resolver: Fix constant propagation for POC
Pipeline overidable constants are not compile-time constant. If a module-scope const has an [[override]] decoration, do not assign the constant value to it, as this will propagate, and the constant value may become inlined in places that should be overridable. Also: Rename sem::GlobalVariable::IsPipelineConstant() to IsOverridable() to make it clearer that this is not a compile-time known value. Add SetIsOverridable() so we can correctly set the IsOverridable() flag even when there isn't an ID. Change-Id: I5ede9dd180d5ff1696b3868ea4313fc28f93af4b Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/69140 Auto-Submit: Ben Clayton <bclayton@google.com> Reviewed-by: James Price <jrprice@google.com> Kokoro: Kokoro <noreply+kokoro@google.com> Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
parent
3fe243b282
commit
6cdb1bf7c0
|
@ -193,7 +193,7 @@ std::vector<EntryPoint> Inspector::GetEntryPoints() {
|
|||
auto name = program_->Symbols().NameFor(decl->symbol);
|
||||
|
||||
auto* global = var->As<sem::GlobalVariable>();
|
||||
if (global && global->IsPipelineConstant()) {
|
||||
if (global && global->IsOverridable()) {
|
||||
OverridableConstant overridable_constant;
|
||||
overridable_constant.name = name;
|
||||
overridable_constant.numeric_id = global->ConstantId();
|
||||
|
@ -245,7 +245,7 @@ std::map<uint32_t, Scalar> Inspector::GetConstantIDs() {
|
|||
std::map<uint32_t, Scalar> result;
|
||||
for (auto* var : program_->AST().GlobalVariables()) {
|
||||
auto* global = program_->Sem().Get<sem::GlobalVariable>(var);
|
||||
if (!global || !global->IsPipelineConstant()) {
|
||||
if (!global || !global->IsOverridable()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -300,7 +300,7 @@ std::map<std::string, uint32_t> Inspector::GetConstantNameToIdMap() {
|
|||
std::map<std::string, uint32_t> result;
|
||||
for (auto* var : program_->AST().GlobalVariables()) {
|
||||
auto* global = program_->Sem().Get<sem::GlobalVariable>(var);
|
||||
if (global && global->IsPipelineConstant()) {
|
||||
if (global && global->IsOverridable()) {
|
||||
auto name = program_->Symbols().NameFor(var->symbol);
|
||||
result[name] = global->ConstantId();
|
||||
}
|
||||
|
|
|
@ -30,24 +30,26 @@ class ResolverPipelineOverridableConstantTest : public ResolverTest {
|
|||
auto* sem = Sem().Get<sem::GlobalVariable>(var);
|
||||
ASSERT_NE(sem, nullptr);
|
||||
EXPECT_EQ(sem->Declaration(), var);
|
||||
EXPECT_TRUE(sem->IsPipelineConstant());
|
||||
EXPECT_TRUE(sem->IsOverridable());
|
||||
EXPECT_EQ(sem->ConstantId(), id);
|
||||
EXPECT_FALSE(sem->ConstantValue());
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(ResolverPipelineOverridableConstantTest, NonOverridable) {
|
||||
auto* a = GlobalConst("a", ty.f32(), Construct(ty.f32()));
|
||||
auto* a = GlobalConst("a", ty.f32(), Expr(1.f));
|
||||
|
||||
EXPECT_TRUE(r()->Resolve()) << r()->error();
|
||||
|
||||
auto* sem_a = Sem().Get<sem::GlobalVariable>(a);
|
||||
ASSERT_NE(sem_a, nullptr);
|
||||
EXPECT_EQ(sem_a->Declaration(), a);
|
||||
EXPECT_FALSE(sem_a->IsPipelineConstant());
|
||||
EXPECT_FALSE(sem_a->IsOverridable());
|
||||
EXPECT_TRUE(sem_a->ConstantValue());
|
||||
}
|
||||
|
||||
TEST_F(ResolverPipelineOverridableConstantTest, WithId) {
|
||||
auto* a = GlobalConst("a", ty.f32(), Construct(ty.f32()), {Override(7u)});
|
||||
auto* a = GlobalConst("a", ty.f32(), Expr(1.f), {Override(7u)});
|
||||
|
||||
EXPECT_TRUE(r()->Resolve()) << r()->error();
|
||||
|
||||
|
@ -55,7 +57,7 @@ TEST_F(ResolverPipelineOverridableConstantTest, WithId) {
|
|||
}
|
||||
|
||||
TEST_F(ResolverPipelineOverridableConstantTest, WithoutId) {
|
||||
auto* a = GlobalConst("a", ty.f32(), Construct(ty.f32()), {Override()});
|
||||
auto* a = GlobalConst("a", ty.f32(), Expr(1.f), {Override()});
|
||||
|
||||
EXPECT_TRUE(r()->Resolve()) << r()->error();
|
||||
|
||||
|
@ -64,12 +66,12 @@ TEST_F(ResolverPipelineOverridableConstantTest, WithoutId) {
|
|||
|
||||
TEST_F(ResolverPipelineOverridableConstantTest, WithAndWithoutIds) {
|
||||
std::vector<ast::Variable*> variables;
|
||||
auto* a = GlobalConst("a", ty.f32(), Construct(ty.f32()), {Override()});
|
||||
auto* b = GlobalConst("b", ty.f32(), Construct(ty.f32()), {Override()});
|
||||
auto* c = GlobalConst("c", ty.f32(), Construct(ty.f32()), {Override(2u)});
|
||||
auto* d = GlobalConst("d", ty.f32(), Construct(ty.f32()), {Override(4u)});
|
||||
auto* e = GlobalConst("e", ty.f32(), Construct(ty.f32()), {Override()});
|
||||
auto* f = GlobalConst("f", ty.f32(), Construct(ty.f32()), {Override(1u)});
|
||||
auto* a = GlobalConst("a", ty.f32(), Expr(1.f), {Override()});
|
||||
auto* b = GlobalConst("b", ty.f32(), Expr(1.f), {Override()});
|
||||
auto* c = GlobalConst("c", ty.f32(), Expr(1.f), {Override(2u)});
|
||||
auto* d = GlobalConst("d", ty.f32(), Expr(1.f), {Override(4u)});
|
||||
auto* e = GlobalConst("e", ty.f32(), Expr(1.f), {Override()});
|
||||
auto* f = GlobalConst("f", ty.f32(), Expr(1.f), {Override(1u)});
|
||||
|
||||
EXPECT_TRUE(r()->Resolve()) << r()->error();
|
||||
|
||||
|
@ -83,10 +85,8 @@ TEST_F(ResolverPipelineOverridableConstantTest, WithAndWithoutIds) {
|
|||
}
|
||||
|
||||
TEST_F(ResolverPipelineOverridableConstantTest, DuplicateIds) {
|
||||
GlobalConst("a", ty.f32(), Construct(ty.f32()),
|
||||
{Override(Source{{12, 34}}, 7u)});
|
||||
GlobalConst("b", ty.f32(), Construct(ty.f32()),
|
||||
{Override(Source{{56, 78}}, 7u)});
|
||||
GlobalConst("a", ty.f32(), Expr(1.f), {Override(Source{{12, 34}}, 7u)});
|
||||
GlobalConst("b", ty.f32(), Expr(1.f), {Override(Source{{56, 78}}, 7u)});
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
|
||||
|
@ -95,8 +95,7 @@ TEST_F(ResolverPipelineOverridableConstantTest, DuplicateIds) {
|
|||
}
|
||||
|
||||
TEST_F(ResolverPipelineOverridableConstantTest, IdTooLarge) {
|
||||
GlobalConst("a", ty.f32(), Construct(ty.f32()),
|
||||
{Override(Source{{12, 34}}, 65536u)});
|
||||
GlobalConst("a", ty.f32(), Expr(1.f), {Override(Source{{12, 34}}, 65536u)});
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
|
||||
|
|
|
@ -547,13 +547,17 @@ sem::Variable* Resolver::Variable(const ast::Variable* var,
|
|||
binding_point = {bp.group->value, bp.binding->value};
|
||||
}
|
||||
|
||||
auto* override =
|
||||
ast::GetDecoration<ast::OverrideDecoration>(var->decorations);
|
||||
bool has_const_val = rhs && var->is_const && !override;
|
||||
|
||||
auto* global = builder_->create<sem::GlobalVariable>(
|
||||
var, var_ty, storage_class, access,
|
||||
(rhs && var->is_const) ? rhs->ConstantValue() : sem::Constant{},
|
||||
has_const_val ? rhs->ConstantValue() : sem::Constant{},
|
||||
binding_point);
|
||||
|
||||
if (auto* override =
|
||||
ast::GetDecoration<ast::OverrideDecoration>(var->decorations)) {
|
||||
if (override) {
|
||||
global->SetIsOverridable();
|
||||
if (override->has_value) {
|
||||
global->SetConstantId(static_cast<uint16_t>(override->value));
|
||||
}
|
||||
|
@ -1995,27 +1999,30 @@ bool Resolver::WorkgroupSizeFor(const ast::Function* func,
|
|||
return false;
|
||||
}
|
||||
|
||||
if (auto* ident = expr->As<ast::IdentifierExpression>()) {
|
||||
// We have an identifier of a module-scope constant.
|
||||
auto* var = variable_stack_.Get(ident->symbol);
|
||||
if (!var || !var->Declaration()->is_const) {
|
||||
sem::Constant value;
|
||||
|
||||
if (auto* user = Sem(expr)->As<sem::VariableUser>()) {
|
||||
// We have an variable of a module-scope constant.
|
||||
auto* decl = user->Variable()->Declaration();
|
||||
if (!decl->is_const) {
|
||||
AddError(kErrBadType, expr->source);
|
||||
return false;
|
||||
}
|
||||
|
||||
auto* decl = var->Declaration();
|
||||
// Capture the constant if an [[override]] attribute is present.
|
||||
if (ast::HasDecoration<ast::OverrideDecoration>(decl->decorations)) {
|
||||
ws[i].overridable_const = decl;
|
||||
}
|
||||
|
||||
expr = decl->constructor;
|
||||
if (!expr) {
|
||||
if (decl->constructor) {
|
||||
value = Sem(decl->constructor)->ConstantValue();
|
||||
} else {
|
||||
// No constructor means this value must be overriden by the user.
|
||||
ws[i].value = 0;
|
||||
continue;
|
||||
}
|
||||
} else if (!expr->Is<ast::LiteralExpression>()) {
|
||||
} else if (expr->Is<ast::LiteralExpression>()) {
|
||||
value = Sem(expr)->ConstantValue();
|
||||
} else {
|
||||
AddError(
|
||||
"workgroup_size argument must be either a literal or a "
|
||||
"module-scope constant",
|
||||
|
@ -2023,20 +2030,19 @@ bool Resolver::WorkgroupSizeFor(const ast::Function* func,
|
|||
return false;
|
||||
}
|
||||
|
||||
auto val = expr_sem->ConstantValue();
|
||||
if (!val) {
|
||||
if (!value) {
|
||||
TINT_ICE(Resolver, diagnostics_)
|
||||
<< "could not resolve constant workgroup_size constant value";
|
||||
continue;
|
||||
}
|
||||
// Validate and set the default value for this dimension.
|
||||
if (is_i32 ? val.Elements()[0].i32 < 1 : val.Elements()[0].u32 < 1) {
|
||||
if (is_i32 ? value.Elements()[0].i32 < 1 : value.Elements()[0].u32 < 1) {
|
||||
AddError("workgroup_size argument must be at least 1", values[i]->source);
|
||||
return false;
|
||||
}
|
||||
|
||||
ws[i].value = is_i32 ? static_cast<uint32_t>(val.Elements()[0].i32)
|
||||
: val.Elements()[0].u32;
|
||||
ws[i].value = is_i32 ? static_cast<uint32_t>(value.Elements()[0].i32)
|
||||
: value.Elements()[0].u32;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -61,8 +61,7 @@ GlobalVariable::GlobalVariable(const ast::Variable* declaration,
|
|||
Constant constant_value,
|
||||
sem::BindingPoint binding_point)
|
||||
: Base(declaration, type, storage_class, access, std::move(constant_value)),
|
||||
binding_point_(binding_point),
|
||||
is_pipeline_constant_(false) {}
|
||||
binding_point_(binding_point) {}
|
||||
|
||||
GlobalVariable::~GlobalVariable() = default;
|
||||
|
||||
|
|
|
@ -129,22 +129,27 @@ class GlobalVariable : public Castable<GlobalVariable, Variable> {
|
|||
/// @returns the resource binding point for the variable
|
||||
sem::BindingPoint BindingPoint() const { return binding_point_; }
|
||||
|
||||
/// @returns the pipeline constant ID associated with the variable
|
||||
uint16_t ConstantId() const { return constant_id_; }
|
||||
|
||||
/// @param id the constant identifier to assign to this variable
|
||||
void SetConstantId(uint16_t id) {
|
||||
constant_id_ = id;
|
||||
is_pipeline_constant_ = true;
|
||||
is_overridable_ = true;
|
||||
}
|
||||
|
||||
/// @returns true if this variable is an overridable pipeline constant
|
||||
bool IsPipelineConstant() const { return is_pipeline_constant_; }
|
||||
/// @returns the pipeline constant ID associated with the variable
|
||||
uint16_t ConstantId() const { return constant_id_; }
|
||||
|
||||
/// @param is_overridable true if this is a pipeline overridable constant
|
||||
void SetIsOverridable(bool is_overridable = true) {
|
||||
is_overridable_ = is_overridable;
|
||||
}
|
||||
|
||||
/// @returns true if this is pipeline overridable constant
|
||||
bool IsOverridable() const { return is_overridable_; }
|
||||
|
||||
private:
|
||||
const sem::BindingPoint binding_point_;
|
||||
|
||||
bool is_pipeline_constant_ = false;
|
||||
bool is_overridable_ = false;
|
||||
uint16_t constant_id_ = 0;
|
||||
};
|
||||
|
||||
|
|
|
@ -1835,7 +1835,7 @@ bool GeneratorImpl::EmitEntryPointFunction(const ast::Function* func) {
|
|||
if (wgsize[i].overridable_const) {
|
||||
auto* global = builder_.Sem().Get<sem::GlobalVariable>(
|
||||
wgsize[i].overridable_const);
|
||||
if (!global->IsPipelineConstant()) {
|
||||
if (!global->IsOverridable()) {
|
||||
TINT_ICE(Writer, builder_.Diagnostics())
|
||||
<< "expected a pipeline-overridable constant";
|
||||
}
|
||||
|
@ -2657,7 +2657,7 @@ bool GeneratorImpl::EmitProgramConstVariable(const ast::Variable* var) {
|
|||
auto* type = sem->Type();
|
||||
|
||||
auto* global = sem->As<sem::GlobalVariable>();
|
||||
if (global && global->IsPipelineConstant()) {
|
||||
if (global && global->IsOverridable()) {
|
||||
auto const_id = global->ConstantId();
|
||||
|
||||
line() << "#ifndef " << kSpecConstantPrefix << const_id;
|
||||
|
|
|
@ -2620,7 +2620,7 @@ bool GeneratorImpl::EmitEntryPointFunction(const ast::Function* func) {
|
|||
if (wgsize[i].overridable_const) {
|
||||
auto* global = builder_.Sem().Get<sem::GlobalVariable>(
|
||||
wgsize[i].overridable_const);
|
||||
if (!global->IsPipelineConstant()) {
|
||||
if (!global->IsOverridable()) {
|
||||
TINT_ICE(Writer, builder_.Diagnostics())
|
||||
<< "expected a pipeline-overridable constant";
|
||||
}
|
||||
|
@ -3407,7 +3407,7 @@ bool GeneratorImpl::EmitProgramConstVariable(const ast::Variable* var) {
|
|||
auto* type = sem->Type();
|
||||
|
||||
auto* global = sem->As<sem::GlobalVariable>();
|
||||
if (global && global->IsPipelineConstant()) {
|
||||
if (global && global->IsOverridable()) {
|
||||
auto const_id = global->ConstantId();
|
||||
|
||||
line() << "#ifndef " << kSpecConstantPrefix << const_id;
|
||||
|
|
|
@ -2651,7 +2651,7 @@ bool GeneratorImpl::EmitProgramConstVariable(const ast::Variable* var) {
|
|||
}
|
||||
|
||||
auto* global = program_->Sem().Get<sem::GlobalVariable>(var);
|
||||
if (global && global->IsPipelineConstant()) {
|
||||
if (global && global->IsOverridable()) {
|
||||
out << " [[function_constant(" << global->ConstantId() << ")]]";
|
||||
} else if (var->constructor != nullptr) {
|
||||
out << " = ";
|
||||
|
|
|
@ -524,7 +524,7 @@ bool Builder::GenerateExecutionModes(const ast::Function* func, uint32_t id) {
|
|||
// Make the constant specializable.
|
||||
auto* sem_const = builder_.Sem().Get<sem::GlobalVariable>(
|
||||
wgsize[i].overridable_const);
|
||||
if (!sem_const->IsPipelineConstant()) {
|
||||
if (!sem_const->IsOverridable()) {
|
||||
TINT_ICE(Writer, builder_.Diagnostics())
|
||||
<< "expected a pipeline-overridable constant";
|
||||
}
|
||||
|
@ -1333,7 +1333,7 @@ uint32_t Builder::GenerateTypeConstructorExpression(
|
|||
|
||||
// Generate the zero initializer if there are no values provided.
|
||||
if (values.empty()) {
|
||||
if (global_var && global_var->IsPipelineConstant()) {
|
||||
if (global_var && global_var->IsOverridable()) {
|
||||
auto constant_id = global_var->ConstantId();
|
||||
if (result_type->Is<sem::I32>()) {
|
||||
return GenerateConstantIfNeeded(
|
||||
|
@ -1669,7 +1669,7 @@ uint32_t Builder::GenerateLiteralIfNeeded(const ast::Variable* var,
|
|||
ScalarConstant constant;
|
||||
|
||||
auto* global = builder_.Sem().Get<sem::GlobalVariable>(var);
|
||||
if (global && global->IsPipelineConstant()) {
|
||||
if (global && global->IsOverridable()) {
|
||||
constant.is_spec_op = true;
|
||||
constant.constant_id = global->ConstantId();
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue