tint: Simplify workgroup size resolving

A `@workgroup_size()` value must be a constant or override expression.
There's nothing specific here about literals or variable expressions.

Remove the semantic tracking of override variables, as these can be override expressions.
The backends will require the `SubstituteOverride` transform to be run, so gut the workgroup_size override handling from the backends.

Bug: tint:1633
Change-Id: Ib3ff843fc64a3595d49223c661b4d58130c0ab30
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/100142
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
This commit is contained in:
Ben Clayton 2022-09-21 21:05:45 +00:00 committed by Dawn LUCI CQ
parent 45a2c5193a
commit 490d9889a7
13 changed files with 202 additions and 166 deletions

View File

@ -148,9 +148,9 @@ EntryPoint Inspector::GetEntryPoint(const tint::ast::Function* func) {
entry_point.stage = PipelineStage::kCompute; entry_point.stage = PipelineStage::kCompute;
auto wgsize = sem->WorkgroupSize(); auto wgsize = sem->WorkgroupSize();
if (!wgsize[0].overridable_const && !wgsize[1].overridable_const && if (wgsize[0].has_value() && wgsize[1].has_value() && wgsize[2].has_value()) {
!wgsize[2].overridable_const) { entry_point.workgroup_size = {wgsize[0].value(), wgsize[1].value(),
entry_point.workgroup_size = {wgsize[0].value, wgsize[1].value, wgsize[2].value}; wgsize[2].value()};
} }
break; break;
} }
@ -849,15 +849,14 @@ void Inspector::GenerateSamplerTargets() {
auto* t = c->args[static_cast<size_t>(texture_index)]; auto* t = c->args[static_cast<size_t>(texture_index)];
auto* s = c->args[static_cast<size_t>(sampler_index)]; auto* s = c->args[static_cast<size_t>(sampler_index)];
GetOriginatingResources( GetOriginatingResources(std::array<const ast::Expression*, 2>{t, s},
std::array<const ast::Expression*, 2>{t, s},
[&](std::array<const sem::GlobalVariable*, 2> globals) { [&](std::array<const sem::GlobalVariable*, 2> globals) {
auto texture_binding_point = globals[0]->BindingPoint(); auto texture_binding_point = globals[0]->BindingPoint();
auto sampler_binding_point = globals[1]->BindingPoint(); auto sampler_binding_point = globals[1]->BindingPoint();
for (auto* entry_point : entry_points) { for (auto* entry_point : entry_points) {
const auto& ep_name = const auto& ep_name = program_->Symbols().NameFor(
program_->Symbols().NameFor(entry_point->Declaration()->symbol); entry_point->Declaration()->symbol);
(*sampler_targets_)[ep_name].Add( (*sampler_targets_)[ep_name].Add(
{sampler_binding_point, texture_binding_point}); {sampler_binding_point, texture_binding_point});
} }

View File

@ -468,7 +468,7 @@ TEST_F(ResolverFunctionValidationTest, FunctionParamsConst) {
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_GoodType_ConstU32) { TEST_F(ResolverFunctionValidationTest, WorkgroupSize_GoodType_ConstU32) {
// const x = 4u; // const x = 4u;
// const x = 8u; // const y = 8u;
// @compute @workgroup_size(x, y, 16u) // @compute @workgroup_size(x, y, 16u)
// fn main() {} // fn main() {}
auto* x = GlobalConst("x", ty.u32(), Expr(4_u)); auto* x = GlobalConst("x", ty.u32(), Expr(4_u));
@ -489,10 +489,29 @@ TEST_F(ResolverFunctionValidationTest, WorkgroupSize_GoodType_ConstU32) {
ASSERT_NE(sem_x, nullptr); ASSERT_NE(sem_x, nullptr);
ASSERT_NE(sem_y, nullptr); ASSERT_NE(sem_y, nullptr);
EXPECT_EQ(sem_func->WorkgroupSize(), (sem::WorkgroupSize{4u, 8u, 16u}));
EXPECT_TRUE(sem_func->DirectlyReferencedGlobals().Contains(sem_x)); EXPECT_TRUE(sem_func->DirectlyReferencedGlobals().Contains(sem_x));
EXPECT_TRUE(sem_func->DirectlyReferencedGlobals().Contains(sem_y)); EXPECT_TRUE(sem_func->DirectlyReferencedGlobals().Contains(sem_y));
} }
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Cast) {
// @compute @workgroup_size(i32(5))
// fn main() {}
auto* func = Func("main", utils::Empty, ty.void_(), utils::Empty,
utils::Vector{
Stage(ast::PipelineStage::kCompute),
WorkgroupSize(Construct(Source{{12, 34}}, ty.i32(), 5_a)),
});
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem_func = Sem().Get(func);
ASSERT_NE(sem_func, nullptr);
EXPECT_EQ(sem_func->WorkgroupSize(), (sem::WorkgroupSize{5u, 1u, 1u}));
}
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_GoodType_I32) { TEST_F(ResolverFunctionValidationTest, WorkgroupSize_GoodType_I32) {
// @compute @workgroup_size(1i, 2i, 3i) // @compute @workgroup_size(1i, 2i, 3i)
// fn main() {} // fn main() {}
@ -651,9 +670,10 @@ TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Literal_BadType) {
}); });
EXPECT_FALSE(r()->Resolve()); EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), EXPECT_EQ(
"12:34 error: workgroup_size argument must be either a literal, constant, or " r()->error(),
"overridable of type abstract-integer, i32 or u32"); "12:34 error: workgroup_size argument must be a constant or override expression of type "
"abstract-integer, i32 or u32");
} }
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Literal_Negative) { TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Literal_Negative) {
@ -696,9 +716,10 @@ TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Const_BadType) {
}); });
EXPECT_FALSE(r()->Resolve()); EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), EXPECT_EQ(
"12:34 error: workgroup_size argument must be either a literal, constant, or " r()->error(),
"overridable of type abstract-integer, i32 or u32"); "12:34 error: workgroup_size argument must be a constant or override expression of type "
"abstract-integer, i32 or u32");
} }
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Const_Negative) { TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Const_Negative) {
@ -759,8 +780,8 @@ TEST_F(ResolverFunctionValidationTest, WorkgroupSize_NonConst) {
EXPECT_FALSE(r()->Resolve()); EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), EXPECT_EQ(r()->error(),
"12:34 error: workgroup_size argument must be either a literal, constant, or " "12:34 error: workgroup_size argument must be a constant or override expression of "
"overridable of type abstract-integer, i32 or u32"); "type abstract-integer, i32 or u32");
} }
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr_x) { TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr_x) {
@ -774,8 +795,8 @@ TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr_x) {
EXPECT_FALSE(r()->Resolve()); EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), EXPECT_EQ(r()->error(),
"12:34 error: workgroup_size argument must be either a literal, constant, or " "12:34 error: workgroup_size argument must be a constant or override expression of "
"overridable of type abstract-integer, i32 or u32"); "type abstract-integer, i32 or u32");
} }
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr_y) { TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr_y) {
@ -789,8 +810,8 @@ TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr_y) {
EXPECT_FALSE(r()->Resolve()); EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), EXPECT_EQ(r()->error(),
"12:34 error: workgroup_size argument must be either a literal, constant, or " "12:34 error: workgroup_size argument must be a constant or override expression of "
"overridable of type abstract-integer, i32 or u32"); "type abstract-integer, i32 or u32");
} }
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr_z) { TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr_z) {
@ -804,8 +825,8 @@ TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr_z) {
EXPECT_FALSE(r()->Resolve()); EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), EXPECT_EQ(r()->error(),
"12:34 error: workgroup_size argument must be either a literal, constant, or " "12:34 error: workgroup_size argument must be a constant or override expression of "
"overridable of type abstract-integer, i32 or u32"); "type abstract-integer, i32 or u32");
} }
TEST_F(ResolverFunctionValidationTest, ReturnIsConstructible_NonPlain) { TEST_F(ResolverFunctionValidationTest, ReturnIsConstructible_NonPlain) {

View File

@ -1050,8 +1050,7 @@ bool Resolver::WorkgroupSize(const ast::Function* func) {
// Set work-group size defaults. // Set work-group size defaults.
sem::WorkgroupSize ws; sem::WorkgroupSize ws;
for (size_t i = 0; i < 3; i++) { for (size_t i = 0; i < 3; i++) {
ws[i].value = 1; ws[i] = 1;
ws[i].overridable_const = nullptr;
} }
auto* attr = ast::GetAttribute<ast::WorkgroupAttribute>(func->attributes); auto* attr = ast::GetAttribute<ast::WorkgroupAttribute>(func->attributes);
@ -1064,7 +1063,7 @@ bool Resolver::WorkgroupSize(const ast::Function* func) {
utils::Vector<const sem::Type*, 3> arg_tys; utils::Vector<const sem::Type*, 3> arg_tys;
constexpr const char* kErrBadExpr = constexpr const char* kErrBadExpr =
"workgroup_size argument must be either a literal, constant, or overridable of type " "workgroup_size argument must be a constant or override expression of type "
"abstract-integer, i32 or u32"; "abstract-integer, i32 or u32";
for (size_t i = 0; i < 3; i++) { for (size_t i = 0; i < 3; i++) {
@ -1084,6 +1083,12 @@ bool Resolver::WorkgroupSize(const ast::Function* func) {
return false; return false;
} }
if (expr->Stage() != sem::EvaluationStage::kConstant &&
expr->Stage() != sem::EvaluationStage::kOverride) {
AddError(kErrBadExpr, value->source);
return false;
}
args.Push(expr); args.Push(expr);
arg_tys.Push(ty); arg_tys.Push(ty);
} }
@ -1105,47 +1110,15 @@ bool Resolver::WorkgroupSize(const ast::Function* func) {
if (!materialized) { if (!materialized) {
return false; return false;
} }
if (auto* value = materialized->ConstantValue()) {
const sem::Constant* value = nullptr;
if (auto* user = args[i]->As<sem::VariableUser>()) {
// We have an variable of a module-scope constant.
auto* decl = user->Variable()->Declaration();
if (!decl->IsAnyOf<ast::Const, ast::Override>()) {
AddError(kErrBadExpr, values[i]->source);
return false;
}
// Capture the constant if it is pipeline-overridable.
if (decl->Is<ast::Override>()) {
ws[i].overridable_const = decl;
}
if (decl->constructor) {
value = sem_.Get(decl->constructor)->ConstantValue();
} else {
// No constructor means this value must be overriden by the user.
ws[i].value = 0;
continue;
}
} else if (values[i]->Is<ast::LiteralExpression>() || args[i]->ConstantValue()) {
value = materialized->ConstantValue();
} else {
AddError(kErrBadExpr, values[i]->source);
return false;
}
if (!value) {
TINT_ICE(Resolver, diagnostics_)
<< "could not resolve constant workgroup_size constant value";
continue;
}
// validator_.Validate and set the default value for this dimension.
if (value->As<AInt>() < 1) { if (value->As<AInt>() < 1) {
AddError("workgroup_size argument must be at least 1", values[i]->source); AddError("workgroup_size argument must be at least 1", values[i]->source);
return false; return false;
} }
ws[i] = value->As<uint32_t>();
ws[i].value = value->As<uint32_t>(); } else {
ws[i] = std::nullopt;
}
} }
current_function_->SetWorkgroupSize(std::move(ws)); current_function_->SetWorkgroupSize(std::move(ws));

View File

@ -993,12 +993,9 @@ TEST_F(ResolverTest, Function_WorkgroupSize_NotSet) {
auto* func_sem = Sem().Get(func); auto* func_sem = Sem().Get(func);
ASSERT_NE(func_sem, nullptr); ASSERT_NE(func_sem, nullptr);
EXPECT_EQ(func_sem->WorkgroupSize()[0].value, 1u); EXPECT_EQ(func_sem->WorkgroupSize()[0], 1u);
EXPECT_EQ(func_sem->WorkgroupSize()[1].value, 1u); EXPECT_EQ(func_sem->WorkgroupSize()[1], 1u);
EXPECT_EQ(func_sem->WorkgroupSize()[2].value, 1u); EXPECT_EQ(func_sem->WorkgroupSize()[2], 1u);
EXPECT_EQ(func_sem->WorkgroupSize()[0].overridable_const, nullptr);
EXPECT_EQ(func_sem->WorkgroupSize()[1].overridable_const, nullptr);
EXPECT_EQ(func_sem->WorkgroupSize()[2].overridable_const, nullptr);
} }
TEST_F(ResolverTest, Function_WorkgroupSize_Literals) { TEST_F(ResolverTest, Function_WorkgroupSize_Literals) {
@ -1015,12 +1012,9 @@ TEST_F(ResolverTest, Function_WorkgroupSize_Literals) {
auto* func_sem = Sem().Get(func); auto* func_sem = Sem().Get(func);
ASSERT_NE(func_sem, nullptr); ASSERT_NE(func_sem, nullptr);
EXPECT_EQ(func_sem->WorkgroupSize()[0].value, 8u); EXPECT_EQ(func_sem->WorkgroupSize()[0], 8u);
EXPECT_EQ(func_sem->WorkgroupSize()[1].value, 2u); EXPECT_EQ(func_sem->WorkgroupSize()[1], 2u);
EXPECT_EQ(func_sem->WorkgroupSize()[2].value, 3u); EXPECT_EQ(func_sem->WorkgroupSize()[2], 3u);
EXPECT_EQ(func_sem->WorkgroupSize()[0].overridable_const, nullptr);
EXPECT_EQ(func_sem->WorkgroupSize()[1].overridable_const, nullptr);
EXPECT_EQ(func_sem->WorkgroupSize()[2].overridable_const, nullptr);
} }
TEST_F(ResolverTest, Function_WorkgroupSize_ViaConst) { TEST_F(ResolverTest, Function_WorkgroupSize_ViaConst) {
@ -1043,12 +1037,9 @@ TEST_F(ResolverTest, Function_WorkgroupSize_ViaConst) {
auto* func_sem = Sem().Get(func); auto* func_sem = Sem().Get(func);
ASSERT_NE(func_sem, nullptr); ASSERT_NE(func_sem, nullptr);
EXPECT_EQ(func_sem->WorkgroupSize()[0].value, 16u); EXPECT_EQ(func_sem->WorkgroupSize()[0], 16u);
EXPECT_EQ(func_sem->WorkgroupSize()[1].value, 8u); EXPECT_EQ(func_sem->WorkgroupSize()[1], 8u);
EXPECT_EQ(func_sem->WorkgroupSize()[2].value, 2u); EXPECT_EQ(func_sem->WorkgroupSize()[2], 2u);
EXPECT_EQ(func_sem->WorkgroupSize()[0].overridable_const, nullptr);
EXPECT_EQ(func_sem->WorkgroupSize()[1].overridable_const, nullptr);
EXPECT_EQ(func_sem->WorkgroupSize()[2].overridable_const, nullptr);
} }
TEST_F(ResolverTest, Function_WorkgroupSize_ViaConst_NestedInitializer) { TEST_F(ResolverTest, Function_WorkgroupSize_ViaConst_NestedInitializer) {
@ -1071,12 +1062,9 @@ TEST_F(ResolverTest, Function_WorkgroupSize_ViaConst_NestedInitializer) {
auto* func_sem = Sem().Get(func); auto* func_sem = Sem().Get(func);
ASSERT_NE(func_sem, nullptr); ASSERT_NE(func_sem, nullptr);
EXPECT_EQ(func_sem->WorkgroupSize()[0].value, 8u); EXPECT_EQ(func_sem->WorkgroupSize()[0], 8u);
EXPECT_EQ(func_sem->WorkgroupSize()[1].value, 4u); EXPECT_EQ(func_sem->WorkgroupSize()[1], 4u);
EXPECT_EQ(func_sem->WorkgroupSize()[2].value, 1u); EXPECT_EQ(func_sem->WorkgroupSize()[2], 1u);
EXPECT_EQ(func_sem->WorkgroupSize()[0].overridable_const, nullptr);
EXPECT_EQ(func_sem->WorkgroupSize()[1].overridable_const, nullptr);
EXPECT_EQ(func_sem->WorkgroupSize()[2].overridable_const, nullptr);
} }
TEST_F(ResolverTest, Function_WorkgroupSize_OverridableConsts) { TEST_F(ResolverTest, Function_WorkgroupSize_OverridableConsts) {
@ -1085,9 +1073,9 @@ TEST_F(ResolverTest, Function_WorkgroupSize_OverridableConsts) {
// @id(2) override depth = 2i; // @id(2) override depth = 2i;
// @compute @workgroup_size(width, height, depth) // @compute @workgroup_size(width, height, depth)
// fn main() {} // fn main() {}
auto* width = Override("width", ty.i32(), Expr(16_i), Id(0_a)); Override("width", ty.i32(), Expr(16_i), Id(0_a));
auto* height = Override("height", ty.i32(), Expr(8_i), Id(1_a)); Override("height", ty.i32(), Expr(8_i), Id(1_a));
auto* depth = Override("depth", ty.i32(), Expr(2_i), Id(2_a)); Override("depth", ty.i32(), Expr(2_i), Id(2_a));
auto* func = Func("main", utils::Empty, ty.void_(), utils::Empty, auto* func = Func("main", utils::Empty, ty.void_(), utils::Empty,
utils::Vector{ utils::Vector{
Stage(ast::PipelineStage::kCompute), Stage(ast::PipelineStage::kCompute),
@ -1099,12 +1087,9 @@ TEST_F(ResolverTest, Function_WorkgroupSize_OverridableConsts) {
auto* func_sem = Sem().Get(func); auto* func_sem = Sem().Get(func);
ASSERT_NE(func_sem, nullptr); ASSERT_NE(func_sem, nullptr);
EXPECT_EQ(func_sem->WorkgroupSize()[0].value, 16u); EXPECT_EQ(func_sem->WorkgroupSize()[0], std::nullopt);
EXPECT_EQ(func_sem->WorkgroupSize()[1].value, 8u); EXPECT_EQ(func_sem->WorkgroupSize()[1], std::nullopt);
EXPECT_EQ(func_sem->WorkgroupSize()[2].value, 2u); EXPECT_EQ(func_sem->WorkgroupSize()[2], std::nullopt);
EXPECT_EQ(func_sem->WorkgroupSize()[0].overridable_const, width);
EXPECT_EQ(func_sem->WorkgroupSize()[1].overridable_const, height);
EXPECT_EQ(func_sem->WorkgroupSize()[2].overridable_const, depth);
} }
TEST_F(ResolverTest, Function_WorkgroupSize_OverridableConsts_NoInit) { TEST_F(ResolverTest, Function_WorkgroupSize_OverridableConsts_NoInit) {
@ -1113,9 +1098,9 @@ TEST_F(ResolverTest, Function_WorkgroupSize_OverridableConsts_NoInit) {
// @id(2) override depth : i32; // @id(2) override depth : i32;
// @compute @workgroup_size(width, height, depth) // @compute @workgroup_size(width, height, depth)
// fn main() {} // fn main() {}
auto* width = Override("width", ty.i32(), Id(0_a)); Override("width", ty.i32(), Id(0_a));
auto* height = Override("height", ty.i32(), Id(1_a)); Override("height", ty.i32(), Id(1_a));
auto* depth = Override("depth", ty.i32(), Id(2_a)); Override("depth", ty.i32(), Id(2_a));
auto* func = Func("main", utils::Empty, ty.void_(), utils::Empty, auto* func = Func("main", utils::Empty, ty.void_(), utils::Empty,
utils::Vector{ utils::Vector{
Stage(ast::PipelineStage::kCompute), Stage(ast::PipelineStage::kCompute),
@ -1127,12 +1112,9 @@ TEST_F(ResolverTest, Function_WorkgroupSize_OverridableConsts_NoInit) {
auto* func_sem = Sem().Get(func); auto* func_sem = Sem().Get(func);
ASSERT_NE(func_sem, nullptr); ASSERT_NE(func_sem, nullptr);
EXPECT_EQ(func_sem->WorkgroupSize()[0].value, 0u); EXPECT_EQ(func_sem->WorkgroupSize()[0], std::nullopt);
EXPECT_EQ(func_sem->WorkgroupSize()[1].value, 0u); EXPECT_EQ(func_sem->WorkgroupSize()[1], std::nullopt);
EXPECT_EQ(func_sem->WorkgroupSize()[2].value, 0u); EXPECT_EQ(func_sem->WorkgroupSize()[2], std::nullopt);
EXPECT_EQ(func_sem->WorkgroupSize()[0].overridable_const, width);
EXPECT_EQ(func_sem->WorkgroupSize()[1].overridable_const, height);
EXPECT_EQ(func_sem->WorkgroupSize()[2].overridable_const, depth);
} }
TEST_F(ResolverTest, Function_WorkgroupSize_Mixed) { TEST_F(ResolverTest, Function_WorkgroupSize_Mixed) {
@ -1140,7 +1122,7 @@ TEST_F(ResolverTest, Function_WorkgroupSize_Mixed) {
// const depth = 3i; // const depth = 3i;
// @compute @workgroup_size(8, height, depth) // @compute @workgroup_size(8, height, depth)
// fn main() {} // fn main() {}
auto* height = Override("height", ty.i32(), Expr(2_i), Id(0_a)); Override("height", ty.i32(), Expr(2_i), Id(0_a));
GlobalConst("depth", ty.i32(), Expr(3_i)); GlobalConst("depth", ty.i32(), Expr(3_i));
auto* func = Func("main", utils::Empty, ty.void_(), utils::Empty, auto* func = Func("main", utils::Empty, ty.void_(), utils::Empty,
utils::Vector{ utils::Vector{
@ -1153,12 +1135,9 @@ TEST_F(ResolverTest, Function_WorkgroupSize_Mixed) {
auto* func_sem = Sem().Get(func); auto* func_sem = Sem().Get(func);
ASSERT_NE(func_sem, nullptr); ASSERT_NE(func_sem, nullptr);
EXPECT_EQ(func_sem->WorkgroupSize()[0].value, 8u); EXPECT_EQ(func_sem->WorkgroupSize()[0], 8u);
EXPECT_EQ(func_sem->WorkgroupSize()[1].value, 2u); EXPECT_EQ(func_sem->WorkgroupSize()[1], std::nullopt);
EXPECT_EQ(func_sem->WorkgroupSize()[2].value, 3u); EXPECT_EQ(func_sem->WorkgroupSize()[2], 3u);
EXPECT_EQ(func_sem->WorkgroupSize()[0].overridable_const, nullptr);
EXPECT_EQ(func_sem->WorkgroupSize()[1].overridable_const, height);
EXPECT_EQ(func_sem->WorkgroupSize()[2].overridable_const, nullptr);
} }
TEST_F(ResolverTest, Expr_MemberAccessor_Struct) { TEST_F(ResolverTest, Expr_MemberAccessor_Struct) {

View File

@ -44,7 +44,7 @@ Function::Function(const ast::Function* declaration,
utils::VectorRef<Parameter*> parameters) utils::VectorRef<Parameter*> parameters)
: Base(return_type, SetOwner(std::move(parameters), this), EvaluationStage::kRuntime), : Base(return_type, SetOwner(std::move(parameters), this), EvaluationStage::kRuntime),
declaration_(declaration), declaration_(declaration),
workgroup_size_{WorkgroupDimension{1}, WorkgroupDimension{1}, WorkgroupDimension{1}}, workgroup_size_{1, 1, 1},
return_location_(return_location) {} return_location_(return_location) {}
Function::~Function() = default; Function::~Function() = default;

View File

@ -39,18 +39,10 @@ class Variable;
namespace tint::sem { namespace tint::sem {
/// WorkgroupDimension describes the size of a single dimension of an entry
/// point's workgroup size.
struct WorkgroupDimension {
/// The size of this dimension.
uint32_t value;
/// A pipeline-overridable constant that overrides the size, or nullptr if
/// this dimension is not overridable.
const ast::Variable* overridable_const = nullptr;
};
/// WorkgroupSize is a three-dimensional array of WorkgroupDimensions. /// WorkgroupSize is a three-dimensional array of WorkgroupDimensions.
using WorkgroupSize = std::array<WorkgroupDimension, 3>; /// Each dimension is a std::optional as a workgroup size can be a constant or override expression.
/// Override expressions are not known at compilation time, so these will be std::nullopt.
using WorkgroupSize = std::array<std::optional<uint32_t>, 3>;
/// Function holds the semantic information for function nodes. /// Function holds the semantic information for function nodes.
class Function final : public Castable<Function, CallTarget> { class Function final : public Castable<Function, CallTarget> {

View File

@ -102,7 +102,6 @@ namespace tint::writer::glsl {
namespace { namespace {
const char kTempNamePrefix[] = "tint_tmp"; const char kTempNamePrefix[] = "tint_tmp";
const char kSpecConstantPrefix[] = "WGSL_SPEC_CONSTANT_";
bool last_is_break_or_fallthrough(const ast::BlockStatement* stmts) { bool last_is_break_or_fallthrough(const ast::BlockStatement* stmts) {
return IsAnyOf<ast::BreakStatement, ast::FallthroughStatement>(stmts->Last()); return IsAnyOf<ast::BreakStatement, ast::FallthroughStatement>(stmts->Last());
@ -1886,8 +1885,9 @@ bool GeneratorImpl::EmitGlobalVariable(const ast::Variable* global) {
[&](const ast::Let* let) { return EmitProgramConstVariable(let); }, [&](const ast::Let* let) { return EmitProgramConstVariable(let); },
[&](const ast::Override*) { [&](const ast::Override*) {
// Override is removed with SubstituteOverride // Override is removed with SubstituteOverride
TINT_ICE(Writer, diagnostics_) diagnostics_.add_error(diag::System::Writer,
<< "Override should have been removed by the substitute_override transform."; "override expressions should have been removed with the "
"SubstituteOverride transform");
return false; return false;
}, },
[&](const ast::Const*) { [&](const ast::Const*) {
@ -2104,16 +2104,14 @@ bool GeneratorImpl::EmitEntryPointFunction(const ast::Function* func) {
} }
out << "local_size_" << (i == 0 ? "x" : i == 1 ? "y" : "z") << " = "; out << "local_size_" << (i == 0 ? "x" : i == 1 ? "y" : "z") << " = ";
if (wgsize[i].overridable_const) { if (!wgsize[i].has_value()) {
auto* global = builder_.Sem().Get<sem::GlobalVariable>(wgsize[i].overridable_const); diagnostics_.add_error(
if (!global->Declaration()->Is<ast::Override>()) { diag::System::Writer,
TINT_ICE(Writer, builder_.Diagnostics()) "override expressions should have been removed with the SubstituteOverride "
<< "expected a pipeline-overridable constant"; "transform");
} return false;
out << kSpecConstantPrefix << global->OverrideId().value;
} else {
out << std::to_string(wgsize[i].value);
} }
out << std::to_string(wgsize[i].value());
} }
out << ") in;"; out << ") in;";
} }

View File

@ -783,6 +783,25 @@ void main() {
)"); )");
} }
TEST_F(GlslGeneratorImplTest_Function,
Emit_Attribute_EntryPoint_Compute_WithWorkgroup_OverridableConst) {
Override("width", ty.i32(), Construct(ty.i32(), 2_i), Id(7_u));
Override("height", ty.i32(), Construct(ty.i32(), 3_i), Id(8_u));
Override("depth", ty.i32(), Construct(ty.i32(), 4_i), Id(9_u));
Func("main", utils::Empty, ty.void_(), {},
utils::Vector{
Stage(ast::PipelineStage::kCompute),
WorkgroupSize("width", "height", "depth"),
});
GeneratorImpl& gen = Build();
EXPECT_FALSE(gen.Generate()) << gen.error();
EXPECT_EQ(
gen.error(),
R"(error: override expressions should have been removed with the SubstituteOverride transform)");
}
TEST_F(GlslGeneratorImplTest_Function, Emit_Function_WithArrayParams) { TEST_F(GlslGeneratorImplTest_Function, Emit_Function_WithArrayParams) {
Func("my_func", utils::Vector{Param("a", ty.array<f32, 5>())}, ty.void_(), Func("my_func", utils::Vector{Param("a", ty.array<f32, 5>())}, ty.void_(),
utils::Vector{ utils::Vector{

View File

@ -81,7 +81,6 @@ namespace tint::writer::hlsl {
namespace { namespace {
const char kTempNamePrefix[] = "tint_tmp"; const char kTempNamePrefix[] = "tint_tmp";
const char kSpecConstantPrefix[] = "WGSL_SPEC_CONSTANT_";
const char* image_format_to_rwtexture_type(ast::TexelFormat image_format) { const char* image_format_to_rwtexture_type(ast::TexelFormat image_format) {
switch (image_format) { switch (image_format) {
@ -2842,8 +2841,9 @@ bool GeneratorImpl::EmitGlobalVariable(const ast::Variable* global) {
}, },
[&](const ast::Override*) { [&](const ast::Override*) {
// Override is removed with SubstituteOverride // Override is removed with SubstituteOverride
TINT_ICE(Writer, diagnostics_) diagnostics_.add_error(diag::System::Writer,
<< "Override should have been removed by the substitute_override transform."; "override expressions should have been removed with the "
"SubstituteOverride transform");
return false; return false;
}, },
[&](const ast::Const*) { [&](const ast::Const*) {
@ -3044,18 +3044,14 @@ bool GeneratorImpl::EmitEntryPointFunction(const ast::Function* func) {
if (i > 0) { if (i > 0) {
out << ", "; out << ", ";
} }
if (!wgsize[i].has_value()) {
if (wgsize[i].overridable_const) { diagnostics_.add_error(
auto* global = diag::System::Writer,
builder_.Sem().Get<sem::GlobalVariable>(wgsize[i].overridable_const); "override expressions should have been removed with the SubstituteOverride "
if (!global->Declaration()->Is<ast::Override>()) { "transform");
TINT_ICE(Writer, diagnostics_) return false;
<< "expected a pipeline-overridable constant";
}
out << kSpecConstantPrefix << global->OverrideId().value;
} else {
out << std::to_string(wgsize[i].value);
} }
out << std::to_string(wgsize[i].value());
} }
out << ")]" << std::endl; out << ")]" << std::endl;
} }

View File

@ -712,6 +712,25 @@ void main() {
)"); )");
} }
TEST_F(HlslGeneratorImplTest_Function,
Emit_Attribute_EntryPoint_Compute_WithWorkgroup_OverridableConst) {
Override("width", ty.i32(), Construct(ty.i32(), 2_i), Id(7_u));
Override("height", ty.i32(), Construct(ty.i32(), 3_i), Id(8_u));
Override("depth", ty.i32(), Construct(ty.i32(), 4_i), Id(9_u));
Func("main", utils::Empty, ty.void_(), utils::Empty,
utils::Vector{
Stage(ast::PipelineStage::kCompute),
WorkgroupSize("width", "height", "depth"),
});
GeneratorImpl& gen = Build();
EXPECT_FALSE(gen.Generate()) << gen.error();
EXPECT_EQ(
gen.error(),
R"(error: override expressions should have been removed with the SubstituteOverride transform)");
}
TEST_F(HlslGeneratorImplTest_Function, Emit_Function_WithArrayParams) { TEST_F(HlslGeneratorImplTest_Function, Emit_Function_WithArrayParams) {
Func("my_func", Func("my_func",
utils::Vector{ utils::Vector{

View File

@ -273,8 +273,9 @@ bool GeneratorImpl::Generate() {
}, },
[&](const ast::Override*) { [&](const ast::Override*) {
// Override is removed with SubstituteOverride // Override is removed with SubstituteOverride
TINT_ICE(Writer, diagnostics_) diagnostics_.add_error(diag::System::Writer,
<< "Override should have been removed by the substitute_override transform."; "override expressions should have been removed with the "
"SubstituteOverride transform.");
return false; return false;
}, },
[&](const ast::Function* func) { [&](const ast::Function* func) {

View File

@ -506,13 +506,17 @@ bool Builder::GenerateExecutionModes(const ast::Function* func, uint32_t id) {
} else if (func->PipelineStage() == ast::PipelineStage::kCompute) { } else if (func->PipelineStage() == ast::PipelineStage::kCompute) {
auto& wgsize = func_sem->WorkgroupSize(); auto& wgsize = func_sem->WorkgroupSize();
// SubstituteOverride replaced all overrides with constants. // Check if the workgroup_size uses pipeline-overridable constants.
uint32_t x = wgsize[0].value; if (!wgsize[0].has_value() || !wgsize[1].has_value() || !wgsize[2].has_value()) {
uint32_t y = wgsize[1].value; error_ =
uint32_t z = wgsize[2].value; "override expressions should have been removed with the SubstituteOverride "
push_execution_mode(spv::Op::OpExecutionMode, "transform";
{Operand(id), U32Operand(SpvExecutionModeLocalSize), Operand(x), return false;
Operand(y), Operand(z)}); }
push_execution_mode(
spv::Op::OpExecutionMode,
{Operand(id), U32Operand(SpvExecutionModeLocalSize), //
Operand(wgsize[0].value()), Operand(wgsize[1].value()), Operand(wgsize[2].value())});
} }
for (auto builtin : func_sem->TransitivelyReferencedBuiltinVariables()) { for (auto builtin : func_sem->TransitivelyReferencedBuiltinVariables()) {

View File

@ -149,6 +149,41 @@ TEST_F(BuilderTest, Decoration_ExecutionMode_WorkgroupSize_Const) {
)"); )");
} }
TEST_F(BuilderTest, Decoration_ExecutionMode_WorkgroupSize_OverridableConst) {
Override("width", ty.i32(), Construct(ty.i32(), 2_i), Id(7_u));
Override("height", ty.i32(), Construct(ty.i32(), 3_i), Id(8_u));
Override("depth", ty.i32(), Construct(ty.i32(), 4_i), Id(9_u));
auto* func = Func("main", utils::Empty, ty.void_(), utils::Empty,
utils::Vector{
WorkgroupSize("width", "height", "depth"),
Stage(ast::PipelineStage::kCompute),
});
spirv::Builder& b = Build();
EXPECT_FALSE(b.GenerateExecutionModes(func, 3)) << b.error();
EXPECT_EQ(
b.error(),
R"(override expressions should have been removed with the SubstituteOverride transform)");
}
TEST_F(BuilderTest, Decoration_ExecutionMode_WorkgroupSize_LiteralAndConst) {
Override("height", ty.i32(), Construct(ty.i32(), 2_i), Id(7_u));
GlobalConst("depth", ty.i32(), Construct(ty.i32(), 3_i));
auto* func = Func("main", utils::Empty, ty.void_(), utils::Empty,
utils::Vector{
WorkgroupSize(4_i, "height", "depth"),
Stage(ast::PipelineStage::kCompute),
});
spirv::Builder& b = Build();
EXPECT_FALSE(b.GenerateExecutionModes(func, 3)) << b.error();
EXPECT_EQ(
b.error(),
R"(override expressions should have been removed with the SubstituteOverride transform)");
}
TEST_F(BuilderTest, Decoration_ExecutionMode_MultipleFragment) { TEST_F(BuilderTest, Decoration_ExecutionMode_MultipleFragment) {
auto* func1 = Func("main1", utils::Empty, ty.void_(), utils::Empty, auto* func1 = Func("main1", utils::Empty, ty.void_(), utils::Empty,
utils::Vector{ utils::Vector{