diff --git a/src/tint/writer/spirv/generator_impl_function_test.cc b/src/tint/writer/spirv/generator_impl_function_test.cc index e4cb71bdd9..bb7008ce3b 100644 --- a/src/tint/writer/spirv/generator_impl_function_test.cc +++ b/src/tint/writer/spirv/generator_impl_function_test.cc @@ -46,5 +46,109 @@ TEST_F(SpvGeneratorImplTest, Function_DeduplicateType) { )"); } +TEST_F(SpvGeneratorImplTest, Function_EntryPoint_Compute) { + auto* func = CreateFunction(); + func->name = ir.symbols.Register("main"); + func->return_type = ir.types.Get(); + func->pipeline_stage = ir::Function::PipelineStage::kCompute; + func->workgroup_size = {32, 4, 1}; + + generator_.EmitFunction(func); + EXPECT_EQ(DumpModule(generator_.Module()), R"(OpEntryPoint GLCompute %1 "main" +OpExecutionMode %1 LocalSize 32 4 1 +OpName %1 "main" +%2 = OpTypeVoid +%3 = OpTypeFunction %2 +%1 = OpFunction %2 None %3 +%4 = OpLabel +OpReturn +OpFunctionEnd +)"); +} + +TEST_F(SpvGeneratorImplTest, Function_EntryPoint_Fragment) { + auto* func = CreateFunction(); + func->name = ir.symbols.Register("main"); + func->return_type = ir.types.Get(); + func->pipeline_stage = ir::Function::PipelineStage::kFragment; + + generator_.EmitFunction(func); + EXPECT_EQ(DumpModule(generator_.Module()), R"(OpEntryPoint Fragment %1 "main" +OpExecutionMode %1 OriginUpperLeft +OpName %1 "main" +%2 = OpTypeVoid +%3 = OpTypeFunction %2 +%1 = OpFunction %2 None %3 +%4 = OpLabel +OpReturn +OpFunctionEnd +)"); +} + +TEST_F(SpvGeneratorImplTest, Function_EntryPoint_Vertex) { + auto* func = CreateFunction(); + func->name = ir.symbols.Register("main"); + func->return_type = ir.types.Get(); + func->pipeline_stage = ir::Function::PipelineStage::kVertex; + + generator_.EmitFunction(func); + EXPECT_EQ(DumpModule(generator_.Module()), R"(OpEntryPoint Vertex %1 "main" +OpName %1 "main" +%2 = OpTypeVoid +%3 = OpTypeFunction %2 +%1 = OpFunction %2 None %3 +%4 = OpLabel +OpReturn +OpFunctionEnd +)"); +} + +TEST_F(SpvGeneratorImplTest, Function_EntryPoint_Multiple) { + auto* f1 = CreateFunction(); + f1->name = ir.symbols.Register("main1"); + f1->return_type = ir.types.Get(); + f1->pipeline_stage = ir::Function::PipelineStage::kCompute; + f1->workgroup_size = {32, 4, 1}; + + auto* f2 = CreateFunction(); + f2->name = ir.symbols.Register("main2"); + f2->return_type = ir.types.Get(); + f2->pipeline_stage = ir::Function::PipelineStage::kCompute; + f2->workgroup_size = {8, 2, 16}; + + auto* f3 = CreateFunction(); + f3->name = ir.symbols.Register("main3"); + f3->return_type = ir.types.Get(); + f3->pipeline_stage = ir::Function::PipelineStage::kFragment; + + generator_.EmitFunction(f1); + generator_.EmitFunction(f2); + generator_.EmitFunction(f3); + EXPECT_EQ(DumpModule(generator_.Module()), R"(OpEntryPoint GLCompute %1 "main1" +OpEntryPoint GLCompute %5 "main2" +OpEntryPoint Fragment %7 "main3" +OpExecutionMode %1 LocalSize 32 4 1 +OpExecutionMode %5 LocalSize 8 2 16 +OpExecutionMode %7 OriginUpperLeft +OpName %1 "main1" +OpName %5 "main2" +OpName %7 "main3" +%2 = OpTypeVoid +%3 = OpTypeFunction %2 +%1 = OpFunction %2 None %3 +%4 = OpLabel +OpReturn +OpFunctionEnd +%5 = OpFunction %2 None %3 +%6 = OpLabel +OpReturn +OpFunctionEnd +%7 = OpFunction %2 None %3 +%8 = OpLabel +OpReturn +OpFunctionEnd +)"); +} + } // namespace } // namespace tint::writer::spirv diff --git a/src/tint/writer/spirv/generator_impl_ir.cc b/src/tint/writer/spirv/generator_impl_ir.cc index e947e6c419..96820ccc91 100644 --- a/src/tint/writer/spirv/generator_impl_ir.cc +++ b/src/tint/writer/spirv/generator_impl_ir.cc @@ -88,7 +88,10 @@ void GeneratorImplIr::EmitFunction(const ir::Function* func) { // Emit the function name. module_.PushDebug(spv::Op::OpName, {id, Operand(func->name.Name())}); - // TODO(jrprice): Emit OpEntryPoint and OpExecutionMode declarations if needed. + // Emit OpEntryPoint and OpExecutionMode declarations if needed. + if (func->pipeline_stage != ir::Function::PipelineStage::kUndefined) { + EmitEntryPoint(func, id); + } // Get the ID for the return type. auto return_type_id = Type(func->return_type); @@ -123,4 +126,34 @@ void GeneratorImplIr::EmitFunction(const ir::Function* func) { module_.PushFunction(current_function_); } +void GeneratorImplIr::EmitEntryPoint(const ir::Function* func, uint32_t id) { + SpvExecutionModel stage; + switch (func->pipeline_stage) { + case ir::Function::PipelineStage::kCompute: { + stage = SpvExecutionModelGLCompute; + module_.PushExecutionMode(spv::Op::OpExecutionMode, + {id, SpvExecutionModeLocalSize, func->workgroup_size->at(0), + func->workgroup_size->at(1), func->workgroup_size->at(2)}); + break; + } + case ir::Function::PipelineStage::kFragment: { + stage = SpvExecutionModelFragment; + module_.PushExecutionMode(spv::Op::OpExecutionMode, + {id, SpvExecutionModeOriginUpperLeft}); + // TODO(jrprice): Add DepthReplacing execution mode if FragDepth is used. + break; + } + case ir::Function::PipelineStage::kVertex: { + stage = SpvExecutionModelVertex; + break; + } + case ir::Function::PipelineStage::kUndefined: + TINT_ICE(Writer, diagnostics_) << "undefined pipeline stage for entry point"; + return; + } + + // TODO(jrprice): Add the interface list of all referenced global variables. + module_.PushEntryPoint(spv::Op::OpEntryPoint, {stage, id, func->name.Name()}); +} + } // namespace tint::writer::spirv diff --git a/src/tint/writer/spirv/generator_impl_ir.h b/src/tint/writer/spirv/generator_impl_ir.h index d9aab3cb02..0aa4e7bbd3 100644 --- a/src/tint/writer/spirv/generator_impl_ir.h +++ b/src/tint/writer/spirv/generator_impl_ir.h @@ -64,6 +64,11 @@ class GeneratorImplIr { /// @param func the function to emit void EmitFunction(const ir::Function* func); + /// Emit entry point declarations for a function. + /// @param func the function to emit entry point declarations for + /// @param id the result ID of the function declaration + void EmitEntryPoint(const ir::Function* func, uint32_t id); + private: const ir::Module* ir_; spirv::Module module_;