diff --git a/src/tint/transform/simplify_pointers.cc b/src/tint/transform/simplify_pointers.cc index 5d36c1dc01..9757ae60d3 100644 --- a/src/tint/transform/simplify_pointers.cc +++ b/src/tint/transform/simplify_pointers.cc @@ -130,6 +130,86 @@ struct SimplifyPointers::State { // A map of saved expressions to their saved variable name utils::Hashmap saved_vars; + bool needs_transform = false; + + // Find all the pointer-typed `let` declarations. + // Note that these must be function-scoped, as module-scoped `let`s are not + // permitted. + for (auto* node : ctx.src->ASTNodes().Objects()) { + Switch( + node, // + [&](const ast::VariableDeclStatement* let) { + if (!let->variable->Is()) { + return; // Not a `let` declaration. Ignore. + } + + auto* var = ctx.src->Sem().Get(let->variable); + if (!var->Type()->Is()) { + return; // Not a pointer type. Ignore. + } + + // We're dealing with a pointer-typed `let` declaration. + + // Scan the initializer expression for array index expressions that need + // to be hoist to temporary "saved" variables. + utils::Vector saved; + CollectSavedArrayIndices( + var->Declaration()->initializer, [&](const ast::Expression* idx_expr) { + // We have a sub-expression that needs to be saved. + // Create a new variable + auto saved_name = ctx.dst->Symbols().New( + ctx.src->Symbols().NameFor(var->Declaration()->symbol) + "_save"); + auto* decl = + ctx.dst->Decl(ctx.dst->Let(saved_name, ctx.Clone(idx_expr))); + saved.Push(decl); + // Record the substitution of `idx_expr` to the saved variable + // with the symbol `saved_name`. This will be used by the + // ReplaceAll() handler above. + saved_vars.Add(idx_expr, saved_name); + }); + + // Find the place to insert the saved declarations. + // Special care needs to be made for lets declared as the initializer + // part of for-loops. In this case the block will hold the for-loop + // statement, not the let. + if (!saved.IsEmpty()) { + auto* stmt = ctx.src->Sem().Get(let); + auto* block = stmt->Block(); + // Find the statement owned by the block (either the let decl or a + // for-loop) + while (block != stmt->Parent()) { + stmt = stmt->Parent(); + } + // Declare the stored variables just before stmt. Order here is + // important as order-of-operations needs to be preserved. + // CollectSavedArrayIndices() visits the LHS of an index accessor + // before the index expression. + for (auto* decl : saved) { + // Note that repeated calls to InsertBefore() with the same `before` + // argument will result in nodes to inserted in the order the + // calls are made (last call is inserted last). + ctx.InsertBefore(block->Declaration()->statements, stmt->Declaration(), + decl); + } + } + + // As the original `let` declaration will be fully inlined, there's no + // need for the original declaration to exist. Remove it. + RemoveStatement(ctx, let); + }, + [&](const ast::UnaryOpExpression* op) { + if (op->op == ast::UnaryOp::kAddressOf) { + // Transform can be skipped if no address-of operator is used, as there + // will be no pointers that can be inlined. + needs_transform = true; + } + }); + } + + if (!needs_transform) { + return SkipTransform; + } + // Register the ast::Expression transform handler. // This performs two different transformations: // * Identifiers that resolve to the pointer-typed `let` declarations are @@ -160,70 +240,6 @@ struct SimplifyPointers::State { return expr; }); - // Find all the pointer-typed `let` declarations. - // Note that these must be function-scoped, as module-scoped `let`s are not - // permitted. - for (auto* node : ctx.src->ASTNodes().Objects()) { - if (auto* let = node->As()) { - if (!let->variable->Is()) { - continue; // Not a `let` declaration. Ignore. - } - - auto* var = ctx.src->Sem().Get(let->variable); - if (!var->Type()->Is()) { - continue; // Not a pointer type. Ignore. - } - - // We're dealing with a pointer-typed `let` declaration. - - // Scan the initializer expression for array index expressions that need - // to be hoist to temporary "saved" variables. - utils::Vector saved; - CollectSavedArrayIndices( - var->Declaration()->initializer, [&](const ast::Expression* idx_expr) { - // We have a sub-expression that needs to be saved. - // Create a new variable - auto saved_name = ctx.dst->Symbols().New( - ctx.src->Symbols().NameFor(var->Declaration()->symbol) + "_save"); - auto* decl = ctx.dst->Decl(ctx.dst->Let(saved_name, ctx.Clone(idx_expr))); - saved.Push(decl); - // Record the substitution of `idx_expr` to the saved variable - // with the symbol `saved_name`. This will be used by the - // ReplaceAll() handler above. - saved_vars.Add(idx_expr, saved_name); - }); - - // Find the place to insert the saved declarations. - // Special care needs to be made for lets declared as the initializer - // part of for-loops. In this case the block will hold the for-loop - // statement, not the let. - if (!saved.IsEmpty()) { - auto* stmt = ctx.src->Sem().Get(let); - auto* block = stmt->Block(); - // Find the statement owned by the block (either the let decl or a - // for-loop) - while (block != stmt->Parent()) { - stmt = stmt->Parent(); - } - // Declare the stored variables just before stmt. Order here is - // important as order-of-operations needs to be preserved. - // CollectSavedArrayIndices() visits the LHS of an index accessor - // before the index expression. - for (auto* decl : saved) { - // Note that repeated calls to InsertBefore() with the same `before` - // argument will result in nodes to inserted in the order the - // calls are made (last call is inserted last). - ctx.InsertBefore(block->Declaration()->statements, stmt->Declaration(), - decl); - } - } - - // As the original `let` declaration will be fully inlined, there's no - // need for the original declaration to exist. Remove it. - RemoveStatement(ctx, let); - } - } - ctx.Clone(); return Program(std::move(b)); } diff --git a/src/tint/transform/simplify_pointers_test.cc b/src/tint/transform/simplify_pointers_test.cc index 9848ff31bb..b05dfc4843 100644 --- a/src/tint/transform/simplify_pointers_test.cc +++ b/src/tint/transform/simplify_pointers_test.cc @@ -24,11 +24,18 @@ using SimplifyPointersTest = TransformTest; TEST_F(SimplifyPointersTest, EmptyModule) { auto* src = ""; - auto* expect = ""; - auto got = Run(src); + EXPECT_FALSE(ShouldRun(src)); +} - EXPECT_EQ(expect, str(got)); +TEST_F(SimplifyPointersTest, NoAddressOf) { + auto* src = R"( +fn f(p : ptr) { + var v : i32; +} +)"; + + EXPECT_FALSE(ShouldRun(src)); } TEST_F(SimplifyPointersTest, FoldPointer) { diff --git a/test/tint/ptr_ref/load/param/workgroup/struct_in_array.wgsl.expected.dxc.hlsl b/test/tint/ptr_ref/load/param/workgroup/struct_in_array.wgsl.expected.dxc.hlsl index 1b5ca94809..38e8b07af1 100644 --- a/test/tint/ptr_ref/load/param/workgroup/struct_in_array.wgsl.expected.dxc.hlsl +++ b/test/tint/ptr_ref/load/param/workgroup/struct_in_array.wgsl.expected.dxc.hlsl @@ -16,13 +16,13 @@ void main_inner(uint local_invocation_index) { { for(uint idx = local_invocation_index; (idx < 4u); idx = (idx + 1u)) { const uint i_1 = idx; - const str tint_symbol_2 = (str)0; - S[i_1] = tint_symbol_2; + const str tint_symbol_3 = (str)0; + S[i_1] = tint_symbol_3; } } GroupMemoryBarrierWithGroupSync(); - const uint tint_symbol_3[1] = {2u}; - const str r = func_S_X(tint_symbol_3); + const uint tint_symbol_2[1] = {2u}; + const str r = func_S_X(tint_symbol_2); } [numthreads(1, 1, 1)] diff --git a/test/tint/ptr_ref/load/param/workgroup/struct_in_array.wgsl.expected.fxc.hlsl b/test/tint/ptr_ref/load/param/workgroup/struct_in_array.wgsl.expected.fxc.hlsl index 1b5ca94809..38e8b07af1 100644 --- a/test/tint/ptr_ref/load/param/workgroup/struct_in_array.wgsl.expected.fxc.hlsl +++ b/test/tint/ptr_ref/load/param/workgroup/struct_in_array.wgsl.expected.fxc.hlsl @@ -16,13 +16,13 @@ void main_inner(uint local_invocation_index) { { for(uint idx = local_invocation_index; (idx < 4u); idx = (idx + 1u)) { const uint i_1 = idx; - const str tint_symbol_2 = (str)0; - S[i_1] = tint_symbol_2; + const str tint_symbol_3 = (str)0; + S[i_1] = tint_symbol_3; } } GroupMemoryBarrierWithGroupSync(); - const uint tint_symbol_3[1] = {2u}; - const str r = func_S_X(tint_symbol_3); + const uint tint_symbol_2[1] = {2u}; + const str r = func_S_X(tint_symbol_2); } [numthreads(1, 1, 1)] diff --git a/test/tint/ptr_ref/store/param/workgroup/struct_in_array.wgsl.expected.dxc.hlsl b/test/tint/ptr_ref/store/param/workgroup/struct_in_array.wgsl.expected.dxc.hlsl index b3b81836fc..36be9e2df8 100644 --- a/test/tint/ptr_ref/store/param/workgroup/struct_in_array.wgsl.expected.dxc.hlsl +++ b/test/tint/ptr_ref/store/param/workgroup/struct_in_array.wgsl.expected.dxc.hlsl @@ -17,13 +17,13 @@ void main_inner(uint local_invocation_index) { { for(uint idx = local_invocation_index; (idx < 4u); idx = (idx + 1u)) { const uint i_1 = idx; - const str tint_symbol_3 = (str)0; - S[i_1] = tint_symbol_3; + const str tint_symbol_4 = (str)0; + S[i_1] = tint_symbol_4; } } GroupMemoryBarrierWithGroupSync(); - const uint tint_symbol_4[1] = {2u}; - func_S_X(tint_symbol_4); + const uint tint_symbol_3[1] = {2u}; + func_S_X(tint_symbol_3); } [numthreads(1, 1, 1)] diff --git a/test/tint/ptr_ref/store/param/workgroup/struct_in_array.wgsl.expected.fxc.hlsl b/test/tint/ptr_ref/store/param/workgroup/struct_in_array.wgsl.expected.fxc.hlsl index b3b81836fc..36be9e2df8 100644 --- a/test/tint/ptr_ref/store/param/workgroup/struct_in_array.wgsl.expected.fxc.hlsl +++ b/test/tint/ptr_ref/store/param/workgroup/struct_in_array.wgsl.expected.fxc.hlsl @@ -17,13 +17,13 @@ void main_inner(uint local_invocation_index) { { for(uint idx = local_invocation_index; (idx < 4u); idx = (idx + 1u)) { const uint i_1 = idx; - const str tint_symbol_3 = (str)0; - S[i_1] = tint_symbol_3; + const str tint_symbol_4 = (str)0; + S[i_1] = tint_symbol_4; } } GroupMemoryBarrierWithGroupSync(); - const uint tint_symbol_4[1] = {2u}; - func_S_X(tint_symbol_4); + const uint tint_symbol_3[1] = {2u}; + func_S_X(tint_symbol_3); } [numthreads(1, 1, 1)]