tint: Add sem::Expression

A new base class for sem::ValueExpression, which other types of
expression can derive from.

Example: sem::TypeExpression - an expression that resolves to a type.

Bug: tint:1810
Change-Id: I90dfb66b265b67d9fdf0c04eb3dce2442c7e18ea
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/118404
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton
2023-02-05 22:59:40 +00:00
committed by Dawn LUCI CQ
parent ef1811a18b
commit 0b4a2f1f50
55 changed files with 295 additions and 156 deletions

View File

@@ -74,7 +74,7 @@ struct OffsetExpr : Offset {
explicit OffsetExpr(const ast::Expression* e) : expr(e) {}
const ast::Expression* Build(CloneContext& ctx) const override {
auto* type = ctx.src->Sem().Get(expr)->Type()->UnwrapRef();
auto* type = ctx.src->Sem().GetVal(expr)->Type()->UnwrapRef();
auto* res = ctx.Clone(expr);
if (!type->Is<type::U32>()) {
res = ctx.dst->Construct<u32>(res);
@@ -881,7 +881,7 @@ Transform::ApplyResult DecomposeMemoryAccess::Apply(const Program* src,
for (auto* node : src->ASTNodes().Objects()) {
if (auto* ident = node->As<ast::IdentifierExpression>()) {
// X
if (auto* sem_ident = sem.Get(ident)) {
if (auto* sem_ident = sem.GetVal(ident)) {
if (auto* var = sem_ident->UnwrapLoad()->As<sem::VariableUser>()) {
if (var->Variable()->AddressSpace() == type::AddressSpace::kStorage ||
var->Variable()->AddressSpace() == type::AddressSpace::kUniform) {

View File

@@ -123,7 +123,7 @@ Transform::ApplyResult DemoteToHelper::Apply(const Program* src, const DataMap&,
}
// Skip writes to invocation-private address spaces.
auto* ref = sem.Get(assign->lhs)->Type()->As<type::Reference>();
auto* ref = sem.GetVal(assign->lhs)->Type()->As<type::Reference>();
switch (ref->AddressSpace()) {
case type::AddressSpace::kStorage:
// Need to mask these.

View File

@@ -216,7 +216,7 @@ struct DirectVariableAccess::State {
// are grown and moved up the expression tree. After this stage, we are left with all the
// expression access chains to variables that we may need to transform.
for (auto* node : ctx.src->ASTNodes().Objects()) {
if (auto* expr = sem.Get<sem::ValueExpression>(node)) {
if (auto* expr = sem.GetVal(node)) {
AppendAccessChain(expr);
}
}
@@ -464,7 +464,7 @@ struct DirectVariableAccess::State {
[&](const ast::Let*) {
if (variable->Type()->Is<type::Pointer>()) {
// variable is a pointer-let.
auto* init = sem.Get(variable->Declaration()->initializer);
auto* init = sem.GetVal(variable->Declaration()->initializer);
// Note: We do not use take_chain() here, as we need to preserve the
// AccessChain on the let's initializer, as the let needs its
// initializer updated, and the let may be used multiple times. Instead
@@ -498,7 +498,7 @@ struct DirectVariableAccess::State {
// If this is a '&' or '*', simply move the chain to the unary op expression.
if (unary->op == ast::UnaryOp::kAddressOf ||
unary->op == ast::UnaryOp::kIndirection) {
take_chain(sem.Get(unary->expr));
take_chain(sem.GetVal(unary->expr));
}
}
});
@@ -990,7 +990,7 @@ struct DirectVariableAccess::State {
return nullptr; // Just clone the expression.
}
auto* expr = sem.Get(ast_expr);
auto* expr = sem.GetVal(ast_expr);
if (!expr) {
// No semantic node for the expression.
return nullptr; // Just clone the expression.

View File

@@ -86,7 +86,10 @@ struct ExpandCompoundAssignment::State {
// Helper function that returns `true` if the type of `expr` is a vector.
auto is_vec = [&](const ast::Expression* expr) {
return ctx.src->Sem().Get(expr)->Type()->UnwrapRef()->Is<type::Vector>();
if (auto* val_expr = ctx.src->Sem().GetVal(expr)) {
return val_expr->Type()->UnwrapRef()->Is<type::Vector>();
}
return false;
};
// Hoist the LHS expression subtree into local constants to produce a new

View File

@@ -138,7 +138,7 @@ Transform::ApplyResult FirstIndexOffset::Apply(const Program* src,
// Fix up all references to the builtins with the offsets
ctx.ReplaceAll([=, &ctx](const ast::Expression* expr) -> const ast::Expression* {
if (auto* sem = ctx.src->Sem().Get(expr)) {
if (auto* sem = ctx.src->Sem().GetVal(expr)) {
if (auto* user = sem->UnwrapLoad()->As<sem::VariableUser>()) {
auto it = builtin_vars.find(user->Variable());
if (it != builtin_vars.end()) {

View File

@@ -164,7 +164,7 @@ struct LocalizeStructArrayAssignment::State {
ast::TraverseExpressions(
expr, b.Diagnostics(), [&](const ast::IndexAccessorExpression* ia) {
// Indexing using a runtime value?
auto* idx_sem = src->Sem().Get(ia->index);
auto* idx_sem = src->Sem().GetVal(ia->index);
if (!idx_sem->ConstantValue()) {
// Indexing a member access expr?
if (auto* ma = ia->object->As<ast::MemberAccessorExpression>()) {
@@ -186,7 +186,7 @@ struct LocalizeStructArrayAssignment::State {
// See https://www.w3.org/TR/WGSL/#originating-variable-section
std::pair<const type::Type*, type::AddressSpace> GetOriginatingTypeAndAddressSpace(
const ast::AssignmentStatement* assign_stmt) {
auto* root_ident = src->Sem().Get(assign_stmt->lhs)->RootIdentifier();
auto* root_ident = src->Sem().GetVal(assign_stmt->lhs)->RootIdentifier();
if (TINT_UNLIKELY(!root_ident)) {
TINT_ICE(Transform, b.Diagnostics())
<< "Unable to determine originating variable for lhs of assignment "

View File

@@ -198,7 +198,7 @@ struct MultiplanarExternalTexture::State {
builtin->Parameters()[0]->Type()->Is<type::ExternalTexture>() &&
builtin->Type() != sem::BuiltinType::kTextureDimensions) {
if (auto* var_user =
sem.Get(expr->args[0])->UnwrapLoad()->As<sem::VariableUser>()) {
sem.GetVal(expr->args[0])->UnwrapLoad()->As<sem::VariableUser>()) {
auto it = new_binding_symbols.find(var_user->Variable());
if (it == new_binding_symbols.end()) {
// If valid new binding locations were not provided earlier, we would have
@@ -223,7 +223,7 @@ struct MultiplanarExternalTexture::State {
// texture_external parameter. These need to be expanded out to multiple plane
// textures and the texture parameters structure.
for (auto* arg : expr->args) {
if (auto* var_user = sem.Get(arg)->UnwrapLoad()->As<sem::VariableUser>()) {
if (auto* var_user = sem.GetVal(arg)->UnwrapLoad()->As<sem::VariableUser>()) {
// Check if a parameter is a texture_external by trying to find
// it in the transform state.
auto it = new_binding_symbols.find(var_user->Variable());

View File

@@ -109,7 +109,7 @@ struct PackedVec3::State {
if (unary->op == ast::UnaryOp::kAddressOf ||
unary->op == ast::UnaryOp::kIndirection) {
// Memory access on the packed vector. Track these.
auto* inner = sem.Get(unary->expr);
auto* inner = sem.GetVal(unary->expr);
if (refs.Remove(inner)) {
refs.Add(expr);
}
@@ -121,7 +121,7 @@ struct PackedVec3::State {
[&](const sem::Statement* e) {
if (auto* assign = e->Declaration()->As<ast::AssignmentStatement>()) {
// We don't want to cast packed_vectors if they're being assigned to.
refs.Remove(sem.Get(assign->lhs));
refs.Remove(sem.GetVal(assign->lhs));
}
});
}

View File

@@ -48,7 +48,7 @@ struct PreservePadding::State {
Switch(
node, //
[&](const ast::AssignmentStatement* assign) {
auto* ty = sem.Get(assign->lhs)->Type();
auto* ty = sem.GetVal(assign->lhs)->Type();
if (assign->lhs->Is<ast::PhonyExpression>()) {
// Ignore phony assignment.
return;
@@ -80,7 +80,7 @@ struct PreservePadding::State {
if (!assignments_to_transform.count(assign)) {
return nullptr;
}
auto* ty = sem.Get(assign->lhs)->Type()->UnwrapRef();
auto* ty = sem.GetVal(assign->lhs)->Type()->UnwrapRef();
return MakeAssignment(ty, ctx.Clone(assign->lhs), ctx.Clone(assign->rhs));
});

View File

@@ -83,7 +83,7 @@ Transform::ApplyResult PromoteInitializersToLet::Apply(const Program* src,
// Walk the AST nodes. This order guarantees that leaf-expressions are visited first.
for (auto* node : src->ASTNodes().Objects()) {
if (auto* sem = src->Sem().Get<sem::ValueExpression>(node)) {
if (auto* sem = src->Sem().GetVal(node)) {
auto* stmt = sem->Stmt();
if (!stmt) {
// Expression is outside of a statement. This usually means the expression is part
@@ -118,7 +118,7 @@ Transform::ApplyResult PromoteInitializersToLet::Apply(const Program* src,
// After walking the full AST, const_chains only contains the outer-most constant expressions.
// Check if any of these need hoisting, and append those to to_hoist.
for (auto* expr : const_chains) {
if (auto* sem = src->Sem().Get(expr); should_hoist(sem)) {
if (auto* sem = src->Sem().GetVal(expr); should_hoist(sem)) {
to_hoist.Push(sem);
}
}

View File

@@ -66,9 +66,8 @@ Transform::ApplyResult SimplifySideEffectStatements::Apply(const Program* src,
HoistToDeclBefore hoist_to_decl_before(ctx);
for (auto* node : ctx.src->ASTNodes().Objects()) {
if (auto* expr = node->As<ast::Expression>()) {
auto* sem_expr = src->Sem().Get(expr);
if (!sem_expr || !sem_expr->HasSideEffects()) {
if (auto* sem_expr = src->Sem().GetVal(node)) {
if (!sem_expr->HasSideEffects()) {
continue;
}
@@ -278,7 +277,7 @@ class DecomposeSideEffects::CollectHoistsState : public StateBase {
return true;
},
[&](const ast::IdentifierExpression* e) {
if (auto* sem_e = sem.Get(e)) {
if (auto* sem_e = sem.GetVal(e)) {
if (auto* var_user = sem_e->UnwrapLoad()->As<sem::VariableUser>()) {
// Don't hoist constants.
if (var_user->ConstantValue()) {
@@ -417,8 +416,8 @@ class DecomposeSideEffects::DecomposeState : public StateBase {
// Returns true if `binary_expr` should be decomposed for short-circuit eval.
bool IsLogicalWithSideEffects(const ast::BinaryExpression* binary_expr) {
return binary_expr->IsLogical() && (sem.Get(binary_expr->lhs)->HasSideEffects() ||
sem.Get(binary_expr->rhs)->HasSideEffects());
return binary_expr->IsLogical() && (sem.GetVal(binary_expr->lhs)->HasSideEffects() ||
sem.GetVal(binary_expr->rhs)->HasSideEffects());
}
// Recursive function used to decompose an expression for short-circuit eval.
@@ -560,7 +559,8 @@ class DecomposeSideEffects::DecomposeState : public StateBase {
return Switch(
stmt,
[&](const ast::AssignmentStatement* s) -> const ast::Statement* {
if (!sem.Get(s->lhs)->HasSideEffects() && !sem.Get(s->rhs)->HasSideEffects()) {
if (!sem.GetVal(s->lhs)->HasSideEffects() &&
!sem.GetVal(s->rhs)->HasSideEffects()) {
return nullptr;
}
// rhs before lhs
@@ -580,7 +580,7 @@ class DecomposeSideEffects::DecomposeState : public StateBase {
return ctx.CloneWithoutTransform(s);
},
[&](const ast::ForLoopStatement* s) -> const ast::Statement* {
if (!s->condition || !sem.Get(s->condition)->HasSideEffects()) {
if (!s->condition || !sem.GetVal(s->condition)->HasSideEffects()) {
return nullptr;
}
tint::utils::Vector<const ast::Statement*, 8> stmts;
@@ -589,7 +589,7 @@ class DecomposeSideEffects::DecomposeState : public StateBase {
return ctx.CloneWithoutTransform(s);
},
[&](const ast::WhileStatement* s) -> const ast::Statement* {
if (!sem.Get(s->condition)->HasSideEffects()) {
if (!sem.GetVal(s->condition)->HasSideEffects()) {
return nullptr;
}
tint::utils::Vector<const ast::Statement*, 8> stmts;
@@ -598,7 +598,7 @@ class DecomposeSideEffects::DecomposeState : public StateBase {
return ctx.CloneWithoutTransform(s);
},
[&](const ast::IfStatement* s) -> const ast::Statement* {
if (!sem.Get(s->condition)->HasSideEffects()) {
if (!sem.GetVal(s->condition)->HasSideEffects()) {
return nullptr;
}
tint::utils::Vector<const ast::Statement*, 8> stmts;
@@ -607,7 +607,7 @@ class DecomposeSideEffects::DecomposeState : public StateBase {
return ctx.CloneWithoutTransform(s);
},
[&](const ast::ReturnStatement* s) -> const ast::Statement* {
if (!s->value || !sem.Get(s->value)->HasSideEffects()) {
if (!s->value || !sem.GetVal(s->value)->HasSideEffects()) {
return nullptr;
}
tint::utils::Vector<const ast::Statement*, 8> stmts;
@@ -626,7 +626,7 @@ class DecomposeSideEffects::DecomposeState : public StateBase {
},
[&](const ast::VariableDeclStatement* s) -> const ast::Statement* {
auto* var = s->variable;
if (!var->initializer || !sem.Get(var->initializer)->HasSideEffects()) {
if (!var->initializer || !sem.GetVal(var->initializer)->HasSideEffects()) {
return nullptr;
}
tint::utils::Vector<const ast::Statement*, 8> stmts;

View File

@@ -102,7 +102,7 @@ Transform::ApplyResult RemovePhonies::Apply(const Program* src, const DataMap&,
ctx.Replace(stmt, [&, side_effects] {
SinkSignature sig;
for (auto* arg : side_effects) {
sig.push_back(sem.Get(arg)->Type()->UnwrapRef());
sig.push_back(sem.GetVal(arg)->Type()->UnwrapRef());
}
auto sink = sinks.GetOrCreate(sig, [&] {
auto name = b.Symbols().New("phony_sink");

View File

@@ -1286,7 +1286,7 @@ Transform::ApplyResult Renamer::Apply(const Program* src,
auto* sem = src->Sem().Get(accessor)->UnwrapLoad();
if (sem->Is<sem::Swizzle>()) {
preserved_identifiers.Add(accessor->member);
} else if (auto* str_expr = src->Sem().Get(accessor->object)) {
} else if (auto* str_expr = src->Sem().GetVal(accessor->object)) {
if (auto* ty = str_expr->Type()->UnwrapRef()->As<sem::Struct>()) {
if (ty->Declaration() == nullptr) { // Builtin structure
preserved_identifiers.Add(accessor->member);

View File

@@ -101,7 +101,7 @@ struct SpirvAtomic::State {
// Keep track of this expression. We'll need to modify the root identifier /
// structure to be atomic.
atomic_expressions.Add(ctx.src->Sem().Get(args[0]));
atomic_expressions.Add(ctx.src->Sem().GetVal(args[0]));
}
// Remove the stub from the output program
@@ -186,7 +186,7 @@ struct SpirvAtomic::State {
},
[&](const sem::ValueExpression* e) {
if (auto* unary = e->Declaration()->As<ast::UnaryOpExpression>()) {
atomic_expressions.Add(ctx.src->Sem().Get(unary->expr));
atomic_expressions.Add(ctx.src->Sem().GetVal(unary->expr));
}
});
}
@@ -249,7 +249,7 @@ struct SpirvAtomic::State {
Switch(
vu->Stmt()->Declaration(),
[&](const ast::AssignmentStatement* assign) {
auto* sem_lhs = ctx.src->Sem().Get(assign->lhs);
auto* sem_lhs = ctx.src->Sem().GetVal(assign->lhs);
if (is_ref_to_atomic_var(sem_lhs)) {
ctx.Replace(assign, [=] {
auto* lhs = ctx.CloneWithoutTransform(assign->lhs);
@@ -261,7 +261,7 @@ struct SpirvAtomic::State {
return;
}
auto sem_rhs = ctx.src->Sem().Get(assign->rhs);
auto sem_rhs = ctx.src->Sem().GetVal(assign->rhs);
if (is_ref_to_atomic_var(sem_rhs->UnwrapLoad())) {
ctx.Replace(assign->rhs, [=] {
auto* rhs = ctx.CloneWithoutTransform(assign->rhs);
@@ -273,7 +273,7 @@ struct SpirvAtomic::State {
},
[&](const ast::VariableDeclStatement* decl) {
auto* var = decl->variable;
if (auto* sem_init = ctx.src->Sem().Get(var->initializer)) {
if (auto* sem_init = ctx.src->Sem().GetVal(var->initializer)) {
if (is_ref_to_atomic_var(sem_init->UnwrapLoad())) {
ctx.Replace(var->initializer, [=] {
auto* rhs = ctx.CloneWithoutTransform(var->initializer);

View File

@@ -494,7 +494,7 @@ struct Std140::State {
/// @returns an AccessChain if the expression is an access to a std140-forked uniform buffer,
/// otherwise returns a std::nullopt.
std::optional<AccessChain> AccessChainFor(const ast::Expression* ast_expr) {
auto* expr = sem.Get(ast_expr);
auto* expr = sem.GetVal(ast_expr);
if (!expr) {
return std::nullopt;
}
@@ -580,7 +580,7 @@ struct Std140::State {
switch (u->op) {
case ast::UnaryOp::kAddressOf:
case ast::UnaryOp::kIndirection:
expr = sem.Get(u->expr);
expr = sem.GetVal(u->expr);
return Action::kContinue;
default:
TINT_ICE(Transform, b.Diagnostics())

View File

@@ -106,7 +106,7 @@ struct Unshadow::State {
ctx.ReplaceAll(
[&](const ast::IdentifierExpression* ident) -> const tint::ast::IdentifierExpression* {
if (auto* sem_ident = sem.Get(ident)) {
if (auto* sem_ident = sem.GetVal(ident)) {
if (auto* user = sem_ident->UnwrapLoad()->As<sem::VariableUser>()) {
if (auto renamed = renamed_to.Find(user->Variable())) {
return b.Expr(*renamed);

View File

@@ -108,7 +108,7 @@ TEST_F(HoistToDeclBeforeTest, ForLoopCond) {
CloneContext ctx(&cloned_b, &original);
HoistToDeclBefore hoistToDeclBefore(ctx);
auto* sem_expr = ctx.src->Sem().Get(expr);
auto* sem_expr = ctx.src->Sem().GetVal(expr);
hoistToDeclBefore.Add(sem_expr, expr, HoistToDeclBefore::VariableKind::kConst);
ctx.Clone();
@@ -189,7 +189,7 @@ TEST_F(HoistToDeclBeforeTest, WhileCond) {
CloneContext ctx(&cloned_b, &original);
HoistToDeclBefore hoistToDeclBefore(ctx);
auto* sem_expr = ctx.src->Sem().Get(expr);
auto* sem_expr = ctx.src->Sem().GetVal(expr);
hoistToDeclBefore.Add(sem_expr, expr, HoistToDeclBefore::VariableKind::kVar);
ctx.Clone();
@@ -233,7 +233,7 @@ TEST_F(HoistToDeclBeforeTest, ElseIf) {
CloneContext ctx(&cloned_b, &original);
HoistToDeclBefore hoistToDeclBefore(ctx);
auto* sem_expr = ctx.src->Sem().Get(expr);
auto* sem_expr = ctx.src->Sem().GetVal(expr);
hoistToDeclBefore.Add(sem_expr, expr, HoistToDeclBefore::VariableKind::kConst);
ctx.Clone();
@@ -339,7 +339,7 @@ TEST_F(HoistToDeclBeforeTest, Prepare_ForLoopCond) {
CloneContext ctx(&cloned_b, &original);
HoistToDeclBefore hoistToDeclBefore(ctx);
auto* sem_expr = ctx.src->Sem().Get(expr);
auto* sem_expr = ctx.src->Sem().GetVal(expr);
hoistToDeclBefore.Prepare(sem_expr);
ctx.Clone();
@@ -422,7 +422,7 @@ TEST_F(HoistToDeclBeforeTest, Prepare_ElseIf) {
CloneContext ctx(&cloned_b, &original);
HoistToDeclBefore hoistToDeclBefore(ctx);
auto* sem_expr = ctx.src->Sem().Get(expr);
auto* sem_expr = ctx.src->Sem().GetVal(expr);
hoistToDeclBefore.Prepare(sem_expr);
ctx.Clone();

View File

@@ -40,13 +40,13 @@ Transform::ApplyResult VarForDynamicIndex::Apply(const Program* src,
auto* object_expr = access_expr->object;
auto& sem = src->Sem();
if (sem.Get(index_expr)->ConstantValue()) {
if (sem.GetVal(index_expr)->ConstantValue()) {
// Index expression resolves to a compile time value.
// As this isn't a dynamic index, we can ignore this.
return true;
}
auto* indexed = sem.Get(object_expr);
auto* indexed = sem.GetVal(object_expr);
if (!indexed->Type()->IsAnyOf<type::Array, type::Matrix>()) {
// We only care about array and matrices.
return true;

View File

@@ -34,7 +34,7 @@ namespace {
bool ShouldRun(const Program* program) {
for (auto* node : program->ASTNodes().Objects()) {
if (auto* sem = program->Sem().Get<sem::ValueExpression>(node)) {
if (auto* sem = program->Sem().GetVal(node)) {
if (auto* call = sem->UnwrapMaterialize()->As<sem::Call>()) {
if (call->Target()->Is<sem::TypeConversion>() && call->Type()->Is<type::Matrix>()) {
auto& args = call->Arguments();

View File

@@ -403,7 +403,7 @@ struct ZeroInitWorkgroupMemory::State {
if (!expr) {
continue;
}
auto* sem = ctx.src->Sem().Get(expr);
auto* sem = ctx.src->Sem().GetVal(expr);
if (auto* c = sem->ConstantValue()) {
workgroup_size_const *= c->ValueAs<AInt>();
continue;