diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc index 9d58070d89..9db64f013e 100644 --- a/src/writer/hlsl/generator_impl.cc +++ b/src/writer/hlsl/generator_impl.cc @@ -47,6 +47,7 @@ const char kOutStructNameSuffix[] = "out"; const char kTintStructInVarPrefix[] = "tint_in"; const char kTintStructOutVarPrefix[] = "tint_out"; const char kTempNamePrefix[] = "tint_tmp"; +const char kSpecConstantPrefix[] = "WGSL_SPEC_CONSTANT_"; bool last_is_break_or_fallthrough(const ast::BlockStatement* stmts) { if (stmts->empty()) { @@ -1994,18 +1995,26 @@ bool GeneratorImpl::EmitEntryPointFunction(std::ostream& out, auto* func_sem = builder_.Sem().Get(func); if (func->pipeline_stage() == ast::PipelineStage::kCompute) { + // Emit the workgroup_size attribute. auto wgsize = func_sem->workgroup_size(); - if (wgsize[0].overridable_const || wgsize[1].overridable_const || - wgsize[2].overridable_const) { - // TODO(crbug.com/tint/713): Handle overridable constants. - TINT_UNIMPLEMENTED(builder_.Diagnostics()) - << "pipeline-overridable workgroup sizes are not implemented"; + out << "[numthreads("; + for (int i = 0; i < 3; i++) { + if (i > 0) { + out << ", "; + } + + if (wgsize[i].overridable_const) { + auto* sem_const = builder_.Sem().Get(wgsize[i].overridable_const); + if (!sem_const->IsPipelineConstant()) { + TINT_ICE(builder_.Diagnostics()) + << "expected a pipeline-overridable constant"; + } + out << kSpecConstantPrefix << sem_const->ConstantId(); + } else { + out << std::to_string(wgsize[i].value); + } } - uint32_t x = wgsize[0].value; - uint32_t y = wgsize[1].value; - uint32_t z = wgsize[2].value; - out << "[numthreads(" << std::to_string(x) << ", " << std::to_string(y) - << ", " << std::to_string(z) << ")]" << std::endl; + out << ")]" << std::endl; make_indent(out); } @@ -2721,10 +2730,10 @@ bool GeneratorImpl::EmitProgramConstVariable(std::ostream& out, if (sem->IsPipelineConstant()) { auto const_id = sem->ConstantId(); - out << "#ifndef WGSL_SPEC_CONSTANT_" << const_id << std::endl; + out << "#ifndef " << kSpecConstantPrefix << const_id << std::endl; if (var->constructor() != nullptr) { - out << "#define WGSL_SPEC_CONSTANT_" << const_id << " " + out << "#define " << kSpecConstantPrefix << const_id << " " << constructor_out.str() << std::endl; } else { out << "#error spec constant required for constant id " << const_id @@ -2736,9 +2745,8 @@ bool GeneratorImpl::EmitProgramConstVariable(std::ostream& out, builder_.Symbols().NameFor(var->symbol()))) { return false; } - out << " " << builder_.Symbols().NameFor(var->symbol()) - << " = WGSL_SPEC_CONSTANT_" << const_id << ";" << std::endl; - out << "#undef WGSL_SPEC_CONSTANT_" << const_id << std::endl; + out << " " << builder_.Symbols().NameFor(var->symbol()) << " = " + << kSpecConstantPrefix << const_id << ";" << std::endl; } else { out << "static const "; if (!EmitType(out, type, sem->StorageClass(), sem->AccessControl(), diff --git a/src/writer/hlsl/generator_impl_function_test.cc b/src/writer/hlsl/generator_impl_function_test.cc index 502b60df92..9ee9788229 100644 --- a/src/writer/hlsl/generator_impl_function_test.cc +++ b/src/writer/hlsl/generator_impl_function_test.cc @@ -919,11 +919,8 @@ void main() { } TEST_F(HlslGeneratorImplTest_Function, - Emit_Decoration_EntryPoint_Compute_WithWorkgroup) { - Func("main", ast::VariableList{}, ty.void_(), - { - Return(), - }, + Emit_Decoration_EntryPoint_Compute_WithWorkgroup_Literal) { + Func("main", ast::VariableList{}, ty.void_(), {}, { Stage(ast::PipelineStage::kCompute), WorkgroupSize(2, 4, 6), @@ -942,6 +939,69 @@ void main() { Validate(); } +TEST_F(HlslGeneratorImplTest_Function, + Emit_Decoration_EntryPoint_Compute_WithWorkgroup_Const) { + GlobalConst("width", ty.i32(), Construct(ty.i32(), 2)); + GlobalConst("height", ty.i32(), Construct(ty.i32(), 3)); + GlobalConst("depth", ty.i32(), Construct(ty.i32(), 4)); + Func("main", ast::VariableList{}, ty.void_(), {}, + { + Stage(ast::PipelineStage::kCompute), + WorkgroupSize("width", "height", "depth"), + }); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.Generate(out)) << gen.error(); + EXPECT_EQ(result(), R"(static const int width = int(2); +static const int height = int(3); +static const int depth = int(4); +[numthreads(2, 3, 4)] +void main() { + return; +} + +)"); + + Validate(); +} + +TEST_F(HlslGeneratorImplTest_Function, + Emit_Decoration_EntryPoint_Compute_WithWorkgroup_OverridableConst) { + GlobalConst("width", ty.i32(), Construct(ty.i32(), 2), {Override(7u)}); + GlobalConst("height", ty.i32(), Construct(ty.i32(), 3), {Override(8u)}); + GlobalConst("depth", ty.i32(), Construct(ty.i32(), 4), {Override(9u)}); + Func("main", ast::VariableList{}, ty.void_(), {}, + { + Stage(ast::PipelineStage::kCompute), + WorkgroupSize("width", "height", "depth"), + }); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.Generate(out)) << gen.error(); + EXPECT_EQ(result(), R"(#ifndef WGSL_SPEC_CONSTANT_7 +#define WGSL_SPEC_CONSTANT_7 int(2) +#endif +static const int width = WGSL_SPEC_CONSTANT_7; +#ifndef WGSL_SPEC_CONSTANT_8 +#define WGSL_SPEC_CONSTANT_8 int(3) +#endif +static const int height = WGSL_SPEC_CONSTANT_8; +#ifndef WGSL_SPEC_CONSTANT_9 +#define WGSL_SPEC_CONSTANT_9 int(4) +#endif +static const int depth = WGSL_SPEC_CONSTANT_9; +[numthreads(WGSL_SPEC_CONSTANT_7, WGSL_SPEC_CONSTANT_8, WGSL_SPEC_CONSTANT_9)] +void main() { + return; +} + +)"); + + Validate(); +} + TEST_F(HlslGeneratorImplTest_Function, Emit_Function_WithArrayParams) { Func("my_func", ast::VariableList{Param("a", ty.array())}, ty.void_(), { diff --git a/src/writer/hlsl/generator_impl_module_constant_test.cc b/src/writer/hlsl/generator_impl_module_constant_test.cc index 6e82c94bd9..8ff9c88584 100644 --- a/src/writer/hlsl/generator_impl_module_constant_test.cc +++ b/src/writer/hlsl/generator_impl_module_constant_test.cc @@ -45,7 +45,6 @@ TEST_F(HlslGeneratorImplTest_ModuleConstant, Emit_SpecConstant) { #define WGSL_SPEC_CONSTANT_23 3.0f #endif static const float pos = WGSL_SPEC_CONSTANT_23; -#undef WGSL_SPEC_CONSTANT_23 )"); } @@ -62,7 +61,6 @@ TEST_F(HlslGeneratorImplTest_ModuleConstant, Emit_SpecConstant_NoConstructor) { #error spec constant required for constant id 23 #endif static const float pos = WGSL_SPEC_CONSTANT_23; -#undef WGSL_SPEC_CONSTANT_23 )"); } @@ -84,12 +82,10 @@ TEST_F(HlslGeneratorImplTest_ModuleConstant, Emit_SpecConstant_NoId) { #define WGSL_SPEC_CONSTANT_0 3.0f #endif static const float a = WGSL_SPEC_CONSTANT_0; -#undef WGSL_SPEC_CONSTANT_0 #ifndef WGSL_SPEC_CONSTANT_1 #define WGSL_SPEC_CONSTANT_1 2.0f #endif static const float b = WGSL_SPEC_CONSTANT_1; -#undef WGSL_SPEC_CONSTANT_1 )"); }