diff --git a/src/ast/function.cc b/src/ast/function.cc index 3b1e839720..f166031822 100644 --- a/src/ast/function.cc +++ b/src/ast/function.cc @@ -42,6 +42,15 @@ Function::Function(Function&&) = default; Function::~Function() = default; +void Function::add_referenced_module_variable(Variable* var) { + for (const auto* v : referenced_module_vars_) { + if (v->name() == var->name()) { + return; + } + } + referenced_module_vars_.push_back(var); +} + bool Function::IsValid() const { for (const auto& param : params_) { if (param == nullptr || !param->IsValid()) diff --git a/src/ast/function.h b/src/ast/function.h index c6ee88c931..fa268497d5 100644 --- a/src/ast/function.h +++ b/src/ast/function.h @@ -68,6 +68,15 @@ class Function : public Node { /// @returns the function params const VariableList& params() const { return params_; } + /// Adds the given variable to the list of referenced module variables if it + /// is not already included. + /// @param var the module variable to add + void add_referenced_module_variable(Variable* var); + /// @returns the referenced module variables + const std::vector& referenced_module_variables() const { + return referenced_module_vars_; + } + /// Sets the return type of the function /// @param type the return type void set_return_type(type::Type* type) { return_type_ = type; } @@ -98,6 +107,7 @@ class Function : public Node { VariableList params_; type::Type* return_type_ = nullptr; StatementList body_; + std::vector referenced_module_vars_; }; /// A list of unique functions diff --git a/src/ast/function_test.cc b/src/ast/function_test.cc index 300244c09f..5b0c394031 100644 --- a/src/ast/function_test.cc +++ b/src/ast/function_test.cc @@ -57,6 +57,26 @@ TEST_F(FunctionTest, Creation_WithSource) { EXPECT_EQ(src.column, 2u); } +TEST_F(FunctionTest, AddDuplicateReferencedVariables) { + type::VoidType void_type; + type::I32Type i32; + + Variable v("var", StorageClass::kInput, &i32); + Function f("func", VariableList{}, &void_type); + + f.add_referenced_module_variable(&v); + ASSERT_EQ(f.referenced_module_variables().size(), 1u); + EXPECT_EQ(f.referenced_module_variables()[0], &v); + + f.add_referenced_module_variable(&v); + ASSERT_EQ(f.referenced_module_variables().size(), 1u); + + Variable v2("var2", StorageClass::kOutput, &i32); + f.add_referenced_module_variable(&v2); + ASSERT_EQ(f.referenced_module_variables().size(), 2u); + EXPECT_EQ(f.referenced_module_variables()[1], &v2); +} + TEST_F(FunctionTest, IsValid) { type::VoidType void_type; type::I32Type i32; diff --git a/src/ast/variable.h b/src/ast/variable.h index b84ef82080..f94036375c 100644 --- a/src/ast/variable.h +++ b/src/ast/variable.h @@ -105,7 +105,7 @@ class Variable : public Node { /// @param name the name to set void set_name(const std::string& name) { name_ = name; } /// @returns the variable name - const std::string& name() { return name_; } + const std::string& name() const { return name_; } /// Sets the value type if a const or formal parameter, or the /// store type if a var. diff --git a/src/type_determiner.cc b/src/type_determiner.cc index 23de278722..fa6a03ce6f 100644 --- a/src/type_determiner.cc +++ b/src/type_determiner.cc @@ -161,6 +161,19 @@ void TypeDeterminer::set_error(const Source& src, const std::string& msg) { error_ += msg; } +void TypeDeterminer::set_referenced_from_function_if_needed( + ast::Variable* var) { + if (current_function_ == nullptr) { + return; + } + if (var->storage_class() == ast::StorageClass::kNone || + var->storage_class() == ast::StorageClass::kFunction) { + return; + } + + current_function_->add_referenced_module_variable(var); +} + bool TypeDeterminer::Determine() { for (const auto& var : mod_->global_variables()) { variable_stack_.set_global(var->name(), var.get()); @@ -190,6 +203,8 @@ bool TypeDeterminer::DetermineFunctions(const ast::FunctionList& funcs) { bool TypeDeterminer::DetermineFunction(ast::Function* func) { name_to_function_[func->name()] = func; + current_function_ = func; + variable_stack_.push_scope(); for (const auto& param : func->params()) { variable_stack_.set(param->name(), param.get()); @@ -200,6 +215,8 @@ bool TypeDeterminer::DetermineFunction(ast::Function* func) { } variable_stack_.pop_scope(); + current_function_ = nullptr; + return true; } @@ -567,6 +584,8 @@ bool TypeDeterminer::DetermineIdentifier(ast::IdentifierExpression* expr) { ctx_.type_mgr().Get(std::make_unique( var->type(), var->storage_class()))); } + + set_referenced_from_function_if_needed(var); return true; } diff --git a/src/type_determiner.h b/src/type_determiner.h index 7f62a57c51..80f239762c 100644 --- a/src/type_determiner.h +++ b/src/type_determiner.h @@ -104,6 +104,7 @@ class TypeDeterminer { private: void set_error(const Source& src, const std::string& msg); + void set_referenced_from_function_if_needed(ast::Variable* var); bool DetermineArrayAccessor(ast::ArrayAccessorExpression* expr); bool DetermineAs(ast::AsExpression* expr); @@ -121,6 +122,7 @@ class TypeDeterminer { std::string error_; ScopeStack variable_stack_; std::unordered_map name_to_function_; + ast::Function* current_function_ = nullptr; }; } // namespace tint diff --git a/src/type_determiner_test.cc b/src/type_determiner_test.cc index 15fd077d57..319bfdfb23 100644 --- a/src/type_determiner_test.cc +++ b/src/type_determiner_test.cc @@ -743,6 +743,93 @@ TEST_F(TypeDeterminerTest, Expr_Identifier_Function) { EXPECT_TRUE(ident.result_type()->IsF32()); } +TEST_F(TypeDeterminerTest, Function_RegisterInputOutputVariables) { + ast::type::F32Type f32; + + auto in_var = std::make_unique( + "in_var", ast::StorageClass::kInput, &f32); + auto out_var = std::make_unique( + "out_var", ast::StorageClass::kOutput, &f32); + auto sb_var = std::make_unique( + "sb_var", ast::StorageClass::kStorageBuffer, &f32); + auto wg_var = std::make_unique( + "wg_var", ast::StorageClass::kWorkgroup, &f32); + auto priv_var = std::make_unique( + "priv_var", ast::StorageClass::kPrivate, &f32); + + auto in_ptr = in_var.get(); + auto out_ptr = out_var.get(); + auto sb_ptr = sb_var.get(); + auto wg_ptr = wg_var.get(); + auto priv_ptr = priv_var.get(); + + mod()->AddGlobalVariable(std::move(in_var)); + mod()->AddGlobalVariable(std::move(out_var)); + mod()->AddGlobalVariable(std::move(sb_var)); + mod()->AddGlobalVariable(std::move(wg_var)); + mod()->AddGlobalVariable(std::move(priv_var)); + + ast::VariableList params; + auto func = + std::make_unique("my_func", std::move(params), &f32); + auto func_ptr = func.get(); + + ast::StatementList body; + body.push_back(std::make_unique( + std::make_unique("out_var"), + std::make_unique("in_var"))); + body.push_back(std::make_unique( + std::make_unique("wg_var"), + std::make_unique("wg_var"))); + body.push_back(std::make_unique( + std::make_unique("sb_var"), + std::make_unique("sb_var"))); + body.push_back(std::make_unique( + std::make_unique("priv_var"), + std::make_unique("priv_var"))); + func->set_body(std::move(body)); + + mod()->AddFunction(std::move(func)); + + // Register the function + EXPECT_TRUE(td()->Determine()); + + const auto& vars = func_ptr->referenced_module_variables(); + ASSERT_EQ(vars.size(), 5); + EXPECT_EQ(vars[0], out_ptr); + EXPECT_EQ(vars[1], in_ptr); + EXPECT_EQ(vars[2], wg_ptr); + EXPECT_EQ(vars[3], sb_ptr); + EXPECT_EQ(vars[4], priv_ptr); +} + +TEST_F(TypeDeterminerTest, Function_NotRegisterFunctionVariable) { + ast::type::F32Type f32; + + auto var = std::make_unique( + "in_var", ast::StorageClass::kFunction, &f32); + + ast::VariableList params; + auto func = + std::make_unique("my_func", std::move(params), &f32); + auto func_ptr = func.get(); + + ast::StatementList body; + body.push_back(std::make_unique(std::move(var))); + body.push_back(std::make_unique( + std::make_unique("var"), + std::make_unique( + std::make_unique(&f32, 1.f)))); + func->set_body(std::move(body)); + + mod()->AddFunction(std::move(func)); + + // Register the function + EXPECT_TRUE(td()->Determine()); + + EXPECT_EQ(func_ptr->referenced_module_variables().size(), 0); +} + TEST_F(TypeDeterminerTest, Expr_MemberAccessor_Struct) { ast::type::I32Type i32; ast::type::F32Type f32; diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc index 0a81e8eb03..93f3133ebb 100644 --- a/src/writer/spirv/builder.cc +++ b/src/writer/spirv/builder.cc @@ -164,6 +164,8 @@ bool Builder::Build() { } } + // Note, the entry points must be generated after the functions as they need + // to be able to lookup the function information based on the name. for (const auto& ep : mod_->entry_points()) { if (!GenerateEntryPoint(ep.get())) { return false; @@ -296,10 +298,16 @@ bool Builder::GenerateEntryPoint(ast::EntryPoint* ep) { OperandList operands = {Operand::Int(stage), Operand::Int(id), Operand::String(name)}; - // TODO(dsinclair): This could be made smarter by only listing the - // input/output variables which are used by the entry point instead of just - // listing all module scoped variables of type input/output. - for (const auto& var : mod_->global_variables()) { + + auto* func = func_name_to_func_[ep->function_name()]; + if (func == nullptr) { + error_ = "processing an entry point when the function has not been seen."; + return false; + } + + for (const auto* var : func->referenced_module_variables()) { + // For SPIR-V 1.3 we only output Input/output variables. If we update to + // SPIR-V 1.4 or later this should be all variables. if (var->storage_class() != ast::StorageClass::kInput && var->storage_class() != ast::StorageClass::kOutput) { continue; @@ -425,6 +433,7 @@ bool Builder::GenerateFunction(ast::Function* func) { scope_stack_.pop_scope(); func_name_to_id_[func->name()] = func_id; + func_name_to_func_[func->name()] = func; return true; } diff --git a/src/writer/spirv/builder.h b/src/writer/spirv/builder.h index 4213462e98..91505e82a0 100644 --- a/src/writer/spirv/builder.h +++ b/src/writer/spirv/builder.h @@ -84,36 +84,6 @@ class Builder { return id; } - /// Sets the id for a given function name - /// @param name the name to set - /// @param id the id to set - void set_func_name_to_id(const std::string& name, uint32_t id) { - func_name_to_id_[name] = id; - } - - /// Retrives the id for the given function name - /// @param name the function name to search for - /// @returns the id for the given name or 0 on failure - uint32_t id_for_func_name(const std::string& name) { - if (func_name_to_id_.count(name) == 0) { - return 0; - } - return func_name_to_id_[name]; - } - - /// Retrieves the id for an entry point function, or 0 if not found. - /// Emits an error if not found. - /// @param ep the entry point - /// @returns 0 on error - uint32_t id_for_entry_point(ast::EntryPoint* ep) { - auto id = id_for_func_name(ep->function_name()); - if (id == 0) { - error_ = "unable to find ID for function: " + ep->function_name(); - return 0; - } - return id; - } - /// Iterates over all the instructions in the correct order and calls the /// given callback /// @param cb the callback to execute @@ -402,6 +372,29 @@ class Builder { /// automatically. Operand result_op(); + /// Retrives the id for the given function name + /// @param name the function name to search for + /// @returns the id for the given name or 0 on failure + uint32_t id_for_func_name(const std::string& name) { + if (func_name_to_id_.count(name) == 0) { + return 0; + } + return func_name_to_id_[name]; + } + + /// Retrieves the id for an entry point function, or 0 if not found. + /// Emits an error if not found. + /// @param ep the entry point + /// @returns 0 on error + uint32_t id_for_entry_point(ast::EntryPoint* ep) { + auto id = id_for_func_name(ep->function_name()); + if (id == 0) { + error_ = "unable to find ID for function: " + ep->function_name(); + return 0; + } + return id; + } + ast::Module* mod_; std::string error_; uint32_t next_id_ = 1; @@ -415,6 +408,7 @@ class Builder { std::unordered_map import_name_to_id_; std::unordered_map func_name_to_id_; + std::unordered_map func_name_to_func_; std::unordered_map type_name_to_id_; std::unordered_map const_to_id_; ScopeStack scope_stack_; diff --git a/src/writer/spirv/builder_entry_point_test.cc b/src/writer/spirv/builder_entry_point_test.cc index 2f1d8e032c..34dc0f81c9 100644 --- a/src/writer/spirv/builder_entry_point_test.cc +++ b/src/writer/spirv/builder_entry_point_test.cc @@ -17,10 +17,16 @@ #include "gtest/gtest.h" #include "spirv/unified1/spirv.h" #include "spirv/unified1/spirv.hpp11" +#include "src/ast/assignment_statement.h" #include "src/ast/entry_point.h" +#include "src/ast/function.h" +#include "src/ast/identifier_expression.h" #include "src/ast/pipeline_stage.h" #include "src/ast/type/f32_type.h" +#include "src/ast/type/void_type.h" #include "src/ast/variable.h" +#include "src/context.h" +#include "src/type_determiner.h" #include "src/writer/spirv/builder.h" #include "src/writer/spirv/spv_dump.h" @@ -32,24 +38,30 @@ namespace { using BuilderTest = testing::Test; TEST_F(BuilderTest, EntryPoint) { + ast::type::VoidType void_type; + + ast::Function func("frag_main", {}, &void_type); ast::EntryPoint ep(ast::PipelineStage::kFragment, "main", "frag_main"); ast::Module mod; Builder b(&mod); - b.set_func_name_to_id("frag_main", 2); - ASSERT_TRUE(b.GenerateEntryPoint(&ep)); + ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); + ASSERT_TRUE(b.GenerateEntryPoint(&ep)) << b.error(); - EXPECT_EQ(DumpInstructions(b.preamble()), R"(OpEntryPoint Fragment %2 "main" + EXPECT_EQ(DumpInstructions(b.preamble()), R"(OpEntryPoint Fragment %3 "main" )"); } TEST_F(BuilderTest, EntryPoint_WithoutName) { + ast::type::VoidType void_type; + + ast::Function func("compute_main", {}, &void_type); ast::EntryPoint ep(ast::PipelineStage::kCompute, "", "compute_main"); ast::Module mod; Builder b(&mod); - b.set_func_name_to_id("compute_main", 3); - ASSERT_TRUE(b.GenerateEntryPoint(&ep)); + ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); + ASSERT_TRUE(b.GenerateEntryPoint(&ep)) << b.error(); EXPECT_EQ(DumpInstructions(b.preamble()), R"(OpEntryPoint GLCompute %3 "compute_main" @@ -77,12 +89,15 @@ using EntryPointStageTest = testing::TestWithParam; TEST_P(EntryPointStageTest, Emit) { auto params = GetParam(); + ast::type::VoidType void_type; + + ast::Function func("main", {}, &void_type); ast::EntryPoint ep(params.stage, "", "main"); ast::Module mod; Builder b(&mod); - b.set_func_name_to_id("main", 3); - ASSERT_TRUE(b.GenerateEntryPoint(&ep)); + ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); + ASSERT_TRUE(b.GenerateEntryPoint(&ep)) << b.error(); auto preamble = b.preamble(); ASSERT_EQ(preamble.size(), 1u); @@ -101,8 +116,12 @@ INSTANTIATE_TEST_SUITE_P( EntryPointStageData{ast::PipelineStage::kCompute, SpvExecutionModelGLCompute})); -TEST_F(BuilderTest, EntryPoint_WithInterfaceIds) { +TEST_F(BuilderTest, EntryPoint_WithUnusedInterfaceIds) { ast::type::F32Type f32; + ast::type::VoidType void_type; + + ast::Function func("main", {}, &void_type); + auto v_in = std::make_unique("my_in", ast::StorageClass::kInput, &f32); auto v_out = std::make_unique( @@ -121,11 +140,12 @@ TEST_F(BuilderTest, EntryPoint_WithInterfaceIds) { mod.AddGlobalVariable(std::move(v_out)); mod.AddGlobalVariable(std::move(v_wg)); - b.set_func_name_to_id("main", 3); - ASSERT_TRUE(b.GenerateEntryPoint(&ep)); + ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); + ASSERT_TRUE(b.GenerateEntryPoint(&ep)) << b.error(); EXPECT_EQ(DumpInstructions(b.debug()), R"(OpName %1 "my_in" OpName %4 "my_out" OpName %7 "my_wg" +OpName %11 "main" )"); EXPECT_EQ(DumpInstructions(b.types()), R"(%3 = OpTypeFloat 32 %2 = OpTypePointer Input %3 @@ -135,35 +155,111 @@ OpName %7 "my_wg" %4 = OpVariable %5 Output %6 %8 = OpTypePointer Workgroup %3 %7 = OpVariable %8 Workgroup +%10 = OpTypeVoid +%9 = OpTypeFunction %10 )"); EXPECT_EQ(DumpInstructions(b.preamble()), - R"(OpEntryPoint Vertex %3 "main" %1 %4 + R"(OpEntryPoint Vertex %11 "main" +)"); +} + +TEST_F(BuilderTest, EntryPoint_WithUsedInterfaceIds) { + ast::type::F32Type f32; + ast::type::VoidType void_type; + + ast::Function func("main", {}, &void_type); + ast::StatementList body; + body.push_back(std::make_unique( + std::make_unique("my_out"), + std::make_unique("my_in"))); + body.push_back(std::make_unique( + std::make_unique("my_wg"), + std::make_unique("my_wg"))); + // Add duplicate usages so we show they don't get output multiple times. + body.push_back(std::make_unique( + std::make_unique("my_out"), + std::make_unique("my_in"))); + func.set_body(std::move(body)); + + auto v_in = + std::make_unique("my_in", ast::StorageClass::kInput, &f32); + auto v_out = std::make_unique( + "my_out", ast::StorageClass::kOutput, &f32); + auto v_wg = std::make_unique( + "my_wg", ast::StorageClass::kWorkgroup, &f32); + ast::EntryPoint ep(ast::PipelineStage::kVertex, "", "main"); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(v_in.get()); + td.RegisterVariableForTesting(v_out.get()); + td.RegisterVariableForTesting(v_wg.get()); + + ASSERT_TRUE(td.DetermineFunction(&func)) << td.error(); + + Builder b(&mod); + + EXPECT_TRUE(b.GenerateGlobalVariable(v_in.get())) << b.error(); + EXPECT_TRUE(b.GenerateGlobalVariable(v_out.get())) << b.error(); + EXPECT_TRUE(b.GenerateGlobalVariable(v_wg.get())) << b.error(); + + mod.AddGlobalVariable(std::move(v_in)); + mod.AddGlobalVariable(std::move(v_out)); + mod.AddGlobalVariable(std::move(v_wg)); + + ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); + ASSERT_TRUE(b.GenerateEntryPoint(&ep)) << b.error(); + EXPECT_EQ(DumpInstructions(b.debug()), R"(OpName %1 "my_in" +OpName %4 "my_out" +OpName %7 "my_wg" +OpName %11 "main" +)"); + EXPECT_EQ(DumpInstructions(b.types()), R"(%3 = OpTypeFloat 32 +%2 = OpTypePointer Input %3 +%1 = OpVariable %2 Input +%5 = OpTypePointer Output %3 +%6 = OpConstantNull %3 +%4 = OpVariable %5 Output %6 +%8 = OpTypePointer Workgroup %3 +%7 = OpVariable %8 Workgroup +%10 = OpTypeVoid +%9 = OpTypeFunction %10 +)"); + EXPECT_EQ(DumpInstructions(b.preamble()), + R"(OpEntryPoint Vertex %11 "main" %4 %1 )"); } TEST_F(BuilderTest, ExecutionModel_Fragment_OriginUpperLeft) { + ast::type::VoidType void_type; + + ast::Function func("frag_main", {}, &void_type); ast::EntryPoint ep(ast::PipelineStage::kFragment, "main", "frag_main"); ast::Module mod; Builder b(&mod); - b.set_func_name_to_id("frag_main", 2); + ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); ASSERT_TRUE(b.GenerateExecutionModes(&ep)); EXPECT_EQ(DumpInstructions(b.preamble()), - R"(OpExecutionMode %2 OriginUpperLeft + R"(OpExecutionMode %3 OriginUpperLeft )"); } TEST_F(BuilderTest, ExecutionModel_Compute_LocalSize) { + ast::type::VoidType void_type; + + ast::Function func("main", {}, &void_type); ast::EntryPoint ep(ast::PipelineStage::kCompute, "main", "main"); ast::Module mod; Builder b(&mod); - b.set_func_name_to_id("main", 2); + ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); ASSERT_TRUE(b.GenerateExecutionModes(&ep)); EXPECT_EQ(DumpInstructions(b.preamble()), - R"(OpExecutionMode %2 LocalSize 1 1 1 + R"(OpExecutionMode %3 LocalSize 1 1 1 )"); }