reader/spirv: Handle the MatrixStride decoration

Add `transform::DecomposeStridedMatrix`, which replaces matrix members of storage or uniform buffer structures, that have a [[stride]] decoration, into an array
of N column vectors.

This is required to correctly handle `mat2x2` matrices in UBOs, as std140 rules will expect a default stride of 16 bytes, when in WGSL the default structure layout expects a stride of 8 bytes.

Bug: tint:1047
Change-Id: If5ca3c6ec087bbc1ac31a8d9a657b99bf34042a4
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/59840
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
Ben Clayton
2021-07-27 08:17:29 +00:00
committed by Tint LUCI CQ
parent c6cbe3fda6
commit 97668c8c37
28 changed files with 1572 additions and 65 deletions

View File

@@ -0,0 +1,251 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/transform/decompose_strided_matrix.h"
#include <unordered_map>
#include <utility>
#include <vector>
#include "src/program_builder.h"
#include "src/sem/expression.h"
#include "src/sem/member_accessor_expression.h"
#include "src/transform/inline_pointer_lets.h"
#include "src/transform/simplify.h"
#include "src/utils/get_or_create.h"
#include "src/utils/hash.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::DecomposeStridedMatrix);
namespace tint {
namespace transform {
namespace {
/// MatrixInfo describes a matrix member with a custom stride
struct MatrixInfo {
/// The stride in bytes between columns of the matrix
uint32_t stride = 0;
/// The type of the matrix
sem::Matrix const* matrix = nullptr;
/// @returns a new ast::Array that holds an vector column for each row of the
/// matrix.
ast::Array* array(ProgramBuilder* b) const {
return b->ty.array(b->ty.vec<ProgramBuilder::f32>(matrix->rows()),
matrix->columns(), stride);
}
/// Equality operator
bool operator==(const MatrixInfo& info) const {
return stride == info.stride && matrix == info.matrix;
}
/// Hash function
struct Hasher {
size_t operator()(const MatrixInfo& t) const {
return utils::Hash(t.stride, t.matrix);
}
};
};
/// Return type of the callback function of GatherCustomStrideMatrixMembers
enum GatherResult { kContinue, kStop };
/// GatherCustomStrideMatrixMembers scans `program` for all matrix members of
/// storage and uniform structs, which are of a matrix type, and have a custom
/// matrix stride attribute. For each matrix member found, `callback` is called.
/// `callback` is a function with the signature:
/// GatherResult(const sem::StructMember* member,
/// sem::Matrix* matrix,
/// uint32_t stride)
/// If `callback` return GatherResult::kStop, then the scanning will immediately
/// terminate, and GatherCustomStrideMatrixMembers() will return, otherwise
/// scanning will continue.
template <typename F>
void GatherCustomStrideMatrixMembers(const Program* program, F&& callback) {
for (auto* node : program->ASTNodes().Objects()) {
if (auto* str = node->As<ast::Struct>()) {
auto* str_ty = program->Sem().Get(str);
if (!str_ty->UsedAs(ast::StorageClass::kUniform) &&
!str_ty->UsedAs(ast::StorageClass::kStorage)) {
continue;
}
for (auto* member : str_ty->Members()) {
auto* matrix = member->Type()->As<sem::Matrix>();
if (!matrix) {
continue;
}
auto* deco = ast::GetDecoration<ast::StrideDecoration>(
member->Declaration()->decorations());
if (!deco) {
continue;
}
uint32_t stride = deco->stride();
if (matrix->ColumnStride() == stride) {
continue;
}
if (callback(member, matrix, stride) == GatherResult::kStop) {
return;
}
}
}
}
}
} // namespace
DecomposeStridedMatrix::DecomposeStridedMatrix() = default;
DecomposeStridedMatrix::~DecomposeStridedMatrix() = default;
bool DecomposeStridedMatrix::ShouldRun(const Program* program) {
bool should_run = false;
GatherCustomStrideMatrixMembers(
program, [&](const sem::StructMember*, sem::Matrix*, uint32_t) {
should_run = true;
return GatherResult::kStop;
});
return should_run;
}
void DecomposeStridedMatrix::Run(CloneContext& ctx, const DataMap&, DataMap&) {
if (!Requires<InlinePointerLets, Simplify>(ctx)) {
return;
}
// Scan the program for all storage and uniform structure matrix members with
// a custom stride attribute. Replace these matrices with an equivalent array,
// and populate the `decomposed` map with the members that have been replaced.
std::unordered_map<ast::StructMember*, MatrixInfo> decomposed;
GatherCustomStrideMatrixMembers(
ctx.src, [&](const sem::StructMember* member, sem::Matrix* matrix,
uint32_t stride) {
// We've got ourselves a struct member of a matrix type with a custom
// stride. Replace this with an array of column vectors.
MatrixInfo info{stride, matrix};
auto* replacement = ctx.dst->Member(
member->Offset(), ctx.Clone(member->Name()), info.array(ctx.dst));
ctx.Replace(member->Declaration(), replacement);
decomposed.emplace(member->Declaration(), info);
return GatherResult::kContinue;
});
// For all expressions where a single matrix column vector was indexed, we can
// preserve these without calling conversion functions.
// Example:
// ssbo.mat[2] -> ssbo.mat[2]
ctx.ReplaceAll(
[&](ast::ArrayAccessorExpression* expr) -> ast::ArrayAccessorExpression* {
if (auto* access =
ctx.src->Sem().Get<sem::StructMemberAccess>(expr->array())) {
auto it = decomposed.find(access->Member()->Declaration());
if (it != decomposed.end()) {
auto* obj = ctx.CloneWithoutTransform(expr->array());
auto* idx = ctx.Clone(expr->idx_expr());
return ctx.dst->IndexAccessor(obj, idx);
}
}
return nullptr;
});
// For all struct member accesses to the matrix on the LHS of an assignment,
// we need to convert the matrix to the array before assigning to the
// structure.
// Example:
// ssbo.mat = mat_to_arr(m)
std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> mat_to_arr;
ctx.ReplaceAll([&](ast::AssignmentStatement* stmt) -> ast::Statement* {
if (auto* access =
ctx.src->Sem().Get<sem::StructMemberAccess>(stmt->lhs())) {
auto it = decomposed.find(access->Member()->Declaration());
if (it == decomposed.end()) {
return nullptr;
}
MatrixInfo info = it->second;
auto fn = utils::GetOrCreate(mat_to_arr, info, [&] {
auto name = ctx.dst->Symbols().New(
"mat" + std::to_string(info.matrix->columns()) + "x" +
std::to_string(info.matrix->rows()) + "_stride_" +
std::to_string(info.stride) + "_to_arr");
auto matrix = [&] { return CreateASTTypeFor(ctx, info.matrix); };
auto array = [&] { return info.array(ctx.dst); };
auto mat = ctx.dst->Sym("mat");
ast::ExpressionList columns(info.matrix->columns());
for (uint32_t i = 0; i < static_cast<uint32_t>(columns.size()); i++) {
columns[i] = ctx.dst->IndexAccessor(mat, i);
}
ctx.dst->Func(name,
{
ctx.dst->Param(mat, matrix()),
},
array(),
{
ctx.dst->Return(ctx.dst->Construct(array(), columns)),
});
return name;
});
auto* lhs = ctx.CloneWithoutTransform(stmt->lhs());
auto* rhs = ctx.dst->Call(fn, ctx.Clone(stmt->rhs()));
return ctx.dst->Assign(lhs, rhs);
}
return nullptr;
});
// For all other struct member accesses, we need to convert the array to the
// matrix type. Example:
// m = arr_to_mat(ssbo.mat)
std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> arr_to_mat;
ctx.ReplaceAll([&](ast::MemberAccessorExpression* expr) -> ast::Expression* {
if (auto* access = ctx.src->Sem().Get<sem::StructMemberAccess>(expr)) {
auto it = decomposed.find(access->Member()->Declaration());
if (it == decomposed.end()) {
return nullptr;
}
MatrixInfo info = it->second;
auto fn = utils::GetOrCreate(arr_to_mat, info, [&] {
auto name = ctx.dst->Symbols().New(
"arr_to_mat" + std::to_string(info.matrix->columns()) + "x" +
std::to_string(info.matrix->rows()) + "_stride_" +
std::to_string(info.stride));
auto matrix = [&] { return CreateASTTypeFor(ctx, info.matrix); };
auto array = [&] { return info.array(ctx.dst); };
auto arr = ctx.dst->Sym("arr");
ast::ExpressionList columns(info.matrix->columns());
for (uint32_t i = 0; i < static_cast<uint32_t>(columns.size()); i++) {
columns[i] = ctx.dst->IndexAccessor(arr, i);
}
ctx.dst->Func(
name,
{
ctx.dst->Param(arr, array()),
},
matrix(),
{
ctx.dst->Return(ctx.dst->Construct(matrix(), columns)),
});
return name;
});
return ctx.dst->Call(fn, ctx.CloneWithoutTransform(expr));
}
return nullptr;
});
ctx.Clone();
}
} // namespace transform
} // namespace tint

View File

@@ -0,0 +1,54 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TRANSFORM_DECOMPOSE_STRIDED_MATRIX_H_
#define SRC_TRANSFORM_DECOMPOSE_STRIDED_MATRIX_H_
#include "src/transform/transform.h"
namespace tint {
namespace transform {
/// DecomposeStridedMatrix transforms replaces matrix members of storage or
/// uniform buffer structures, that have a [[stride]] decoration, into an array
/// of N column vectors.
/// This transform is used by the SPIR-V reader to handle the SPIR-V
/// MatrixStride decoration.
class DecomposeStridedMatrix
: public Castable<DecomposeStridedMatrix, Transform> {
public:
/// Constructor
DecomposeStridedMatrix();
/// Destructor
~DecomposeStridedMatrix() override;
/// @param program the program to inspect
/// @returns true if this transform should be run for the given program
static bool ShouldRun(const Program* program);
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override;
};
} // namespace transform
} // namespace tint
#endif // SRC_TRANSFORM_DECOMPOSE_STRIDED_MATRIX_H_

View File

@@ -0,0 +1,727 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/transform/decompose_strided_matrix.h"
#include <memory>
#include <utility>
#include <vector>
#include "src/ast/disable_validation_decoration.h"
#include "src/program_builder.h"
#include "src/transform/inline_pointer_lets.h"
#include "src/transform/simplify.h"
#include "src/transform/test_helper.h"
namespace tint {
namespace transform {
namespace {
using DecomposeStridedMatrixTest = TransformTest;
using f32 = ProgramBuilder::f32;
TEST_F(DecomposeStridedMatrixTest, Empty) {
auto* src = R"()";
auto* expect = src;
auto got = Run<InlinePointerLets, Simplify, DecomposeStridedMatrix>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedMatrixTest, MissingDependencyInlinePointerLets) {
auto* src = R"()";
auto* expect =
R"(error: tint::transform::DecomposeStridedMatrix depends on tint::transform::InlinePointerLets but the dependency was not run)";
auto got = Run<Simplify, DecomposeStridedMatrix>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedMatrixTest, MissingDependencySimplify) {
auto* src = R"()";
auto* expect =
R"(error: tint::transform::DecomposeStridedMatrix depends on tint::transform::Simplify but the dependency was not run)";
auto got = Run<InlinePointerLets, DecomposeStridedMatrix>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedMatrixTest, ReadUniformMatrix) {
// [[block]]
// struct S {
// [[offset(16), stride(32)]]
// [[internal(ignore_stride_decoration)]]
// m : mat2x2<f32>;
// };
// [[group(0), binding(0)]] var<uniform> s : S;
//
// [[stage(compute), workgroup_size(1)]]
// fn f() {
// let x : mat2x2<f32> = s.m;
// }
ProgramBuilder b;
auto* S = b.Structure(
"S",
{
b.Member(
"m", b.ty.mat2x2<f32>(),
{
b.create<ast::StructMemberOffsetDecoration>(16),
b.create<ast::StrideDecoration>(32),
b.ASTNodes().Create<ast::DisableValidationDecoration>(
b.ID(), ast::DisabledValidation::kIgnoreStrideDecoration),
}),
},
{
b.StructBlock(),
});
b.Global("s", b.ty.Of(S), ast::StorageClass::kUniform,
b.GroupAndBinding(0, 0));
b.Func(
"f", {}, b.ty.void_(),
{
b.Decl(b.Const("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))),
},
{
b.Stage(ast::PipelineStage::kCompute),
b.WorkgroupSize(1),
});
auto* expect = R"(
[[block]]
struct S {
[[size(16)]]
padding : u32;
m : [[stride(32)]] array<vec2<f32>, 2>;
};
[[group(0), binding(0)]] var<uniform> s : S;
fn arr_to_mat2x2_stride_32(arr : [[stride(32)]] array<vec2<f32>, 2>) -> mat2x2<f32> {
return mat2x2<f32>(arr[0u], arr[1u]);
}
[[stage(compute), workgroup_size(1)]]
fn f() {
let x : mat2x2<f32> = arr_to_mat2x2_stride_32(s.m);
}
)";
auto got = Run<InlinePointerLets, Simplify, DecomposeStridedMatrix>(
Program(std::move(b)));
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedMatrixTest, ReadUniformColumn) {
// [[block]]
// struct S {
// [[offset(16), stride(32)]]
// [[internal(ignore_stride_decoration)]]
// m : mat2x2<f32>;
// };
// [[group(0), binding(0)]] var<uniform> s : S;
//
// [[stage(compute), workgroup_size(1)]]
// fn f() {
// let x : vec2<f32> = s.m[1];
// }
ProgramBuilder b;
auto* S = b.Structure(
"S",
{
b.Member(
"m", b.ty.mat2x2<f32>(),
{
b.create<ast::StructMemberOffsetDecoration>(16),
b.create<ast::StrideDecoration>(32),
b.ASTNodes().Create<ast::DisableValidationDecoration>(
b.ID(), ast::DisabledValidation::kIgnoreStrideDecoration),
}),
},
{
b.StructBlock(),
});
b.Global("s", b.ty.Of(S), ast::StorageClass::kUniform,
b.GroupAndBinding(0, 0));
b.Func("f", {}, b.ty.void_(),
{
b.Decl(b.Const("x", b.ty.vec2<f32>(),
b.IndexAccessor(b.MemberAccessor("s", "m"), 1))),
},
{
b.Stage(ast::PipelineStage::kCompute),
b.WorkgroupSize(1),
});
auto* expect = R"(
[[block]]
struct S {
[[size(16)]]
padding : u32;
m : [[stride(32)]] array<vec2<f32>, 2>;
};
[[group(0), binding(0)]] var<uniform> s : S;
[[stage(compute), workgroup_size(1)]]
fn f() {
let x : vec2<f32> = s.m[1];
}
)";
auto got = Run<InlinePointerLets, Simplify, DecomposeStridedMatrix>(
Program(std::move(b)));
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedMatrixTest, ReadUniformMatrix_DefaultStride) {
// [[block]]
// struct S {
// [[offset(16), stride(8)]]
// [[internal(ignore_stride_decoration)]]
// m : mat2x2<f32>;
// };
// [[group(0), binding(0)]] var<uniform> s : S;
//
// [[stage(compute), workgroup_size(1)]]
// fn f() {
// let x : mat2x2<f32> = s.m;
// }
ProgramBuilder b;
auto* S = b.Structure(
"S",
{
b.Member(
"m", b.ty.mat2x2<f32>(),
{
b.create<ast::StructMemberOffsetDecoration>(16),
b.create<ast::StrideDecoration>(8),
b.ASTNodes().Create<ast::DisableValidationDecoration>(
b.ID(), ast::DisabledValidation::kIgnoreStrideDecoration),
}),
},
{
b.StructBlock(),
});
b.Global("s", b.ty.Of(S), ast::StorageClass::kUniform,
b.GroupAndBinding(0, 0));
b.Func(
"f", {}, b.ty.void_(),
{
b.Decl(b.Const("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))),
},
{
b.Stage(ast::PipelineStage::kCompute),
b.WorkgroupSize(1),
});
auto* expect = R"(
[[block]]
struct S {
[[size(16)]]
padding : u32;
[[stride(8), internal(disable_validation__ignore_stride)]]
m : mat2x2<f32>;
};
[[group(0), binding(0)]] var<uniform> s : S;
[[stage(compute), workgroup_size(1)]]
fn f() {
let x : mat2x2<f32> = s.m;
}
)";
auto got = Run<InlinePointerLets, Simplify, DecomposeStridedMatrix>(
Program(std::move(b)));
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedMatrixTest, ReadStorageMatrix) {
// [[block]]
// struct S {
// [[offset(8), stride(32)]]
// [[internal(ignore_stride_decoration)]]
// m : mat2x2<f32>;
// };
// [[group(0), binding(0)]] var<storage, read_write> s : S;
//
// [[stage(compute), workgroup_size(1)]]
// fn f() {
// let x : mat2x2<f32> = s.m;
// }
ProgramBuilder b;
auto* S = b.Structure(
"S",
{
b.Member(
"m", b.ty.mat2x2<f32>(),
{
b.create<ast::StructMemberOffsetDecoration>(8),
b.create<ast::StrideDecoration>(32),
b.ASTNodes().Create<ast::DisableValidationDecoration>(
b.ID(), ast::DisabledValidation::kIgnoreStrideDecoration),
}),
},
{
b.StructBlock(),
});
b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
ast::Access::kReadWrite, b.GroupAndBinding(0, 0));
b.Func(
"f", {}, b.ty.void_(),
{
b.Decl(b.Const("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))),
},
{
b.Stage(ast::PipelineStage::kCompute),
b.WorkgroupSize(1),
});
auto* expect = R"(
[[block]]
struct S {
[[size(8)]]
padding : u32;
m : [[stride(32)]] array<vec2<f32>, 2>;
};
[[group(0), binding(0)]] var<storage, read_write> s : S;
fn arr_to_mat2x2_stride_32(arr : [[stride(32)]] array<vec2<f32>, 2>) -> mat2x2<f32> {
return mat2x2<f32>(arr[0u], arr[1u]);
}
[[stage(compute), workgroup_size(1)]]
fn f() {
let x : mat2x2<f32> = arr_to_mat2x2_stride_32(s.m);
}
)";
auto got = Run<InlinePointerLets, Simplify, DecomposeStridedMatrix>(
Program(std::move(b)));
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedMatrixTest, ReadStorageColumn) {
// [[block]]
// struct S {
// [[offset(16), stride(32)]]
// [[internal(ignore_stride_decoration)]]
// m : mat2x2<f32>;
// };
// [[group(0), binding(0)]] var<storage, read_write> s : S;
//
// [[stage(compute), workgroup_size(1)]]
// fn f() {
// let x : vec2<f32> = s.m[1];
// }
ProgramBuilder b;
auto* S = b.Structure(
"S",
{
b.Member(
"m", b.ty.mat2x2<f32>(),
{
b.create<ast::StructMemberOffsetDecoration>(16),
b.create<ast::StrideDecoration>(32),
b.ASTNodes().Create<ast::DisableValidationDecoration>(
b.ID(), ast::DisabledValidation::kIgnoreStrideDecoration),
}),
},
{
b.StructBlock(),
});
b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
ast::Access::kReadWrite, b.GroupAndBinding(0, 0));
b.Func("f", {}, b.ty.void_(),
{
b.Decl(b.Const("x", b.ty.vec2<f32>(),
b.IndexAccessor(b.MemberAccessor("s", "m"), 1))),
},
{
b.Stage(ast::PipelineStage::kCompute),
b.WorkgroupSize(1),
});
auto* expect = R"(
[[block]]
struct S {
[[size(16)]]
padding : u32;
m : [[stride(32)]] array<vec2<f32>, 2>;
};
[[group(0), binding(0)]] var<storage, read_write> s : S;
[[stage(compute), workgroup_size(1)]]
fn f() {
let x : vec2<f32> = s.m[1];
}
)";
auto got = Run<InlinePointerLets, Simplify, DecomposeStridedMatrix>(
Program(std::move(b)));
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedMatrixTest, WriteStorageMatrix) {
// [[block]]
// struct S {
// [[offset(8), stride(32)]]
// [[internal(ignore_stride_decoration)]]
// m : mat2x2<f32>;
// };
// [[group(0), binding(0)]] var<storage, read_write> s : S;
//
// [[stage(compute), workgroup_size(1)]]
// fn f() {
// s.m = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
// }
ProgramBuilder b;
auto* S = b.Structure(
"S",
{
b.Member(
"m", b.ty.mat2x2<f32>(),
{
b.create<ast::StructMemberOffsetDecoration>(8),
b.create<ast::StrideDecoration>(32),
b.ASTNodes().Create<ast::DisableValidationDecoration>(
b.ID(), ast::DisabledValidation::kIgnoreStrideDecoration),
}),
},
{
b.StructBlock(),
});
b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
ast::Access::kReadWrite, b.GroupAndBinding(0, 0));
b.Func("f", {}, b.ty.void_(),
{
b.Assign(b.MemberAccessor("s", "m"),
b.mat2x2<f32>(b.vec2<f32>(1.0f, 2.0f),
b.vec2<f32>(3.0f, 4.0f))),
},
{
b.Stage(ast::PipelineStage::kCompute),
b.WorkgroupSize(1),
});
auto* expect = R"(
[[block]]
struct S {
[[size(8)]]
padding : u32;
m : [[stride(32)]] array<vec2<f32>, 2>;
};
[[group(0), binding(0)]] var<storage, read_write> s : S;
fn mat2x2_stride_32_to_arr(mat : mat2x2<f32>) -> [[stride(32)]] array<vec2<f32>, 2> {
return [[stride(32)]] array<vec2<f32>, 2>(mat[0u], mat[1u]);
}
[[stage(compute), workgroup_size(1)]]
fn f() {
s.m = mat2x2_stride_32_to_arr(mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0)));
}
)";
auto got = Run<InlinePointerLets, Simplify, DecomposeStridedMatrix>(
Program(std::move(b)));
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedMatrixTest, WriteStorageColumn) {
// [[block]]
// struct S {
// [[offset(8), stride(32)]]
// [[internal(ignore_stride_decoration)]]
// m : mat2x2<f32>;
// };
// [[group(0), binding(0)]] var<storage, read_write> s : S;
//
// [[stage(compute), workgroup_size(1)]]
// fn f() {
// s.m[1] = vec2<f32>(1.0, 2.0);
// }
ProgramBuilder b;
auto* S = b.Structure(
"S",
{
b.Member(
"m", b.ty.mat2x2<f32>(),
{
b.create<ast::StructMemberOffsetDecoration>(8),
b.create<ast::StrideDecoration>(32),
b.ASTNodes().Create<ast::DisableValidationDecoration>(
b.ID(), ast::DisabledValidation::kIgnoreStrideDecoration),
}),
},
{
b.StructBlock(),
});
b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
ast::Access::kReadWrite, b.GroupAndBinding(0, 0));
b.Func("f", {}, b.ty.void_(),
{
b.Assign(b.IndexAccessor(b.MemberAccessor("s", "m"), 1),
b.vec2<f32>(1.0f, 2.0f)),
},
{
b.Stage(ast::PipelineStage::kCompute),
b.WorkgroupSize(1),
});
auto* expect = R"(
[[block]]
struct S {
[[size(8)]]
padding : u32;
m : [[stride(32)]] array<vec2<f32>, 2>;
};
[[group(0), binding(0)]] var<storage, read_write> s : S;
[[stage(compute), workgroup_size(1)]]
fn f() {
s.m[1] = vec2<f32>(1.0, 2.0);
}
)";
auto got = Run<InlinePointerLets, Simplify, DecomposeStridedMatrix>(
Program(std::move(b)));
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedMatrixTest, ReadWriteViaPointerLets) {
// [[block]]
// struct S {
// [[offset(8), stride(32)]]
// [[internal(ignore_stride_decoration)]]
// m : mat2x2<f32>;
// };
// [[group(0), binding(0)]] var<storage, read_write> s : S;
//
// [[stage(compute), workgroup_size(1)]]
// fn f() {
// let a = &s.m;
// let b = &*&*(a);
// let x = *b;
// let y = (*b)[1];
// let z = x[1];
// (*b) = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
// (*b)[1] = vec2<f32>(5.0, 6.0);
// }
ProgramBuilder b;
auto* S = b.Structure(
"S",
{
b.Member(
"m", b.ty.mat2x2<f32>(),
{
b.create<ast::StructMemberOffsetDecoration>(8),
b.create<ast::StrideDecoration>(32),
b.ASTNodes().Create<ast::DisableValidationDecoration>(
b.ID(), ast::DisabledValidation::kIgnoreStrideDecoration),
}),
},
{
b.StructBlock(),
});
b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
ast::Access::kReadWrite, b.GroupAndBinding(0, 0));
b.Func(
"f", {}, b.ty.void_(),
{
b.Decl(
b.Const("a", nullptr, b.AddressOf(b.MemberAccessor("s", "m")))),
b.Decl(b.Const("b", nullptr,
b.AddressOf(b.Deref(b.AddressOf(b.Deref("a")))))),
b.Decl(b.Const("x", nullptr, b.Deref("b"))),
b.Decl(b.Const("y", nullptr, b.IndexAccessor(b.Deref("b"), 1))),
b.Decl(b.Const("z", nullptr, b.IndexAccessor("x", 1))),
b.Assign(b.Deref("b"), b.mat2x2<f32>(b.vec2<f32>(1.0f, 2.0f),
b.vec2<f32>(3.0f, 4.0f))),
b.Assign(b.IndexAccessor(b.Deref("b"), 1), b.vec2<f32>(5.0f, 6.0f)),
},
{
b.Stage(ast::PipelineStage::kCompute),
b.WorkgroupSize(1),
});
auto* expect = R"(
[[block]]
struct S {
[[size(8)]]
padding : u32;
m : [[stride(32)]] array<vec2<f32>, 2>;
};
[[group(0), binding(0)]] var<storage, read_write> s : S;
fn arr_to_mat2x2_stride_32(arr : [[stride(32)]] array<vec2<f32>, 2>) -> mat2x2<f32> {
return mat2x2<f32>(arr[0u], arr[1u]);
}
fn mat2x2_stride_32_to_arr(mat : mat2x2<f32>) -> [[stride(32)]] array<vec2<f32>, 2> {
return [[stride(32)]] array<vec2<f32>, 2>(mat[0u], mat[1u]);
}
[[stage(compute), workgroup_size(1)]]
fn f() {
let x = arr_to_mat2x2_stride_32(s.m);
let y = s.m[1];
let z = x[1];
s.m = mat2x2_stride_32_to_arr(mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0)));
s.m[1] = vec2<f32>(5.0, 6.0);
}
)";
auto got = Run<InlinePointerLets, Simplify, DecomposeStridedMatrix>(
Program(std::move(b)));
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedMatrixTest, ReadPrivateMatrix) {
// struct S {
// [[offset(8), stride(32)]]
// [[internal(ignore_stride_decoration)]]
// m : mat2x2<f32>;
// };
// var<private> s : S;
//
// [[stage(compute), workgroup_size(1)]]
// fn f() {
// let x : mat2x2<f32> = s.m;
// }
ProgramBuilder b;
auto* S = b.Structure(
"S",
{
b.Member(
"m", b.ty.mat2x2<f32>(),
{
b.create<ast::StructMemberOffsetDecoration>(8),
b.create<ast::StrideDecoration>(32),
b.ASTNodes().Create<ast::DisableValidationDecoration>(
b.ID(), ast::DisabledValidation::kIgnoreStrideDecoration),
}),
});
b.Global("s", b.ty.Of(S), ast::StorageClass::kPrivate);
b.Func(
"f", {}, b.ty.void_(),
{
b.Decl(b.Const("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))),
},
{
b.Stage(ast::PipelineStage::kCompute),
b.WorkgroupSize(1),
});
auto* expect = R"(
struct S {
[[size(8)]]
padding : u32;
[[stride(32), internal(disable_validation__ignore_stride)]]
m : mat2x2<f32>;
};
var<private> s : S;
[[stage(compute), workgroup_size(1)]]
fn f() {
let x : mat2x2<f32> = s.m;
}
)";
auto got = Run<InlinePointerLets, Simplify, DecomposeStridedMatrix>(
Program(std::move(b)));
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedMatrixTest, WritePrivateMatrix) {
// struct S {
// [[offset(8), stride(32)]]
// [[internal(ignore_stride_decoration)]]
// m : mat2x2<f32>;
// };
// var<private> s : S;
//
// [[stage(compute), workgroup_size(1)]]
// fn f() {
// s.m = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
// }
ProgramBuilder b;
auto* S = b.Structure(
"S",
{
b.Member(
"m", b.ty.mat2x2<f32>(),
{
b.create<ast::StructMemberOffsetDecoration>(8),
b.create<ast::StrideDecoration>(32),
b.ASTNodes().Create<ast::DisableValidationDecoration>(
b.ID(), ast::DisabledValidation::kIgnoreStrideDecoration),
}),
});
b.Global("s", b.ty.Of(S), ast::StorageClass::kPrivate);
b.Func("f", {}, b.ty.void_(),
{
b.Assign(b.MemberAccessor("s", "m"),
b.mat2x2<f32>(b.vec2<f32>(1.0f, 2.0f),
b.vec2<f32>(3.0f, 4.0f))),
},
{
b.Stage(ast::PipelineStage::kCompute),
b.WorkgroupSize(1),
});
auto* expect = R"(
struct S {
[[size(8)]]
padding : u32;
[[stride(32), internal(disable_validation__ignore_stride)]]
m : mat2x2<f32>;
};
var<private> s : S;
[[stage(compute), workgroup_size(1)]]
fn f() {
s.m = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
}
)";
auto got = Run<InlinePointerLets, Simplify, DecomposeStridedMatrix>(
Program(std::move(b)));
EXPECT_EQ(expect, str(got));
}
} // namespace
} // namespace transform
} // namespace tint

View File

@@ -59,6 +59,16 @@ class TransformTestBase : public BASE {
// Keep this pointer alive after Transform() returns
files_.emplace_back(std::move(file));
return Run<TRANSFORMS...>(std::move(program), data);
}
/// Transforms and returns program `program`, transformed using a transform of
/// type `TRANSFORM`.
/// @param program the input Program
/// @param data the optional DataMap to pass to Transform::Run()
/// @return the transformed output
template <typename... TRANSFORMS>
Output Run(Program&& program, const DataMap& data = {}) {
if (!program.IsValid()) {
return Output(std::move(program));
}