diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc index eee67439f6..1bfb0ad5e2 100644 --- a/src/writer/spirv/builder.cc +++ b/src/writer/spirv/builder.cc @@ -266,8 +266,26 @@ bool Builder::GenerateEntryPoint(ast::EntryPoint* ep) { return false; } - push_preamble(spv::Op::OpEntryPoint, - {Operand::Int(stage), Operand::Int(id), Operand::String(name)}); + std::vector operands = {Operand::Int(stage), Operand::Int(id), + Operand::String(name)}; + // TODO(dsinclair): This could be made smarter by only listing the + // input/output variables which are used by the entry point instead of just + // listing all module scoped variables of type input/output. + for (const auto& var : mod_->global_variables()) { + if (var->storage_class() != ast::StorageClass::kInput && + var->storage_class() != ast::StorageClass::kOutput) { + continue; + } + + uint32_t var_id; + if (!scope_stack_.get(var->name(), &var_id)) { + error_ = "unable to find ID for global variable: " + var->name(); + return false; + } + + operands.push_back(Operand::Int(var_id)); + } + push_preamble(spv::Op::OpEntryPoint, operands); return true; } diff --git a/src/writer/spirv/builder_entry_point_test.cc b/src/writer/spirv/builder_entry_point_test.cc index 788787beee..2c4a71bafc 100644 --- a/src/writer/spirv/builder_entry_point_test.cc +++ b/src/writer/spirv/builder_entry_point_test.cc @@ -19,6 +19,8 @@ #include "spirv/unified1/spirv.hpp11" #include "src/ast/entry_point.h" #include "src/ast/pipeline_stage.h" +#include "src/ast/type/f32_type.h" +#include "src/ast/variable.h" #include "src/writer/spirv/builder.h" #include "src/writer/spirv/spv_dump.h" @@ -103,8 +105,44 @@ INSTANTIATE_TEST_SUITE_P( EntryPointStageData{ast::PipelineStage::kCompute, SpvExecutionModelGLCompute})); -// TODO(http://crbug.com/tint/28) -TEST_F(BuilderTest, DISABLED_EntryPoint_WithInterfaceIds) {} +TEST_F(BuilderTest, EntryPoint_WithInterfaceIds) { + ast::type::F32Type f32; + auto v_in = + std::make_unique("my_in", ast::StorageClass::kInput, &f32); + auto v_out = std::make_unique( + "my_out", ast::StorageClass::kOutput, &f32); + auto v_wg = std::make_unique( + "my_wg", ast::StorageClass::kWorkgroup, &f32); + ast::EntryPoint ep(ast::PipelineStage::kVertex, "", "main"); + + ast::Module mod; + Builder b(&mod); + EXPECT_TRUE(b.GenerateGlobalVariable(v_in.get())) << b.error(); + EXPECT_TRUE(b.GenerateGlobalVariable(v_out.get())) << b.error(); + EXPECT_TRUE(b.GenerateGlobalVariable(v_wg.get())) << b.error(); + + mod.AddGlobalVariable(std::move(v_in)); + mod.AddGlobalVariable(std::move(v_out)); + mod.AddGlobalVariable(std::move(v_wg)); + + b.set_func_name_to_id("main", 3); + ASSERT_TRUE(b.GenerateEntryPoint(&ep)); + EXPECT_EQ(DumpInstructions(b.debug()), R"(OpName %1 "my_in" +OpName %4 "my_out" +OpName %6 "my_wg" +)"); + EXPECT_EQ(DumpInstructions(b.types()), R"(%3 = OpTypeFloat 32 +%2 = OpTypePointer Input %3 +%1 = OpVariable %2 Input +%5 = OpTypePointer Output %3 +%4 = OpVariable %5 Output +%7 = OpTypePointer Workgroup %3 +%6 = OpVariable %7 Workgroup +)"); + EXPECT_EQ(DumpInstructions(b.preamble()), + R"(OpEntryPoint Vertex %3 "main" %1 %4 +)"); +} TEST_F(BuilderTest, ExecutionModel_Fragment_OriginUpperLeft) { ast::EntryPoint ep(ast::PipelineStage::kFragment, "main", "frag_main");