resolver: Fix constant propagation for POC

Pipeline overidable constants are not compile-time constant.
If a module-scope const has an [[override]] decoration, do not assign
the constant value to it, as this will propagate, and the constant value
may become inlined in places that should be overridable.

Also: Rename sem::GlobalVariable::IsPipelineConstant() to
IsOverridable() to make it clearer that this is not a compile-time known
value. Add SetIsOverridable() so we can correctly set the
IsOverridable() flag even when there isn't an ID.

Change-Id: I5ede9dd180d5ff1696b3868ea4313fc28f93af4b
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/69140
Auto-Submit: Ben Clayton <bclayton@google.com>
Reviewed-by: James Price <jrprice@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton 2021-11-11 19:12:36 +00:00 committed by Tint LUCI CQ
parent 3fe243b282
commit 6cdb1bf7c0
9 changed files with 63 additions and 54 deletions

View File

@ -193,7 +193,7 @@ std::vector<EntryPoint> Inspector::GetEntryPoints() {
auto name = program_->Symbols().NameFor(decl->symbol); auto name = program_->Symbols().NameFor(decl->symbol);
auto* global = var->As<sem::GlobalVariable>(); auto* global = var->As<sem::GlobalVariable>();
if (global && global->IsPipelineConstant()) { if (global && global->IsOverridable()) {
OverridableConstant overridable_constant; OverridableConstant overridable_constant;
overridable_constant.name = name; overridable_constant.name = name;
overridable_constant.numeric_id = global->ConstantId(); overridable_constant.numeric_id = global->ConstantId();
@ -245,7 +245,7 @@ std::map<uint32_t, Scalar> Inspector::GetConstantIDs() {
std::map<uint32_t, Scalar> result; std::map<uint32_t, Scalar> result;
for (auto* var : program_->AST().GlobalVariables()) { for (auto* var : program_->AST().GlobalVariables()) {
auto* global = program_->Sem().Get<sem::GlobalVariable>(var); auto* global = program_->Sem().Get<sem::GlobalVariable>(var);
if (!global || !global->IsPipelineConstant()) { if (!global || !global->IsOverridable()) {
continue; continue;
} }
@ -300,7 +300,7 @@ std::map<std::string, uint32_t> Inspector::GetConstantNameToIdMap() {
std::map<std::string, uint32_t> result; std::map<std::string, uint32_t> result;
for (auto* var : program_->AST().GlobalVariables()) { for (auto* var : program_->AST().GlobalVariables()) {
auto* global = program_->Sem().Get<sem::GlobalVariable>(var); auto* global = program_->Sem().Get<sem::GlobalVariable>(var);
if (global && global->IsPipelineConstant()) { if (global && global->IsOverridable()) {
auto name = program_->Symbols().NameFor(var->symbol); auto name = program_->Symbols().NameFor(var->symbol);
result[name] = global->ConstantId(); result[name] = global->ConstantId();
} }

View File

@ -30,24 +30,26 @@ class ResolverPipelineOverridableConstantTest : public ResolverTest {
auto* sem = Sem().Get<sem::GlobalVariable>(var); auto* sem = Sem().Get<sem::GlobalVariable>(var);
ASSERT_NE(sem, nullptr); ASSERT_NE(sem, nullptr);
EXPECT_EQ(sem->Declaration(), var); EXPECT_EQ(sem->Declaration(), var);
EXPECT_TRUE(sem->IsPipelineConstant()); EXPECT_TRUE(sem->IsOverridable());
EXPECT_EQ(sem->ConstantId(), id); EXPECT_EQ(sem->ConstantId(), id);
EXPECT_FALSE(sem->ConstantValue());
} }
}; };
TEST_F(ResolverPipelineOverridableConstantTest, NonOverridable) { TEST_F(ResolverPipelineOverridableConstantTest, NonOverridable) {
auto* a = GlobalConst("a", ty.f32(), Construct(ty.f32())); auto* a = GlobalConst("a", ty.f32(), Expr(1.f));
EXPECT_TRUE(r()->Resolve()) << r()->error(); EXPECT_TRUE(r()->Resolve()) << r()->error();
auto* sem_a = Sem().Get<sem::GlobalVariable>(a); auto* sem_a = Sem().Get<sem::GlobalVariable>(a);
ASSERT_NE(sem_a, nullptr); ASSERT_NE(sem_a, nullptr);
EXPECT_EQ(sem_a->Declaration(), a); EXPECT_EQ(sem_a->Declaration(), a);
EXPECT_FALSE(sem_a->IsPipelineConstant()); EXPECT_FALSE(sem_a->IsOverridable());
EXPECT_TRUE(sem_a->ConstantValue());
} }
TEST_F(ResolverPipelineOverridableConstantTest, WithId) { TEST_F(ResolverPipelineOverridableConstantTest, WithId) {
auto* a = GlobalConst("a", ty.f32(), Construct(ty.f32()), {Override(7u)}); auto* a = GlobalConst("a", ty.f32(), Expr(1.f), {Override(7u)});
EXPECT_TRUE(r()->Resolve()) << r()->error(); EXPECT_TRUE(r()->Resolve()) << r()->error();
@ -55,7 +57,7 @@ TEST_F(ResolverPipelineOverridableConstantTest, WithId) {
} }
TEST_F(ResolverPipelineOverridableConstantTest, WithoutId) { TEST_F(ResolverPipelineOverridableConstantTest, WithoutId) {
auto* a = GlobalConst("a", ty.f32(), Construct(ty.f32()), {Override()}); auto* a = GlobalConst("a", ty.f32(), Expr(1.f), {Override()});
EXPECT_TRUE(r()->Resolve()) << r()->error(); EXPECT_TRUE(r()->Resolve()) << r()->error();
@ -64,12 +66,12 @@ TEST_F(ResolverPipelineOverridableConstantTest, WithoutId) {
TEST_F(ResolverPipelineOverridableConstantTest, WithAndWithoutIds) { TEST_F(ResolverPipelineOverridableConstantTest, WithAndWithoutIds) {
std::vector<ast::Variable*> variables; std::vector<ast::Variable*> variables;
auto* a = GlobalConst("a", ty.f32(), Construct(ty.f32()), {Override()}); auto* a = GlobalConst("a", ty.f32(), Expr(1.f), {Override()});
auto* b = GlobalConst("b", ty.f32(), Construct(ty.f32()), {Override()}); auto* b = GlobalConst("b", ty.f32(), Expr(1.f), {Override()});
auto* c = GlobalConst("c", ty.f32(), Construct(ty.f32()), {Override(2u)}); auto* c = GlobalConst("c", ty.f32(), Expr(1.f), {Override(2u)});
auto* d = GlobalConst("d", ty.f32(), Construct(ty.f32()), {Override(4u)}); auto* d = GlobalConst("d", ty.f32(), Expr(1.f), {Override(4u)});
auto* e = GlobalConst("e", ty.f32(), Construct(ty.f32()), {Override()}); auto* e = GlobalConst("e", ty.f32(), Expr(1.f), {Override()});
auto* f = GlobalConst("f", ty.f32(), Construct(ty.f32()), {Override(1u)}); auto* f = GlobalConst("f", ty.f32(), Expr(1.f), {Override(1u)});
EXPECT_TRUE(r()->Resolve()) << r()->error(); EXPECT_TRUE(r()->Resolve()) << r()->error();
@ -83,10 +85,8 @@ TEST_F(ResolverPipelineOverridableConstantTest, WithAndWithoutIds) {
} }
TEST_F(ResolverPipelineOverridableConstantTest, DuplicateIds) { TEST_F(ResolverPipelineOverridableConstantTest, DuplicateIds) {
GlobalConst("a", ty.f32(), Construct(ty.f32()), GlobalConst("a", ty.f32(), Expr(1.f), {Override(Source{{12, 34}}, 7u)});
{Override(Source{{12, 34}}, 7u)}); GlobalConst("b", ty.f32(), Expr(1.f), {Override(Source{{56, 78}}, 7u)});
GlobalConst("b", ty.f32(), Construct(ty.f32()),
{Override(Source{{56, 78}}, 7u)});
EXPECT_FALSE(r()->Resolve()); EXPECT_FALSE(r()->Resolve());
@ -95,8 +95,7 @@ TEST_F(ResolverPipelineOverridableConstantTest, DuplicateIds) {
} }
TEST_F(ResolverPipelineOverridableConstantTest, IdTooLarge) { TEST_F(ResolverPipelineOverridableConstantTest, IdTooLarge) {
GlobalConst("a", ty.f32(), Construct(ty.f32()), GlobalConst("a", ty.f32(), Expr(1.f), {Override(Source{{12, 34}}, 65536u)});
{Override(Source{{12, 34}}, 65536u)});
EXPECT_FALSE(r()->Resolve()); EXPECT_FALSE(r()->Resolve());

View File

@ -547,13 +547,17 @@ sem::Variable* Resolver::Variable(const ast::Variable* var,
binding_point = {bp.group->value, bp.binding->value}; binding_point = {bp.group->value, bp.binding->value};
} }
auto* override =
ast::GetDecoration<ast::OverrideDecoration>(var->decorations);
bool has_const_val = rhs && var->is_const && !override;
auto* global = builder_->create<sem::GlobalVariable>( auto* global = builder_->create<sem::GlobalVariable>(
var, var_ty, storage_class, access, var, var_ty, storage_class, access,
(rhs && var->is_const) ? rhs->ConstantValue() : sem::Constant{}, has_const_val ? rhs->ConstantValue() : sem::Constant{},
binding_point); binding_point);
if (auto* override = if (override) {
ast::GetDecoration<ast::OverrideDecoration>(var->decorations)) { global->SetIsOverridable();
if (override->has_value) { if (override->has_value) {
global->SetConstantId(static_cast<uint16_t>(override->value)); global->SetConstantId(static_cast<uint16_t>(override->value));
} }
@ -1995,27 +1999,30 @@ bool Resolver::WorkgroupSizeFor(const ast::Function* func,
return false; return false;
} }
if (auto* ident = expr->As<ast::IdentifierExpression>()) { sem::Constant value;
// We have an identifier of a module-scope constant.
auto* var = variable_stack_.Get(ident->symbol); if (auto* user = Sem(expr)->As<sem::VariableUser>()) {
if (!var || !var->Declaration()->is_const) { // We have an variable of a module-scope constant.
auto* decl = user->Variable()->Declaration();
if (!decl->is_const) {
AddError(kErrBadType, expr->source); AddError(kErrBadType, expr->source);
return false; return false;
} }
auto* decl = var->Declaration();
// Capture the constant if an [[override]] attribute is present. // Capture the constant if an [[override]] attribute is present.
if (ast::HasDecoration<ast::OverrideDecoration>(decl->decorations)) { if (ast::HasDecoration<ast::OverrideDecoration>(decl->decorations)) {
ws[i].overridable_const = decl; ws[i].overridable_const = decl;
} }
expr = decl->constructor; if (decl->constructor) {
if (!expr) { value = Sem(decl->constructor)->ConstantValue();
} else {
// No constructor means this value must be overriden by the user. // No constructor means this value must be overriden by the user.
ws[i].value = 0; ws[i].value = 0;
continue; continue;
} }
} else if (!expr->Is<ast::LiteralExpression>()) { } else if (expr->Is<ast::LiteralExpression>()) {
value = Sem(expr)->ConstantValue();
} else {
AddError( AddError(
"workgroup_size argument must be either a literal or a " "workgroup_size argument must be either a literal or a "
"module-scope constant", "module-scope constant",
@ -2023,20 +2030,19 @@ bool Resolver::WorkgroupSizeFor(const ast::Function* func,
return false; return false;
} }
auto val = expr_sem->ConstantValue(); if (!value) {
if (!val) {
TINT_ICE(Resolver, diagnostics_) TINT_ICE(Resolver, diagnostics_)
<< "could not resolve constant workgroup_size constant value"; << "could not resolve constant workgroup_size constant value";
continue; continue;
} }
// Validate and set the default value for this dimension. // Validate and set the default value for this dimension.
if (is_i32 ? val.Elements()[0].i32 < 1 : val.Elements()[0].u32 < 1) { if (is_i32 ? value.Elements()[0].i32 < 1 : value.Elements()[0].u32 < 1) {
AddError("workgroup_size argument must be at least 1", values[i]->source); AddError("workgroup_size argument must be at least 1", values[i]->source);
return false; return false;
} }
ws[i].value = is_i32 ? static_cast<uint32_t>(val.Elements()[0].i32) ws[i].value = is_i32 ? static_cast<uint32_t>(value.Elements()[0].i32)
: val.Elements()[0].u32; : value.Elements()[0].u32;
} }
return true; return true;
} }

View File

@ -61,8 +61,7 @@ GlobalVariable::GlobalVariable(const ast::Variable* declaration,
Constant constant_value, Constant constant_value,
sem::BindingPoint binding_point) sem::BindingPoint binding_point)
: Base(declaration, type, storage_class, access, std::move(constant_value)), : Base(declaration, type, storage_class, access, std::move(constant_value)),
binding_point_(binding_point), binding_point_(binding_point) {}
is_pipeline_constant_(false) {}
GlobalVariable::~GlobalVariable() = default; GlobalVariable::~GlobalVariable() = default;

View File

@ -129,22 +129,27 @@ class GlobalVariable : public Castable<GlobalVariable, Variable> {
/// @returns the resource binding point for the variable /// @returns the resource binding point for the variable
sem::BindingPoint BindingPoint() const { return binding_point_; } sem::BindingPoint BindingPoint() const { return binding_point_; }
/// @returns the pipeline constant ID associated with the variable
uint16_t ConstantId() const { return constant_id_; }
/// @param id the constant identifier to assign to this variable /// @param id the constant identifier to assign to this variable
void SetConstantId(uint16_t id) { void SetConstantId(uint16_t id) {
constant_id_ = id; constant_id_ = id;
is_pipeline_constant_ = true; is_overridable_ = true;
} }
/// @returns true if this variable is an overridable pipeline constant /// @returns the pipeline constant ID associated with the variable
bool IsPipelineConstant() const { return is_pipeline_constant_; } uint16_t ConstantId() const { return constant_id_; }
/// @param is_overridable true if this is a pipeline overridable constant
void SetIsOverridable(bool is_overridable = true) {
is_overridable_ = is_overridable;
}
/// @returns true if this is pipeline overridable constant
bool IsOverridable() const { return is_overridable_; }
private: private:
const sem::BindingPoint binding_point_; const sem::BindingPoint binding_point_;
bool is_pipeline_constant_ = false; bool is_overridable_ = false;
uint16_t constant_id_ = 0; uint16_t constant_id_ = 0;
}; };

View File

@ -1835,7 +1835,7 @@ bool GeneratorImpl::EmitEntryPointFunction(const ast::Function* func) {
if (wgsize[i].overridable_const) { if (wgsize[i].overridable_const) {
auto* global = builder_.Sem().Get<sem::GlobalVariable>( auto* global = builder_.Sem().Get<sem::GlobalVariable>(
wgsize[i].overridable_const); wgsize[i].overridable_const);
if (!global->IsPipelineConstant()) { if (!global->IsOverridable()) {
TINT_ICE(Writer, builder_.Diagnostics()) TINT_ICE(Writer, builder_.Diagnostics())
<< "expected a pipeline-overridable constant"; << "expected a pipeline-overridable constant";
} }
@ -2657,7 +2657,7 @@ bool GeneratorImpl::EmitProgramConstVariable(const ast::Variable* var) {
auto* type = sem->Type(); auto* type = sem->Type();
auto* global = sem->As<sem::GlobalVariable>(); auto* global = sem->As<sem::GlobalVariable>();
if (global && global->IsPipelineConstant()) { if (global && global->IsOverridable()) {
auto const_id = global->ConstantId(); auto const_id = global->ConstantId();
line() << "#ifndef " << kSpecConstantPrefix << const_id; line() << "#ifndef " << kSpecConstantPrefix << const_id;

View File

@ -2620,7 +2620,7 @@ bool GeneratorImpl::EmitEntryPointFunction(const ast::Function* func) {
if (wgsize[i].overridable_const) { if (wgsize[i].overridable_const) {
auto* global = builder_.Sem().Get<sem::GlobalVariable>( auto* global = builder_.Sem().Get<sem::GlobalVariable>(
wgsize[i].overridable_const); wgsize[i].overridable_const);
if (!global->IsPipelineConstant()) { if (!global->IsOverridable()) {
TINT_ICE(Writer, builder_.Diagnostics()) TINT_ICE(Writer, builder_.Diagnostics())
<< "expected a pipeline-overridable constant"; << "expected a pipeline-overridable constant";
} }
@ -3407,7 +3407,7 @@ bool GeneratorImpl::EmitProgramConstVariable(const ast::Variable* var) {
auto* type = sem->Type(); auto* type = sem->Type();
auto* global = sem->As<sem::GlobalVariable>(); auto* global = sem->As<sem::GlobalVariable>();
if (global && global->IsPipelineConstant()) { if (global && global->IsOverridable()) {
auto const_id = global->ConstantId(); auto const_id = global->ConstantId();
line() << "#ifndef " << kSpecConstantPrefix << const_id; line() << "#ifndef " << kSpecConstantPrefix << const_id;

View File

@ -2651,7 +2651,7 @@ bool GeneratorImpl::EmitProgramConstVariable(const ast::Variable* var) {
} }
auto* global = program_->Sem().Get<sem::GlobalVariable>(var); auto* global = program_->Sem().Get<sem::GlobalVariable>(var);
if (global && global->IsPipelineConstant()) { if (global && global->IsOverridable()) {
out << " [[function_constant(" << global->ConstantId() << ")]]"; out << " [[function_constant(" << global->ConstantId() << ")]]";
} else if (var->constructor != nullptr) { } else if (var->constructor != nullptr) {
out << " = "; out << " = ";

View File

@ -524,7 +524,7 @@ bool Builder::GenerateExecutionModes(const ast::Function* func, uint32_t id) {
// Make the constant specializable. // Make the constant specializable.
auto* sem_const = builder_.Sem().Get<sem::GlobalVariable>( auto* sem_const = builder_.Sem().Get<sem::GlobalVariable>(
wgsize[i].overridable_const); wgsize[i].overridable_const);
if (!sem_const->IsPipelineConstant()) { if (!sem_const->IsOverridable()) {
TINT_ICE(Writer, builder_.Diagnostics()) TINT_ICE(Writer, builder_.Diagnostics())
<< "expected a pipeline-overridable constant"; << "expected a pipeline-overridable constant";
} }
@ -1333,7 +1333,7 @@ uint32_t Builder::GenerateTypeConstructorExpression(
// Generate the zero initializer if there are no values provided. // Generate the zero initializer if there are no values provided.
if (values.empty()) { if (values.empty()) {
if (global_var && global_var->IsPipelineConstant()) { if (global_var && global_var->IsOverridable()) {
auto constant_id = global_var->ConstantId(); auto constant_id = global_var->ConstantId();
if (result_type->Is<sem::I32>()) { if (result_type->Is<sem::I32>()) {
return GenerateConstantIfNeeded( return GenerateConstantIfNeeded(
@ -1669,7 +1669,7 @@ uint32_t Builder::GenerateLiteralIfNeeded(const ast::Variable* var,
ScalarConstant constant; ScalarConstant constant;
auto* global = builder_.Sem().Get<sem::GlobalVariable>(var); auto* global = builder_.Sem().Get<sem::GlobalVariable>(var);
if (global && global->IsPipelineConstant()) { if (global && global->IsOverridable()) {
constant.is_spec_op = true; constant.is_spec_op = true;
constant.constant_id = global->ConstantId(); constant.constant_id = global->ConstantId();
} }