resolver: Fix constant propagation for POC

Pipeline overidable constants are not compile-time constant.
If a module-scope const has an [[override]] decoration, do not assign
the constant value to it, as this will propagate, and the constant value
may become inlined in places that should be overridable.

Also: Rename sem::GlobalVariable::IsPipelineConstant() to
IsOverridable() to make it clearer that this is not a compile-time known
value. Add SetIsOverridable() so we can correctly set the
IsOverridable() flag even when there isn't an ID.

Change-Id: I5ede9dd180d5ff1696b3868ea4313fc28f93af4b
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/69140
Auto-Submit: Ben Clayton <bclayton@google.com>
Reviewed-by: James Price <jrprice@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton
2021-11-11 19:12:36 +00:00
committed by Tint LUCI CQ
parent 3fe243b282
commit 6cdb1bf7c0
9 changed files with 63 additions and 54 deletions

View File

@@ -30,24 +30,26 @@ class ResolverPipelineOverridableConstantTest : public ResolverTest {
auto* sem = Sem().Get<sem::GlobalVariable>(var);
ASSERT_NE(sem, nullptr);
EXPECT_EQ(sem->Declaration(), var);
EXPECT_TRUE(sem->IsPipelineConstant());
EXPECT_TRUE(sem->IsOverridable());
EXPECT_EQ(sem->ConstantId(), id);
EXPECT_FALSE(sem->ConstantValue());
}
};
TEST_F(ResolverPipelineOverridableConstantTest, NonOverridable) {
auto* a = GlobalConst("a", ty.f32(), Construct(ty.f32()));
auto* a = GlobalConst("a", ty.f32(), Expr(1.f));
EXPECT_TRUE(r()->Resolve()) << r()->error();
auto* sem_a = Sem().Get<sem::GlobalVariable>(a);
ASSERT_NE(sem_a, nullptr);
EXPECT_EQ(sem_a->Declaration(), a);
EXPECT_FALSE(sem_a->IsPipelineConstant());
EXPECT_FALSE(sem_a->IsOverridable());
EXPECT_TRUE(sem_a->ConstantValue());
}
TEST_F(ResolverPipelineOverridableConstantTest, WithId) {
auto* a = GlobalConst("a", ty.f32(), Construct(ty.f32()), {Override(7u)});
auto* a = GlobalConst("a", ty.f32(), Expr(1.f), {Override(7u)});
EXPECT_TRUE(r()->Resolve()) << r()->error();
@@ -55,7 +57,7 @@ TEST_F(ResolverPipelineOverridableConstantTest, WithId) {
}
TEST_F(ResolverPipelineOverridableConstantTest, WithoutId) {
auto* a = GlobalConst("a", ty.f32(), Construct(ty.f32()), {Override()});
auto* a = GlobalConst("a", ty.f32(), Expr(1.f), {Override()});
EXPECT_TRUE(r()->Resolve()) << r()->error();
@@ -64,12 +66,12 @@ TEST_F(ResolverPipelineOverridableConstantTest, WithoutId) {
TEST_F(ResolverPipelineOverridableConstantTest, WithAndWithoutIds) {
std::vector<ast::Variable*> variables;
auto* a = GlobalConst("a", ty.f32(), Construct(ty.f32()), {Override()});
auto* b = GlobalConst("b", ty.f32(), Construct(ty.f32()), {Override()});
auto* c = GlobalConst("c", ty.f32(), Construct(ty.f32()), {Override(2u)});
auto* d = GlobalConst("d", ty.f32(), Construct(ty.f32()), {Override(4u)});
auto* e = GlobalConst("e", ty.f32(), Construct(ty.f32()), {Override()});
auto* f = GlobalConst("f", ty.f32(), Construct(ty.f32()), {Override(1u)});
auto* a = GlobalConst("a", ty.f32(), Expr(1.f), {Override()});
auto* b = GlobalConst("b", ty.f32(), Expr(1.f), {Override()});
auto* c = GlobalConst("c", ty.f32(), Expr(1.f), {Override(2u)});
auto* d = GlobalConst("d", ty.f32(), Expr(1.f), {Override(4u)});
auto* e = GlobalConst("e", ty.f32(), Expr(1.f), {Override()});
auto* f = GlobalConst("f", ty.f32(), Expr(1.f), {Override(1u)});
EXPECT_TRUE(r()->Resolve()) << r()->error();
@@ -83,10 +85,8 @@ TEST_F(ResolverPipelineOverridableConstantTest, WithAndWithoutIds) {
}
TEST_F(ResolverPipelineOverridableConstantTest, DuplicateIds) {
GlobalConst("a", ty.f32(), Construct(ty.f32()),
{Override(Source{{12, 34}}, 7u)});
GlobalConst("b", ty.f32(), Construct(ty.f32()),
{Override(Source{{56, 78}}, 7u)});
GlobalConst("a", ty.f32(), Expr(1.f), {Override(Source{{12, 34}}, 7u)});
GlobalConst("b", ty.f32(), Expr(1.f), {Override(Source{{56, 78}}, 7u)});
EXPECT_FALSE(r()->Resolve());
@@ -95,8 +95,7 @@ TEST_F(ResolverPipelineOverridableConstantTest, DuplicateIds) {
}
TEST_F(ResolverPipelineOverridableConstantTest, IdTooLarge) {
GlobalConst("a", ty.f32(), Construct(ty.f32()),
{Override(Source{{12, 34}}, 65536u)});
GlobalConst("a", ty.f32(), Expr(1.f), {Override(Source{{12, 34}}, 65536u)});
EXPECT_FALSE(r()->Resolve());

View File

@@ -547,13 +547,17 @@ sem::Variable* Resolver::Variable(const ast::Variable* var,
binding_point = {bp.group->value, bp.binding->value};
}
auto* override =
ast::GetDecoration<ast::OverrideDecoration>(var->decorations);
bool has_const_val = rhs && var->is_const && !override;
auto* global = builder_->create<sem::GlobalVariable>(
var, var_ty, storage_class, access,
(rhs && var->is_const) ? rhs->ConstantValue() : sem::Constant{},
has_const_val ? rhs->ConstantValue() : sem::Constant{},
binding_point);
if (auto* override =
ast::GetDecoration<ast::OverrideDecoration>(var->decorations)) {
if (override) {
global->SetIsOverridable();
if (override->has_value) {
global->SetConstantId(static_cast<uint16_t>(override->value));
}
@@ -1995,27 +1999,30 @@ bool Resolver::WorkgroupSizeFor(const ast::Function* func,
return false;
}
if (auto* ident = expr->As<ast::IdentifierExpression>()) {
// We have an identifier of a module-scope constant.
auto* var = variable_stack_.Get(ident->symbol);
if (!var || !var->Declaration()->is_const) {
sem::Constant value;
if (auto* user = Sem(expr)->As<sem::VariableUser>()) {
// We have an variable of a module-scope constant.
auto* decl = user->Variable()->Declaration();
if (!decl->is_const) {
AddError(kErrBadType, expr->source);
return false;
}
auto* decl = var->Declaration();
// Capture the constant if an [[override]] attribute is present.
if (ast::HasDecoration<ast::OverrideDecoration>(decl->decorations)) {
ws[i].overridable_const = decl;
}
expr = decl->constructor;
if (!expr) {
if (decl->constructor) {
value = Sem(decl->constructor)->ConstantValue();
} else {
// No constructor means this value must be overriden by the user.
ws[i].value = 0;
continue;
}
} else if (!expr->Is<ast::LiteralExpression>()) {
} else if (expr->Is<ast::LiteralExpression>()) {
value = Sem(expr)->ConstantValue();
} else {
AddError(
"workgroup_size argument must be either a literal or a "
"module-scope constant",
@@ -2023,20 +2030,19 @@ bool Resolver::WorkgroupSizeFor(const ast::Function* func,
return false;
}
auto val = expr_sem->ConstantValue();
if (!val) {
if (!value) {
TINT_ICE(Resolver, diagnostics_)
<< "could not resolve constant workgroup_size constant value";
continue;
}
// Validate and set the default value for this dimension.
if (is_i32 ? val.Elements()[0].i32 < 1 : val.Elements()[0].u32 < 1) {
if (is_i32 ? value.Elements()[0].i32 < 1 : value.Elements()[0].u32 < 1) {
AddError("workgroup_size argument must be at least 1", values[i]->source);
return false;
}
ws[i].value = is_i32 ? static_cast<uint32_t>(val.Elements()[0].i32)
: val.Elements()[0].u32;
ws[i].value = is_i32 ? static_cast<uint32_t>(value.Elements()[0].i32)
: value.Elements()[0].u32;
}
return true;
}