tint/transform: Skip SimplifyPointers if possible

Change-Id: Id937d82e9062cf7a4c54401121ed6d22e5d4fd73
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/116870
Auto-Submit: Ben Clayton <bclayton@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
Commit-Queue: Ben Clayton <bclayton@chromium.org>
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
Ben Clayton 2023-01-12 18:29:07 +00:00 committed by Dawn LUCI CQ
parent 3d6c263446
commit 42363a5b18
6 changed files with 106 additions and 83 deletions

View File

@ -130,48 +130,22 @@ struct SimplifyPointers::State {
// A map of saved expressions to their saved variable name // A map of saved expressions to their saved variable name
utils::Hashmap<const ast::Expression*, Symbol, 8> saved_vars; utils::Hashmap<const ast::Expression*, Symbol, 8> saved_vars;
// Register the ast::Expression transform handler. bool needs_transform = false;
// This performs two different transformations:
// * Identifiers that resolve to the pointer-typed `let` declarations are
// replaced with the recursively inlined initializer expression for the
// `let` declaration.
// * Sub-expressions inside the pointer-typed `let` initializer expression
// that have been hoisted to a saved variable are replaced with the saved
// variable identifier.
ctx.ReplaceAll([&](const ast::Expression* expr) -> const ast::Expression* {
// Look to see if we need to swap this Expression with a saved variable.
if (auto saved_var = saved_vars.Find(expr)) {
return ctx.dst->Expr(*saved_var);
}
// Reduce the expression, folding away chains of address-of / indirections
auto op = Reduce(expr);
// Clone the reduced root expression
expr = ctx.CloneWithoutTransform(op.expr);
// And reapply the minimum number of address-of / indirections
for (int i = 0; i < op.indirections; i++) {
expr = ctx.dst->Deref(expr);
}
for (int i = 0; i > op.indirections; i--) {
expr = ctx.dst->AddressOf(expr);
}
return expr;
});
// Find all the pointer-typed `let` declarations. // Find all the pointer-typed `let` declarations.
// Note that these must be function-scoped, as module-scoped `let`s are not // Note that these must be function-scoped, as module-scoped `let`s are not
// permitted. // permitted.
for (auto* node : ctx.src->ASTNodes().Objects()) { for (auto* node : ctx.src->ASTNodes().Objects()) {
if (auto* let = node->As<ast::VariableDeclStatement>()) { Switch(
node, //
[&](const ast::VariableDeclStatement* let) {
if (!let->variable->Is<ast::Let>()) { if (!let->variable->Is<ast::Let>()) {
continue; // Not a `let` declaration. Ignore. return; // Not a `let` declaration. Ignore.
} }
auto* var = ctx.src->Sem().Get(let->variable); auto* var = ctx.src->Sem().Get(let->variable);
if (!var->Type()->Is<type::Pointer>()) { if (!var->Type()->Is<type::Pointer>()) {
continue; // Not a pointer type. Ignore. return; // Not a pointer type. Ignore.
} }
// We're dealing with a pointer-typed `let` declaration. // We're dealing with a pointer-typed `let` declaration.
@ -185,7 +159,8 @@ struct SimplifyPointers::State {
// Create a new variable // Create a new variable
auto saved_name = ctx.dst->Symbols().New( auto saved_name = ctx.dst->Symbols().New(
ctx.src->Symbols().NameFor(var->Declaration()->symbol) + "_save"); ctx.src->Symbols().NameFor(var->Declaration()->symbol) + "_save");
auto* decl = ctx.dst->Decl(ctx.dst->Let(saved_name, ctx.Clone(idx_expr))); auto* decl =
ctx.dst->Decl(ctx.dst->Let(saved_name, ctx.Clone(idx_expr)));
saved.Push(decl); saved.Push(decl);
// Record the substitution of `idx_expr` to the saved variable // Record the substitution of `idx_expr` to the saved variable
// with the symbol `saved_name`. This will be used by the // with the symbol `saved_name`. This will be used by the
@ -221,9 +196,50 @@ struct SimplifyPointers::State {
// As the original `let` declaration will be fully inlined, there's no // As the original `let` declaration will be fully inlined, there's no
// need for the original declaration to exist. Remove it. // need for the original declaration to exist. Remove it.
RemoveStatement(ctx, let); RemoveStatement(ctx, let);
},
[&](const ast::UnaryOpExpression* op) {
if (op->op == ast::UnaryOp::kAddressOf) {
// Transform can be skipped if no address-of operator is used, as there
// will be no pointers that can be inlined.
needs_transform = true;
} }
});
} }
if (!needs_transform) {
return SkipTransform;
}
// Register the ast::Expression transform handler.
// This performs two different transformations:
// * Identifiers that resolve to the pointer-typed `let` declarations are
// replaced with the recursively inlined initializer expression for the
// `let` declaration.
// * Sub-expressions inside the pointer-typed `let` initializer expression
// that have been hoisted to a saved variable are replaced with the saved
// variable identifier.
ctx.ReplaceAll([&](const ast::Expression* expr) -> const ast::Expression* {
// Look to see if we need to swap this Expression with a saved variable.
if (auto saved_var = saved_vars.Find(expr)) {
return ctx.dst->Expr(*saved_var);
}
// Reduce the expression, folding away chains of address-of / indirections
auto op = Reduce(expr);
// Clone the reduced root expression
expr = ctx.CloneWithoutTransform(op.expr);
// And reapply the minimum number of address-of / indirections
for (int i = 0; i < op.indirections; i++) {
expr = ctx.dst->Deref(expr);
}
for (int i = 0; i > op.indirections; i--) {
expr = ctx.dst->AddressOf(expr);
}
return expr;
});
ctx.Clone(); ctx.Clone();
return Program(std::move(b)); return Program(std::move(b));
} }

View File

@ -24,11 +24,18 @@ using SimplifyPointersTest = TransformTest;
TEST_F(SimplifyPointersTest, EmptyModule) { TEST_F(SimplifyPointersTest, EmptyModule) {
auto* src = ""; auto* src = "";
auto* expect = "";
auto got = Run<Unshadow, SimplifyPointers>(src); EXPECT_FALSE(ShouldRun<SimplifyPointers>(src));
}
EXPECT_EQ(expect, str(got)); TEST_F(SimplifyPointersTest, NoAddressOf) {
auto* src = R"(
fn f(p : ptr<function, i32>) {
var v : i32;
}
)";
EXPECT_FALSE(ShouldRun<SimplifyPointers>(src));
} }
TEST_F(SimplifyPointersTest, FoldPointer) { TEST_F(SimplifyPointersTest, FoldPointer) {

View File

@ -16,13 +16,13 @@ void main_inner(uint local_invocation_index) {
{ {
for(uint idx = local_invocation_index; (idx < 4u); idx = (idx + 1u)) { for(uint idx = local_invocation_index; (idx < 4u); idx = (idx + 1u)) {
const uint i_1 = idx; const uint i_1 = idx;
const str tint_symbol_2 = (str)0; const str tint_symbol_3 = (str)0;
S[i_1] = tint_symbol_2; S[i_1] = tint_symbol_3;
} }
} }
GroupMemoryBarrierWithGroupSync(); GroupMemoryBarrierWithGroupSync();
const uint tint_symbol_3[1] = {2u}; const uint tint_symbol_2[1] = {2u};
const str r = func_S_X(tint_symbol_3); const str r = func_S_X(tint_symbol_2);
} }
[numthreads(1, 1, 1)] [numthreads(1, 1, 1)]

View File

@ -16,13 +16,13 @@ void main_inner(uint local_invocation_index) {
{ {
for(uint idx = local_invocation_index; (idx < 4u); idx = (idx + 1u)) { for(uint idx = local_invocation_index; (idx < 4u); idx = (idx + 1u)) {
const uint i_1 = idx; const uint i_1 = idx;
const str tint_symbol_2 = (str)0; const str tint_symbol_3 = (str)0;
S[i_1] = tint_symbol_2; S[i_1] = tint_symbol_3;
} }
} }
GroupMemoryBarrierWithGroupSync(); GroupMemoryBarrierWithGroupSync();
const uint tint_symbol_3[1] = {2u}; const uint tint_symbol_2[1] = {2u};
const str r = func_S_X(tint_symbol_3); const str r = func_S_X(tint_symbol_2);
} }
[numthreads(1, 1, 1)] [numthreads(1, 1, 1)]

View File

@ -17,13 +17,13 @@ void main_inner(uint local_invocation_index) {
{ {
for(uint idx = local_invocation_index; (idx < 4u); idx = (idx + 1u)) { for(uint idx = local_invocation_index; (idx < 4u); idx = (idx + 1u)) {
const uint i_1 = idx; const uint i_1 = idx;
const str tint_symbol_3 = (str)0; const str tint_symbol_4 = (str)0;
S[i_1] = tint_symbol_3; S[i_1] = tint_symbol_4;
} }
} }
GroupMemoryBarrierWithGroupSync(); GroupMemoryBarrierWithGroupSync();
const uint tint_symbol_4[1] = {2u}; const uint tint_symbol_3[1] = {2u};
func_S_X(tint_symbol_4); func_S_X(tint_symbol_3);
} }
[numthreads(1, 1, 1)] [numthreads(1, 1, 1)]

View File

@ -17,13 +17,13 @@ void main_inner(uint local_invocation_index) {
{ {
for(uint idx = local_invocation_index; (idx < 4u); idx = (idx + 1u)) { for(uint idx = local_invocation_index; (idx < 4u); idx = (idx + 1u)) {
const uint i_1 = idx; const uint i_1 = idx;
const str tint_symbol_3 = (str)0; const str tint_symbol_4 = (str)0;
S[i_1] = tint_symbol_3; S[i_1] = tint_symbol_4;
} }
} }
GroupMemoryBarrierWithGroupSync(); GroupMemoryBarrierWithGroupSync();
const uint tint_symbol_4[1] = {2u}; const uint tint_symbol_3[1] = {2u};
func_S_X(tint_symbol_4); func_S_X(tint_symbol_3);
} }
[numthreads(1, 1, 1)] [numthreads(1, 1, 1)]