diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc index db5c51d972..341e68775f 100644 --- a/src/reader/spirv/function.cc +++ b/src/reader/spirv/function.cc @@ -855,36 +855,213 @@ bool FunctionEmitter::Emit() { return true; } + // The function declaration, corresponding to how it's written in SPIR-V, + // and without regard to whether it's an entry point. FunctionDeclaration decl; if (!ParseFunctionDeclaration(&decl)) { return false; } + bool make_body_function = true; + if (ep_info_) { + if (ep_info_->inner_name.empty()) { + // This is an entry point, and we don't want to emit it as a wrapper + // around its own body. Emit it as one function. + decl.name = ep_info_->name; + decl.decorations.emplace_back( + create(Source{}, ep_info_->stage)); + } else if (ep_info_->owns_inner_implementation) { + // This is an entry point, and we want to emit it as a wrapper around + // an implementation function. + decl.name = ep_info_->inner_name; + } else { + // This is a second entry point that shares an inner implementation + // function. + make_body_function = false; + } + } + + if (make_body_function) { + auto* body = MakeFunctionBody(); + if (!body) { + return false; + } + + builder_.AST().AddFunction(create( + decl.source, builder_.Symbols().Register(decl.name), + std::move(decl.params), decl.return_type->Build(builder_), body, + std::move(decl.decorations), ast::DecorationList{})); + } + + if (ep_info_ && !ep_info_->inner_name.empty()) { + return EmitEntryPointAsWrapper(); + } + + return success(); +} + +ast::BlockStatement* FunctionEmitter::MakeFunctionBody() { + TINT_ASSERT(statements_stack_.size() == 1); + if (!EmitBody()) { - return false; + return nullptr; } // Set the body of the AST function node. if (statements_stack_.size() != 1) { - return Fail() << "internal error: statement-list stack should have 1 " - "element but has " - << statements_stack_.size(); + Fail() << "internal error: statement-list stack should have 1 " + "element but has " + << statements_stack_.size(); + return nullptr; } statements_stack_[0].Finalize(&builder_); - auto& statements = statements_stack_[0].GetStatements(); auto* body = create(Source{}, statements); - builder_.AST().AddFunction(create( - decl.source, builder_.Symbols().Register(decl.name), - std::move(decl.params), decl.return_type->Build(builder_), body, - std::move(decl.decorations), ast::DecorationList{})); // Maintain the invariant by repopulating the one and only element. statements_stack_.clear(); PushNewStatementBlock(constructs_[0].get(), 0, nullptr); - return success(); + return body; +} + +bool FunctionEmitter::EmitEntryPointAsWrapper() { + Source source; + + // The statements in the body. + ast::StatementList stmts; + + FunctionDeclaration decl; + decl.source = source; + decl.name = ep_info_->name; + ast::Type* return_type = nullptr; // Populated below. + + // Pipeline inputs become parameters to the wrapper function, and + // their values are saved into the corresponding private variables that + // have already been created. + for (uint32_t var_id : ep_info_->inputs) { + const auto* var = def_use_mgr_->GetDef(var_id); + TINT_ASSERT(var != nullptr); + TINT_ASSERT(var->opcode() == SpvOpVariable); + auto* store_type = GetVariableStoreType(*var); + auto* forced_store_type = store_type; + ast::DecorationList param_decos; + if (!parser_impl_.ConvertDecorationsForVariable(var_id, &forced_store_type, + ¶m_decos)) { + // This occurs, and is not an error, for the PointSize builtin. + if (!success()) { + // But exit early if an error was logged. + return false; + } + continue; + } + + const auto var_name = namer_.GetName(var_id); + const auto var_sym = builder_.Symbols().Register(var_name); + const auto param_name = namer_.MakeDerivedName(var_name + "_param"); + const auto param_sym = builder_.Symbols().Register(param_name); + auto* param = create( + source, param_sym, ast::StorageClass::kNone, + forced_store_type->Build(builder_), true /* is const */, + nullptr /* no constructor */, param_decos); + decl.params.push_back(param); + + // Add a body statement to copy the parameter to the corresponding private + // variable. + ast::Expression* param_value = + create(source, param_sym); + if (forced_store_type != store_type) { + // Insert a bitcast if needed. + const auto cast_name = namer_.MakeDerivedName(param_name + "_cast"); + const auto cast_sym = builder_.Symbols().Register(cast_name); + + param_value = create( + source, forced_store_type->Build(builder_), param_value); + } + + stmts.push_back(create( + source, create(source, var_sym), + param_value)); + } + + // Call the inner function. It has no parameters. + stmts.push_back(create( + source, + create( + source, + create( + source, builder_.Symbols().Register(ep_info_->inner_name)), + ast::ExpressionList{}))); + + if (ep_info_->outputs.empty()) { + return_type = ty_.Void()->Build(builder_); + } else { + // Pipeline outputs are converted to a structure that is written + // to just before returning. + + const auto return_struct_name = + namer_.MakeDerivedName(ep_info_->name + "_out"); + const auto return_struct_sym = + builder_.Symbols().Register(return_struct_name); + + // Define the structure. + ast::ExpressionList return_exprs; + std::vector return_members; + for (uint32_t var_id : ep_info_->outputs) { + const auto* var = def_use_mgr_->GetDef(var_id); + TINT_ASSERT(var != nullptr); + TINT_ASSERT(var->opcode() == SpvOpVariable); + const auto* store_type = GetVariableStoreType(*var); + const auto* forced_store_type = store_type; + ast::DecorationList out_decos; + if (!parser_impl_.ConvertDecorationsForVariable( + var_id, &forced_store_type, &out_decos)) { + // This occurs, and is not an error, for the PointSize builtin. + continue; + } + + // TODO(dneto): flatten structs and arrays to vectors or scalars. + // The Per-vertex structure is already flattened. + + // The member name is the same as the variable name, which is already + // unique across all module-scope declarations. + const auto var_name = namer_.GetName(var_id); + const auto var_sym = builder_.Symbols().Register(var_name); + + // Form the member type. + // Reuse the var name for the member name. They can't clash. + ast::StructMember* return_member = create( + Source{}, var_sym, forced_store_type->Build(builder_), + std::move(out_decos)); + return_members.push_back(return_member); + + // Save the expression. + return_exprs.push_back( + create(source, var_sym)); + } + + // Create and register the result type. + return_type = create( + Source{}, return_struct_sym, return_members, ast::DecorationList{}); + parser_impl_.AddConstructedType(return_struct_sym, return_type->As()); + + // Add the return-value statement. + stmts.push_back(create( + source, create( + source, return_type, std::move(return_exprs)))); + } + + auto* body = create(source, stmts); + ast::DecorationList fn_decos; + fn_decos.emplace_back(create(source, ep_info_->stage)); + + builder_.AST().AddFunction( + create(source, builder_.Symbols().Register(ep_info_->name), + std::move(decl.params), return_type, body, + std::move(fn_decos), ast::DecorationList{})); + + return true; } bool FunctionEmitter::ParseFunctionDeclaration(FunctionDeclaration* decl) { @@ -892,12 +1069,7 @@ bool FunctionEmitter::ParseFunctionDeclaration(FunctionDeclaration* decl) { return false; } - std::string name; - if (ep_info_ == nullptr) { - name = namer_.Name(function_.result_id()); - } else { - name = ep_info_->name; - } + const std::string name = namer_.Name(function_.result_id()); // Surprisingly, the "type id" on an OpFunction is the result type of the // function, not the type of the function. This is the one exceptional case @@ -932,15 +1104,10 @@ bool FunctionEmitter::ParseFunctionDeclaration(FunctionDeclaration* decl) { if (failed()) { return false; } - ast::DecorationList decos; - if (ep_info_ != nullptr) { - decos.emplace_back(create(Source{}, ep_info_->stage)); - } - decl->name = name; decl->params = std::move(ast_params); decl->return_type = ret_ty; - decl->decorations = std::move(decos); + decl->decorations.clear(); return success(); } diff --git a/src/reader/spirv/function.h b/src/reader/spirv/function.h index 58203d99ca..edff9079ef 100644 --- a/src/reader/spirv/function.h +++ b/src/reader/spirv/function.h @@ -423,6 +423,16 @@ class FunctionEmitter { /// @returns the parser implementation ParserImpl* parser() { return &parser_impl_; } + /// Emits the entry point as a wrapper around its implementation function. + /// @returns false if emission failed. + bool EmitEntryPointAsWrapper(); + + /// Create an ast::BlockStatement representing the body of the function. + /// This creates the statement stack, which is non-empty for the lifetime + /// of the function. + /// @returns the body of the function, or null on error + ast::BlockStatement* MakeFunctionBody(); + /// Emits the function body, populating the bottom entry of the statements /// stack. /// @returns false if emission failed. diff --git a/src/reader/spirv/parser_impl.cc b/src/reader/spirv/parser_impl.cc index c8d3c64f1c..068d07764d 100644 --- a/src/reader/spirv/parser_impl.cc +++ b/src/reader/spirv/parser_impl.cc @@ -1365,8 +1365,12 @@ ast::Variable* ParserImpl::MakeVariable(uint32_t id, sc = ast::StorageClass::kNone; } - if (!ConvertDecorationsForVariable(id, &type, &decorations)) { - return nullptr; + // In almost all cases, copy the decorations from SPIR-V to the variable. + // But avoid doing so when converting pipeline IO to private variables. + if (sc != ast::StorageClass::kPrivate) { + if (!ConvertDecorationsForVariable(id, &type, &decorations)) { + return nullptr; + } } std::string name = namer_.Name(id); diff --git a/src/reader/spirv/parser_impl.h b/src/reader/spirv/parser_impl.h index 834c1a5345..09807a18ff 100644 --- a/src/reader/spirv/parser_impl.h +++ b/src/reader/spirv/parser_impl.h @@ -176,6 +176,11 @@ class ParserImpl : Reader { const spvtools::opt::analysis::Type* type, const Type* ast_type); + /// Adds `type` as a constructed type if it hasn't been added yet. + /// @param name the type's unique name + /// @param type the type to add + void AddConstructedType(Symbol name, ast::NamedType* type); + /// @returns the fail stream object FailStream& fail_stream() { return fail_stream_; } /// @returns the namer object @@ -635,11 +640,6 @@ class ParserImpl : Reader { bool ParseArrayDecorations(const spvtools::opt::analysis::Type* spv_type, uint32_t* array_stride); - /// Adds `type` as a constructed type if it hasn't been added yet. - /// @param name the type's unique name - /// @param type the type to add - void AddConstructedType(Symbol name, ast::NamedType* type); - /// Creates a new `ast::Node` owned by the ProgramBuilder. /// @param args the arguments to pass to the type constructor /// @returns the node pointer diff --git a/src/reader/spirv/parser_impl_module_var_test.cc b/src/reader/spirv/parser_impl_module_var_test.cc index c8afb2ca80..97d2aeda13 100644 --- a/src/reader/spirv/parser_impl_module_var_test.cc +++ b/src/reader/spirv/parser_impl_module_var_test.cc @@ -52,6 +52,13 @@ std::string MainBody() { )"; } +std::string CommonCapabilities() { + return R"( + OpCapability Shader + OpMemoryModel Logical Simple +)"; +} + std::string CommonTypes() { return R"( %void = OpTypeVoid @@ -3837,6 +3844,118 @@ TEST_F(SpvModuleScopeVarParserTest, OutputVarsConvertedToPrivate) { EXPECT_THAT(got, HasSubstr(expected)) << got; } +TEST_F(SpvModuleScopeVarParserTest, EntryPointWrapping_IOLocations) { + const auto assembly = CommonCapabilities() + R"( + OpEntryPoint Vertex %main "main" %1 %2 %3 %4 + OpDecorate %1 Location 0 + OpDecorate %2 Location 0 + OpDecorate %3 Location 30 + OpDecorate %4 Location 40 +)" + CommonTypes() + + R"( + %ptr_in_uint = OpTypePointer Input %uint + %ptr_out_uint = OpTypePointer Output %uint + %1 = OpVariable %ptr_in_uint Input + %2 = OpVariable %ptr_out_uint Output + %3 = OpVariable %ptr_in_uint Input + %4 = OpVariable %ptr_out_uint Output + + %main = OpFunction %void None %voidfn + %entry = OpLabel + OpReturn + OpFunctionEnd + )"; + auto p = parser(test::Assemble(assembly)); + + // TODO(crbug.com/tint/508): Remove this when everything is converted + // to HLSL style pipeline IO. + p->SetHLSLStylePipelineIO(); + + ASSERT_TRUE(p->BuildAndParseInternalModule()); + EXPECT_TRUE(p->error().empty()); + const auto got = p->program().to_str(); + const std::string expected = + R"( + Struct main_out { + StructMember{[[ LocationDecoration{0} + ]] x_2: __u32} + StructMember{[[ LocationDecoration{40} + ]] x_4: __u32} + } + Variable{ + x_1 + private + __u32 + } + Variable{ + x_2 + private + __u32 + } + Variable{ + x_3 + private + __u32 + } + Variable{ + x_4 + private + __u32 + } + Function main_1 -> __void + () + { + Return{} + } + Function main -> __struct_main_out + StageDecoration{vertex} + ( + VariableConst{ + Decorations{ + LocationDecoration{0} + } + x_1_param + none + __u32 + } + VariableConst{ + Decorations{ + LocationDecoration{30} + } + x_3_param + none + __u32 + } + ) + { + Assignment{ + Identifier[not set]{x_1} + Identifier[not set]{x_1_param} + } + Assignment{ + Identifier[not set]{x_3} + Identifier[not set]{x_3_param} + } + Call[not set]{ + Identifier[not set]{main_1} + ( + ) + } + Return{ + { + TypeConstructor[not set]{ + __struct_main_out + Identifier[not set]{x_2} + Identifier[not set]{x_4} + } + } + } + } +} +)"; + EXPECT_THAT(got, HasSubstr(expected)) << got; +} + // TODO(dneto): Test passing pointer to SampleMask as function parameter, // both input case and output case.