reader/spirv: Decompose arrays with strides

Transform any SPIR-V that has an array with a custom stride:

  @stride(S) array<T, N>

into:

  struct strided_arr {
    @size(S) er : T;
  };
  array<strided_arr, N>

Also remove any @stride decorations that match the default array stride.

Bug: tint:1394
Bug: tint:1381
Change-Id: I8be8f3a76c5335fdb2bc5183388366091dbc7642
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/78781
Reviewed-by: David Neto <dneto@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton
2022-02-04 15:39:34 +00:00
committed by Tint LUCI CQ
parent de857e1c58
commit 009d129103
28 changed files with 1499 additions and 161 deletions

View File

@@ -441,6 +441,8 @@ libtint_source_set("libtint_core_all_src") {
"transform/combine_samplers.h",
"transform/decompose_memory_access.cc",
"transform/decompose_memory_access.h",
"transform/decompose_strided_array.cc",
"transform/decompose_strided_array.h",
"transform/decompose_strided_matrix.cc",
"transform/decompose_strided_matrix.h",
"transform/external_texture_transform.cc",

View File

@@ -309,6 +309,8 @@ set(TINT_LIB_SRCS
transform/canonicalize_entry_point_io.h
transform/decompose_memory_access.cc
transform/decompose_memory_access.h
transform/decompose_strided_array.cc
transform/decompose_strided_array.h
transform/decompose_strided_matrix.cc
transform/decompose_strided_matrix.h
transform/external_texture_transform.cc
@@ -984,6 +986,7 @@ if(TINT_BUILD_TESTS)
transform/canonicalize_entry_point_io_test.cc
transform/combine_samplers_test.cc
transform/decompose_memory_access_test.cc
transform/decompose_strided_array_test.cc
transform/decompose_strided_matrix_test.cc
transform/external_texture_transform_test.cc
transform/first_index_offset_test.cc

View File

@@ -17,6 +17,7 @@
#include <utility>
#include "src/reader/spirv/parser_impl.h"
#include "src/transform/decompose_strided_array.h"
#include "src/transform/decompose_strided_matrix.h"
#include "src/transform/manager.h"
#include "src/transform/remove_unreachable_statements.h"
@@ -54,6 +55,7 @@ Program Parse(const std::vector<uint32_t>& input) {
manager.Add<transform::Unshadow>();
manager.Add<transform::SimplifyPointers>();
manager.Add<transform::DecomposeStridedMatrix>();
manager.Add<transform::DecomposeStridedArray>();
manager.Add<transform::RemoveUnreachableStatements>();
return manager.Run(&program).program;
}

View File

@@ -21,6 +21,7 @@ namespace tint {
// Forward declarations
namespace ast {
class Array;
class CallExpression;
class Expression;
class ElseStatement;
@@ -60,6 +61,7 @@ class Variable;
/// rules will be used to infer the return type based on the argument type.
struct TypeMappings {
//! @cond Doxygen_Suppress
Array* operator()(ast::Array*);
Call* operator()(ast::CallExpression*);
Expression* operator()(ast::Expression*);
ElseStatement* operator()(ast::ElseStatement*);

View File

@@ -0,0 +1,162 @@
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/transform/decompose_strided_array.h"
#include <unordered_map>
#include <utility>
#include <vector>
#include "src/program_builder.h"
#include "src/sem/call.h"
#include "src/sem/expression.h"
#include "src/sem/member_accessor_expression.h"
#include "src/sem/type_constructor.h"
#include "src/transform/simplify_pointers.h"
#include "src/utils/hash.h"
#include "src/utils/map.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::DecomposeStridedArray);
namespace tint {
namespace transform {
namespace {
using DecomposedArrays = std::unordered_map<const sem::Array*, Symbol>;
} // namespace
DecomposeStridedArray::DecomposeStridedArray() = default;
DecomposeStridedArray::~DecomposeStridedArray() = default;
bool DecomposeStridedArray::ShouldRun(const Program* program,
const DataMap&) const {
for (auto* node : program->ASTNodes().Objects()) {
if (auto* ast = node->As<ast::Array>()) {
if (ast::GetAttribute<ast::StrideAttribute>(ast->attributes)) {
return true;
}
}
}
return false;
}
void DecomposeStridedArray::Run(CloneContext& ctx,
const DataMap&,
DataMap&) const {
const auto& sem = ctx.src->Sem();
static constexpr const char* kMemberName = "el";
// Maps an array type in the source program to the name of the struct wrapper
// type in the target program.
std::unordered_map<const sem::Array*, Symbol> decomposed;
// Find and replace all arrays with a @stride attribute with a array that has
// the @stride removed. If the source array stride does not match the natural
// stride for the array element type, then replace the array element type with
// a structure, holding a single field with a @size attribute equal to the
// array stride.
ctx.ReplaceAll([&](const ast::Array* ast) -> const ast::Array* {
if (auto* arr = sem.Get(ast)) {
if (!arr->IsStrideImplicit()) {
auto el_ty = utils::GetOrCreate(decomposed, arr, [&] {
auto name = ctx.dst->Symbols().New("strided_arr");
auto* member_ty = ctx.Clone(ast->type);
auto* member = ctx.dst->Member(kMemberName, member_ty,
{ctx.dst->MemberSize(arr->Stride())});
ctx.dst->Structure(name, {member});
return name;
});
auto* count = ctx.Clone(ast->count);
return ctx.dst->ty.array(ctx.dst->ty.type_name(el_ty), count);
}
if (ast::GetAttribute<ast::StrideAttribute>(ast->attributes)) {
// Strip the @stride attribute
auto* ty = ctx.Clone(ast->type);
auto* count = ctx.Clone(ast->count);
return ctx.dst->ty.array(ty, count);
}
}
return nullptr;
});
// Find all array index-accessors expressions for arrays that have had their
// element changed to a single field structure. These expressions are adjusted
// to insert an additional member accessor for the single structure field.
// Example: `arr[i]` -> `arr[i].el`
ctx.ReplaceAll(
[&](const ast::IndexAccessorExpression* idx) -> const ast::Expression* {
if (auto* ty = ctx.src->TypeOf(idx->object)) {
if (auto* arr = ty->UnwrapRef()->As<sem::Array>()) {
if (!arr->IsStrideImplicit()) {
auto* expr = ctx.CloneWithoutTransform(idx);
return ctx.dst->MemberAccessor(expr, kMemberName);
}
}
}
return nullptr;
});
// Find all array type constructor expressions for array types that have had
// their element changed to a single field structure. These constructors are
// adjusted to wrap each of the arguments with an additional constructor for
// the new element structure type.
// Example:
// `@stride(32) array<i32, 3>(1, 2, 3)`
// ->
// `array<strided_arr, 3>(strided_arr(1), strided_arr(2), strided_arr(3))`
ctx.ReplaceAll(
[&](const ast::CallExpression* expr) -> const ast::Expression* {
if (!expr->args.empty()) {
if (auto* call = sem.Get(expr)) {
if (auto* ctor = call->Target()->As<sem::TypeConstructor>()) {
if (auto* arr = ctor->ReturnType()->As<sem::Array>()) {
// Begin by cloning the array constructor type or name
// If this is an unaliased array, this may add a new entry to
// decomposed.
// If this is an aliased array, decomposed should already be
// populated with any strided aliases.
ast::CallExpression::Target target;
if (expr->target.type) {
target.type = ctx.Clone(expr->target.type);
} else {
target.name = ctx.Clone(expr->target.name);
}
ast::ExpressionList args;
if (auto it = decomposed.find(arr); it != decomposed.end()) {
args.reserve(expr->args.size());
for (auto* arg : expr->args) {
args.emplace_back(
ctx.dst->Call(it->second, ctx.Clone(arg)));
}
} else {
args = ctx.Clone(expr->args);
}
return target.type ? ctx.dst->Construct(target.type, args)
: ctx.dst->Call(target.name, args);
}
}
}
}
return nullptr;
});
ctx.Clone();
}
} // namespace transform
} // namespace tint

View File

@@ -0,0 +1,61 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TRANSFORM_DECOMPOSE_STRIDED_ARRAY_H_
#define SRC_TRANSFORM_DECOMPOSE_STRIDED_ARRAY_H_
#include "src/transform/transform.h"
namespace tint {
namespace transform {
/// DecomposeStridedArray transforms replaces arrays with a non-default
/// `@stride` attribute with an array of structure elements, where the
/// structure contains a single field with an equivalent `@size` attribute.
/// `@stride` attributes on arrays that match the default stride are also
/// removed.
///
/// @note Depends on the following transforms to have been run first:
/// * SimplifyPointers
class DecomposeStridedArray
: public Castable<DecomposeStridedArray, Transform> {
public:
/// Constructor
DecomposeStridedArray();
/// Destructor
~DecomposeStridedArray() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program,
const DataMap& data = {}) const override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const override;
};
} // namespace transform
} // namespace tint
#endif // SRC_TRANSFORM_DECOMPOSE_STRIDED_ARRAY_H_

View File

@@ -0,0 +1,698 @@
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/transform/decompose_strided_array.h"
#include <memory>
#include <utility>
#include <vector>
#include "src/program_builder.h"
#include "src/transform/simplify_pointers.h"
#include "src/transform/test_helper.h"
#include "src/transform/unshadow.h"
namespace tint {
namespace transform {
namespace {
using DecomposeStridedArrayTest = TransformTest;
using f32 = ProgramBuilder::f32;
TEST_F(DecomposeStridedArrayTest, ShouldRunEmptyModule) {
ProgramBuilder b;
EXPECT_FALSE(ShouldRun<DecomposeStridedArray>(Program(std::move(b))));
}
TEST_F(DecomposeStridedArrayTest, ShouldRunNonStridedArray) {
// var<private> arr : array<f32, 4>
ProgramBuilder b;
b.Global("arr", b.ty.array<f32, 4>(), ast::StorageClass::kPrivate);
EXPECT_FALSE(ShouldRun<DecomposeStridedArray>(Program(std::move(b))));
}
TEST_F(DecomposeStridedArrayTest, ShouldRunDefaultStridedArray) {
// var<private> arr : @stride(4) array<f32, 4>
ProgramBuilder b;
b.Global("arr", b.ty.array<f32, 4>(4), ast::StorageClass::kPrivate);
EXPECT_TRUE(ShouldRun<DecomposeStridedArray>(Program(std::move(b))));
}
TEST_F(DecomposeStridedArrayTest, ShouldRunExplicitStridedArray) {
// var<private> arr : @stride(16) array<f32, 4>
ProgramBuilder b;
b.Global("arr", b.ty.array<f32, 4>(16), ast::StorageClass::kPrivate);
EXPECT_TRUE(ShouldRun<DecomposeStridedArray>(Program(std::move(b))));
}
TEST_F(DecomposeStridedArrayTest, Empty) {
auto* src = R"()";
auto* expect = src;
auto got = Run<DecomposeStridedArray>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedArrayTest, PrivateDefaultStridedArray) {
// var<private> arr : @stride(4) array<f32, 4>
//
// @stage(compute) @workgroup_size(1)
// fn f() {
// let a : @stride(4) array<f32, 4> = a;
// let b : f32 = arr[1];
// }
ProgramBuilder b;
b.Global("arr", b.ty.array<f32, 4>(4), ast::StorageClass::kPrivate);
b.Func("f", {}, b.ty.void_(),
{
b.Decl(b.Const("a", b.ty.array<f32, 4>(4), b.Expr("arr"))),
b.Decl(b.Const("b", b.ty.f32(), b.IndexAccessor("arr", 1))),
},
{
b.Stage(ast::PipelineStage::kCompute),
b.WorkgroupSize(1),
});
auto* expect = R"(
var<private> arr : array<f32, 4>;
@stage(compute) @workgroup_size(1)
fn f() {
let a : array<f32, 4> = arr;
let b : f32 = arr[1];
}
)";
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(
Program(std::move(b)));
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedArrayTest, PrivateStridedArray) {
// var<private> arr : @stride(32) array<f32, 4>
//
// @stage(compute) @workgroup_size(1)
// fn f() {
// let a : @stride(32) array<f32, 4> = a;
// let b : f32 = arr[1];
// }
ProgramBuilder b;
b.Global("arr", b.ty.array<f32, 4>(32), ast::StorageClass::kPrivate);
b.Func("f", {}, b.ty.void_(),
{
b.Decl(b.Const("a", b.ty.array<f32, 4>(32), b.Expr("arr"))),
b.Decl(b.Const("b", b.ty.f32(), b.IndexAccessor("arr", 1))),
},
{
b.Stage(ast::PipelineStage::kCompute),
b.WorkgroupSize(1),
});
auto* expect = R"(
struct strided_arr {
@size(32)
el : f32;
}
var<private> arr : array<strided_arr, 4>;
@stage(compute) @workgroup_size(1)
fn f() {
let a : array<strided_arr, 4> = arr;
let b : f32 = arr[1].el;
}
)";
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(
Program(std::move(b)));
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedArrayTest, ReadUniformStridedArray) {
// struct S {
// a : @stride(32) array<f32, 4>;
// };
// @group(0) @binding(0) var<uniform> s : S;
//
// @stage(compute) @workgroup_size(1)
// fn f() {
// let a : @stride(32) array<f32, 4> = s.a;
// let b : f32 = s.a[1];
// }
ProgramBuilder b;
auto* S = b.Structure("S", {b.Member("a", b.ty.array<f32, 4>(32))});
b.Global("s", b.ty.Of(S), ast::StorageClass::kUniform,
b.GroupAndBinding(0, 0));
b.Func("f", {}, b.ty.void_(),
{
b.Decl(b.Const("a", b.ty.array<f32, 4>(32),
b.MemberAccessor("s", "a"))),
b.Decl(b.Const("b", b.ty.f32(),
b.IndexAccessor(b.MemberAccessor("s", "a"), 1))),
},
{
b.Stage(ast::PipelineStage::kCompute),
b.WorkgroupSize(1),
});
auto* expect = R"(
struct strided_arr {
@size(32)
el : f32;
}
struct S {
a : array<strided_arr, 4>;
}
@group(0) @binding(0) var<uniform> s : S;
@stage(compute) @workgroup_size(1)
fn f() {
let a : array<strided_arr, 4> = s.a;
let b : f32 = s.a[1].el;
}
)";
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(
Program(std::move(b)));
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedArrayTest, ReadUniformDefaultStridedArray) {
// struct S {
// a : @stride(16) array<vec4<f32>, 4>;
// };
// @group(0) @binding(0) var<uniform> s : S;
//
// @stage(compute) @workgroup_size(1)
// fn f() {
// let a : @stride(16) array<vec4<f32>, 4> = s.a;
// let b : f32 = s.a[1][2];
// }
ProgramBuilder b;
auto* S =
b.Structure("S", {b.Member("a", b.ty.array(b.ty.vec4<f32>(), 4, 16))});
b.Global("s", b.ty.Of(S), ast::StorageClass::kUniform,
b.GroupAndBinding(0, 0));
b.Func("f", {}, b.ty.void_(),
{
b.Decl(b.Const("a", b.ty.array(b.ty.vec4<f32>(), 4, 16),
b.MemberAccessor("s", "a"))),
b.Decl(b.Const(
"b", b.ty.f32(),
b.IndexAccessor(b.IndexAccessor(b.MemberAccessor("s", "a"), 1),
2))),
},
{
b.Stage(ast::PipelineStage::kCompute),
b.WorkgroupSize(1),
});
auto* expect =
R"(
struct S {
a : array<vec4<f32>, 4>;
}
@group(0) @binding(0) var<uniform> s : S;
@stage(compute) @workgroup_size(1)
fn f() {
let a : array<vec4<f32>, 4> = s.a;
let b : f32 = s.a[1][2];
}
)";
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(
Program(std::move(b)));
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedArrayTest, ReadStorageStridedArray) {
// struct S {
// a : @stride(32) array<f32, 4>;
// };
// @group(0) @binding(0) var<storage> s : S;
//
// @stage(compute) @workgroup_size(1)
// fn f() {
// let a : @stride(32) array<f32, 4> = s.a;
// let b : f32 = s.a[1];
// }
ProgramBuilder b;
auto* S = b.Structure("S", {b.Member("a", b.ty.array<f32, 4>(32))});
b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
b.GroupAndBinding(0, 0));
b.Func("f", {}, b.ty.void_(),
{
b.Decl(b.Const("a", b.ty.array<f32, 4>(32),
b.MemberAccessor("s", "a"))),
b.Decl(b.Const("b", b.ty.f32(),
b.IndexAccessor(b.MemberAccessor("s", "a"), 1))),
},
{
b.Stage(ast::PipelineStage::kCompute),
b.WorkgroupSize(1),
});
auto* expect = R"(
struct strided_arr {
@size(32)
el : f32;
}
struct S {
a : array<strided_arr, 4>;
}
@group(0) @binding(0) var<storage> s : S;
@stage(compute) @workgroup_size(1)
fn f() {
let a : array<strided_arr, 4> = s.a;
let b : f32 = s.a[1].el;
}
)";
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(
Program(std::move(b)));
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedArrayTest, ReadStorageDefaultStridedArray) {
// struct S {
// a : @stride(4) array<f32, 4>;
// };
// @group(0) @binding(0) var<storage> s : S;
//
// @stage(compute) @workgroup_size(1)
// fn f() {
// let a : @stride(4) array<f32, 4> = s.a;
// let b : f32 = s.a[1];
// }
ProgramBuilder b;
auto* S = b.Structure("S", {b.Member("a", b.ty.array<f32, 4>(4))});
b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
b.GroupAndBinding(0, 0));
b.Func("f", {}, b.ty.void_(),
{
b.Decl(b.Const("a", b.ty.array<f32, 4>(4),
b.MemberAccessor("s", "a"))),
b.Decl(b.Const("b", b.ty.f32(),
b.IndexAccessor(b.MemberAccessor("s", "a"), 1))),
},
{
b.Stage(ast::PipelineStage::kCompute),
b.WorkgroupSize(1),
});
auto* expect = R"(
struct S {
a : array<f32, 4>;
}
@group(0) @binding(0) var<storage> s : S;
@stage(compute) @workgroup_size(1)
fn f() {
let a : array<f32, 4> = s.a;
let b : f32 = s.a[1];
}
)";
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(
Program(std::move(b)));
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedArrayTest, WriteStorageStridedArray) {
// struct S {
// a : @stride(32) array<f32, 4>;
// };
// @group(0) @binding(0) var<storage, read_write> s : S;
//
// @stage(compute) @workgroup_size(1)
// fn f() {
// s.a = @stride(32) array<f32, 4>();
// s.a = @stride(32) array<f32, 4>(1.0, 2.0, 3.0, 4.0);
// s.a[1] = 5.0;
// }
ProgramBuilder b;
auto* S = b.Structure("S", {b.Member("a", b.ty.array<f32, 4>(32))});
b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
ast::Access::kReadWrite, b.GroupAndBinding(0, 0));
b.Func(
"f", {}, b.ty.void_(),
{
b.Assign(b.MemberAccessor("s", "a"),
b.Construct(b.ty.array<f32, 4>(32))),
b.Assign(b.MemberAccessor("s", "a"),
b.Construct(b.ty.array<f32, 4>(32), 1.0f, 2.0f, 3.0f, 4.0f)),
b.Assign(b.IndexAccessor(b.MemberAccessor("s", "a"), 1), 5.0f),
},
{
b.Stage(ast::PipelineStage::kCompute),
b.WorkgroupSize(1),
});
auto* expect =
R"(
struct strided_arr {
@size(32)
el : f32;
}
struct S {
a : array<strided_arr, 4>;
}
@group(0) @binding(0) var<storage, read_write> s : S;
@stage(compute) @workgroup_size(1)
fn f() {
s.a = array<strided_arr, 4>();
s.a = array<strided_arr, 4>(strided_arr(1.0), strided_arr(2.0), strided_arr(3.0), strided_arr(4.0));
s.a[1].el = 5.0;
}
)";
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(
Program(std::move(b)));
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedArrayTest, WriteStorageDefaultStridedArray) {
// struct S {
// a : @stride(4) array<f32, 4>;
// };
// @group(0) @binding(0) var<storage, read_write> s : S;
//
// @stage(compute) @workgroup_size(1)
// fn f() {
// s.a = @stride(4) array<f32, 4>();
// s.a = @stride(4) array<f32, 4>(1.0, 2.0, 3.0, 4.0);
// s.a[1] = 5.0;
// }
ProgramBuilder b;
auto* S = b.Structure("S", {b.Member("a", b.ty.array<f32, 4>(4))});
b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
ast::Access::kReadWrite, b.GroupAndBinding(0, 0));
b.Func(
"f", {}, b.ty.void_(),
{
b.Assign(b.MemberAccessor("s", "a"),
b.Construct(b.ty.array<f32, 4>(4))),
b.Assign(b.MemberAccessor("s", "a"),
b.Construct(b.ty.array<f32, 4>(4), 1.0f, 2.0f, 3.0f, 4.0f)),
b.Assign(b.IndexAccessor(b.MemberAccessor("s", "a"), 1), 5.0f),
},
{
b.Stage(ast::PipelineStage::kCompute),
b.WorkgroupSize(1),
});
auto* expect =
R"(
struct S {
a : array<f32, 4>;
}
@group(0) @binding(0) var<storage, read_write> s : S;
@stage(compute) @workgroup_size(1)
fn f() {
s.a = array<f32, 4>();
s.a = array<f32, 4>(1.0, 2.0, 3.0, 4.0);
s.a[1] = 5.0;
}
)";
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(
Program(std::move(b)));
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedArrayTest, ReadWriteViaPointerLets) {
// struct S {
// a : @stride(32) array<f32, 4>;
// };
// @group(0) @binding(0) var<storage, read_write> s : S;
//
// @stage(compute) @workgroup_size(1)
// fn f() {
// let a = &s.a;
// let b = &*&*(a);
// let c = *b;
// let d = (*b)[1];
// (*b) = @stride(32) array<f32, 4>(1.0, 2.0, 3.0, 4.0);
// (*b)[1] = 5.0;
// }
ProgramBuilder b;
auto* S = b.Structure("S", {b.Member("a", b.ty.array<f32, 4>(32))});
b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
ast::Access::kReadWrite, b.GroupAndBinding(0, 0));
b.Func("f", {}, b.ty.void_(),
{
b.Decl(b.Const("a", nullptr,
b.AddressOf(b.MemberAccessor("s", "a")))),
b.Decl(b.Const("b", nullptr,
b.AddressOf(b.Deref(b.AddressOf(b.Deref("a")))))),
b.Decl(b.Const("c", nullptr, b.Deref("b"))),
b.Decl(b.Const("d", nullptr, b.IndexAccessor(b.Deref("b"), 1))),
b.Assign(b.Deref("b"), b.Construct(b.ty.array<f32, 4>(32), 1.0f,
2.0f, 3.0f, 4.0f)),
b.Assign(b.IndexAccessor(b.Deref("b"), 1), 5.0f),
},
{
b.Stage(ast::PipelineStage::kCompute),
b.WorkgroupSize(1),
});
auto* expect =
R"(
struct strided_arr {
@size(32)
el : f32;
}
struct S {
a : array<strided_arr, 4>;
}
@group(0) @binding(0) var<storage, read_write> s : S;
@stage(compute) @workgroup_size(1)
fn f() {
let c = s.a;
let d = s.a[1].el;
s.a = array<strided_arr, 4>(strided_arr(1.0), strided_arr(2.0), strided_arr(3.0), strided_arr(4.0));
s.a[1].el = 5.0;
}
)";
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(
Program(std::move(b)));
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedArrayTest, PrivateAliasedStridedArray) {
// type ARR = @stride(32) array<f32, 4>;
// struct S {
// a : ARR;
// };
// @group(0) @binding(0) var<storage, read_write> s : S;
//
// @stage(compute) @workgroup_size(1)
// fn f() {
// let a : ARR = s.a;
// let b : f32 = s.a[1];
// s.a = ARR();
// s.a = ARR(1.0, 2.0, 3.0, 4.0);
// s.a[1] = 5.0;
// }
ProgramBuilder b;
b.Alias("ARR", b.ty.array<f32, 4>(32));
auto* S = b.Structure("S", {b.Member("a", b.ty.type_name("ARR"))});
b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
ast::Access::kReadWrite, b.GroupAndBinding(0, 0));
b.Func(
"f", {}, b.ty.void_(),
{
b.Decl(
b.Const("a", b.ty.type_name("ARR"), b.MemberAccessor("s", "a"))),
b.Decl(b.Const("b", b.ty.f32(),
b.IndexAccessor(b.MemberAccessor("s", "a"), 1))),
b.Assign(b.MemberAccessor("s", "a"),
b.Construct(b.ty.type_name("ARR"))),
b.Assign(b.MemberAccessor("s", "a"),
b.Construct(b.ty.type_name("ARR"), 1.0f, 2.0f, 3.0f, 4.0f)),
b.Assign(b.IndexAccessor(b.MemberAccessor("s", "a"), 1), 5.0f),
},
{
b.Stage(ast::PipelineStage::kCompute),
b.WorkgroupSize(1),
});
auto* expect = R"(
struct strided_arr {
@size(32)
el : f32;
}
type ARR = array<strided_arr, 4>;
struct S {
a : ARR;
}
@group(0) @binding(0) var<storage, read_write> s : S;
@stage(compute) @workgroup_size(1)
fn f() {
let a : ARR = s.a;
let b : f32 = s.a[1].el;
s.a = ARR();
s.a = ARR(strided_arr(1.0), strided_arr(2.0), strided_arr(3.0), strided_arr(4.0));
s.a[1].el = 5.0;
}
)";
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(
Program(std::move(b)));
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedArrayTest, PrivateNestedStridedArray) {
// type ARR_A = @stride(8) array<f32, 2>;
// type ARR_B = @stride(128) array<@stride(16) array<ARR_A, 3>, 4>;
// struct S {
// a : ARR_B;
// };
// @group(0) @binding(0) var<storage, read_write> s : S;
//
// @stage(compute) @workgroup_size(1)
// fn f() {
// let a : ARR_B = s.a;
// let b : array<@stride(8) array<f32, 2>, 3> = s.a[3];
// let c = s.a[3][2];
// let d = s.a[3][2][1];
// s.a = ARR_B();
// s.a[3][2][1] = 5.0;
// }
ProgramBuilder b;
b.Alias("ARR_A", b.ty.array<f32, 2>(8));
b.Alias("ARR_B",
b.ty.array( //
b.ty.array(b.ty.type_name("ARR_A"), 3, 16), //
4, 128));
auto* S = b.Structure("S", {b.Member("a", b.ty.type_name("ARR_B"))});
b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
ast::Access::kReadWrite, b.GroupAndBinding(0, 0));
b.Func("f", {}, b.ty.void_(),
{
b.Decl(b.Const("a", b.ty.type_name("ARR_B"),
b.MemberAccessor("s", "a"))),
b.Decl(b.Const("b", b.ty.array(b.ty.type_name("ARR_A"), 3, 16),
b.IndexAccessor( //
b.MemberAccessor("s", "a"), //
3))),
b.Decl(b.Const("c", b.ty.type_name("ARR_A"),
b.IndexAccessor( //
b.IndexAccessor( //
b.MemberAccessor("s", "a"), //
3),
2))),
b.Decl(b.Const("d", b.ty.f32(),
b.IndexAccessor( //
b.IndexAccessor( //
b.IndexAccessor( //
b.MemberAccessor("s", "a"), //
3),
2),
1))),
b.Assign(b.MemberAccessor("s", "a"),
b.Construct(b.ty.type_name("ARR_B"))),
b.Assign(b.IndexAccessor( //
b.IndexAccessor( //
b.IndexAccessor( //
b.MemberAccessor("s", "a"), //
3),
2),
1),
5.0f),
},
{
b.Stage(ast::PipelineStage::kCompute),
b.WorkgroupSize(1),
});
auto* expect =
R"(
struct strided_arr {
@size(8)
el : f32;
}
type ARR_A = array<strided_arr, 2>;
struct strided_arr_1 {
@size(128)
el : array<ARR_A, 3>;
}
type ARR_B = array<strided_arr_1, 4>;
struct S {
a : ARR_B;
}
@group(0) @binding(0) var<storage, read_write> s : S;
@stage(compute) @workgroup_size(1)
fn f() {
let a : ARR_B = s.a;
let b : array<ARR_A, 3> = s.a[3].el;
let c : ARR_A = s.a[3].el[2];
let d : f32 = s.a[3].el[2][1].el;
s.a = ARR_B();
s.a[3].el[2][1].el = 5.0;
}
)";
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(
Program(std::move(b)));
EXPECT_EQ(expect, str(got));
}
} // namespace
} // namespace transform
} // namespace tint

View File

@@ -81,6 +81,15 @@ class TransformTestBase : public BASE {
return manager.Run(&program, data);
}
/// @param program the input program
/// @param data the optional DataMap to pass to Transform::Run()
/// @return true if the transform should be run for the given input.
template <typename TRANSFORM>
bool ShouldRun(Program&& program, const DataMap& data = {}) {
EXPECT_TRUE(program.IsValid()) << program.Diagnostics().str();
return TRANSFORM().ShouldRun(&program, data);
}
/// @param in the input WGSL source
/// @param data the optional DataMap to pass to Transform::Run()
/// @return true if the transform should be run for the given input.
@@ -88,8 +97,7 @@ class TransformTestBase : public BASE {
bool ShouldRun(std::string in, const DataMap& data = {}) {
auto file = std::make_unique<Source::File>("test", in);
auto program = reader::wgsl::Parse(file.get());
EXPECT_TRUE(program.IsValid()) << program.Diagnostics().str();
return TRANSFORM().ShouldRun(&program, data);
return ShouldRun<TRANSFORM>(std::move(program), data);
}
/// @param output the output of the transform