From b9b6e696313b1a9b101231948ca75c3b1b3fcfb8 Mon Sep 17 00:00:00 2001 From: James Price Date: Thu, 31 Mar 2022 22:30:10 +0000 Subject: [PATCH] Add ExpandCompoundAssignment transform This transform converts compound assignment statements into regular assignments, hoisting LHS expressions and converting for-loops and else-if statements if necessary. The vector-component case needs particular care, as we cannot take the address of a vector component. We need to capture a pointer to the whole vector and also the component index expression: // Before vector_array[foo()][bar()] *= 2.0; // After: let _vec = &vector_array[foo()]; let _idx = bar(); (*_vec)[_idx] = (*_vec)[_idx] * 2.0; Bug: tint:1325 Change-Id: I8b9b31fc9ac4b3697f954100ceb4be24d063bca6 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/85282 Reviewed-by: Antonio Maiorano Kokoro: Kokoro --- src/tint/BUILD.gn | 2 + src/tint/CMakeLists.txt | 3 + .../transform/expand_compound_assignment.cc | 149 ++++++ .../transform/expand_compound_assignment.h | 68 +++ .../expand_compound_assignment_test.cc | 457 ++++++++++++++++++ test/tint/BUILD.gn | 1 + 6 files changed, 680 insertions(+) create mode 100644 src/tint/transform/expand_compound_assignment.cc create mode 100644 src/tint/transform/expand_compound_assignment.h create mode 100644 src/tint/transform/expand_compound_assignment_test.cc diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn index 7a5ca93279..6c525c858a 100644 --- a/src/tint/BUILD.gn +++ b/src/tint/BUILD.gn @@ -457,6 +457,8 @@ libtint_source_set("libtint_core_all_src") { "transform/fold_trivial_single_use_lets.h", "transform/for_loop_to_loop.cc", "transform/for_loop_to_loop.h", + "transform/expand_compound_assignment.cc", + "transform/expand_compound_assignment.h", "transform/localize_struct_array_assignment.cc", "transform/localize_struct_array_assignment.h", "transform/loop_to_for_loop.cc", diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt index 731c19f73e..207b988821 100644 --- a/src/tint/CMakeLists.txt +++ b/src/tint/CMakeLists.txt @@ -335,6 +335,8 @@ set(TINT_LIB_SRCS transform/localize_struct_array_assignment.h transform/for_loop_to_loop.cc transform/for_loop_to_loop.h + transform/expand_compound_assignment.cc + transform/expand_compound_assignment.h transform/glsl.cc transform/glsl.h transform/loop_to_for_loop.cc @@ -1026,6 +1028,7 @@ if(TINT_BUILD_TESTS) transform/fold_constants_test.cc transform/fold_trivial_single_use_lets_test.cc transform/for_loop_to_loop_test.cc + transform/expand_compound_assignment.cc transform/localize_struct_array_assignment_test.cc transform/loop_to_for_loop_test.cc transform/module_scope_var_to_entry_point_param_test.cc diff --git a/src/tint/transform/expand_compound_assignment.cc b/src/tint/transform/expand_compound_assignment.cc new file mode 100644 index 0000000000..7c967d3624 --- /dev/null +++ b/src/tint/transform/expand_compound_assignment.cc @@ -0,0 +1,149 @@ +// Copyright 2022 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/tint/transform/expand_compound_assignment.h" + +#include + +#include "src/tint/ast/compound_assignment_statement.h" +#include "src/tint/program_builder.h" +#include "src/tint/sem/block_statement.h" +#include "src/tint/sem/expression.h" +#include "src/tint/sem/for_loop_statement.h" +#include "src/tint/sem/statement.h" +#include "src/tint/transform/utils/hoist_to_decl_before.h" + +TINT_INSTANTIATE_TYPEINFO(tint::transform::ExpandCompoundAssignment); + +namespace tint { +namespace transform { + +ExpandCompoundAssignment::ExpandCompoundAssignment() = default; + +ExpandCompoundAssignment::~ExpandCompoundAssignment() = default; + +bool ExpandCompoundAssignment::ShouldRun(const Program* program, + const DataMap&) const { + for (auto* node : program->ASTNodes().Objects()) { + if (node->Is()) { + return true; + } + } + return false; +} + +void ExpandCompoundAssignment::Run(CloneContext& ctx, + const DataMap&, + DataMap&) const { + HoistToDeclBefore hoist_to_decl_before(ctx); + + for (auto* node : ctx.src->ASTNodes().Objects()) { + if (auto* assign = node->As()) { + auto* sem_assign = ctx.src->Sem().Get(assign); + + // Helper function to create the LHS expression. This will be called twice + // when building the non-compound assignment statement, so must not + // produce expressions that cause side effects. + std::function lhs; + + // Helper function to create a variable that is a pointer to `expr`. + auto hoist_pointer_to = [&](const ast::Expression* expr) { + auto name = ctx.dst->Sym(); + auto* ptr = ctx.dst->AddressOf(ctx.Clone(expr)); + auto* decl = ctx.dst->Decl(ctx.dst->Const(name, nullptr, ptr)); + hoist_to_decl_before.InsertBefore(sem_assign, decl); + return name; + }; + + // Helper function to hoist `expr` to a let declaration. + auto hoist_expr_to_let = [&](const ast::Expression* expr) { + auto name = ctx.dst->Sym(); + auto* decl = + ctx.dst->Decl(ctx.dst->Const(name, nullptr, ctx.Clone(expr))); + hoist_to_decl_before.InsertBefore(sem_assign, decl); + return name; + }; + + // Helper function that returns `true` if the type of `expr` is a vector. + auto is_vec = [&](const ast::Expression* expr) { + return ctx.src->Sem().Get(expr)->Type()->UnwrapRef()->Is(); + }; + + // Hoist the LHS expression subtree into local constants to produce a new + // LHS that we can evaluate twice. + // We need to special case compound assignments to vector components since + // we cannot take the address of a vector component. + auto* index_accessor = assign->lhs->As(); + auto* member_accessor = assign->lhs->As(); + if (assign->lhs->Is() || + (member_accessor && + member_accessor->structure->Is())) { + // This is the simple case with no side effects, so we can just use the + // original LHS expression directly. + // Before: + // foo.bar += rhs; + // After: + // foo.bar = foo.bar + rhs; + lhs = [&]() { return ctx.Clone(assign->lhs); }; + } else if (index_accessor && is_vec(index_accessor->object)) { + // This is the case for vector component via an array accessor. We need + // to capture a pointer to the vector and also the index value. + // Before: + // v[idx()] += rhs; + // After: + // let vec_ptr = &v; + // let index = idx(); + // (*vec_ptr)[index] = (*vec_ptr)[index] + rhs; + auto lhs_ptr = hoist_pointer_to(index_accessor->object); + auto index = hoist_expr_to_let(index_accessor->index); + lhs = [&, lhs_ptr, index]() { + return ctx.dst->IndexAccessor(ctx.dst->Deref(lhs_ptr), index); + }; + } else if (member_accessor && is_vec(member_accessor->structure)) { + // This is the case for vector component via a member accessor. We just + // need to capture a pointer to the vector. + // Before: + // a[idx()].y += rhs; + // After: + // let vec_ptr = &a[idx()]; + // (*vec_ptr).y = (*vec_ptr).y + rhs; + auto lhs_ptr = hoist_pointer_to(member_accessor->structure); + lhs = [&, lhs_ptr]() { + return ctx.dst->MemberAccessor(ctx.dst->Deref(lhs_ptr), + ctx.Clone(member_accessor->member)); + }; + } else { + // For all other statements that may have side-effecting expressions, we + // just need to capture a pointer to the whole LHS. + // Before: + // a[idx()] += rhs; + // After: + // let lhs_ptr = &a[idx()]; + // (*lhs_ptr) = (*lhs_ptr) + rhs; + auto lhs_ptr = hoist_pointer_to(assign->lhs); + lhs = [&, lhs_ptr]() { return ctx.dst->Deref(lhs_ptr); }; + } + + // Replace the compound assignment with a regular assignment. + auto* rhs = ctx.dst->create( + assign->op, lhs(), ctx.Clone(assign->rhs)); + ctx.Replace(assign, ctx.dst->Assign(lhs(), rhs)); + } + } + hoist_to_decl_before.Apply(); + ctx.Clone(); +} + +} // namespace transform +} // namespace tint diff --git a/src/tint/transform/expand_compound_assignment.h b/src/tint/transform/expand_compound_assignment.h new file mode 100644 index 0000000000..73b0c83a63 --- /dev/null +++ b/src/tint/transform/expand_compound_assignment.h @@ -0,0 +1,68 @@ +// Copyright 2022 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SRC_TINT_TRANSFORM_EXPAND_COMPOUND_ASSIGNMENT_H_ +#define SRC_TINT_TRANSFORM_EXPAND_COMPOUND_ASSIGNMENT_H_ + +#include "src/tint/transform/transform.h" + +namespace tint { +namespace transform { + +/// Converts compound assignment statements to regular assignment statements, +/// hoisting the LHS expression if necessary. +/// +/// Before: +/// ``` +/// a += 1; +/// vector_array[foo()][bar()] *= 2.0; +/// ``` +/// +/// After: +/// ``` +/// a = a + 1; +/// let _vec = &vector_array[foo()]; +/// let _idx = bar(); +/// (*_vec)[_idx] = (*_vec)[_idx] * 2.0; +/// ``` +class ExpandCompoundAssignment + : public Castable { + public: + /// Constructor + ExpandCompoundAssignment(); + /// Destructor + ~ExpandCompoundAssignment() override; + + /// @param program the program to inspect + /// @param data optional extra transform-specific input data + /// @returns true if this transform should be run for the given program + bool ShouldRun(const Program* program, + const DataMap& data = {}) const override; + + protected: + /// Runs the transform using the CloneContext built for transforming a + /// program. Run() is responsible for calling Clone() on the CloneContext. + /// @param ctx the CloneContext primed with the input program and + /// ProgramBuilder + /// @param inputs optional extra transform-specific input data + /// @param outputs optional extra transform-specific output data + void Run(CloneContext& ctx, + const DataMap& inputs, + DataMap& outputs) const override; +}; + +} // namespace transform +} // namespace tint + +#endif // SRC_TINT_TRANSFORM_EXPAND_COMPOUND_ASSIGNMENT_H_ diff --git a/src/tint/transform/expand_compound_assignment_test.cc b/src/tint/transform/expand_compound_assignment_test.cc new file mode 100644 index 0000000000..4ad02d9a33 --- /dev/null +++ b/src/tint/transform/expand_compound_assignment_test.cc @@ -0,0 +1,457 @@ +// Copyright 2022 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/tint/transform/expand_compound_assignment.h" + +#include + +#include "src/tint/transform/test_helper.h" + +namespace tint { +namespace transform { +namespace { + +using ExpandCompoundAssignmentTest = TransformTest; + +TEST_F(ExpandCompoundAssignmentTest, ShouldRunEmptyModule) { + auto* src = R"()"; + + EXPECT_FALSE(ShouldRun(src)); +} + +TEST_F(ExpandCompoundAssignmentTest, ShouldRunHasCompoundAssignment) { + auto* src = R"( +fn foo() { + var v : i32; + v += 1; +} +)"; + + EXPECT_TRUE(ShouldRun(src)); +} + +TEST_F(ExpandCompoundAssignmentTest, Basic) { + auto* src = R"( +fn main() { + var v : i32; + v += 1; +} +)"; + + auto* expect = R"( +fn main() { + var v : i32; + v = (v + 1); +} +)"; + + auto got = Run(src); + + EXPECT_EQ(expect, str(got)); +} + +TEST_F(ExpandCompoundAssignmentTest, LhsPointer) { + auto* src = R"( +fn main() { + var v : i32; + let p = &v; + *p += 1; +} +)"; + + auto* expect = R"( +fn main() { + var v : i32; + let p = &(v); + let tint_symbol = &(*(p)); + *(tint_symbol) = (*(tint_symbol) + 1); +} +)"; + + auto got = Run(src); + + EXPECT_EQ(expect, str(got)); +} + +TEST_F(ExpandCompoundAssignmentTest, LhsStructMember) { + auto* src = R"( +struct S { + m : f32, +} + +fn main() { + var s : S; + s.m += 1.0; +} +)"; + + auto* expect = R"( +struct S { + m : f32, +} + +fn main() { + var s : S; + s.m = (s.m + 1.0); +} +)"; + + auto got = Run(src); + + EXPECT_EQ(expect, str(got)); +} + +TEST_F(ExpandCompoundAssignmentTest, LhsArrayElement) { + auto* src = R"( +var a : array; + +fn idx() -> i32 { + a[1] = 42; + return 1; +} + +fn main() { + a[idx()] += 1; +} +)"; + + auto* expect = R"( +var a : array; + +fn idx() -> i32 { + a[1] = 42; + return 1; +} + +fn main() { + let tint_symbol = &(a[idx()]); + *(tint_symbol) = (*(tint_symbol) + 1); +} +)"; + + auto got = Run(src); + + EXPECT_EQ(expect, str(got)); +} + +TEST_F(ExpandCompoundAssignmentTest, LhsVectorComponent_ArrayAccessor) { + auto* src = R"( +var v : vec4; + +fn idx() -> i32 { + v.y = 42; + return 1; +} + +fn main() { + v[idx()] += 1; +} +)"; + + auto* expect = R"( +var v : vec4; + +fn idx() -> i32 { + v.y = 42; + return 1; +} + +fn main() { + let tint_symbol = &(v); + let tint_symbol_1 = idx(); + (*(tint_symbol))[tint_symbol_1] = ((*(tint_symbol))[tint_symbol_1] + 1); +} +)"; + + auto got = Run(src); + + EXPECT_EQ(expect, str(got)); +} + +TEST_F(ExpandCompoundAssignmentTest, LhsVectorComponent_MemberAccessor) { + auto* src = R"( +fn main() { + var v : vec4; + v.y += 1; +} +)"; + + auto* expect = R"( +fn main() { + var v : vec4; + v.y = (v.y + 1); +} +)"; + + auto got = Run(src); + + EXPECT_EQ(expect, str(got)); +} + +TEST_F(ExpandCompoundAssignmentTest, LhsMatrixColumn) { + auto* src = R"( +var m : mat4x4; + +fn idx() -> i32 { + m[0].y = 42.0; + return 1; +} + +fn main() { + m[idx()] += 1.0; +} +)"; + + auto* expect = R"( +var m : mat4x4; + +fn idx() -> i32 { + m[0].y = 42.0; + return 1; +} + +fn main() { + let tint_symbol = &(m[idx()]); + *(tint_symbol) = (*(tint_symbol) + 1.0); +} +)"; + + auto got = Run(src); + + EXPECT_EQ(expect, str(got)); +} + +TEST_F(ExpandCompoundAssignmentTest, LhsMatrixElement) { + auto* src = R"( +var m : mat4x4; + +fn idx1() -> i32 { + m[0].y = 42.0; + return 1; +} + +fn idx2() -> i32 { + m[1].z = 42.0; + return 1; +} + +fn main() { + m[idx1()][idx2()] += 1.0; +} +)"; + + auto* expect = R"( +var m : mat4x4; + +fn idx1() -> i32 { + m[0].y = 42.0; + return 1; +} + +fn idx2() -> i32 { + m[1].z = 42.0; + return 1; +} + +fn main() { + let tint_symbol = &(m[idx1()]); + let tint_symbol_1 = idx2(); + (*(tint_symbol))[tint_symbol_1] = ((*(tint_symbol))[tint_symbol_1] + 1.0); +} +)"; + + auto got = Run(src); + + EXPECT_EQ(expect, str(got)); +} + +TEST_F(ExpandCompoundAssignmentTest, LhsMultipleSideEffects) { + auto* src = R"( +struct S { + a : array, 3>, +} + +@group(0) @binding(0) var buffer : array; + +var p : i32; + +fn idx1() -> i32 { + p += 1; + return 3; +} + +fn idx2() -> i32 { + p *= 3; + return 2; +} + +fn idx3() -> i32 { + p -= 2; + return 1; +} + +fn main() { + buffer[idx1()].a[idx2()][idx3()] += 1.0; +} +)"; + + auto* expect = R"( +struct S { + a : array, 3>, +} + +@group(0) @binding(0) var buffer : array; + +var p : i32; + +fn idx1() -> i32 { + p = (p + 1); + return 3; +} + +fn idx2() -> i32 { + p = (p * 3); + return 2; +} + +fn idx3() -> i32 { + p = (p - 2); + return 1; +} + +fn main() { + let tint_symbol = &(buffer[idx1()].a[idx2()]); + let tint_symbol_1 = idx3(); + (*(tint_symbol))[tint_symbol_1] = ((*(tint_symbol))[tint_symbol_1] + 1.0); +} +)"; + + auto got = Run(src); + + EXPECT_EQ(expect, str(got)); +} + +TEST_F(ExpandCompoundAssignmentTest, ForLoopInit) { + auto* src = R"( +var a : array, 4>; + +var p : i32; + +fn idx1() -> i32 { + p = (p + 1); + return 3; +} + +fn idx2() -> i32 { + p = (p * 3); + return 2; +} + +fn main() { + for (a[idx1()][idx2()] += 1; ; ) { + break; + } +} +)"; + + auto* expect = R"( +var a : array, 4>; + +var p : i32; + +fn idx1() -> i32 { + p = (p + 1); + return 3; +} + +fn idx2() -> i32 { + p = (p * 3); + return 2; +} + +fn main() { + let tint_symbol = &(a[idx1()]); + let tint_symbol_1 = idx2(); + for((*(tint_symbol))[tint_symbol_1] = ((*(tint_symbol))[tint_symbol_1] + 1); ; ) { + break; + } +} +)"; + + auto got = Run(src); + + EXPECT_EQ(expect, str(got)); +} + +TEST_F(ExpandCompoundAssignmentTest, ForLoopCont) { + auto* src = R"( +var a : array, 4>; + +var p : i32; + +fn idx1() -> i32 { + p = (p + 1); + return 3; +} + +fn idx2() -> i32 { + p = (p * 3); + return 2; +} + +fn main() { + for (; ; a[idx1()][idx2()] += 1) { + break; + } +} +)"; + + auto* expect = R"( +var a : array, 4>; + +var p : i32; + +fn idx1() -> i32 { + p = (p + 1); + return 3; +} + +fn idx2() -> i32 { + p = (p * 3); + return 2; +} + +fn main() { + loop { + { + break; + } + + continuing { + let tint_symbol = &(a[idx1()]); + let tint_symbol_1 = idx2(); + (*(tint_symbol))[tint_symbol_1] = ((*(tint_symbol))[tint_symbol_1] + 1); + } + } +} +)"; + + auto got = Run(src); + + EXPECT_EQ(expect, str(got)); +} + +} // namespace +} // namespace transform +} // namespace tint diff --git a/test/tint/BUILD.gn b/test/tint/BUILD.gn index 8d193dad17..b6bcdd49a8 100644 --- a/test/tint/BUILD.gn +++ b/test/tint/BUILD.gn @@ -324,6 +324,7 @@ tint_unittests_source_set("tint_unittests_transform_src") { "../../src/tint/transform/fold_constants_test.cc", "../../src/tint/transform/fold_trivial_single_use_lets_test.cc", "../../src/tint/transform/for_loop_to_loop_test.cc", + "../../src/tint/transform/expand_compound_assignment_test.cc", "../../src/tint/transform/localize_struct_array_assignment_test.cc", "../../src/tint/transform/loop_to_for_loop_test.cc", "../../src/tint/transform/module_scope_var_to_entry_point_param_test.cc",