diff --git a/src/tint/inspector/inspector.cc b/src/tint/inspector/inspector.cc index 087e786b88..d32ef6f62b 100644 --- a/src/tint/inspector/inspector.cc +++ b/src/tint/inspector/inspector.cc @@ -148,9 +148,9 @@ EntryPoint Inspector::GetEntryPoint(const tint::ast::Function* func) { entry_point.stage = PipelineStage::kCompute; auto wgsize = sem->WorkgroupSize(); - if (!wgsize[0].overridable_const && !wgsize[1].overridable_const && - !wgsize[2].overridable_const) { - entry_point.workgroup_size = {wgsize[0].value, wgsize[1].value, wgsize[2].value}; + if (wgsize[0].has_value() && wgsize[1].has_value() && wgsize[2].has_value()) { + entry_point.workgroup_size = {wgsize[0].value(), wgsize[1].value(), + wgsize[2].value()}; } break; } @@ -849,19 +849,18 @@ void Inspector::GenerateSamplerTargets() { auto* t = c->args[static_cast(texture_index)]; auto* s = c->args[static_cast(sampler_index)]; - GetOriginatingResources( - std::array{t, s}, - [&](std::array globals) { - auto texture_binding_point = globals[0]->BindingPoint(); - auto sampler_binding_point = globals[1]->BindingPoint(); + GetOriginatingResources(std::array{t, s}, + [&](std::array globals) { + auto texture_binding_point = globals[0]->BindingPoint(); + auto sampler_binding_point = globals[1]->BindingPoint(); - for (auto* entry_point : entry_points) { - const auto& ep_name = - program_->Symbols().NameFor(entry_point->Declaration()->symbol); - (*sampler_targets_)[ep_name].Add( - {sampler_binding_point, texture_binding_point}); - } - }); + for (auto* entry_point : entry_points) { + const auto& ep_name = program_->Symbols().NameFor( + entry_point->Declaration()->symbol); + (*sampler_targets_)[ep_name].Add( + {sampler_binding_point, texture_binding_point}); + } + }); } } diff --git a/src/tint/resolver/function_validation_test.cc b/src/tint/resolver/function_validation_test.cc index 22f26205c2..cf872962e7 100644 --- a/src/tint/resolver/function_validation_test.cc +++ b/src/tint/resolver/function_validation_test.cc @@ -468,7 +468,7 @@ TEST_F(ResolverFunctionValidationTest, FunctionParamsConst) { TEST_F(ResolverFunctionValidationTest, WorkgroupSize_GoodType_ConstU32) { // const x = 4u; - // const x = 8u; + // const y = 8u; // @compute @workgroup_size(x, y, 16u) // fn main() {} auto* x = GlobalConst("x", ty.u32(), Expr(4_u)); @@ -489,10 +489,29 @@ TEST_F(ResolverFunctionValidationTest, WorkgroupSize_GoodType_ConstU32) { ASSERT_NE(sem_x, nullptr); ASSERT_NE(sem_y, nullptr); + EXPECT_EQ(sem_func->WorkgroupSize(), (sem::WorkgroupSize{4u, 8u, 16u})); + EXPECT_TRUE(sem_func->DirectlyReferencedGlobals().Contains(sem_x)); EXPECT_TRUE(sem_func->DirectlyReferencedGlobals().Contains(sem_y)); } +TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Cast) { + // @compute @workgroup_size(i32(5)) + // fn main() {} + auto* func = Func("main", utils::Empty, ty.void_(), utils::Empty, + utils::Vector{ + Stage(ast::PipelineStage::kCompute), + WorkgroupSize(Construct(Source{{12, 34}}, ty.i32(), 5_a)), + }); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem_func = Sem().Get(func); + + ASSERT_NE(sem_func, nullptr); + EXPECT_EQ(sem_func->WorkgroupSize(), (sem::WorkgroupSize{5u, 1u, 1u})); +} + TEST_F(ResolverFunctionValidationTest, WorkgroupSize_GoodType_I32) { // @compute @workgroup_size(1i, 2i, 3i) // fn main() {} @@ -651,9 +670,10 @@ TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Literal_BadType) { }); EXPECT_FALSE(r()->Resolve()); - EXPECT_EQ(r()->error(), - "12:34 error: workgroup_size argument must be either a literal, constant, or " - "overridable of type abstract-integer, i32 or u32"); + EXPECT_EQ( + r()->error(), + "12:34 error: workgroup_size argument must be a constant or override expression of type " + "abstract-integer, i32 or u32"); } TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Literal_Negative) { @@ -696,9 +716,10 @@ TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Const_BadType) { }); EXPECT_FALSE(r()->Resolve()); - EXPECT_EQ(r()->error(), - "12:34 error: workgroup_size argument must be either a literal, constant, or " - "overridable of type abstract-integer, i32 or u32"); + EXPECT_EQ( + r()->error(), + "12:34 error: workgroup_size argument must be a constant or override expression of type " + "abstract-integer, i32 or u32"); } TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Const_Negative) { @@ -759,8 +780,8 @@ TEST_F(ResolverFunctionValidationTest, WorkgroupSize_NonConst) { EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), - "12:34 error: workgroup_size argument must be either a literal, constant, or " - "overridable of type abstract-integer, i32 or u32"); + "12:34 error: workgroup_size argument must be a constant or override expression of " + "type abstract-integer, i32 or u32"); } TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr_x) { @@ -774,8 +795,8 @@ TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr_x) { EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), - "12:34 error: workgroup_size argument must be either a literal, constant, or " - "overridable of type abstract-integer, i32 or u32"); + "12:34 error: workgroup_size argument must be a constant or override expression of " + "type abstract-integer, i32 or u32"); } TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr_y) { @@ -789,8 +810,8 @@ TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr_y) { EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), - "12:34 error: workgroup_size argument must be either a literal, constant, or " - "overridable of type abstract-integer, i32 or u32"); + "12:34 error: workgroup_size argument must be a constant or override expression of " + "type abstract-integer, i32 or u32"); } TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr_z) { @@ -804,8 +825,8 @@ TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr_z) { EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), - "12:34 error: workgroup_size argument must be either a literal, constant, or " - "overridable of type abstract-integer, i32 or u32"); + "12:34 error: workgroup_size argument must be a constant or override expression of " + "type abstract-integer, i32 or u32"); } TEST_F(ResolverFunctionValidationTest, ReturnIsConstructible_NonPlain) { diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc index d57be4cf97..c1920d87a3 100644 --- a/src/tint/resolver/resolver.cc +++ b/src/tint/resolver/resolver.cc @@ -1050,8 +1050,7 @@ bool Resolver::WorkgroupSize(const ast::Function* func) { // Set work-group size defaults. sem::WorkgroupSize ws; for (size_t i = 0; i < 3; i++) { - ws[i].value = 1; - ws[i].overridable_const = nullptr; + ws[i] = 1; } auto* attr = ast::GetAttribute(func->attributes); @@ -1064,7 +1063,7 @@ bool Resolver::WorkgroupSize(const ast::Function* func) { utils::Vector arg_tys; constexpr const char* kErrBadExpr = - "workgroup_size argument must be either a literal, constant, or overridable of type " + "workgroup_size argument must be a constant or override expression of type " "abstract-integer, i32 or u32"; for (size_t i = 0; i < 3; i++) { @@ -1084,6 +1083,12 @@ bool Resolver::WorkgroupSize(const ast::Function* func) { return false; } + if (expr->Stage() != sem::EvaluationStage::kConstant && + expr->Stage() != sem::EvaluationStage::kOverride) { + AddError(kErrBadExpr, value->source); + return false; + } + args.Push(expr); arg_tys.Push(ty); } @@ -1105,47 +1110,15 @@ bool Resolver::WorkgroupSize(const ast::Function* func) { if (!materialized) { return false; } - - const sem::Constant* value = nullptr; - - if (auto* user = args[i]->As()) { - // We have an variable of a module-scope constant. - auto* decl = user->Variable()->Declaration(); - if (!decl->IsAnyOf()) { - AddError(kErrBadExpr, values[i]->source); + if (auto* value = materialized->ConstantValue()) { + if (value->As() < 1) { + AddError("workgroup_size argument must be at least 1", values[i]->source); return false; } - // Capture the constant if it is pipeline-overridable. - if (decl->Is()) { - ws[i].overridable_const = decl; - } - - if (decl->constructor) { - value = sem_.Get(decl->constructor)->ConstantValue(); - } else { - // No constructor means this value must be overriden by the user. - ws[i].value = 0; - continue; - } - } else if (values[i]->Is() || args[i]->ConstantValue()) { - value = materialized->ConstantValue(); + ws[i] = value->As(); } else { - AddError(kErrBadExpr, values[i]->source); - return false; + ws[i] = std::nullopt; } - - if (!value) { - TINT_ICE(Resolver, diagnostics_) - << "could not resolve constant workgroup_size constant value"; - continue; - } - // validator_.Validate and set the default value for this dimension. - if (value->As() < 1) { - AddError("workgroup_size argument must be at least 1", values[i]->source); - return false; - } - - ws[i].value = value->As(); } current_function_->SetWorkgroupSize(std::move(ws)); diff --git a/src/tint/resolver/resolver_test.cc b/src/tint/resolver/resolver_test.cc index 1a3c623366..79cc139ba8 100644 --- a/src/tint/resolver/resolver_test.cc +++ b/src/tint/resolver/resolver_test.cc @@ -993,12 +993,9 @@ TEST_F(ResolverTest, Function_WorkgroupSize_NotSet) { auto* func_sem = Sem().Get(func); ASSERT_NE(func_sem, nullptr); - EXPECT_EQ(func_sem->WorkgroupSize()[0].value, 1u); - EXPECT_EQ(func_sem->WorkgroupSize()[1].value, 1u); - EXPECT_EQ(func_sem->WorkgroupSize()[2].value, 1u); - EXPECT_EQ(func_sem->WorkgroupSize()[0].overridable_const, nullptr); - EXPECT_EQ(func_sem->WorkgroupSize()[1].overridable_const, nullptr); - EXPECT_EQ(func_sem->WorkgroupSize()[2].overridable_const, nullptr); + EXPECT_EQ(func_sem->WorkgroupSize()[0], 1u); + EXPECT_EQ(func_sem->WorkgroupSize()[1], 1u); + EXPECT_EQ(func_sem->WorkgroupSize()[2], 1u); } TEST_F(ResolverTest, Function_WorkgroupSize_Literals) { @@ -1015,12 +1012,9 @@ TEST_F(ResolverTest, Function_WorkgroupSize_Literals) { auto* func_sem = Sem().Get(func); ASSERT_NE(func_sem, nullptr); - EXPECT_EQ(func_sem->WorkgroupSize()[0].value, 8u); - EXPECT_EQ(func_sem->WorkgroupSize()[1].value, 2u); - EXPECT_EQ(func_sem->WorkgroupSize()[2].value, 3u); - EXPECT_EQ(func_sem->WorkgroupSize()[0].overridable_const, nullptr); - EXPECT_EQ(func_sem->WorkgroupSize()[1].overridable_const, nullptr); - EXPECT_EQ(func_sem->WorkgroupSize()[2].overridable_const, nullptr); + EXPECT_EQ(func_sem->WorkgroupSize()[0], 8u); + EXPECT_EQ(func_sem->WorkgroupSize()[1], 2u); + EXPECT_EQ(func_sem->WorkgroupSize()[2], 3u); } TEST_F(ResolverTest, Function_WorkgroupSize_ViaConst) { @@ -1043,12 +1037,9 @@ TEST_F(ResolverTest, Function_WorkgroupSize_ViaConst) { auto* func_sem = Sem().Get(func); ASSERT_NE(func_sem, nullptr); - EXPECT_EQ(func_sem->WorkgroupSize()[0].value, 16u); - EXPECT_EQ(func_sem->WorkgroupSize()[1].value, 8u); - EXPECT_EQ(func_sem->WorkgroupSize()[2].value, 2u); - EXPECT_EQ(func_sem->WorkgroupSize()[0].overridable_const, nullptr); - EXPECT_EQ(func_sem->WorkgroupSize()[1].overridable_const, nullptr); - EXPECT_EQ(func_sem->WorkgroupSize()[2].overridable_const, nullptr); + EXPECT_EQ(func_sem->WorkgroupSize()[0], 16u); + EXPECT_EQ(func_sem->WorkgroupSize()[1], 8u); + EXPECT_EQ(func_sem->WorkgroupSize()[2], 2u); } TEST_F(ResolverTest, Function_WorkgroupSize_ViaConst_NestedInitializer) { @@ -1071,12 +1062,9 @@ TEST_F(ResolverTest, Function_WorkgroupSize_ViaConst_NestedInitializer) { auto* func_sem = Sem().Get(func); ASSERT_NE(func_sem, nullptr); - EXPECT_EQ(func_sem->WorkgroupSize()[0].value, 8u); - EXPECT_EQ(func_sem->WorkgroupSize()[1].value, 4u); - EXPECT_EQ(func_sem->WorkgroupSize()[2].value, 1u); - EXPECT_EQ(func_sem->WorkgroupSize()[0].overridable_const, nullptr); - EXPECT_EQ(func_sem->WorkgroupSize()[1].overridable_const, nullptr); - EXPECT_EQ(func_sem->WorkgroupSize()[2].overridable_const, nullptr); + EXPECT_EQ(func_sem->WorkgroupSize()[0], 8u); + EXPECT_EQ(func_sem->WorkgroupSize()[1], 4u); + EXPECT_EQ(func_sem->WorkgroupSize()[2], 1u); } TEST_F(ResolverTest, Function_WorkgroupSize_OverridableConsts) { @@ -1085,9 +1073,9 @@ TEST_F(ResolverTest, Function_WorkgroupSize_OverridableConsts) { // @id(2) override depth = 2i; // @compute @workgroup_size(width, height, depth) // fn main() {} - auto* width = Override("width", ty.i32(), Expr(16_i), Id(0_a)); - auto* height = Override("height", ty.i32(), Expr(8_i), Id(1_a)); - auto* depth = Override("depth", ty.i32(), Expr(2_i), Id(2_a)); + Override("width", ty.i32(), Expr(16_i), Id(0_a)); + Override("height", ty.i32(), Expr(8_i), Id(1_a)); + Override("depth", ty.i32(), Expr(2_i), Id(2_a)); auto* func = Func("main", utils::Empty, ty.void_(), utils::Empty, utils::Vector{ Stage(ast::PipelineStage::kCompute), @@ -1099,12 +1087,9 @@ TEST_F(ResolverTest, Function_WorkgroupSize_OverridableConsts) { auto* func_sem = Sem().Get(func); ASSERT_NE(func_sem, nullptr); - EXPECT_EQ(func_sem->WorkgroupSize()[0].value, 16u); - EXPECT_EQ(func_sem->WorkgroupSize()[1].value, 8u); - EXPECT_EQ(func_sem->WorkgroupSize()[2].value, 2u); - EXPECT_EQ(func_sem->WorkgroupSize()[0].overridable_const, width); - EXPECT_EQ(func_sem->WorkgroupSize()[1].overridable_const, height); - EXPECT_EQ(func_sem->WorkgroupSize()[2].overridable_const, depth); + EXPECT_EQ(func_sem->WorkgroupSize()[0], std::nullopt); + EXPECT_EQ(func_sem->WorkgroupSize()[1], std::nullopt); + EXPECT_EQ(func_sem->WorkgroupSize()[2], std::nullopt); } TEST_F(ResolverTest, Function_WorkgroupSize_OverridableConsts_NoInit) { @@ -1113,9 +1098,9 @@ TEST_F(ResolverTest, Function_WorkgroupSize_OverridableConsts_NoInit) { // @id(2) override depth : i32; // @compute @workgroup_size(width, height, depth) // fn main() {} - auto* width = Override("width", ty.i32(), Id(0_a)); - auto* height = Override("height", ty.i32(), Id(1_a)); - auto* depth = Override("depth", ty.i32(), Id(2_a)); + Override("width", ty.i32(), Id(0_a)); + Override("height", ty.i32(), Id(1_a)); + Override("depth", ty.i32(), Id(2_a)); auto* func = Func("main", utils::Empty, ty.void_(), utils::Empty, utils::Vector{ Stage(ast::PipelineStage::kCompute), @@ -1127,12 +1112,9 @@ TEST_F(ResolverTest, Function_WorkgroupSize_OverridableConsts_NoInit) { auto* func_sem = Sem().Get(func); ASSERT_NE(func_sem, nullptr); - EXPECT_EQ(func_sem->WorkgroupSize()[0].value, 0u); - EXPECT_EQ(func_sem->WorkgroupSize()[1].value, 0u); - EXPECT_EQ(func_sem->WorkgroupSize()[2].value, 0u); - EXPECT_EQ(func_sem->WorkgroupSize()[0].overridable_const, width); - EXPECT_EQ(func_sem->WorkgroupSize()[1].overridable_const, height); - EXPECT_EQ(func_sem->WorkgroupSize()[2].overridable_const, depth); + EXPECT_EQ(func_sem->WorkgroupSize()[0], std::nullopt); + EXPECT_EQ(func_sem->WorkgroupSize()[1], std::nullopt); + EXPECT_EQ(func_sem->WorkgroupSize()[2], std::nullopt); } TEST_F(ResolverTest, Function_WorkgroupSize_Mixed) { @@ -1140,7 +1122,7 @@ TEST_F(ResolverTest, Function_WorkgroupSize_Mixed) { // const depth = 3i; // @compute @workgroup_size(8, height, depth) // fn main() {} - auto* height = Override("height", ty.i32(), Expr(2_i), Id(0_a)); + Override("height", ty.i32(), Expr(2_i), Id(0_a)); GlobalConst("depth", ty.i32(), Expr(3_i)); auto* func = Func("main", utils::Empty, ty.void_(), utils::Empty, utils::Vector{ @@ -1153,12 +1135,9 @@ TEST_F(ResolverTest, Function_WorkgroupSize_Mixed) { auto* func_sem = Sem().Get(func); ASSERT_NE(func_sem, nullptr); - EXPECT_EQ(func_sem->WorkgroupSize()[0].value, 8u); - EXPECT_EQ(func_sem->WorkgroupSize()[1].value, 2u); - EXPECT_EQ(func_sem->WorkgroupSize()[2].value, 3u); - EXPECT_EQ(func_sem->WorkgroupSize()[0].overridable_const, nullptr); - EXPECT_EQ(func_sem->WorkgroupSize()[1].overridable_const, height); - EXPECT_EQ(func_sem->WorkgroupSize()[2].overridable_const, nullptr); + EXPECT_EQ(func_sem->WorkgroupSize()[0], 8u); + EXPECT_EQ(func_sem->WorkgroupSize()[1], std::nullopt); + EXPECT_EQ(func_sem->WorkgroupSize()[2], 3u); } TEST_F(ResolverTest, Expr_MemberAccessor_Struct) { diff --git a/src/tint/sem/function.cc b/src/tint/sem/function.cc index 6562526e39..fc7809be63 100644 --- a/src/tint/sem/function.cc +++ b/src/tint/sem/function.cc @@ -44,7 +44,7 @@ Function::Function(const ast::Function* declaration, utils::VectorRef parameters) : Base(return_type, SetOwner(std::move(parameters), this), EvaluationStage::kRuntime), declaration_(declaration), - workgroup_size_{WorkgroupDimension{1}, WorkgroupDimension{1}, WorkgroupDimension{1}}, + workgroup_size_{1, 1, 1}, return_location_(return_location) {} Function::~Function() = default; diff --git a/src/tint/sem/function.h b/src/tint/sem/function.h index 3f7256ae4c..50d853c73f 100644 --- a/src/tint/sem/function.h +++ b/src/tint/sem/function.h @@ -39,18 +39,10 @@ class Variable; namespace tint::sem { -/// WorkgroupDimension describes the size of a single dimension of an entry -/// point's workgroup size. -struct WorkgroupDimension { - /// The size of this dimension. - uint32_t value; - /// A pipeline-overridable constant that overrides the size, or nullptr if - /// this dimension is not overridable. - const ast::Variable* overridable_const = nullptr; -}; - /// WorkgroupSize is a three-dimensional array of WorkgroupDimensions. -using WorkgroupSize = std::array; +/// Each dimension is a std::optional as a workgroup size can be a constant or override expression. +/// Override expressions are not known at compilation time, so these will be std::nullopt. +using WorkgroupSize = std::array, 3>; /// Function holds the semantic information for function nodes. class Function final : public Castable { diff --git a/src/tint/writer/glsl/generator_impl.cc b/src/tint/writer/glsl/generator_impl.cc index e332495dc0..21840bc290 100644 --- a/src/tint/writer/glsl/generator_impl.cc +++ b/src/tint/writer/glsl/generator_impl.cc @@ -102,7 +102,6 @@ namespace tint::writer::glsl { namespace { const char kTempNamePrefix[] = "tint_tmp"; -const char kSpecConstantPrefix[] = "WGSL_SPEC_CONSTANT_"; bool last_is_break_or_fallthrough(const ast::BlockStatement* stmts) { return IsAnyOf(stmts->Last()); @@ -1886,8 +1885,9 @@ bool GeneratorImpl::EmitGlobalVariable(const ast::Variable* global) { [&](const ast::Let* let) { return EmitProgramConstVariable(let); }, [&](const ast::Override*) { // Override is removed with SubstituteOverride - TINT_ICE(Writer, diagnostics_) - << "Override should have been removed by the substitute_override transform."; + diagnostics_.add_error(diag::System::Writer, + "override expressions should have been removed with the " + "SubstituteOverride transform"); return false; }, [&](const ast::Const*) { @@ -2104,16 +2104,14 @@ bool GeneratorImpl::EmitEntryPointFunction(const ast::Function* func) { } out << "local_size_" << (i == 0 ? "x" : i == 1 ? "y" : "z") << " = "; - if (wgsize[i].overridable_const) { - auto* global = builder_.Sem().Get(wgsize[i].overridable_const); - if (!global->Declaration()->Is()) { - TINT_ICE(Writer, builder_.Diagnostics()) - << "expected a pipeline-overridable constant"; - } - out << kSpecConstantPrefix << global->OverrideId().value; - } else { - out << std::to_string(wgsize[i].value); + if (!wgsize[i].has_value()) { + diagnostics_.add_error( + diag::System::Writer, + "override expressions should have been removed with the SubstituteOverride " + "transform"); + return false; } + out << std::to_string(wgsize[i].value()); } out << ") in;"; } diff --git a/src/tint/writer/glsl/generator_impl_function_test.cc b/src/tint/writer/glsl/generator_impl_function_test.cc index eb0a9a8521..6afac16a86 100644 --- a/src/tint/writer/glsl/generator_impl_function_test.cc +++ b/src/tint/writer/glsl/generator_impl_function_test.cc @@ -783,6 +783,25 @@ void main() { )"); } +TEST_F(GlslGeneratorImplTest_Function, + Emit_Attribute_EntryPoint_Compute_WithWorkgroup_OverridableConst) { + Override("width", ty.i32(), Construct(ty.i32(), 2_i), Id(7_u)); + Override("height", ty.i32(), Construct(ty.i32(), 3_i), Id(8_u)); + Override("depth", ty.i32(), Construct(ty.i32(), 4_i), Id(9_u)); + Func("main", utils::Empty, ty.void_(), {}, + utils::Vector{ + Stage(ast::PipelineStage::kCompute), + WorkgroupSize("width", "height", "depth"), + }); + + GeneratorImpl& gen = Build(); + + EXPECT_FALSE(gen.Generate()) << gen.error(); + EXPECT_EQ( + gen.error(), + R"(error: override expressions should have been removed with the SubstituteOverride transform)"); +} + TEST_F(GlslGeneratorImplTest_Function, Emit_Function_WithArrayParams) { Func("my_func", utils::Vector{Param("a", ty.array())}, ty.void_(), utils::Vector{ diff --git a/src/tint/writer/hlsl/generator_impl.cc b/src/tint/writer/hlsl/generator_impl.cc index 31869554dc..516acdd4c0 100644 --- a/src/tint/writer/hlsl/generator_impl.cc +++ b/src/tint/writer/hlsl/generator_impl.cc @@ -81,7 +81,6 @@ namespace tint::writer::hlsl { namespace { const char kTempNamePrefix[] = "tint_tmp"; -const char kSpecConstantPrefix[] = "WGSL_SPEC_CONSTANT_"; const char* image_format_to_rwtexture_type(ast::TexelFormat image_format) { switch (image_format) { @@ -2842,8 +2841,9 @@ bool GeneratorImpl::EmitGlobalVariable(const ast::Variable* global) { }, [&](const ast::Override*) { // Override is removed with SubstituteOverride - TINT_ICE(Writer, diagnostics_) - << "Override should have been removed by the substitute_override transform."; + diagnostics_.add_error(diag::System::Writer, + "override expressions should have been removed with the " + "SubstituteOverride transform"); return false; }, [&](const ast::Const*) { @@ -3044,18 +3044,14 @@ bool GeneratorImpl::EmitEntryPointFunction(const ast::Function* func) { if (i > 0) { out << ", "; } - - if (wgsize[i].overridable_const) { - auto* global = - builder_.Sem().Get(wgsize[i].overridable_const); - if (!global->Declaration()->Is()) { - TINT_ICE(Writer, diagnostics_) - << "expected a pipeline-overridable constant"; - } - out << kSpecConstantPrefix << global->OverrideId().value; - } else { - out << std::to_string(wgsize[i].value); + if (!wgsize[i].has_value()) { + diagnostics_.add_error( + diag::System::Writer, + "override expressions should have been removed with the SubstituteOverride " + "transform"); + return false; } + out << std::to_string(wgsize[i].value()); } out << ")]" << std::endl; } diff --git a/src/tint/writer/hlsl/generator_impl_function_test.cc b/src/tint/writer/hlsl/generator_impl_function_test.cc index 5c167fae44..322b56074a 100644 --- a/src/tint/writer/hlsl/generator_impl_function_test.cc +++ b/src/tint/writer/hlsl/generator_impl_function_test.cc @@ -712,6 +712,25 @@ void main() { )"); } +TEST_F(HlslGeneratorImplTest_Function, + Emit_Attribute_EntryPoint_Compute_WithWorkgroup_OverridableConst) { + Override("width", ty.i32(), Construct(ty.i32(), 2_i), Id(7_u)); + Override("height", ty.i32(), Construct(ty.i32(), 3_i), Id(8_u)); + Override("depth", ty.i32(), Construct(ty.i32(), 4_i), Id(9_u)); + Func("main", utils::Empty, ty.void_(), utils::Empty, + utils::Vector{ + Stage(ast::PipelineStage::kCompute), + WorkgroupSize("width", "height", "depth"), + }); + + GeneratorImpl& gen = Build(); + + EXPECT_FALSE(gen.Generate()) << gen.error(); + EXPECT_EQ( + gen.error(), + R"(error: override expressions should have been removed with the SubstituteOverride transform)"); +} + TEST_F(HlslGeneratorImplTest_Function, Emit_Function_WithArrayParams) { Func("my_func", utils::Vector{ diff --git a/src/tint/writer/msl/generator_impl.cc b/src/tint/writer/msl/generator_impl.cc index 5ff57519cc..d1bab39473 100644 --- a/src/tint/writer/msl/generator_impl.cc +++ b/src/tint/writer/msl/generator_impl.cc @@ -273,8 +273,9 @@ bool GeneratorImpl::Generate() { }, [&](const ast::Override*) { // Override is removed with SubstituteOverride - TINT_ICE(Writer, diagnostics_) - << "Override should have been removed by the substitute_override transform."; + diagnostics_.add_error(diag::System::Writer, + "override expressions should have been removed with the " + "SubstituteOverride transform."); return false; }, [&](const ast::Function* func) { diff --git a/src/tint/writer/spirv/builder.cc b/src/tint/writer/spirv/builder.cc index f51d808a97..05f371678d 100644 --- a/src/tint/writer/spirv/builder.cc +++ b/src/tint/writer/spirv/builder.cc @@ -506,13 +506,17 @@ bool Builder::GenerateExecutionModes(const ast::Function* func, uint32_t id) { } else if (func->PipelineStage() == ast::PipelineStage::kCompute) { auto& wgsize = func_sem->WorkgroupSize(); - // SubstituteOverride replaced all overrides with constants. - uint32_t x = wgsize[0].value; - uint32_t y = wgsize[1].value; - uint32_t z = wgsize[2].value; - push_execution_mode(spv::Op::OpExecutionMode, - {Operand(id), U32Operand(SpvExecutionModeLocalSize), Operand(x), - Operand(y), Operand(z)}); + // Check if the workgroup_size uses pipeline-overridable constants. + if (!wgsize[0].has_value() || !wgsize[1].has_value() || !wgsize[2].has_value()) { + error_ = + "override expressions should have been removed with the SubstituteOverride " + "transform"; + return false; + } + push_execution_mode( + spv::Op::OpExecutionMode, + {Operand(id), U32Operand(SpvExecutionModeLocalSize), // + Operand(wgsize[0].value()), Operand(wgsize[1].value()), Operand(wgsize[2].value())}); } for (auto builtin : func_sem->TransitivelyReferencedBuiltinVariables()) { diff --git a/src/tint/writer/spirv/builder_function_attribute_test.cc b/src/tint/writer/spirv/builder_function_attribute_test.cc index 9e6c338814..60bf062a29 100644 --- a/src/tint/writer/spirv/builder_function_attribute_test.cc +++ b/src/tint/writer/spirv/builder_function_attribute_test.cc @@ -149,6 +149,41 @@ TEST_F(BuilderTest, Decoration_ExecutionMode_WorkgroupSize_Const) { )"); } +TEST_F(BuilderTest, Decoration_ExecutionMode_WorkgroupSize_OverridableConst) { + Override("width", ty.i32(), Construct(ty.i32(), 2_i), Id(7_u)); + Override("height", ty.i32(), Construct(ty.i32(), 3_i), Id(8_u)); + Override("depth", ty.i32(), Construct(ty.i32(), 4_i), Id(9_u)); + auto* func = Func("main", utils::Empty, ty.void_(), utils::Empty, + utils::Vector{ + WorkgroupSize("width", "height", "depth"), + Stage(ast::PipelineStage::kCompute), + }); + + spirv::Builder& b = Build(); + + EXPECT_FALSE(b.GenerateExecutionModes(func, 3)) << b.error(); + EXPECT_EQ( + b.error(), + R"(override expressions should have been removed with the SubstituteOverride transform)"); +} + +TEST_F(BuilderTest, Decoration_ExecutionMode_WorkgroupSize_LiteralAndConst) { + Override("height", ty.i32(), Construct(ty.i32(), 2_i), Id(7_u)); + GlobalConst("depth", ty.i32(), Construct(ty.i32(), 3_i)); + auto* func = Func("main", utils::Empty, ty.void_(), utils::Empty, + utils::Vector{ + WorkgroupSize(4_i, "height", "depth"), + Stage(ast::PipelineStage::kCompute), + }); + + spirv::Builder& b = Build(); + + EXPECT_FALSE(b.GenerateExecutionModes(func, 3)) << b.error(); + EXPECT_EQ( + b.error(), + R"(override expressions should have been removed with the SubstituteOverride transform)"); +} + TEST_F(BuilderTest, Decoration_ExecutionMode_MultipleFragment) { auto* func1 = Func("main1", utils::Empty, ty.void_(), utils::Empty, utils::Vector{