diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc index 85ab480989..03f45d99e7 100644 --- a/src/reader/spirv/function.cc +++ b/src/reader/spirv/function.cc @@ -980,11 +980,14 @@ bool FunctionEmitter::EmitPipelineInput(std::string var_name, index_prefix.push_back(0); for (int i = 0; i < static_cast(members.size()); ++i) { index_prefix.back() = i; + auto* location = parser_impl_.GetMemberLocation(*struct_type, i); + auto* saved_location = SetLocation(decos, location); if (!EmitPipelineInput(var_name, var_type, decos, index_prefix, members[i], forced_param_type, params, statements)) { return false; } + SetLocation(decos, saved_location); } return success(); } @@ -1036,6 +1039,12 @@ bool FunctionEmitter::EmitPipelineInput(std::string var_name, statements->push_back(builder_.Assign(store_dest, param_value)); // Increment the location attribute, in case more parameters will follow. + IncrementLocation(decos); + + return success(); +} + +void FunctionEmitter::IncrementLocation(ast::DecorationList* decos) { for (auto*& deco : *decos) { if (auto* loc_deco = deco->As()) { // Replace this location decoration with a new one with one higher index. @@ -1044,8 +1053,27 @@ bool FunctionEmitter::EmitPipelineInput(std::string var_name, deco = builder_.Location(loc_deco->source(), loc_deco->value() + 1); } } +} - return success(); +ast::Decoration* FunctionEmitter::SetLocation( + ast::DecorationList* decos, + ast::Decoration* replacement) { + if (!replacement) { + return nullptr; + } + for (auto*& deco : *decos) { + if (deco->Is()) { + // Replace this location decoration with a new one with one higher index. + // The old one doesn't leak because it's kept in the builder's AST node + // list. + ast::Decoration* result = deco; + deco = replacement; + return result; + } + } + // The list didn't have a location. Add it. + decos->push_back(replacement); + return nullptr; } bool FunctionEmitter::EmitPipelineOutput(std::string var_name, @@ -1095,11 +1123,14 @@ bool FunctionEmitter::EmitPipelineOutput(std::string var_name, index_prefix.push_back(0); for (int i = 0; i < static_cast(members.size()); ++i) { index_prefix.back() = i; + auto* location = parser_impl_.GetMemberLocation(*struct_type, i); + auto* saved_location = SetLocation(decos, location); if (!EmitPipelineOutput(var_name, var_type, decos, index_prefix, members[i], forced_member_type, return_members, return_exprs)) { return false; } + SetLocation(decos, saved_location); } return success(); } @@ -1150,14 +1181,7 @@ bool FunctionEmitter::EmitPipelineOutput(std::string var_name, return_exprs->push_back(load_source); // Increment the location attribute, in case more parameters will follow. - for (auto*& deco : *decos) { - if (auto* loc_deco = deco->As()) { - // Replace this location decoration with a new one with one higher index. - // The old one doesn't leak because it's kept in the builder's AST node - // list. - deco = builder_.Location(loc_deco->source(), loc_deco->value() + 1); - } - } + IncrementLocation(decos); return success(); } diff --git a/src/reader/spirv/function.h b/src/reader/spirv/function.h index 3a287a4b6c..06f04505ec 100644 --- a/src/reader/spirv/function.h +++ b/src/reader/spirv/function.h @@ -466,6 +466,23 @@ class FunctionEmitter { ast::StructMemberList* return_members, ast::ExpressionList* return_exprs); + /// Updates the decoration list, replacing an existing Location decoration + /// with another having one higher location value. Does nothing if no + /// location decoration exists. + /// Assumes the list contains at most one Location decoration. + /// @param decos the decoration list to modify + void IncrementLocation(ast::DecorationList* decos); + + /// Updates the decoration list, placing a non-null location decoration into + /// the list, replacing an existing one if it exists. Does nothing if the + /// replacement is nullptr. + /// Assumes the list contains at most one Location decoration. + /// @param decos the decoration list to modify + /// @param replacement the location decoration to place into the list + /// @returns the location decoration that was replaced, if one was replaced. + ast::Decoration* SetLocation(ast::DecorationList* decos, + ast::Decoration* replacement); + /// Create an ast::BlockStatement representing the body of the function. /// This creates the statement stack, which is non-empty for the lifetime /// of the function. diff --git a/src/reader/spirv/function_var_test.cc b/src/reader/spirv/function_var_test.cc index cf2e9aa072..9b07da7d32 100644 --- a/src/reader/spirv/function_var_test.cc +++ b/src/reader/spirv/function_var_test.cc @@ -2784,7 +2784,6 @@ TEST_F(SpvParserFunctionVarTest, DISABLED_EmitStatement_Hoist_UsedAsPtrArg) { OpReturn OpFunctionEnd )"; - std::cout << assembly << std::endl; auto p = parser(test::Assemble(assembly)); ASSERT_TRUE(p->BuildAndParseInternalModule()) << p->error() << assembly; auto fe = p->function_emitter(100); diff --git a/src/reader/spirv/parser_impl.cc b/src/reader/spirv/parser_impl.cc index db4383a4e8..c5830c5438 100644 --- a/src/reader/spirv/parser_impl.cc +++ b/src/reader/spirv/parser_impl.cc @@ -1098,6 +1098,8 @@ const Type* ParserImpl::ConvertType( // apply the ReadOnly access control to the containing struct if all // the members are non-writable. is_non_writable = true; + } else if (decoration[0] == SpvDecorationLocation) { + // Location decorations are handled when emitting the entry point. } else { auto* ast_member_decoration = ConvertMemberDecoration(type_id, member_index, decoration); @@ -1134,7 +1136,7 @@ const Type* ParserImpl::ConvertType( // Now make the struct. auto sym = builder_.Symbols().Register(name); ast::DecorationList ast_struct_decorations; - if (is_block_decorated) { + if (is_block_decorated && struct_types_for_buffers_.count(type_id)) { ast_struct_decorations.emplace_back( create(Source{})); } @@ -1208,6 +1210,37 @@ bool ParserImpl::RegisterTypes() { if (!success_) { return false; } + + // First record the structure types that should have a `block` decoration + // in WGSL. In particular, exclude user-defined pipeline IO in a + // block-decorated struct. + for (const auto& type_or_value : module_->types_values()) { + if (type_or_value.opcode() != SpvOpVariable) { + continue; + } + const auto& var = type_or_value; + const auto spirv_storage_class = + SpvStorageClass(var.GetSingleWordInOperand(0)); + if ((spirv_storage_class != SpvStorageClassStorageBuffer) && + (spirv_storage_class != SpvStorageClassUniform)) { + continue; + } + const auto* ptr_type = def_use_mgr_->GetDef(var.type_id()); + if (ptr_type->opcode() != SpvOpTypePointer) { + return Fail() << "OpVariable type expected to be a pointer: " + << var.PrettyPrint(); + } + const auto* store_type = + def_use_mgr_->GetDef(ptr_type->GetSingleWordInOperand(1)); + if (store_type->opcode() == SpvOpTypeStruct) { + struct_types_for_buffers_.insert(store_type->result_id()); + } else { + Fail() << "WGSL does not support arrays of buffers: " + << var.PrettyPrint(); + } + } + + // Now convert each type. for (auto& type_or_const : module_->types_values()) { const auto* type = type_mgr_->GetType(type_or_const.result_id()); if (type == nullptr) { @@ -2630,6 +2663,22 @@ std::string ParserImpl::GetMemberName(const Struct& struct_type, return namer_.GetMemberName(where->second, member_index); } +ast::Decoration* ParserImpl::GetMemberLocation(const Struct& struct_type, + int member_index) { + auto where = struct_id_for_symbol_.find(struct_type.name); + if (where == struct_id_for_symbol_.end()) { + Fail() << "no structure type registered for symbol"; + return nullptr; + } + const auto type_id = where->second; + for (auto& deco : GetDecorationsForMember(type_id, member_index)) { + if ((deco.size() == 2) && (deco[0] == SpvDecorationLocation)) { + return create(Source{}, deco[1]); + } + } + return nullptr; +} + WorkgroupSizeInfo::WorkgroupSizeInfo() = default; WorkgroupSizeInfo::~WorkgroupSizeInfo() = default; diff --git a/src/reader/spirv/parser_impl.h b/src/reader/spirv/parser_impl.h index 3321fd95c8..b041605f45 100644 --- a/src/reader/spirv/parser_impl.h +++ b/src/reader/spirv/parser_impl.h @@ -379,6 +379,13 @@ class ParserImpl : Reader { /// @returns the field name std::string GetMemberName(const Struct& struct_type, int member_index); + /// Returns the location decoration, if any on a struct member. + /// @param struct_type the parser's structure type. + /// @param member_index the member index + /// @returns a newly created location node, or nullptr + ast::Decoration* GetMemberLocation(const Struct& struct_type, + int member_index); + /// Creates an AST Variable node for a SPIR-V ID, including any attached /// decorations, unless it's an ignorable builtin variable. /// @param id the SPIR-V result ID @@ -765,6 +772,10 @@ class ParserImpl : Reader { // "NonSemanticInfo." import is ignored. std::unordered_set ignored_imports_; + // The SPIR-V IDs of structure types that are the store type for buffer + // variables, either UBO or SSBO. + std::unordered_set struct_types_for_buffers_; + // Bookkeeping for the gl_Position builtin. // In Vulkan SPIR-V, it's the 0 member of the gl_PerVertex structure. // But in WGSL we make a module-scope variable: diff --git a/src/reader/spirv/parser_impl_module_var_test.cc b/src/reader/spirv/parser_impl_module_var_test.cc index 78e1ea94d2..beeef8587c 100644 --- a/src/reader/spirv/parser_impl_module_var_test.cc +++ b/src/reader/spirv/parser_impl_module_var_test.cc @@ -6528,7 +6528,7 @@ TEST_F(SpvModuleScopeVarParserTest, Input_FlattenMatrix) { EXPECT_EQ(got, expected) << got; } -TEST_F(SpvModuleScopeVarParserTest, Input_FlattenStruct) { +TEST_F(SpvModuleScopeVarParserTest, Input_FlattenStruct_LocOnVariable) { const std::string assembly = R"( OpCapability Shader OpMemoryModel Logical Simple @@ -6993,7 +6993,7 @@ TEST_F(SpvModuleScopeVarParserTest, Output_FlattenMatrix) { EXPECT_EQ(got, expected) << got; } -TEST_F(SpvModuleScopeVarParserTest, Output_FlattenStruct) { +TEST_F(SpvModuleScopeVarParserTest, Output_FlattenStruct_LocOnVariable) { const std::string assembly = R"( OpCapability Shader OpMemoryModel Logical Simple @@ -7092,6 +7092,192 @@ TEST_F(SpvModuleScopeVarParserTest, Output_FlattenStruct) { EXPECT_EQ(got, expected) << got; } +TEST_F(SpvModuleScopeVarParserTest, FlattenStruct_LocOnMembers) { + // Block-decorated struct may have its members decorated with Location. + const std::string assembly = R"( + OpCapability Shader + OpMemoryModel Logical Simple + OpEntryPoint Vertex %main "main" %1 %2 %3 + + OpName %strct "Communicators" + OpMemberName %strct 0 "alice" + OpMemberName %strct 1 "bob" + + OpMemberDecorate %strct 0 Location 9 + OpMemberDecorate %strct 1 Location 11 + OpDecorate %strct Block + OpDecorate %2 BuiltIn Position + + %void = OpTypeVoid + %voidfn = OpTypeFunction %void + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 + %strct = OpTypeStruct %float %v4float + + %11 = OpTypePointer Input %strct + %13 = OpTypePointer Output %strct + + %1 = OpVariable %11 Input + %3 = OpVariable %13 Output + + %12 = OpTypePointer Output %v4float + %2 = OpVariable %12 Output + + %main = OpFunction %void None %voidfn + %entry = OpLabel + OpReturn + OpFunctionEnd +)"; + auto p = parser(test::Assemble(assembly)); + + ASSERT_TRUE(p->Parse()) << p->error() << assembly; + EXPECT_TRUE(p->error().empty()); + + const auto got = p->program().to_str(); + const std::string expected = R"(Module{ + Struct Communicators { + StructMember{alice: __f32} + StructMember{bob: __vec_4__f32} + } + Struct main_out { + StructMember{[[ BuiltinDecoration{position} + ]] x_2_1: __vec_4__f32} + StructMember{[[ LocationDecoration{9} + ]] x_3_1: __f32} + StructMember{[[ LocationDecoration{11} + ]] x_3_2: __vec_4__f32} + } + Variable{ + x_1 + private + undefined + __type_name_Communicators + } + Variable{ + x_3 + private + undefined + __type_name_Communicators + } + Variable{ + x_2 + private + undefined + __vec_4__f32 + } + Function main_1 -> __void + () + { + Return{} + } + Function main -> __type_name_main_out + StageDecoration{vertex} + ( + VariableConst{ + Decorations{ + LocationDecoration{9} + } + x_1_param + none + undefined + __f32 + } + VariableConst{ + Decorations{ + LocationDecoration{11} + } + x_1_param_1 + none + undefined + __vec_4__f32 + } + ) + { + Assignment{ + MemberAccessor[not set]{ + Identifier[not set]{x_1} + Identifier[not set]{alice} + } + Identifier[not set]{x_1_param} + } + Assignment{ + MemberAccessor[not set]{ + Identifier[not set]{x_1} + Identifier[not set]{bob} + } + Identifier[not set]{x_1_param_1} + } + Call[not set]{ + Identifier[not set]{main_1} + ( + ) + } + Return{ + { + TypeConstructor[not set]{ + __type_name_main_out + Identifier[not set]{x_2} + MemberAccessor[not set]{ + Identifier[not set]{x_3} + Identifier[not set]{alice} + } + MemberAccessor[not set]{ + Identifier[not set]{x_3} + Identifier[not set]{bob} + } + } + } + } + } +} +)"; + EXPECT_EQ(got, expected) << got; +} + +TEST_F(SpvModuleScopeVarParserTest, FlattenStruct_LocOnStruct) { + const std::string assembly = R"( + OpCapability Shader + OpMemoryModel Logical Simple + OpEntryPoint Vertex %main "main" %1 %2 %3 + + OpName %strct "Communicators" + OpMemberName %strct 0 "alice" + OpMemberName %strct 1 "bob" + + OpDecorate %strct Location 9 + OpDecorate %strct Block + OpDecorate %2 BuiltIn Position + + %void = OpTypeVoid + %voidfn = OpTypeFunction %void + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 + %strct = OpTypeStruct %float %v4float + + %11 = OpTypePointer Input %strct + %13 = OpTypePointer Output %strct + + %1 = OpVariable %11 Input + %3 = OpVariable %13 Output + + %12 = OpTypePointer Output %v4float + %2 = OpVariable %12 Output + + %main = OpFunction %void None %voidfn + %entry = OpLabel + OpReturn + OpFunctionEnd +)"; + auto p = parser(test::Assemble(assembly)); + + // The validator rejects this because Location decorations + // can only go on OpVariable or members of a structure type. + ASSERT_FALSE(p->Parse()) << p->error() << assembly; + EXPECT_THAT(p->error(), + HasSubstr("Location decoration can only be applied to a variable " + "or member of a structure type")); +} + } // namespace } // namespace spirv } // namespace reader