// Copyright 2022 The Tint Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "src/tint/transform/decompose_strided_array.h" #include #include #include #include "src/tint/program_builder.h" #include "src/tint/transform/simplify_pointers.h" #include "src/tint/transform/test_helper.h" #include "src/tint/transform/unshadow.h" using namespace tint::number_suffixes; // NOLINT namespace tint::transform { namespace { using DecomposeStridedArrayTest = TransformTest; TEST_F(DecomposeStridedArrayTest, ShouldRunEmptyModule) { ProgramBuilder b; EXPECT_FALSE(ShouldRun(Program(std::move(b)))); } TEST_F(DecomposeStridedArrayTest, ShouldRunNonStridedArray) { // var arr : array ProgramBuilder b; b.Global("arr", b.ty.array(), ast::StorageClass::kPrivate); EXPECT_FALSE(ShouldRun(Program(std::move(b)))); } TEST_F(DecomposeStridedArrayTest, ShouldRunDefaultStridedArray) { // var arr : @stride(4) array ProgramBuilder b; b.Global("arr", b.ty.array(4), ast::StorageClass::kPrivate); EXPECT_TRUE(ShouldRun(Program(std::move(b)))); } TEST_F(DecomposeStridedArrayTest, ShouldRunExplicitStridedArray) { // var arr : @stride(16) array ProgramBuilder b; b.Global("arr", b.ty.array(16), ast::StorageClass::kPrivate); EXPECT_TRUE(ShouldRun(Program(std::move(b)))); } TEST_F(DecomposeStridedArrayTest, Empty) { auto* src = R"()"; auto* expect = src; auto got = Run(src); EXPECT_EQ(expect, str(got)); } TEST_F(DecomposeStridedArrayTest, PrivateDefaultStridedArray) { // var arr : @stride(4) array // // @stage(compute) @workgroup_size(1) // fn f() { // let a : @stride(4) array = a; // let b : f32 = arr[1]; // } ProgramBuilder b; b.Global("arr", b.ty.array(4), ast::StorageClass::kPrivate); b.Func("f", {}, b.ty.void_(), { b.Decl(b.Let("a", b.ty.array(4), b.Expr("arr"))), b.Decl(b.Let("b", b.ty.f32(), b.IndexAccessor("arr", 1_i))), }, { b.Stage(ast::PipelineStage::kCompute), b.WorkgroupSize(1_i), }); auto* expect = R"( var arr : array; @stage(compute) @workgroup_size(1i) fn f() { let a : array = arr; let b : f32 = arr[1i]; } )"; auto got = Run(Program(std::move(b))); EXPECT_EQ(expect, str(got)); } TEST_F(DecomposeStridedArrayTest, PrivateStridedArray) { // var arr : @stride(32) array // // @stage(compute) @workgroup_size(1) // fn f() { // let a : @stride(32) array = a; // let b : f32 = arr[1]; // } ProgramBuilder b; b.Global("arr", b.ty.array(32), ast::StorageClass::kPrivate); b.Func("f", {}, b.ty.void_(), { b.Decl(b.Let("a", b.ty.array(32), b.Expr("arr"))), b.Decl(b.Let("b", b.ty.f32(), b.IndexAccessor("arr", 1_i))), }, { b.Stage(ast::PipelineStage::kCompute), b.WorkgroupSize(1_i), }); auto* expect = R"( struct strided_arr { @size(32) el : f32, } var arr : array; @stage(compute) @workgroup_size(1i) fn f() { let a : array = arr; let b : f32 = arr[1i].el; } )"; auto got = Run(Program(std::move(b))); EXPECT_EQ(expect, str(got)); } TEST_F(DecomposeStridedArrayTest, ReadUniformStridedArray) { // struct S { // a : @stride(32) array, // }; // @group(0) @binding(0) var s : S; // // @stage(compute) @workgroup_size(1) // fn f() { // let a : @stride(32) array = s.a; // let b : f32 = s.a[1]; // } ProgramBuilder b; auto* S = b.Structure("S", {b.Member("a", b.ty.array(32))}); b.Global("s", b.ty.Of(S), ast::StorageClass::kUniform, b.GroupAndBinding(0, 0)); b.Func("f", {}, b.ty.void_(), { b.Decl(b.Let("a", b.ty.array(32), b.MemberAccessor("s", "a"))), b.Decl(b.Let("b", b.ty.f32(), b.IndexAccessor(b.MemberAccessor("s", "a"), 1_i))), }, { b.Stage(ast::PipelineStage::kCompute), b.WorkgroupSize(1_i), }); auto* expect = R"( struct strided_arr { @size(32) el : f32, } struct S { a : array, } @group(0) @binding(0) var s : S; @stage(compute) @workgroup_size(1i) fn f() { let a : array = s.a; let b : f32 = s.a[1i].el; } )"; auto got = Run(Program(std::move(b))); EXPECT_EQ(expect, str(got)); } TEST_F(DecomposeStridedArrayTest, ReadUniformDefaultStridedArray) { // struct S { // a : @stride(16) array, 4u>, // }; // @group(0) @binding(0) var s : S; // // @stage(compute) @workgroup_size(1) // fn f() { // let a : @stride(16) array, 4u> = s.a; // let b : f32 = s.a[1][2]; // } ProgramBuilder b; auto* S = b.Structure("S", {b.Member("a", b.ty.array(b.ty.vec4(), 4_u, 16))}); b.Global("s", b.ty.Of(S), ast::StorageClass::kUniform, b.GroupAndBinding(0, 0)); b.Func( "f", {}, b.ty.void_(), { b.Decl(b.Let("a", b.ty.array(b.ty.vec4(), 4_u, 16), b.MemberAccessor("s", "a"))), b.Decl(b.Let("b", b.ty.f32(), b.IndexAccessor(b.IndexAccessor(b.MemberAccessor("s", "a"), 1_i), 2_i))), }, { b.Stage(ast::PipelineStage::kCompute), b.WorkgroupSize(1_i), }); auto* expect = R"( struct S { a : array, 4u>, } @group(0) @binding(0) var s : S; @stage(compute) @workgroup_size(1i) fn f() { let a : array, 4u> = s.a; let b : f32 = s.a[1i][2i]; } )"; auto got = Run(Program(std::move(b))); EXPECT_EQ(expect, str(got)); } TEST_F(DecomposeStridedArrayTest, ReadStorageStridedArray) { // struct S { // a : @stride(32) array, // }; // @group(0) @binding(0) var s : S; // // @stage(compute) @workgroup_size(1) // fn f() { // let a : @stride(32) array = s.a; // let b : f32 = s.a[1]; // } ProgramBuilder b; auto* S = b.Structure("S", {b.Member("a", b.ty.array(32))}); b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage, b.GroupAndBinding(0, 0)); b.Func("f", {}, b.ty.void_(), { b.Decl(b.Let("a", b.ty.array(32), b.MemberAccessor("s", "a"))), b.Decl(b.Let("b", b.ty.f32(), b.IndexAccessor(b.MemberAccessor("s", "a"), 1_i))), }, { b.Stage(ast::PipelineStage::kCompute), b.WorkgroupSize(1_i), }); auto* expect = R"( struct strided_arr { @size(32) el : f32, } struct S { a : array, } @group(0) @binding(0) var s : S; @stage(compute) @workgroup_size(1i) fn f() { let a : array = s.a; let b : f32 = s.a[1i].el; } )"; auto got = Run(Program(std::move(b))); EXPECT_EQ(expect, str(got)); } TEST_F(DecomposeStridedArrayTest, ReadStorageDefaultStridedArray) { // struct S { // a : @stride(4) array, // }; // @group(0) @binding(0) var s : S; // // @stage(compute) @workgroup_size(1) // fn f() { // let a : @stride(4) array = s.a; // let b : f32 = s.a[1]; // } ProgramBuilder b; auto* S = b.Structure("S", {b.Member("a", b.ty.array(4))}); b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage, b.GroupAndBinding(0, 0)); b.Func("f", {}, b.ty.void_(), { b.Decl(b.Let("a", b.ty.array(4), b.MemberAccessor("s", "a"))), b.Decl(b.Let("b", b.ty.f32(), b.IndexAccessor(b.MemberAccessor("s", "a"), 1_i))), }, { b.Stage(ast::PipelineStage::kCompute), b.WorkgroupSize(1_i), }); auto* expect = R"( struct S { a : array, } @group(0) @binding(0) var s : S; @stage(compute) @workgroup_size(1i) fn f() { let a : array = s.a; let b : f32 = s.a[1i]; } )"; auto got = Run(Program(std::move(b))); EXPECT_EQ(expect, str(got)); } TEST_F(DecomposeStridedArrayTest, WriteStorageStridedArray) { // struct S { // a : @stride(32) array, // }; // @group(0) @binding(0) var s : S; // // @stage(compute) @workgroup_size(1) // fn f() { // s.a = @stride(32) array(); // s.a = @stride(32) array(1.0, 2.0, 3.0, 4.0); // s.a[1i] = 5.0; // } ProgramBuilder b; auto* S = b.Structure("S", {b.Member("a", b.ty.array(32))}); b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage, ast::Access::kReadWrite, b.GroupAndBinding(0, 0)); b.Func("f", {}, b.ty.void_(), { b.Assign(b.MemberAccessor("s", "a"), b.Construct(b.ty.array(32))), b.Assign(b.MemberAccessor("s", "a"), b.Construct(b.ty.array(32), 1_f, 2_f, 3_f, 4_f)), b.Assign(b.IndexAccessor(b.MemberAccessor("s", "a"), 1_i), 5_f), }, { b.Stage(ast::PipelineStage::kCompute), b.WorkgroupSize(1_i), }); auto* expect = R"( struct strided_arr { @size(32) el : f32, } struct S { a : array, } @group(0) @binding(0) var s : S; @stage(compute) @workgroup_size(1i) fn f() { s.a = array(); s.a = array(strided_arr(1.0), strided_arr(2.0), strided_arr(3.0), strided_arr(4.0)); s.a[1i].el = 5.0; } )"; auto got = Run(Program(std::move(b))); EXPECT_EQ(expect, str(got)); } TEST_F(DecomposeStridedArrayTest, WriteStorageDefaultStridedArray) { // struct S { // a : @stride(4) array, // }; // @group(0) @binding(0) var s : S; // // @stage(compute) @workgroup_size(1) // fn f() { // s.a = @stride(4) array(); // s.a = @stride(4) array(1.0, 2.0, 3.0, 4.0); // s.a[1] = 5.0; // } ProgramBuilder b; auto* S = b.Structure("S", {b.Member("a", b.ty.array(4))}); b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage, ast::Access::kReadWrite, b.GroupAndBinding(0, 0)); b.Func("f", {}, b.ty.void_(), { b.Assign(b.MemberAccessor("s", "a"), b.Construct(b.ty.array(4))), b.Assign(b.MemberAccessor("s", "a"), b.Construct(b.ty.array(4), 1_f, 2_f, 3_f, 4_f)), b.Assign(b.IndexAccessor(b.MemberAccessor("s", "a"), 1_i), 5_f), }, { b.Stage(ast::PipelineStage::kCompute), b.WorkgroupSize(1_i), }); auto* expect = R"( struct S { a : array, } @group(0) @binding(0) var s : S; @stage(compute) @workgroup_size(1i) fn f() { s.a = array(); s.a = array(1.0, 2.0, 3.0, 4.0); s.a[1i] = 5.0; } )"; auto got = Run(Program(std::move(b))); EXPECT_EQ(expect, str(got)); } TEST_F(DecomposeStridedArrayTest, ReadWriteViaPointerLets) { // struct S { // a : @stride(32) array, // }; // @group(0) @binding(0) var s : S; // // @stage(compute) @workgroup_size(1) // fn f() { // let a = &s.a; // let b = &*&*(a); // let c = *b; // let d = (*b)[1]; // (*b) = @stride(32) array(1.0, 2.0, 3.0, 4.0); // (*b)[1] = 5.0; // } ProgramBuilder b; auto* S = b.Structure("S", {b.Member("a", b.ty.array(32))}); b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage, ast::Access::kReadWrite, b.GroupAndBinding(0, 0)); b.Func("f", {}, b.ty.void_(), { b.Decl(b.Let("a", nullptr, b.AddressOf(b.MemberAccessor("s", "a")))), b.Decl(b.Let("b", nullptr, b.AddressOf(b.Deref(b.AddressOf(b.Deref("a")))))), b.Decl(b.Let("c", nullptr, b.Deref("b"))), b.Decl(b.Let("d", nullptr, b.IndexAccessor(b.Deref("b"), 1_i))), b.Assign(b.Deref("b"), b.Construct(b.ty.array(32), 1_f, 2_f, 3_f, 4_f)), b.Assign(b.IndexAccessor(b.Deref("b"), 1_i), 5_f), }, { b.Stage(ast::PipelineStage::kCompute), b.WorkgroupSize(1_i), }); auto* expect = R"( struct strided_arr { @size(32) el : f32, } struct S { a : array, } @group(0) @binding(0) var s : S; @stage(compute) @workgroup_size(1i) fn f() { let c = s.a; let d = s.a[1i].el; s.a = array(strided_arr(1.0), strided_arr(2.0), strided_arr(3.0), strided_arr(4.0)); s.a[1i].el = 5.0; } )"; auto got = Run(Program(std::move(b))); EXPECT_EQ(expect, str(got)); } TEST_F(DecomposeStridedArrayTest, PrivateAliasedStridedArray) { // type ARR = @stride(32) array; // struct S { // a : ARR, // }; // @group(0) @binding(0) var s : S; // // @stage(compute) @workgroup_size(1) // fn f() { // let a : ARR = s.a; // let b : f32 = s.a[1]; // s.a = ARR(); // s.a = ARR(1.0, 2.0, 3.0, 4.0); // s.a[1] = 5.0; // } ProgramBuilder b; b.Alias("ARR", b.ty.array(32)); auto* S = b.Structure("S", {b.Member("a", b.ty.type_name("ARR"))}); b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage, ast::Access::kReadWrite, b.GroupAndBinding(0, 0)); b.Func("f", {}, b.ty.void_(), { b.Decl(b.Let("a", b.ty.type_name("ARR"), b.MemberAccessor("s", "a"))), b.Decl(b.Let("b", b.ty.f32(), b.IndexAccessor(b.MemberAccessor("s", "a"), 1_i))), b.Assign(b.MemberAccessor("s", "a"), b.Construct(b.ty.type_name("ARR"))), b.Assign(b.MemberAccessor("s", "a"), b.Construct(b.ty.type_name("ARR"), 1_f, 2_f, 3_f, 4_f)), b.Assign(b.IndexAccessor(b.MemberAccessor("s", "a"), 1_i), 5_f), }, { b.Stage(ast::PipelineStage::kCompute), b.WorkgroupSize(1_i), }); auto* expect = R"( struct strided_arr { @size(32) el : f32, } type ARR = array; struct S { a : ARR, } @group(0) @binding(0) var s : S; @stage(compute) @workgroup_size(1i) fn f() { let a : ARR = s.a; let b : f32 = s.a[1i].el; s.a = ARR(); s.a = ARR(strided_arr(1.0), strided_arr(2.0), strided_arr(3.0), strided_arr(4.0)); s.a[1i].el = 5.0; } )"; auto got = Run(Program(std::move(b))); EXPECT_EQ(expect, str(got)); } TEST_F(DecomposeStridedArrayTest, PrivateNestedStridedArray) { // type ARR_A = @stride(8) array; // type ARR_B = @stride(128) array<@stride(16) array, 4u>; // struct S { // a : ARR_B, // }; // @group(0) @binding(0) var s : S; // // @stage(compute) @workgroup_size(1) // fn f() { // let a : ARR_B = s.a; // let b : array<@stride(8) array, 3u> = s.a[3]; // let c = s.a[3][2]; // let d = s.a[3][2][1]; // s.a = ARR_B(); // s.a[3][2][1] = 5.0; // } ProgramBuilder b; b.Alias("ARR_A", b.ty.array(8)); b.Alias("ARR_B", b.ty.array( // b.ty.array(b.ty.type_name("ARR_A"), 3_u, 16), // 4_u, 128)); auto* S = b.Structure("S", {b.Member("a", b.ty.type_name("ARR_B"))}); b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage, ast::Access::kReadWrite, b.GroupAndBinding(0, 0)); b.Func("f", {}, b.ty.void_(), { b.Decl(b.Let("a", b.ty.type_name("ARR_B"), b.MemberAccessor("s", "a"))), b.Decl(b.Let("b", b.ty.array(b.ty.type_name("ARR_A"), 3_u, 16), b.IndexAccessor( // b.MemberAccessor("s", "a"), // 3_i))), b.Decl(b.Let("c", b.ty.type_name("ARR_A"), b.IndexAccessor( // b.IndexAccessor( // b.MemberAccessor("s", "a"), // 3_i), 2_i))), b.Decl(b.Let("d", b.ty.f32(), b.IndexAccessor( // b.IndexAccessor( // b.IndexAccessor( // b.MemberAccessor("s", "a"), // 3_i), 2_i), 1_i))), b.Assign(b.MemberAccessor("s", "a"), b.Construct(b.ty.type_name("ARR_B"))), b.Assign(b.IndexAccessor( // b.IndexAccessor( // b.IndexAccessor( // b.MemberAccessor("s", "a"), // 3_i), 2_i), 1_i), 5_f), }, { b.Stage(ast::PipelineStage::kCompute), b.WorkgroupSize(1_i), }); auto* expect = R"( struct strided_arr { @size(8) el : f32, } type ARR_A = array; struct strided_arr_1 { @size(128) el : array, } type ARR_B = array; struct S { a : ARR_B, } @group(0) @binding(0) var s : S; @stage(compute) @workgroup_size(1i) fn f() { let a : ARR_B = s.a; let b : array = s.a[3i].el; let c : ARR_A = s.a[3i].el[2i]; let d : f32 = s.a[3i].el[2i][1i].el; s.a = ARR_B(); s.a[3i].el[2i][1i].el = 5.0; } )"; auto got = Run(Program(std::move(b))); EXPECT_EQ(expect, str(got)); } } // namespace } // namespace tint::transform