From 17287fcf1a15b94eb8f2519bb044d9b5cd10629b Mon Sep 17 00:00:00 2001 From: David Neto Date: Thu, 17 Jun 2021 22:40:43 +0000 Subject: [PATCH] spirv-reader: Set workgroup size, but not specializable The WorkgroupSize builtin decoration applies to a composite constant. Because WGSL does not yet support specializable constants for this, use the *default* values for that SPIR-V spec constant. Update end-to-end test expectations. Fixed: tint:503 Change-Id: I012b316d13544ab9282e3276b58906327adab133 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/41960 Auto-Submit: David Neto Kokoro: Kokoro Commit-Queue: David Neto Reviewed-by: Alan Baker --- src/reader/spirv/entry_point_info.cc | 6 +- src/reader/spirv/entry_point_info.h | 18 +- src/reader/spirv/function.cc | 16 ++ src/reader/spirv/function_composite_test.cc | 121 ++++++++++- src/reader/spirv/parser_impl.cc | 184 ++++++++++++++++- src/reader/spirv/parser_impl.h | 54 ++++- .../spirv/parser_impl_function_decl_test.cc | 188 ++++++++++++++++++ .../spirv/parser_impl_module_var_test.cc | 33 ++- test/access/let/matrix.spvasm.expected.wgsl | 2 +- test/access/let/vector.spvasm.expected.wgsl | 2 +- test/access/var/matrix.spvasm.expected.wgsl | 2 +- test/access/var/vector.spvasm.expected.wgsl | 2 +- test/bug/tint/413.spvasm.expected.wgsl | 2 +- .../access/matrix.spvasm.expected.wgsl | 2 +- .../access/vector.spvasm.expected.wgsl | 2 +- .../copy/ptr_copy.spvasm.expected.wgsl | 2 +- .../load/global/i32.spvasm.expected.wgsl | 2 +- .../global/struct_field.spvasm.expected.wgsl | 2 +- .../load/local/i32.spvasm.expected.wgsl | 2 +- .../local/struct_field.spvasm.expected.wgsl | 2 +- .../load/param/ptr.spvasm.expected.wgsl | 2 +- .../store/global/i32.spvasm.expected.wgsl | 2 +- .../global/struct_field.spvasm.expected.wgsl | 2 +- .../store/local/i32.spvasm.expected.wgsl | 2 +- .../local/struct_field.spvasm.expected.wgsl | 2 +- .../store/param/ptr.spvasm.expected.wgsl | 2 +- 26 files changed, 617 insertions(+), 39 deletions(-) diff --git a/src/reader/spirv/entry_point_info.cc b/src/reader/spirv/entry_point_info.cc index 44242a5b67..63e55a2e57 100644 --- a/src/reader/spirv/entry_point_info.cc +++ b/src/reader/spirv/entry_point_info.cc @@ -25,13 +25,15 @@ EntryPointInfo::EntryPointInfo(std::string the_name, bool the_owns_inner_implementation, std::string the_inner_name, std::vector&& the_inputs, - std::vector&& the_outputs) + std::vector&& the_outputs, + GridSize the_wg_size) : name(the_name), stage(the_stage), owns_inner_implementation(the_owns_inner_implementation), inner_name(std::move(the_inner_name)), inputs(std::move(the_inputs)), - outputs(std::move(the_outputs)) {} + outputs(std::move(the_outputs)), + workgroup_size(the_wg_size) {} EntryPointInfo::EntryPointInfo(const EntryPointInfo&) = default; diff --git a/src/reader/spirv/entry_point_info.h b/src/reader/spirv/entry_point_info.h index 95fb760993..946d62f757 100644 --- a/src/reader/spirv/entry_point_info.h +++ b/src/reader/spirv/entry_point_info.h @@ -24,6 +24,13 @@ namespace tint { namespace reader { namespace spirv { +/// The size of an integer-coordinate grid, in the x, y, and z dimensions. +struct GridSize { + uint32_t x = 0; + uint32_t y = 0; + uint32_t z = 0; +}; + /// Entry point information for a function struct EntryPointInfo { /// Constructor. @@ -35,12 +42,14 @@ struct EntryPointInfo { /// entry point /// @param the_inputs list of IDs for Input variables used by the shader /// @param the_outputs list of IDs for Output variables used by the shader + /// @param the_wg_size the workgroup_size, for a compute shader EntryPointInfo(std::string the_name, ast::PipelineStage the_stage, bool the_owns_inner_implementation, std::string the_inner_name, std::vector&& the_inputs, - std::vector&& the_outputs); + std::vector&& the_outputs, + GridSize the_wg_size); /// Copy constructor /// @param other the other entry point info to be built from EntryPointInfo(const EntryPointInfo& other); @@ -55,6 +64,7 @@ struct EntryPointInfo { std::string name; /// The entry point stage ast::PipelineStage stage = ast::PipelineStage::kNone; + /// True when this entry point is responsible for generating the /// inner implementation function. False when this is the second entry /// point encountered for the same function in SPIR-V. It's unusual, but @@ -67,6 +77,12 @@ struct EntryPointInfo { std::vector inputs; /// IDs of pipeline output variables, sorted and without duplicates. std::vector outputs; + + /// If this is a compute shader, this is the workgroup size in the x, y, + /// and z dimensions set via LocalSize, or via the composite value + /// decorated as the WorkgroupSize BuiltIn. The WorkgroupSize builtin + /// takes priority. + GridSize workgroup_size; }; } // namespace spirv diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc index d2ada0b22b..42ae77c137 100644 --- a/src/reader/spirv/function.cc +++ b/src/reader/spirv/function.cc @@ -1131,6 +1131,19 @@ bool FunctionEmitter::EmitEntryPointAsWrapper() { ast::DecorationList fn_decos; fn_decos.emplace_back(create(source, ep_info_->stage)); + if (ep_info_->stage == ast::PipelineStage::kCompute) { + auto& size = ep_info_->workgroup_size; + if (size.x != 0 && size.y != 0 && size.z != 0) { + ast::Expression* x = builder_.Expr(static_cast(size.x)); + ast::Expression* y = + size.y ? builder_.Expr(static_cast(size.y)) : nullptr; + ast::Expression* z = + size.z ? builder_.Expr(static_cast(size.z)) : nullptr; + fn_decos.emplace_back( + create(Source{}, x, y, z)); + } + } + builder_.AST().AddFunction( create(source, builder_.Symbols().Register(ep_info_->name), std::move(decl.params), return_type, body, @@ -2261,6 +2274,9 @@ bool FunctionEmitter::EmitFunctionVariables() { constructor = parser_impl_.MakeConstantExpression(inst.GetSingleWordInOperand(1)) .expr; + if (!constructor) { + return false; + } } auto* var = parser_impl_.MakeVariable( inst.result_id(), ast::StorageClass::kNone, var_store_type, false, diff --git a/src/reader/spirv/function_composite_test.cc b/src/reader/spirv/function_composite_test.cc index aca7eba228..7621aa50a2 100644 --- a/src/reader/spirv/function_composite_test.cc +++ b/src/reader/spirv/function_composite_test.cc @@ -1462,7 +1462,7 @@ TEST_F(SpvParserTest_VectorExtractDynamic, UnsignedIndex) { using SpvParserTest_VectorInsertDynamic = SpvParserTest; -TEST_F(SpvParserTest_VectorExtractDynamic, Sample) { +TEST_F(SpvParserTest_VectorInsertDynamic, Sample) { const auto assembly = Preamble() + R"( %100 = OpFunction %void None %voidfn %entry = OpLabel @@ -1512,6 +1512,125 @@ VariableDeclStatement{ << assembly; } +TEST_F(SpvParserTest, DISABLED_WorkgroupSize_Overridable) { + // TODO(dneto): Support specializable workgroup size. crbug.com/tint/504 + const auto* assembly = R"( + OpCapability Shader + OpMemoryModel Logical Simple + OpEntryPoint GLCompute %100 "main" + OpDecorate %1 BuiltIn WorkgroupSize + OpDecorate %uint_2 SpecId 0 + OpDecorate %uint_4 SpecId 1 + OpDecorate %uint_8 SpecId 2 + + %uint = OpTypeInt 32 0 + %uint_2 = OpSpecConstant %uint 2 + %uint_4 = OpSpecConstant %uint 4 + %uint_8 = OpSpecConstant %uint 8 + %v3uint = OpTypeVector %uint 3 + %1 = OpSpecConstantComposite %v3uint %uint_2 %uint_4 %uint_8 + %void = OpTypeVoid + %voidfn = OpTypeFunction %void + + %100 = OpFunction %void None %voidfn + %entry = OpLabel + %10 = OpCopyObject %v3uint %1 + %11 = OpCopyObject %uint %uint_2 + %12 = OpCopyObject %uint %uint_4 + %13 = OpCopyObject %uint %uint_8 + OpReturn + OpFunctionEnd +)"; + + auto p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly; + auto fe = p->function_emitter(100); + EXPECT_TRUE(fe.Emit()) << p->error(); + const auto got = p->program().to_str(); + EXPECT_THAT(got, HasSubstr(R"( + VariableConst{ + Decorations{ + OverrideDecoration{0} + } + x_2 + none + __u32 + { + ScalarConstructor[not set]{2} + } + } + VariableConst{ + Decorations{ + OverrideDecoration{1} + } + x_3 + none + __u32 + { + ScalarConstructor[not set]{4} + } + } + VariableConst{ + Decorations{ + OverrideDecoration{2} + } + x_4 + none + __u32 + { + ScalarConstructor[not set]{8} + } + } +)")) << got; + EXPECT_THAT(got, HasSubstr(R"( + VariableDeclStatement{ + VariableConst{ + x_10 + none + __vec_3__u32 + { + TypeConstructor[not set]{ + __vec_3__u32 + ScalarConstructor[not set]{2} + ScalarConstructor[not set]{4} + ScalarConstructor[not set]{8} + } + } + } + } + VariableDeclStatement{ + VariableConst{ + x_11 + none + __u32 + { + Identifier[not set]{x_2} + } + } + } + VariableDeclStatement{ + VariableConst{ + x_12 + none + __u32 + { + Identifier[not set]{x_3} + } + } + } + VariableDeclStatement{ + VariableConst{ + x_13 + none + __u32 + { + Identifier[not set]{x_4} + } + } + })")) + << got << assembly; +} + } // namespace } // namespace spirv } // namespace reader diff --git a/src/reader/spirv/parser_impl.cc b/src/reader/spirv/parser_impl.cc index 46ad25ac90..cb0887ad9a 100644 --- a/src/reader/spirv/parser_impl.cc +++ b/src/reader/spirv/parser_impl.cc @@ -580,6 +580,9 @@ bool ParserImpl::ParseInternalModuleExceptFunctions() { if (!RegisterUserAndStructMemberNames()) { return false; } + if (!RegisterWorkgroupSizeBuiltin()) { + return false; + } if (!RegisterEntryPoints()) { return false; } @@ -720,13 +723,102 @@ bool ParserImpl::IsValidIdentifier(const std::string& str) { return true; } +bool ParserImpl::RegisterWorkgroupSizeBuiltin() { + WorkgroupSizeInfo& info = workgroup_size_builtin_; + for (const spvtools::opt::Instruction& inst : module_->annotations()) { + if (inst.opcode() != SpvOpDecorate) { + continue; + } + if (inst.GetSingleWordInOperand(1) != SpvDecorationBuiltIn) { + continue; + } + if (inst.GetSingleWordInOperand(2) != SpvBuiltInWorkgroupSize) { + continue; + } + info.id = inst.GetSingleWordInOperand(0); + } + if (info.id == 0) { + return true; + } + // Gather the values. + const spvtools::opt::Instruction* composite_def = + def_use_mgr_->GetDef(info.id); + if (!composite_def) { + return Fail() << "Invalid WorkgroupSize builtin value"; + } + // SPIR-V validation checks that the result is a 3-element vector of 32-bit + // integer scalars (signed or unsigned). Rely on validation to check the + // type. In theory the instruction could be OpConstantNull and still + // pass validation, but that would be non-sensical. Be a little more + // stringent here and check for specific opcodes. WGSL does not support + // const-expr yet, so avoid supporting OpSpecConstantOp here. + // TODO(dneto): See https://github.com/gpuweb/gpuweb/issues/1272 for WGSL + // const_expr proposals. + if ((composite_def->opcode() != SpvOpSpecConstantComposite && + composite_def->opcode() != SpvOpConstantComposite)) { + return Fail() << "Invalid WorkgroupSize builtin. Expected 3-element " + "OpSpecConstantComposite or OpConstantComposite: " + << composite_def->PrettyPrint(); + } + info.type_id = composite_def->type_id(); + // Extract the component type from the vector type. + info.component_type_id = + def_use_mgr_->GetDef(info.type_id)->GetSingleWordInOperand(0); + + /// Sets the ID and value of the index'th member of the composite constant. + /// Returns false and emits a diagnostic on error. + auto set_param = [this, composite_def](uint32_t* id_ptr, uint32_t* value_ptr, + int index) -> bool { + const auto id = composite_def->GetSingleWordInOperand(index); + const auto* def = def_use_mgr_->GetDef(id); + if (!def || + (def->opcode() != SpvOpSpecConstant && + def->opcode() != SpvOpConstant) || + (def->NumInOperands() != 1)) { + return Fail() << "invalid component " << index << " of workgroupsize " + << (def ? def->PrettyPrint() + : std::string("no definition")); + } + *id_ptr = id; + // Use the default value of a spec constant. + *value_ptr = def->GetSingleWordInOperand(0); + return true; + }; + + return set_param(&info.x_id, &info.x_value, 0) && + set_param(&info.y_id, &info.y_value, 1) && + set_param(&info.z_id, &info.z_value, 2); +} + bool ParserImpl::RegisterEntryPoints() { + // Mapping from entry point ID to GridSize computed from LocalSize + // decorations. + std::unordered_map local_size; + for (const spvtools::opt::Instruction& inst : module_->execution_modes()) { + auto mode = static_cast(inst.GetSingleWordInOperand(1)); + if (mode == SpvExecutionModeLocalSize) { + if (inst.NumInOperands() != 5) { + // This won't even get past SPIR-V binary parsing. + return Fail() << "invalid LocalSize execution mode: " + << inst.PrettyPrint(); + } + uint32_t function_id = inst.GetSingleWordInOperand(0); + local_size[function_id] = GridSize{inst.GetSingleWordInOperand(2), + inst.GetSingleWordInOperand(3), + inst.GetSingleWordInOperand(4)}; + } + } + for (const spvtools::opt::Instruction& entry_point : module_->entry_points()) { const auto stage = SpvExecutionModel(entry_point.GetSingleWordInOperand(0)); const uint32_t function_id = entry_point.GetSingleWordInOperand(1); const std::string ep_name = entry_point.GetOperand(2).AsString(); + if (!IsValidIdentifier(ep_name)) { + return Fail() << "entry point name is not a valid WGSL identifier: " + << ep_name; + } bool owns_inner_implementation = false; std::string inner_implementation_name; @@ -769,11 +861,30 @@ bool ParserImpl::RegisterEntryPoints() { std::vector sorted_outputs(outputs.begin(), outputs.end()); std::sort(sorted_inputs.begin(), sorted_inputs.end()); + const auto ast_stage = enum_converter_.ToPipelineStage(stage); + GridSize wgsize; + if (ast_stage == ast::PipelineStage::kCompute) { + if (workgroup_size_builtin_.id) { + // Store the default values. + // WGSL allows specializing these, but this code doesn't support that + // yet. https://github.com/gpuweb/gpuweb/issues/1442 + wgsize = GridSize{workgroup_size_builtin_.x_value, + workgroup_size_builtin_.y_value, + workgroup_size_builtin_.z_value}; + } else { + // Use the LocalSize execution mode. This is the second choice. + auto where = local_size.find(function_id); + if (where != local_size.end()) { + wgsize = where->second; + } + } + } function_to_ep_info_[function_id].emplace_back( - ep_name, enum_converter_.ToPipelineStage(stage), - owns_inner_implementation, inner_implementation_name, - std::move(sorted_inputs), std::move(sorted_outputs)); + ep_name, ast_stage, owns_inner_implementation, + inner_implementation_name, std::move(sorted_inputs), + std::move(sorted_outputs), wgsize); } + // The enum conversion could have failed, so return the existing status value. return success_; } @@ -1521,10 +1632,59 @@ bool ParserImpl::ConvertDecorationsForVariable(uint32_t id, return success(); } +bool ParserImpl::CanMakeConstantExpression(uint32_t id) { + if ((id == workgroup_size_builtin_.id) || + (id == workgroup_size_builtin_.x_id) || + (id == workgroup_size_builtin_.y_id) || + (id == workgroup_size_builtin_.z_id)) { + return true; + } + const auto* inst = def_use_mgr_->GetDef(id); + if (!inst) { + return false; + } + if (inst->opcode() == SpvOpUndef) { + return true; + } + return nullptr != constant_mgr_->FindDeclaredConstant(id); +} + TypedExpression ParserImpl::MakeConstantExpression(uint32_t id) { if (!success_) { return {}; } + + // Handle the special cases for workgroup sizing. + if (id == workgroup_size_builtin_.id) { + auto x = MakeConstantExpression(workgroup_size_builtin_.x_id); + auto y = MakeConstantExpression(workgroup_size_builtin_.y_id); + auto z = MakeConstantExpression(workgroup_size_builtin_.z_id); + auto* ast_type = ty_.Vector(x.type, 3); + return {ast_type, create( + Source{}, ast_type->Build(builder_), + ast::ExpressionList{x.expr, y.expr, z.expr})}; + } else if (id == workgroup_size_builtin_.x_id) { + return MakeConstantExpressionForSpirvConstant( + Source{}, ConvertType(workgroup_size_builtin_.component_type_id), + constant_mgr_->GetConstant( + type_mgr_->GetType(workgroup_size_builtin_.component_type_id), + {workgroup_size_builtin_.x_value})); + } else if (id == workgroup_size_builtin_.y_id) { + return MakeConstantExpressionForSpirvConstant( + Source{}, ConvertType(workgroup_size_builtin_.component_type_id), + constant_mgr_->GetConstant( + type_mgr_->GetType(workgroup_size_builtin_.component_type_id), + {workgroup_size_builtin_.y_value})); + } else if (id == workgroup_size_builtin_.z_id) { + return MakeConstantExpressionForSpirvConstant( + Source{}, ConvertType(workgroup_size_builtin_.component_type_id), + constant_mgr_->GetConstant( + type_mgr_->GetType(workgroup_size_builtin_.component_type_id), + {workgroup_size_builtin_.z_value})); + } + + // Handle the general case where a constant is already registered + // with the SPIR-V optimizer's analysis framework. const auto* inst = def_use_mgr_->GetDef(id); if (inst == nullptr) { Fail() << "ID " << id << " is not a registered instruction"; @@ -1548,6 +1708,14 @@ TypedExpression ParserImpl::MakeConstantExpression(uint32_t id) { } auto source = GetSourceForInst(inst); + return MakeConstantExpressionForSpirvConstant(source, original_ast_type, + spirv_const); +} + +TypedExpression ParserImpl::MakeConstantExpressionForSpirvConstant( + Source source, + const Type* original_ast_type, + const spvtools::opt::analysis::Constant* spirv_const) { auto* ast_type = original_ast_type->UnwrapAlias(); // TODO(dneto): Note: NullConstant for int, uint, float map to a regular 0. @@ -1603,12 +1771,10 @@ TypedExpression ParserImpl::MakeConstantExpression(uint32_t id) { Source{}, original_ast_type->Build(builder_), std::move(ast_components))}; } - auto* spirv_null_const = spirv_const->AsNullConstant(); - if (spirv_null_const != nullptr) { + if (spirv_const->AsNullConstant()) { return {original_ast_type, MakeNullValue(original_ast_type)}; } - Fail() << "Unhandled constant type " << inst->type_id() << " for value ID " - << id; + Fail() << "Unhandled constant type "; return {}; } @@ -2464,6 +2630,10 @@ const spvtools::opt::Instruction* ParserImpl::GetInstructionForTest( return def_use_mgr_ ? def_use_mgr_->GetDef(id) : nullptr; } +WorkgroupSizeInfo::WorkgroupSizeInfo() = default; + +WorkgroupSizeInfo::~WorkgroupSizeInfo() = default; + } // namespace spirv } // namespace reader } // namespace tint diff --git a/src/reader/spirv/parser_impl.h b/src/reader/spirv/parser_impl.h index 60e96cfee7..453bb6651e 100644 --- a/src/reader/spirv/parser_impl.h +++ b/src/reader/spirv/parser_impl.h @@ -85,6 +85,29 @@ struct TypedExpression { ast::Expression* expr = nullptr; }; +/// Info about the WorkgroupSize builtin. +struct WorkgroupSizeInfo { + /// Constructor + WorkgroupSizeInfo(); + /// Destructor + ~WorkgroupSizeInfo(); + /// The SPIR-V ID of the WorkgroupSize builtin, if any. + uint32_t id = 0u; + /// The SPIR-V type ID of the WorkgroupSize builtin, if any. + uint32_t type_id = 0u; + /// The SPIR-V type IDs of the x, y, and z components. + uint32_t component_type_id = 0u; + /// The SPIR-V IDs of the X, Y, and Z components of the workgroup size + /// builtin. + uint32_t x_id = 0u; + uint32_t y_id = 0u; + uint32_t z_id = 0u; + /// The effective workgroup size, if this is a compute shader. + uint32_t x_value = 0u; + uint32_t y_value = 0u; + uint32_t z_value = 0u; +}; + /// Parser implementation for SPIR-V. class ParserImpl : Reader { public: @@ -306,6 +329,14 @@ class ParserImpl : Reader { /// @returns true if parser is still successful. bool RegisterUserAndStructMemberNames(); + /// Register the WorkgroupSize builtin and its associated constant value. + /// @returns true if parser is still successful. + bool RegisterWorkgroupSizeBuiltin(); + + const WorkgroupSizeInfo& workgroup_size_builtin() { + return workgroup_size_builtin_; + } + /// Register entry point information. /// This is a no-op if the parser has already failed. /// @returns true if parser is still successful. @@ -366,11 +397,27 @@ class ParserImpl : Reader { ast::Expression* constructor, ast::DecorationList decorations); - /// Creates an AST expression node for a SPIR-V constant. + /// Returns true if a constant expression can be generated. + /// @param id the SPIR-V ID of the value + /// @returns true if a constant expression can be generated + bool CanMakeConstantExpression(uint32_t id); + + /// Creates an AST expression node for a SPIR-V ID. This is valid to call + /// when `CanMakeConstantExpression` returns true. /// @param id the SPIR-V ID of the constant /// @returns a new expression TypedExpression MakeConstantExpression(uint32_t id); + /// Creates an AST expression node for a SPIR-V constant. + /// @param source the source location + /// @param ast_type the AST type for the value + /// @param spirv_const the internal representation of the SPIR-V constant. + /// @returns a new expression + TypedExpression MakeConstantExpressionForSpirvConstant( + Source source, + const Type* ast_type, + const spvtools::opt::analysis::Constant* spirv_const); + /// Creates an AST expression node for the null value for the given type. /// @param type the AST type /// @returns a new expression @@ -778,6 +825,11 @@ class ParserImpl : Reader { /// This is temporary while this module is converted to use the new style /// of pipeline IO. bool hlsl_style_pipeline_io_ = false; + + /// Info about the WorkgroupSize builtin. If it's not present, then the 'id' + /// field will be 0. Sadly, in SPIR-V right now, there's only one workgroup + /// size object in the module. + WorkgroupSizeInfo workgroup_size_builtin_; }; } // namespace spirv diff --git a/src/reader/spirv/parser_impl_function_decl_test.cc b/src/reader/spirv/parser_impl_function_decl_test.cc index e573c8d5b4..0216660ac1 100644 --- a/src/reader/spirv/parser_impl_function_decl_test.cc +++ b/src/reader/spirv/parser_impl_function_decl_test.cc @@ -151,6 +151,11 @@ TEST_F(SpvParserTest, EmitFunctions_Function_EntryPoint_GLCompute) { Function )" + program.Symbols().Get("main").to_str() + R"( -> __void StageDecoration{compute} + WorkgroupDecoration{ + ScalarConstructor[not set]{1} + ScalarConstructor[not set]{1} + ScalarConstructor[not set]{1} + } () {)")); } @@ -182,6 +187,189 @@ OpEntryPoint Vertex %main "second_shader" {)")); } +TEST_F(SpvParserTest, + EmitFunctions_Function_EntryPoint_GLCompute_LocalSize_Only) { + std::string input = Caps() + Names({"main"}) + + R"(OpEntryPoint GLCompute %main "comp_main" +OpExecutionMode %main LocalSize 2 4 8 +)" + CommonTypes() + R"( +%main = OpFunction %void None %voidfn +%entry = OpLabel +OpReturn +OpFunctionEnd)"; + + auto p = parser(test::Assemble(input)); + ASSERT_TRUE(p->BuildAndParseInternalModule()); + ASSERT_TRUE(p->error().empty()) << p->error(); + Program program = p->program(); + const auto program_ast = program.to_str(false); + EXPECT_THAT(program_ast, HasSubstr(R"( + Function )" + program.Symbols().Get("comp_main").to_str() + + R"( -> __void + StageDecoration{compute} + WorkgroupDecoration{ + ScalarConstructor[not set]{2} + ScalarConstructor[not set]{4} + ScalarConstructor[not set]{8} + } + () + {)")) + << program_ast; +} + +TEST_F(SpvParserTest, + EmitFunctions_Function_EntryPoint_WorkgroupSizeBuiltin_Constant_Only) { + std::string input = Caps() + R"(OpEntryPoint GLCompute %main "comp_main" +OpDecorate %wgsize BuiltIn WorkgroupSize +)" + CommonTypes() + R"( +%uvec3 = OpTypeVector %uint 3 +%uint_3 = OpConstant %uint 3 +%uint_5 = OpConstant %uint 5 +%uint_7 = OpConstant %uint 7 +%wgsize = OpConstantComposite %uvec3 %uint_3 %uint_5 %uint_7 +%main = OpFunction %void None %voidfn +%entry = OpLabel +OpReturn +OpFunctionEnd)"; + + auto p = parser(test::Assemble(input)); + ASSERT_TRUE(p->BuildAndParseInternalModule()); + ASSERT_TRUE(p->error().empty()) << p->error(); + Program program = p->program(); + const auto program_ast = program.to_str(false); + EXPECT_THAT(program_ast, HasSubstr(R"( + Function )" + program.Symbols().Get("comp_main").to_str() + + R"( -> __void + StageDecoration{compute} + WorkgroupDecoration{ + ScalarConstructor[not set]{3} + ScalarConstructor[not set]{5} + ScalarConstructor[not set]{7} + } + () + {)")) + << program_ast; +} + +TEST_F( + SpvParserTest, + EmitFunctions_Function_EntryPoint_WorkgroupSizeBuiltin_SpecConstant_Only) { + std::string input = Caps() + + R"(OpEntryPoint GLCompute %main "comp_main" +OpDecorate %wgsize BuiltIn WorkgroupSize +OpDecorate %uint_3 SpecId 0 +OpDecorate %uint_5 SpecId 1 +OpDecorate %uint_7 SpecId 2 +)" + CommonTypes() + R"( +%uvec3 = OpTypeVector %uint 3 +%uint_3 = OpSpecConstant %uint 3 +%uint_5 = OpSpecConstant %uint 5 +%uint_7 = OpSpecConstant %uint 7 +%wgsize = OpSpecConstantComposite %uvec3 %uint_3 %uint_5 %uint_7 +%main = OpFunction %void None %voidfn +%entry = OpLabel +OpReturn +OpFunctionEnd)"; + + auto p = parser(test::Assemble(input)); + ASSERT_TRUE(p->BuildAndParseInternalModule()); + ASSERT_TRUE(p->error().empty()) << p->error(); + Program program = p->program(); + const auto program_ast = program.to_str(false); + EXPECT_THAT(program_ast, HasSubstr(R"( + Function )" + program.Symbols().Get("comp_main").to_str() + + R"( -> __void + StageDecoration{compute} + WorkgroupDecoration{ + ScalarConstructor[not set]{3} + ScalarConstructor[not set]{5} + ScalarConstructor[not set]{7} + } + () + {)")) + << program_ast; +} + +TEST_F( + SpvParserTest, + EmitFunctions_Function_EntryPoint_WorkgroupSize_MixedConstantSpecConstant) { + std::string input = Caps() + + R"(OpEntryPoint GLCompute %main "comp_main" +OpDecorate %wgsize BuiltIn WorkgroupSize +OpDecorate %uint_3 SpecId 0 +OpDecorate %uint_7 SpecId 2 +)" + CommonTypes() + R"( +%uvec3 = OpTypeVector %uint 3 +%uint_3 = OpSpecConstant %uint 3 +%uint_5 = OpConstant %uint 5 +%uint_7 = OpSpecConstant %uint 7 +%wgsize = OpSpecConstantComposite %uvec3 %uint_3 %uint_5 %uint_7 +%main = OpFunction %void None %voidfn +%entry = OpLabel +OpReturn +OpFunctionEnd)"; + + auto p = parser(test::Assemble(input)); + ASSERT_TRUE(p->BuildAndParseInternalModule()); + ASSERT_TRUE(p->error().empty()) << p->error(); + Program program = p->program(); + const auto program_ast = program.to_str(false); + EXPECT_THAT(program_ast, HasSubstr(R"( + Function )" + program.Symbols().Get("comp_main").to_str() + + R"( -> __void + StageDecoration{compute} + WorkgroupDecoration{ + ScalarConstructor[not set]{3} + ScalarConstructor[not set]{5} + ScalarConstructor[not set]{7} + } + () + {)")) + << program_ast; +} + +TEST_F( + SpvParserTest, + // I had to shorten the name to pass the linter. + EmitFunctions_Function_EntryPoint_LocalSize_And_WGSBuiltin_SpecConstant) { + // WorkgroupSize builtin wins. + std::string input = Caps() + + R"(OpEntryPoint GLCompute %main "comp_main" +OpExecutionMode %main LocalSize 2 4 8 +OpDecorate %wgsize BuiltIn WorkgroupSize +OpDecorate %uint_3 SpecId 0 +OpDecorate %uint_5 SpecId 1 +OpDecorate %uint_7 SpecId 2 +)" + CommonTypes() + R"( +%uvec3 = OpTypeVector %uint 3 +%uint_3 = OpSpecConstant %uint 3 +%uint_5 = OpSpecConstant %uint 5 +%uint_7 = OpSpecConstant %uint 7 +%wgsize = OpSpecConstantComposite %uvec3 %uint_3 %uint_5 %uint_7 +%main = OpFunction %void None %voidfn +%entry = OpLabel +OpReturn +OpFunctionEnd)"; + + auto p = parser(test::Assemble(input)); + ASSERT_TRUE(p->BuildAndParseInternalModule()); + ASSERT_TRUE(p->error().empty()) << p->error(); + Program program = p->program(); + const auto program_ast = program.to_str(false); + EXPECT_THAT(program_ast, HasSubstr(R"( + Function )" + program.Symbols().Get("comp_main").to_str() + + R"( -> __void + StageDecoration{compute} + WorkgroupDecoration{ + ScalarConstructor[not set]{3} + ScalarConstructor[not set]{5} + ScalarConstructor[not set]{7} + } + () + {)")) + << program_ast; +} + TEST_F(SpvParserTest, EmitFunctions_VoidFunctionWithoutParams) { auto p = parser(test::Assemble(Preamble() + Names({"another_function"}) + CommonTypes() + R"( diff --git a/src/reader/spirv/parser_impl_module_var_test.cc b/src/reader/spirv/parser_impl_module_var_test.cc index 1196b232fd..4bc5193355 100644 --- a/src/reader/spirv/parser_impl_module_var_test.cc +++ b/src/reader/spirv/parser_impl_module_var_test.cc @@ -4903,6 +4903,11 @@ TEST_P(SpvModuleScopeVarParserTest_ComputeBuiltin, Load_Direct) { } Function main -> __void StageDecoration{compute} + WorkgroupDecoration{ + ScalarConstructor[not set]{1} + ScalarConstructor[not set]{1} + ScalarConstructor[not set]{1} + } ( VariableConst{ Decorations{ @@ -4919,10 +4924,10 @@ TEST_P(SpvModuleScopeVarParserTest_ComputeBuiltin, Load_Direct) { Assignment{ Identifier[not set]{x_1})" + (wgsl_type == unsigned_wgsl_type ? - R"( + R"( Identifier[not set]{x_1_param})" - : - R"( + : + R"( Bitcast[not set]<)" + signed_wgsl_type + R"(>{ Identifier[not set]{x_1_param} })") + R"( @@ -5001,6 +5006,11 @@ TEST_P(SpvModuleScopeVarParserTest_ComputeBuiltin, Load_CopyObject) { } Function main -> __void StageDecoration{compute} + WorkgroupDecoration{ + ScalarConstructor[not set]{1} + ScalarConstructor[not set]{1} + ScalarConstructor[not set]{1} + } ( VariableConst{ Decorations{ @@ -5017,10 +5027,10 @@ TEST_P(SpvModuleScopeVarParserTest_ComputeBuiltin, Load_CopyObject) { Assignment{ Identifier[not set]{x_1})" + (wgsl_type == unsigned_wgsl_type ? - R"( + R"( Identifier[not set]{x_1_param})" - : - R"( + : + R"( Bitcast[not set]<)" + signed_wgsl_type + R"(>{ Identifier[not set]{x_1_param} })") + R"( @@ -5081,6 +5091,11 @@ TEST_P(SpvModuleScopeVarParserTest_ComputeBuiltin, Load_AccessChain) { } Function main -> __void StageDecoration{compute} + WorkgroupDecoration{ + ScalarConstructor[not set]{1} + ScalarConstructor[not set]{1} + ScalarConstructor[not set]{1} + } ( VariableConst{ Decorations{ @@ -5097,10 +5112,10 @@ TEST_P(SpvModuleScopeVarParserTest_ComputeBuiltin, Load_AccessChain) { Assignment{ Identifier[not set]{x_1})" + (wgsl_type == unsigned_wgsl_type ? - R"( + R"( Identifier[not set]{x_1_param})" - : - R"( + : + R"( Bitcast[not set]<)" + signed_wgsl_type + R"(>{ Identifier[not set]{x_1_param} })") + R"( diff --git a/test/access/let/matrix.spvasm.expected.wgsl b/test/access/let/matrix.spvasm.expected.wgsl index 156324f5b7..e779e9fcea 100644 --- a/test/access/let/matrix.spvasm.expected.wgsl +++ b/test/access/let/matrix.spvasm.expected.wgsl @@ -3,7 +3,7 @@ fn main_1() { return; } -[[stage(compute)]] +[[stage(compute), workgroup_size(1, 1, 1)]] fn main() { main_1(); } diff --git a/test/access/let/vector.spvasm.expected.wgsl b/test/access/let/vector.spvasm.expected.wgsl index 3c3957fafc..b341f302cd 100644 --- a/test/access/let/vector.spvasm.expected.wgsl +++ b/test/access/let/vector.spvasm.expected.wgsl @@ -5,7 +5,7 @@ fn main_1() { return; } -[[stage(compute)]] +[[stage(compute), workgroup_size(1, 1, 1)]] fn main() { main_1(); } diff --git a/test/access/var/matrix.spvasm.expected.wgsl b/test/access/var/matrix.spvasm.expected.wgsl index 851efcce1a..080a67b636 100644 --- a/test/access/var/matrix.spvasm.expected.wgsl +++ b/test/access/var/matrix.spvasm.expected.wgsl @@ -5,7 +5,7 @@ fn main_1() { return; } -[[stage(compute)]] +[[stage(compute), workgroup_size(1, 1, 1)]] fn main() { main_1(); } diff --git a/test/access/var/vector.spvasm.expected.wgsl b/test/access/var/vector.spvasm.expected.wgsl index ced274e1dc..c904348691 100644 --- a/test/access/var/vector.spvasm.expected.wgsl +++ b/test/access/var/vector.spvasm.expected.wgsl @@ -8,7 +8,7 @@ fn main_1() { return; } -[[stage(compute)]] +[[stage(compute), workgroup_size(1, 1, 1)]] fn main() { main_1(); } diff --git a/test/bug/tint/413.spvasm.expected.wgsl b/test/bug/tint/413.spvasm.expected.wgsl index 5a581e82d2..8cf34a878c 100644 --- a/test/bug/tint/413.spvasm.expected.wgsl +++ b/test/bug/tint/413.spvasm.expected.wgsl @@ -14,7 +14,7 @@ fn main_1() { return; } -[[stage(compute)]] +[[stage(compute), workgroup_size(1, 1, 1)]] fn main() { main_1(); } diff --git a/test/ptr_ref/access/matrix.spvasm.expected.wgsl b/test/ptr_ref/access/matrix.spvasm.expected.wgsl index 9318800c12..57e94a58b7 100644 --- a/test/ptr_ref/access/matrix.spvasm.expected.wgsl +++ b/test/ptr_ref/access/matrix.spvasm.expected.wgsl @@ -5,7 +5,7 @@ fn main_1() { return; } -[[stage(compute)]] +[[stage(compute), workgroup_size(1, 1, 1)]] fn main() { main_1(); } diff --git a/test/ptr_ref/access/vector.spvasm.expected.wgsl b/test/ptr_ref/access/vector.spvasm.expected.wgsl index 1c01974801..69dee0458c 100644 --- a/test/ptr_ref/access/vector.spvasm.expected.wgsl +++ b/test/ptr_ref/access/vector.spvasm.expected.wgsl @@ -5,7 +5,7 @@ fn main_1() { return; } -[[stage(compute)]] +[[stage(compute), workgroup_size(1, 1, 1)]] fn main() { main_1(); } diff --git a/test/ptr_ref/copy/ptr_copy.spvasm.expected.wgsl b/test/ptr_ref/copy/ptr_copy.spvasm.expected.wgsl index 360caf4c3a..512d5c8dd1 100644 --- a/test/ptr_ref/copy/ptr_copy.spvasm.expected.wgsl +++ b/test/ptr_ref/copy/ptr_copy.spvasm.expected.wgsl @@ -5,7 +5,7 @@ fn main_1() { return; } -[[stage(compute)]] +[[stage(compute), workgroup_size(1, 1, 1)]] fn main() { main_1(); } diff --git a/test/ptr_ref/load/global/i32.spvasm.expected.wgsl b/test/ptr_ref/load/global/i32.spvasm.expected.wgsl index f9147e488d..2f193e8be8 100644 --- a/test/ptr_ref/load/global/i32.spvasm.expected.wgsl +++ b/test/ptr_ref/load/global/i32.spvasm.expected.wgsl @@ -6,7 +6,7 @@ fn main_1() { return; } -[[stage(compute)]] +[[stage(compute), workgroup_size(1, 1, 1)]] fn main() { main_1(); } diff --git a/test/ptr_ref/load/global/struct_field.spvasm.expected.wgsl b/test/ptr_ref/load/global/struct_field.spvasm.expected.wgsl index c423f7795e..8771732fd5 100644 --- a/test/ptr_ref/load/global/struct_field.spvasm.expected.wgsl +++ b/test/ptr_ref/load/global/struct_field.spvasm.expected.wgsl @@ -11,7 +11,7 @@ fn main_1() { return; } -[[stage(compute)]] +[[stage(compute), workgroup_size(1, 1, 1)]] fn main() { main_1(); } diff --git a/test/ptr_ref/load/local/i32.spvasm.expected.wgsl b/test/ptr_ref/load/local/i32.spvasm.expected.wgsl index 8a47e29786..89d3a08c19 100644 --- a/test/ptr_ref/load/local/i32.spvasm.expected.wgsl +++ b/test/ptr_ref/load/local/i32.spvasm.expected.wgsl @@ -6,7 +6,7 @@ fn main_1() { return; } -[[stage(compute)]] +[[stage(compute), workgroup_size(1, 1, 1)]] fn main() { main_1(); } diff --git a/test/ptr_ref/load/local/struct_field.spvasm.expected.wgsl b/test/ptr_ref/load/local/struct_field.spvasm.expected.wgsl index 8ddaccf949..aa20b42594 100644 --- a/test/ptr_ref/load/local/struct_field.spvasm.expected.wgsl +++ b/test/ptr_ref/load/local/struct_field.spvasm.expected.wgsl @@ -10,7 +10,7 @@ fn main_1() { return; } -[[stage(compute)]] +[[stage(compute), workgroup_size(1, 1, 1)]] fn main() { main_1(); } diff --git a/test/ptr_ref/load/param/ptr.spvasm.expected.wgsl b/test/ptr_ref/load/param/ptr.spvasm.expected.wgsl index 97ca0e7588..05c17cfe63 100644 --- a/test/ptr_ref/load/param/ptr.spvasm.expected.wgsl +++ b/test/ptr_ref/load/param/ptr.spvasm.expected.wgsl @@ -11,7 +11,7 @@ fn main_1() { return; } -[[stage(compute)]] +[[stage(compute), workgroup_size(1, 1, 1)]] fn main() { main_1(); } diff --git a/test/ptr_ref/store/global/i32.spvasm.expected.wgsl b/test/ptr_ref/store/global/i32.spvasm.expected.wgsl index ffd907acef..b48aa90008 100644 --- a/test/ptr_ref/store/global/i32.spvasm.expected.wgsl +++ b/test/ptr_ref/store/global/i32.spvasm.expected.wgsl @@ -6,7 +6,7 @@ fn main_1() { return; } -[[stage(compute)]] +[[stage(compute), workgroup_size(1, 1, 1)]] fn main() { main_1(); } diff --git a/test/ptr_ref/store/global/struct_field.spvasm.expected.wgsl b/test/ptr_ref/store/global/struct_field.spvasm.expected.wgsl index 7362392bcc..c217a5f895 100644 --- a/test/ptr_ref/store/global/struct_field.spvasm.expected.wgsl +++ b/test/ptr_ref/store/global/struct_field.spvasm.expected.wgsl @@ -9,7 +9,7 @@ fn main_1() { return; } -[[stage(compute)]] +[[stage(compute), workgroup_size(1, 1, 1)]] fn main() { main_1(); } diff --git a/test/ptr_ref/store/local/i32.spvasm.expected.wgsl b/test/ptr_ref/store/local/i32.spvasm.expected.wgsl index f1695d1f8b..64f374429f 100644 --- a/test/ptr_ref/store/local/i32.spvasm.expected.wgsl +++ b/test/ptr_ref/store/local/i32.spvasm.expected.wgsl @@ -6,7 +6,7 @@ fn main_1() { return; } -[[stage(compute)]] +[[stage(compute), workgroup_size(1, 1, 1)]] fn main() { main_1(); } diff --git a/test/ptr_ref/store/local/struct_field.spvasm.expected.wgsl b/test/ptr_ref/store/local/struct_field.spvasm.expected.wgsl index d9c0e1fa2f..d628330c98 100644 --- a/test/ptr_ref/store/local/struct_field.spvasm.expected.wgsl +++ b/test/ptr_ref/store/local/struct_field.spvasm.expected.wgsl @@ -8,7 +8,7 @@ fn main_1() { return; } -[[stage(compute)]] +[[stage(compute), workgroup_size(1, 1, 1)]] fn main() { main_1(); } diff --git a/test/ptr_ref/store/param/ptr.spvasm.expected.wgsl b/test/ptr_ref/store/param/ptr.spvasm.expected.wgsl index d56ff4cbfe..30fdea1035 100644 --- a/test/ptr_ref/store/param/ptr.spvasm.expected.wgsl +++ b/test/ptr_ref/store/param/ptr.spvasm.expected.wgsl @@ -10,7 +10,7 @@ fn main_1() { return; } -[[stage(compute)]] +[[stage(compute), workgroup_size(1, 1, 1)]] fn main() { main_1(); }