tint/transform: Skip SimplifyPointers if possible

Change-Id: Id937d82e9062cf7a4c54401121ed6d22e5d4fd73
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/116870
Auto-Submit: Ben Clayton <bclayton@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
Commit-Queue: Ben Clayton <bclayton@chromium.org>
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
Ben Clayton 2023-01-12 18:29:07 +00:00 committed by Dawn LUCI CQ
parent 3d6c263446
commit 42363a5b18
6 changed files with 106 additions and 83 deletions

View File

@ -130,6 +130,86 @@ struct SimplifyPointers::State {
// A map of saved expressions to their saved variable name
utils::Hashmap<const ast::Expression*, Symbol, 8> saved_vars;
bool needs_transform = false;
// Find all the pointer-typed `let` declarations.
// Note that these must be function-scoped, as module-scoped `let`s are not
// permitted.
for (auto* node : ctx.src->ASTNodes().Objects()) {
Switch(
node, //
[&](const ast::VariableDeclStatement* let) {
if (!let->variable->Is<ast::Let>()) {
return; // Not a `let` declaration. Ignore.
}
auto* var = ctx.src->Sem().Get(let->variable);
if (!var->Type()->Is<type::Pointer>()) {
return; // Not a pointer type. Ignore.
}
// We're dealing with a pointer-typed `let` declaration.
// Scan the initializer expression for array index expressions that need
// to be hoist to temporary "saved" variables.
utils::Vector<const ast::VariableDeclStatement*, 8> saved;
CollectSavedArrayIndices(
var->Declaration()->initializer, [&](const ast::Expression* idx_expr) {
// We have a sub-expression that needs to be saved.
// Create a new variable
auto saved_name = ctx.dst->Symbols().New(
ctx.src->Symbols().NameFor(var->Declaration()->symbol) + "_save");
auto* decl =
ctx.dst->Decl(ctx.dst->Let(saved_name, ctx.Clone(idx_expr)));
saved.Push(decl);
// Record the substitution of `idx_expr` to the saved variable
// with the symbol `saved_name`. This will be used by the
// ReplaceAll() handler above.
saved_vars.Add(idx_expr, saved_name);
});
// Find the place to insert the saved declarations.
// Special care needs to be made for lets declared as the initializer
// part of for-loops. In this case the block will hold the for-loop
// statement, not the let.
if (!saved.IsEmpty()) {
auto* stmt = ctx.src->Sem().Get(let);
auto* block = stmt->Block();
// Find the statement owned by the block (either the let decl or a
// for-loop)
while (block != stmt->Parent()) {
stmt = stmt->Parent();
}
// Declare the stored variables just before stmt. Order here is
// important as order-of-operations needs to be preserved.
// CollectSavedArrayIndices() visits the LHS of an index accessor
// before the index expression.
for (auto* decl : saved) {
// Note that repeated calls to InsertBefore() with the same `before`
// argument will result in nodes to inserted in the order the
// calls are made (last call is inserted last).
ctx.InsertBefore(block->Declaration()->statements, stmt->Declaration(),
decl);
}
}
// As the original `let` declaration will be fully inlined, there's no
// need for the original declaration to exist. Remove it.
RemoveStatement(ctx, let);
},
[&](const ast::UnaryOpExpression* op) {
if (op->op == ast::UnaryOp::kAddressOf) {
// Transform can be skipped if no address-of operator is used, as there
// will be no pointers that can be inlined.
needs_transform = true;
}
});
}
if (!needs_transform) {
return SkipTransform;
}
// Register the ast::Expression transform handler.
// This performs two different transformations:
// * Identifiers that resolve to the pointer-typed `let` declarations are
@ -160,70 +240,6 @@ struct SimplifyPointers::State {
return expr;
});
// Find all the pointer-typed `let` declarations.
// Note that these must be function-scoped, as module-scoped `let`s are not
// permitted.
for (auto* node : ctx.src->ASTNodes().Objects()) {
if (auto* let = node->As<ast::VariableDeclStatement>()) {
if (!let->variable->Is<ast::Let>()) {
continue; // Not a `let` declaration. Ignore.
}
auto* var = ctx.src->Sem().Get(let->variable);
if (!var->Type()->Is<type::Pointer>()) {
continue; // Not a pointer type. Ignore.
}
// We're dealing with a pointer-typed `let` declaration.
// Scan the initializer expression for array index expressions that need
// to be hoist to temporary "saved" variables.
utils::Vector<const ast::VariableDeclStatement*, 8> saved;
CollectSavedArrayIndices(
var->Declaration()->initializer, [&](const ast::Expression* idx_expr) {
// We have a sub-expression that needs to be saved.
// Create a new variable
auto saved_name = ctx.dst->Symbols().New(
ctx.src->Symbols().NameFor(var->Declaration()->symbol) + "_save");
auto* decl = ctx.dst->Decl(ctx.dst->Let(saved_name, ctx.Clone(idx_expr)));
saved.Push(decl);
// Record the substitution of `idx_expr` to the saved variable
// with the symbol `saved_name`. This will be used by the
// ReplaceAll() handler above.
saved_vars.Add(idx_expr, saved_name);
});
// Find the place to insert the saved declarations.
// Special care needs to be made for lets declared as the initializer
// part of for-loops. In this case the block will hold the for-loop
// statement, not the let.
if (!saved.IsEmpty()) {
auto* stmt = ctx.src->Sem().Get(let);
auto* block = stmt->Block();
// Find the statement owned by the block (either the let decl or a
// for-loop)
while (block != stmt->Parent()) {
stmt = stmt->Parent();
}
// Declare the stored variables just before stmt. Order here is
// important as order-of-operations needs to be preserved.
// CollectSavedArrayIndices() visits the LHS of an index accessor
// before the index expression.
for (auto* decl : saved) {
// Note that repeated calls to InsertBefore() with the same `before`
// argument will result in nodes to inserted in the order the
// calls are made (last call is inserted last).
ctx.InsertBefore(block->Declaration()->statements, stmt->Declaration(),
decl);
}
}
// As the original `let` declaration will be fully inlined, there's no
// need for the original declaration to exist. Remove it.
RemoveStatement(ctx, let);
}
}
ctx.Clone();
return Program(std::move(b));
}

View File

@ -24,11 +24,18 @@ using SimplifyPointersTest = TransformTest;
TEST_F(SimplifyPointersTest, EmptyModule) {
auto* src = "";
auto* expect = "";
auto got = Run<Unshadow, SimplifyPointers>(src);
EXPECT_FALSE(ShouldRun<SimplifyPointers>(src));
}
EXPECT_EQ(expect, str(got));
TEST_F(SimplifyPointersTest, NoAddressOf) {
auto* src = R"(
fn f(p : ptr<function, i32>) {
var v : i32;
}
)";
EXPECT_FALSE(ShouldRun<SimplifyPointers>(src));
}
TEST_F(SimplifyPointersTest, FoldPointer) {

View File

@ -16,13 +16,13 @@ void main_inner(uint local_invocation_index) {
{
for(uint idx = local_invocation_index; (idx < 4u); idx = (idx + 1u)) {
const uint i_1 = idx;
const str tint_symbol_2 = (str)0;
S[i_1] = tint_symbol_2;
const str tint_symbol_3 = (str)0;
S[i_1] = tint_symbol_3;
}
}
GroupMemoryBarrierWithGroupSync();
const uint tint_symbol_3[1] = {2u};
const str r = func_S_X(tint_symbol_3);
const uint tint_symbol_2[1] = {2u};
const str r = func_S_X(tint_symbol_2);
}
[numthreads(1, 1, 1)]

View File

@ -16,13 +16,13 @@ void main_inner(uint local_invocation_index) {
{
for(uint idx = local_invocation_index; (idx < 4u); idx = (idx + 1u)) {
const uint i_1 = idx;
const str tint_symbol_2 = (str)0;
S[i_1] = tint_symbol_2;
const str tint_symbol_3 = (str)0;
S[i_1] = tint_symbol_3;
}
}
GroupMemoryBarrierWithGroupSync();
const uint tint_symbol_3[1] = {2u};
const str r = func_S_X(tint_symbol_3);
const uint tint_symbol_2[1] = {2u};
const str r = func_S_X(tint_symbol_2);
}
[numthreads(1, 1, 1)]

View File

@ -17,13 +17,13 @@ void main_inner(uint local_invocation_index) {
{
for(uint idx = local_invocation_index; (idx < 4u); idx = (idx + 1u)) {
const uint i_1 = idx;
const str tint_symbol_3 = (str)0;
S[i_1] = tint_symbol_3;
const str tint_symbol_4 = (str)0;
S[i_1] = tint_symbol_4;
}
}
GroupMemoryBarrierWithGroupSync();
const uint tint_symbol_4[1] = {2u};
func_S_X(tint_symbol_4);
const uint tint_symbol_3[1] = {2u};
func_S_X(tint_symbol_3);
}
[numthreads(1, 1, 1)]

View File

@ -17,13 +17,13 @@ void main_inner(uint local_invocation_index) {
{
for(uint idx = local_invocation_index; (idx < 4u); idx = (idx + 1u)) {
const uint i_1 = idx;
const str tint_symbol_3 = (str)0;
S[i_1] = tint_symbol_3;
const str tint_symbol_4 = (str)0;
S[i_1] = tint_symbol_4;
}
}
GroupMemoryBarrierWithGroupSync();
const uint tint_symbol_4[1] = {2u};
func_S_X(tint_symbol_4);
const uint tint_symbol_3[1] = {2u};
func_S_X(tint_symbol_3);
}
[numthreads(1, 1, 1)]