reader/spirv: Handle the MatrixStride decoration

Add `transform::DecomposeStridedMatrix`, which replaces matrix members of storage or uniform buffer structures, that have a [[stride]] decoration, into an array
of N column vectors.

This is required to correctly handle `mat2x2` matrices in UBOs, as std140 rules will expect a default stride of 16 bytes, when in WGSL the default structure layout expects a stride of 8 bytes.

Bug: tint:1047
Change-Id: If5ca3c6ec087bbc1ac31a8d9a657b99bf34042a4
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/59840
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
Ben Clayton 2021-07-27 08:17:29 +00:00 committed by Tint LUCI CQ
parent c6cbe3fda6
commit 97668c8c37
28 changed files with 1572 additions and 65 deletions

View File

@ -426,6 +426,8 @@ libtint_source_set("libtint_core_all_src") {
"transform/canonicalize_entry_point_io.h",
"transform/decompose_memory_access.cc",
"transform/decompose_memory_access.h",
"transform/decompose_strided_matrix.cc",
"transform/decompose_strided_matrix.h",
"transform/external_texture_transform.cc",
"transform/external_texture_transform.h",
"transform/first_index_offset.cc",

View File

@ -296,6 +296,8 @@ set(TINT_LIB_SRCS
transform/canonicalize_entry_point_io.h
transform/decompose_memory_access.cc
transform/decompose_memory_access.h
transform/decompose_strided_matrix.cc
transform/decompose_strided_matrix.h
transform/external_texture_transform.cc
transform/external_texture_transform.h
transform/first_index_offset.cc
@ -904,6 +906,7 @@ if(${TINT_BUILD_TESTS})
transform/calculate_array_length_test.cc
transform/canonicalize_entry_point_io_test.cc
transform/decompose_memory_access_test.cc
transform/decompose_strided_matrix_test.cc
transform/external_texture_transform_test.cc
transform/first_index_offset_test.cc
transform/fold_constants_test.cc

View File

@ -40,6 +40,8 @@ std::string DisableValidationDecoration::InternalName() const {
return "disable_validation__entry_point_parameter";
case DisabledValidation::kIgnoreConstructibleFunctionParameter:
return "disable_validation__ignore_constructible_function_parameter";
case DisabledValidation::kIgnoreStrideDecoration:
return "disable_validation__ignore_stride";
}
return "<invalid>";
}

View File

@ -40,6 +40,9 @@ enum class DisabledValidation {
/// When applied to a function parameter, the validator will not
/// check if parameter type is constructible
kIgnoreConstructibleFunctionParameter,
/// When applied to a member decoration, a stride decoration may be applied to
/// non-array types.
kIgnoreStrideDecoration,
};
/// An internal decoration used to tell the validator to ignore specific

View File

@ -32,7 +32,7 @@ void InternalDecoration::to_str(const sem::Info&,
std::ostream& out,
size_t indent) const {
make_indent(out, indent);
out << "tint_internal(" << InternalName() << ")" << std::endl;
out << "tint_internal(" << InternalName() << ")";
}
} // namespace ast

View File

@ -17,6 +17,10 @@
#include <utility>
#include "src/reader/spirv/parser_impl.h"
#include "src/transform/decompose_strided_matrix.h"
#include "src/transform/inline_pointer_lets.h"
#include "src/transform/manager.h"
#include "src/transform/simplify.h"
namespace tint {
namespace reader {
@ -40,7 +44,19 @@ Program Parse(const std::vector<uint32_t>& input) {
ProgramBuilder output;
CloneContext(&output, &program_with_disjoint_ast, false).Clone();
return Program(std::move(output));
auto program = Program(std::move(output));
// If the generated program contains matrices with a custom MatrixStride
// attribute then we need to decompose these into an array of vectors
if (transform::DecomposeStridedMatrix::ShouldRun(&program)) {
transform::Manager manager;
manager.Add<transform::InlinePointerLets>();
manager.Add<transform::Simplify>();
manager.Add<transform::DecomposeStridedMatrix>();
return manager.Run(&program).program;
}
return program;
}
} // namespace spirv

View File

@ -21,6 +21,7 @@
#include "source/opt/build_module.h"
#include "src/ast/bitcast_expression.h"
#include "src/ast/disable_validation_decoration.h"
#include "src/ast/interpolate_decoration.h"
#include "src/ast/override_decoration.h"
#include "src/ast/struct_block_decoration.h"
@ -439,13 +440,14 @@ std::string ParserImpl::ShowType(uint32_t type_id) {
return "SPIR-V type " + std::to_string(type_id);
}
ast::Decoration* ParserImpl::ConvertMemberDecoration(
ast::DecorationList ParserImpl::ConvertMemberDecoration(
uint32_t struct_type_id,
uint32_t member_index,
const Type* member_ty,
const Decoration& decoration) {
if (decoration.empty()) {
Fail() << "malformed SPIR-V decoration: it's empty";
return nullptr;
return {};
}
switch (decoration[0]) {
case SpvDecorationOffset:
@ -454,38 +456,49 @@ ast::Decoration* ParserImpl::ConvertMemberDecoration(
<< "malformed Offset decoration: expected 1 literal operand, has "
<< decoration.size() - 1 << ": member " << member_index << " of "
<< ShowType(struct_type_id);
return nullptr;
return {};
}
return create<ast::StructMemberOffsetDecoration>(Source{}, decoration[1]);
return {
create<ast::StructMemberOffsetDecoration>(Source{}, decoration[1]),
};
case SpvDecorationNonReadable:
// WGSL doesn't have a member decoration for this. Silently drop it.
return nullptr;
return {};
case SpvDecorationNonWritable:
// WGSL doesn't have a member decoration for this.
return nullptr;
return {};
case SpvDecorationColMajor:
// WGSL only supports column major matrices.
return nullptr;
return {};
case SpvDecorationRelaxedPrecision:
// WGSL doesn't support relaxed precision.
return nullptr;
return {};
case SpvDecorationRowMajor:
Fail() << "WGSL does not support row-major matrices: can't "
"translate member "
<< member_index << " of " << ShowType(struct_type_id);
return nullptr;
return {};
case SpvDecorationMatrixStride: {
if (decoration.size() != 2) {
Fail() << "malformed MatrixStride decoration: expected 1 literal "
"operand, has "
<< decoration.size() - 1 << ": member " << member_index << " of "
<< ShowType(struct_type_id);
return nullptr;
return {};
}
// TODO(dneto): Fail if the matrix stride is not allocation size of the
// column vector of the underlying matrix. This would need to unpack
// any levels of array-ness.
return nullptr;
uint32_t stride = decoration[1];
uint32_t natural_stride = 0;
if (auto* mat = member_ty->As<Matrix>()) {
natural_stride = (mat->rows == 2) ? 8 : 16;
}
if (stride == natural_stride) {
return {};
}
return {
create<ast::StrideDecoration>(Source{}, decoration[1]),
builder_.ASTNodes().Create<ast::DisableValidationDecoration>(
builder_.ID(), ast::DisabledValidation::kIgnoreStrideDecoration),
};
}
default:
// TODO(dneto): Support the remaining member decorations.
@ -493,7 +506,7 @@ ast::Decoration* ParserImpl::ConvertMemberDecoration(
}
Fail() << "unhandled member decoration: " << decoration[0] << " on member "
<< member_index << " of " << ShowType(struct_type_id);
return nullptr;
return {};
}
bool ParserImpl::BuildInternalModule() {
@ -1126,14 +1139,14 @@ const Type* ParserImpl::ConvertType(
// the members are non-writable.
is_non_writable = true;
} else {
auto* ast_member_decoration =
ConvertMemberDecoration(type_id, member_index, decoration);
auto decos = ConvertMemberDecoration(type_id, member_index,
ast_member_ty, decoration);
for (auto* deco : decos) {
ast_member_decorations.emplace_back(deco);
}
if (!success_) {
return nullptr;
}
if (ast_member_decoration) {
ast_member_decorations.push_back(ast_member_decoration);
}
}
}

View File

@ -272,16 +272,19 @@ class ParserImpl : Reader {
ast::Decoration* SetLocation(ast::DecorationList* decos,
ast::Decoration* replacement);
/// Converts a SPIR-V struct member decoration. If the decoration is
/// recognized but deliberately dropped, then returns nullptr without a
/// diagnostic. On failure, emits a diagnostic and returns nullptr.
/// Converts a SPIR-V struct member decoration into a number of AST
/// decorations. If the decoration is recognized but deliberately dropped,
/// then returns an empty list without a diagnostic. On failure, emits a
/// diagnostic and returns an empty list.
/// @param struct_type_id the ID of the struct type
/// @param member_index the index of the member
/// @param member_ty the type of the member
/// @param decoration an encoded SPIR-V Decoration
/// @returns the corresponding ast::StructuMemberDecoration
ast::Decoration* ConvertMemberDecoration(uint32_t struct_type_id,
uint32_t member_index,
const Decoration& decoration);
/// @returns the AST decorations
ast::DecorationList ConvertMemberDecoration(uint32_t struct_type_id,
uint32_t member_index,
const Type* member_ty,
const Decoration& decoration);
/// Returns a string for the given type. If the type ID is invalid,
/// then the resulting string only names the type ID.

View File

@ -25,16 +25,17 @@ using ::testing::Eq;
TEST_F(SpvParserTest, ConvertMemberDecoration_Empty) {
auto p = parser(std::vector<uint32_t>{});
auto* result = p->ConvertMemberDecoration(1, 1, {});
EXPECT_EQ(result, nullptr);
auto result = p->ConvertMemberDecoration(1, 1, nullptr, {});
EXPECT_TRUE(result.empty());
EXPECT_THAT(p->error(), Eq("malformed SPIR-V decoration: it's empty"));
}
TEST_F(SpvParserTest, ConvertMemberDecoration_OffsetWithoutOperand) {
auto p = parser(std::vector<uint32_t>{});
auto* result = p->ConvertMemberDecoration(12, 13, {SpvDecorationOffset});
EXPECT_EQ(result, nullptr);
auto result =
p->ConvertMemberDecoration(12, 13, nullptr, {SpvDecorationOffset});
EXPECT_TRUE(result.empty());
EXPECT_THAT(p->error(), Eq("malformed Offset decoration: expected 1 literal "
"operand, has 0: member 13 of SPIR-V type 12"));
}
@ -42,9 +43,9 @@ TEST_F(SpvParserTest, ConvertMemberDecoration_OffsetWithoutOperand) {
TEST_F(SpvParserTest, ConvertMemberDecoration_OffsetWithTooManyOperands) {
auto p = parser(std::vector<uint32_t>{});
auto* result =
p->ConvertMemberDecoration(12, 13, {SpvDecorationOffset, 3, 4});
EXPECT_EQ(result, nullptr);
auto result =
p->ConvertMemberDecoration(12, 13, nullptr, {SpvDecorationOffset, 3, 4});
EXPECT_TRUE(result.empty());
EXPECT_THAT(p->error(), Eq("malformed Offset decoration: expected 1 literal "
"operand, has 2: member 13 of SPIR-V type 12"));
}
@ -52,32 +53,100 @@ TEST_F(SpvParserTest, ConvertMemberDecoration_OffsetWithTooManyOperands) {
TEST_F(SpvParserTest, ConvertMemberDecoration_Offset) {
auto p = parser(std::vector<uint32_t>{});
auto* result = p->ConvertMemberDecoration(1, 1, {SpvDecorationOffset, 8});
ASSERT_NE(result, nullptr);
EXPECT_TRUE(result->Is<ast::StructMemberOffsetDecoration>());
auto* offset_deco = result->As<ast::StructMemberOffsetDecoration>();
auto result =
p->ConvertMemberDecoration(1, 1, nullptr, {SpvDecorationOffset, 8});
ASSERT_FALSE(result.empty());
EXPECT_TRUE(result[0]->Is<ast::StructMemberOffsetDecoration>());
auto* offset_deco = result[0]->As<ast::StructMemberOffsetDecoration>();
ASSERT_NE(offset_deco, nullptr);
EXPECT_EQ(offset_deco->offset(), 8u);
EXPECT_TRUE(p->error().empty());
}
TEST_F(SpvParserTest, ConvertMemberDecoration_Matrix2x2_Stride_Natural) {
auto p = parser(std::vector<uint32_t>{});
spirv::F32 f32;
spirv::Matrix matrix(&f32, 2, 2);
auto result =
p->ConvertMemberDecoration(1, 1, &matrix, {SpvDecorationMatrixStride, 8});
EXPECT_TRUE(result.empty());
EXPECT_TRUE(p->error().empty());
}
TEST_F(SpvParserTest, ConvertMemberDecoration_Matrix2x2_Stride_Custom) {
auto p = parser(std::vector<uint32_t>{});
spirv::F32 f32;
spirv::Matrix matrix(&f32, 2, 2);
auto result = p->ConvertMemberDecoration(1, 1, &matrix,
{SpvDecorationMatrixStride, 16});
ASSERT_FALSE(result.empty());
EXPECT_TRUE(result[0]->Is<ast::StrideDecoration>());
auto* stride_deco = result[0]->As<ast::StrideDecoration>();
ASSERT_NE(stride_deco, nullptr);
EXPECT_EQ(stride_deco->stride(), 16u);
EXPECT_TRUE(p->error().empty());
}
TEST_F(SpvParserTest, ConvertMemberDecoration_Matrix2x4_Stride_Natural) {
auto p = parser(std::vector<uint32_t>{});
spirv::F32 f32;
spirv::Matrix matrix(&f32, 2, 4);
auto result = p->ConvertMemberDecoration(1, 1, &matrix,
{SpvDecorationMatrixStride, 16});
EXPECT_TRUE(result.empty());
EXPECT_TRUE(p->error().empty());
}
TEST_F(SpvParserTest, ConvertMemberDecoration_Matrix2x4_Stride_Custom) {
auto p = parser(std::vector<uint32_t>{});
spirv::F32 f32;
spirv::Matrix matrix(&f32, 2, 4);
auto result = p->ConvertMemberDecoration(1, 1, &matrix,
{SpvDecorationMatrixStride, 64});
ASSERT_FALSE(result.empty());
EXPECT_TRUE(result[0]->Is<ast::StrideDecoration>());
auto* stride_deco = result[0]->As<ast::StrideDecoration>();
ASSERT_NE(stride_deco, nullptr);
EXPECT_EQ(stride_deco->stride(), 64u);
EXPECT_TRUE(p->error().empty());
}
TEST_F(SpvParserTest, ConvertMemberDecoration_Matrix2x3_Stride_Custom) {
auto p = parser(std::vector<uint32_t>{});
spirv::F32 f32;
spirv::Matrix matrix(&f32, 2, 3);
auto result = p->ConvertMemberDecoration(1, 1, &matrix,
{SpvDecorationMatrixStride, 32});
ASSERT_FALSE(result.empty());
EXPECT_TRUE(result[0]->Is<ast::StrideDecoration>());
auto* stride_deco = result[0]->As<ast::StrideDecoration>();
ASSERT_NE(stride_deco, nullptr);
EXPECT_EQ(stride_deco->stride(), 32u);
EXPECT_TRUE(p->error().empty());
}
TEST_F(SpvParserTest, ConvertMemberDecoration_RelaxedPrecision) {
// WGSL does not support relaxed precision. Drop it.
// It's functionally correct to use full precision f32 instead of
// relaxed precision f32.
auto p = parser(std::vector<uint32_t>{});
auto* result =
p->ConvertMemberDecoration(1, 1, {SpvDecorationRelaxedPrecision});
EXPECT_EQ(result, nullptr);
auto result = p->ConvertMemberDecoration(1, 1, nullptr,
{SpvDecorationRelaxedPrecision});
EXPECT_TRUE(result.empty());
EXPECT_TRUE(p->error().empty());
}
TEST_F(SpvParserTest, ConvertMemberDecoration_UnhandledDecoration) {
auto p = parser(std::vector<uint32_t>{});
auto* result = p->ConvertMemberDecoration(12, 13, {12345678});
EXPECT_EQ(result, nullptr);
auto result = p->ConvertMemberDecoration(12, 13, nullptr, {12345678});
EXPECT_TRUE(result.empty());
EXPECT_THAT(p->error(), Eq("unhandled member decoration: 12345678 on member "
"13 of SPIR-V type 12"));
}

View File

@ -2015,7 +2015,7 @@ TEST_F(SpvModuleScopeVarParserTest, ColMajorDecoration_Dropped) {
})")) << module_str;
}
TEST_F(SpvModuleScopeVarParserTest, MatrixStrideDecoration_Dropped) {
TEST_F(SpvModuleScopeVarParserTest, MatrixStrideDecoration_Natural_Dropped) {
auto p = parser(test::Assemble(Preamble() + FragMain() + R"(
OpName %myvar "myvar"
OpDecorate %myvar DescriptorSet 0
@ -2054,6 +2054,45 @@ TEST_F(SpvModuleScopeVarParserTest, MatrixStrideDecoration_Dropped) {
})")) << module_str;
}
TEST_F(SpvModuleScopeVarParserTest, MatrixStrideDecoration) {
auto p = parser(test::Assemble(Preamble() + FragMain() + R"(
OpName %myvar "myvar"
OpDecorate %myvar DescriptorSet 0
OpDecorate %myvar Binding 0
OpDecorate %s Block
OpMemberDecorate %s 0 MatrixStride 64
OpMemberDecorate %s 0 Offset 0
%void = OpTypeVoid
%voidfn = OpTypeFunction %void
%float = OpTypeFloat 32
%v2float = OpTypeVector %float 2
%m3v2float = OpTypeMatrix %v2float 3
%s = OpTypeStruct %m3v2float
%ptr_sb_s = OpTypePointer StorageBuffer %s
%myvar = OpVariable %ptr_sb_s StorageBuffer
)" + MainBody()));
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
EXPECT_TRUE(p->error().empty());
const auto module_str = p->program().to_str();
EXPECT_THAT(module_str, HasSubstr(R"(
Struct S {
[[block]]
StructMember{[[ stride 64 tint_internal(disable_validation__ignore_stride) offset 0 ]] field0: __mat_2_3__f32}
}
Variable{
Decorations{
GroupDecoration{0}
BindingDecoration{0}
}
myvar
storage
read_write
__type_name_S
}
})")) << module_str;
}
TEST_F(SpvModuleScopeVarParserTest, RowMajorDecoration_IsError) {
auto p = parser(test::Assemble(Preamble() + FragMain() + R"(
OpName %myvar "myvar"
@ -2620,7 +2659,8 @@ TEST_F(SpvModuleScopeVarParserTest, SampleId_I32_Load_AccessChain) {
private
undefined
__i32
})")) <<module_str;
})"))
<< module_str;
// Correct creation of value
EXPECT_THAT(module_str, HasSubstr(R"(
@ -3006,7 +3046,8 @@ TEST_F(SpvModuleScopeVarParserTest, SampleMask_In_U32_Direct) {
private
undefined
__array__u32_1
})")) <<module_str;
})"))
<< module_str;
// Correct creation of value
EXPECT_THAT(module_str, HasSubstr(R"(
@ -3149,7 +3190,8 @@ TEST_F(SpvModuleScopeVarParserTest, SampleMask_In_U32_AccessChain) {
private
undefined
__array__u32_1
})")) <<module_str;
})"))
<< module_str;
// Correct creation of value
EXPECT_THAT(module_str, HasSubstr(R"(
@ -5543,7 +5585,6 @@ INSTANTIATE_TEST_SUITE_P(
// {"NumWorkgroups", "%uint", "num_workgroups"}
// {"NumWorkgroups", "%int", "num_workgroups"}
TEST_F(SpvModuleScopeVarParserTest, RegisterInputOutputVars) {
const std::string assembly =
R"(

View File

@ -184,18 +184,21 @@ class ParserImplWrapperForTest {
return impl_.GetDecorationsForMember(id, member_index);
}
/// Converts a SPIR-V struct member decoration. If the decoration is
/// recognized but deliberately dropped, then returns nullptr without a
/// diagnostic. On failure, emits a diagnostic and returns nullptr.
/// Converts a SPIR-V struct member decoration into a number of AST
/// decorations. If the decoration is recognized but deliberately dropped,
/// then returns an empty list without a diagnostic. On failure, emits a
/// diagnostic and returns an empty list.
/// @param struct_type_id the ID of the struct type
/// @param member_index the index of the member
/// @param member_ty the type of the member
/// @param decoration an encoded SPIR-V Decoration
/// @returns the corresponding ast::StructuMemberDecoration
ast::Decoration* ConvertMemberDecoration(uint32_t struct_type_id,
uint32_t member_index,
const Decoration& decoration) {
/// @returns the AST decorations
ast::DecorationList ConvertMemberDecoration(uint32_t struct_type_id,
uint32_t member_index,
const Type* member_ty,
const Decoration& decoration) {
return impl_.ConvertMemberDecoration(struct_type_id, member_index,
decoration);
member_ty, decoration);
}
/// For a SPIR-V ID that might define a sampler, image, or sampled image

View File

@ -3937,13 +3937,20 @@ bool Resolver::ValidateStructure(const sem::Struct* str) {
auto has_position = false;
ast::InvariantDecoration* invariant_attribute = nullptr;
for (auto* deco : member->Declaration()->decorations()) {
if (!(deco->Is<ast::BuiltinDecoration>() ||
deco->Is<ast::InterpolateDecoration>() ||
deco->Is<ast::InvariantDecoration>() ||
deco->Is<ast::LocationDecoration>() ||
deco->Is<ast::StructMemberOffsetDecoration>() ||
deco->Is<ast::StructMemberSizeDecoration>() ||
deco->Is<ast::StructMemberAlignDecoration>())) {
if (!deco->IsAnyOf<ast::BuiltinDecoration, //
ast::InternalDecoration, //
ast::InterpolateDecoration, //
ast::InvariantDecoration, //
ast::LocationDecoration, //
ast::StructMemberOffsetDecoration, //
ast::StructMemberSizeDecoration, //
ast::StructMemberAlignDecoration>()) {
if (deco->Is<ast::StrideDecoration>() &&
IsValidationDisabled(
member->Declaration()->decorations(),
ast::DisabledValidation::kIgnoreStrideDecoration)) {
continue;
}
AddError("decoration is not valid for structure members",
deco->source());
return false;

View File

@ -0,0 +1,251 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/transform/decompose_strided_matrix.h"
#include <unordered_map>
#include <utility>
#include <vector>
#include "src/program_builder.h"
#include "src/sem/expression.h"
#include "src/sem/member_accessor_expression.h"
#include "src/transform/inline_pointer_lets.h"
#include "src/transform/simplify.h"
#include "src/utils/get_or_create.h"
#include "src/utils/hash.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::DecomposeStridedMatrix);
namespace tint {
namespace transform {
namespace {
/// MatrixInfo describes a matrix member with a custom stride
struct MatrixInfo {
/// The stride in bytes between columns of the matrix
uint32_t stride = 0;
/// The type of the matrix
sem::Matrix const* matrix = nullptr;
/// @returns a new ast::Array that holds an vector column for each row of the
/// matrix.
ast::Array* array(ProgramBuilder* b) const {
return b->ty.array(b->ty.vec<ProgramBuilder::f32>(matrix->rows()),
matrix->columns(), stride);
}
/// Equality operator
bool operator==(const MatrixInfo& info) const {
return stride == info.stride && matrix == info.matrix;
}
/// Hash function
struct Hasher {
size_t operator()(const MatrixInfo& t) const {
return utils::Hash(t.stride, t.matrix);
}
};
};
/// Return type of the callback function of GatherCustomStrideMatrixMembers
enum GatherResult { kContinue, kStop };
/// GatherCustomStrideMatrixMembers scans `program` for all matrix members of
/// storage and uniform structs, which are of a matrix type, and have a custom
/// matrix stride attribute. For each matrix member found, `callback` is called.
/// `callback` is a function with the signature:
/// GatherResult(const sem::StructMember* member,
/// sem::Matrix* matrix,
/// uint32_t stride)
/// If `callback` return GatherResult::kStop, then the scanning will immediately
/// terminate, and GatherCustomStrideMatrixMembers() will return, otherwise
/// scanning will continue.
template <typename F>
void GatherCustomStrideMatrixMembers(const Program* program, F&& callback) {
for (auto* node : program->ASTNodes().Objects()) {
if (auto* str = node->As<ast::Struct>()) {
auto* str_ty = program->Sem().Get(str);
if (!str_ty->UsedAs(ast::StorageClass::kUniform) &&
!str_ty->UsedAs(ast::StorageClass::kStorage)) {
continue;
}
for (auto* member : str_ty->Members()) {
auto* matrix = member->Type()->As<sem::Matrix>();
if (!matrix) {
continue;
}
auto* deco = ast::GetDecoration<ast::StrideDecoration>(
member->Declaration()->decorations());
if (!deco) {
continue;
}
uint32_t stride = deco->stride();
if (matrix->ColumnStride() == stride) {
continue;
}
if (callback(member, matrix, stride) == GatherResult::kStop) {
return;
}
}
}
}
}
} // namespace
DecomposeStridedMatrix::DecomposeStridedMatrix() = default;
DecomposeStridedMatrix::~DecomposeStridedMatrix() = default;
bool DecomposeStridedMatrix::ShouldRun(const Program* program) {
bool should_run = false;
GatherCustomStrideMatrixMembers(
program, [&](const sem::StructMember*, sem::Matrix*, uint32_t) {
should_run = true;
return GatherResult::kStop;
});
return should_run;
}
void DecomposeStridedMatrix::Run(CloneContext& ctx, const DataMap&, DataMap&) {
if (!Requires<InlinePointerLets, Simplify>(ctx)) {
return;
}
// Scan the program for all storage and uniform structure matrix members with
// a custom stride attribute. Replace these matrices with an equivalent array,
// and populate the `decomposed` map with the members that have been replaced.
std::unordered_map<ast::StructMember*, MatrixInfo> decomposed;
GatherCustomStrideMatrixMembers(
ctx.src, [&](const sem::StructMember* member, sem::Matrix* matrix,
uint32_t stride) {
// We've got ourselves a struct member of a matrix type with a custom
// stride. Replace this with an array of column vectors.
MatrixInfo info{stride, matrix};
auto* replacement = ctx.dst->Member(
member->Offset(), ctx.Clone(member->Name()), info.array(ctx.dst));
ctx.Replace(member->Declaration(), replacement);
decomposed.emplace(member->Declaration(), info);
return GatherResult::kContinue;
});
// For all expressions where a single matrix column vector was indexed, we can
// preserve these without calling conversion functions.
// Example:
// ssbo.mat[2] -> ssbo.mat[2]
ctx.ReplaceAll(
[&](ast::ArrayAccessorExpression* expr) -> ast::ArrayAccessorExpression* {
if (auto* access =
ctx.src->Sem().Get<sem::StructMemberAccess>(expr->array())) {
auto it = decomposed.find(access->Member()->Declaration());
if (it != decomposed.end()) {
auto* obj = ctx.CloneWithoutTransform(expr->array());
auto* idx = ctx.Clone(expr->idx_expr());
return ctx.dst->IndexAccessor(obj, idx);
}
}
return nullptr;
});
// For all struct member accesses to the matrix on the LHS of an assignment,
// we need to convert the matrix to the array before assigning to the
// structure.
// Example:
// ssbo.mat = mat_to_arr(m)
std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> mat_to_arr;
ctx.ReplaceAll([&](ast::AssignmentStatement* stmt) -> ast::Statement* {
if (auto* access =
ctx.src->Sem().Get<sem::StructMemberAccess>(stmt->lhs())) {
auto it = decomposed.find(access->Member()->Declaration());
if (it == decomposed.end()) {
return nullptr;
}
MatrixInfo info = it->second;
auto fn = utils::GetOrCreate(mat_to_arr, info, [&] {
auto name = ctx.dst->Symbols().New(
"mat" + std::to_string(info.matrix->columns()) + "x" +
std::to_string(info.matrix->rows()) + "_stride_" +
std::to_string(info.stride) + "_to_arr");
auto matrix = [&] { return CreateASTTypeFor(ctx, info.matrix); };
auto array = [&] { return info.array(ctx.dst); };
auto mat = ctx.dst->Sym("mat");
ast::ExpressionList columns(info.matrix->columns());
for (uint32_t i = 0; i < static_cast<uint32_t>(columns.size()); i++) {
columns[i] = ctx.dst->IndexAccessor(mat, i);
}
ctx.dst->Func(name,
{
ctx.dst->Param(mat, matrix()),
},
array(),
{
ctx.dst->Return(ctx.dst->Construct(array(), columns)),
});
return name;
});
auto* lhs = ctx.CloneWithoutTransform(stmt->lhs());
auto* rhs = ctx.dst->Call(fn, ctx.Clone(stmt->rhs()));
return ctx.dst->Assign(lhs, rhs);
}
return nullptr;
});
// For all other struct member accesses, we need to convert the array to the
// matrix type. Example:
// m = arr_to_mat(ssbo.mat)
std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> arr_to_mat;
ctx.ReplaceAll([&](ast::MemberAccessorExpression* expr) -> ast::Expression* {
if (auto* access = ctx.src->Sem().Get<sem::StructMemberAccess>(expr)) {
auto it = decomposed.find(access->Member()->Declaration());
if (it == decomposed.end()) {
return nullptr;
}
MatrixInfo info = it->second;
auto fn = utils::GetOrCreate(arr_to_mat, info, [&] {
auto name = ctx.dst->Symbols().New(
"arr_to_mat" + std::to_string(info.matrix->columns()) + "x" +
std::to_string(info.matrix->rows()) + "_stride_" +
std::to_string(info.stride));
auto matrix = [&] { return CreateASTTypeFor(ctx, info.matrix); };
auto array = [&] { return info.array(ctx.dst); };
auto arr = ctx.dst->Sym("arr");
ast::ExpressionList columns(info.matrix->columns());
for (uint32_t i = 0; i < static_cast<uint32_t>(columns.size()); i++) {
columns[i] = ctx.dst->IndexAccessor(arr, i);
}
ctx.dst->Func(
name,
{
ctx.dst->Param(arr, array()),
},
matrix(),
{
ctx.dst->Return(ctx.dst->Construct(matrix(), columns)),
});
return name;
});
return ctx.dst->Call(fn, ctx.CloneWithoutTransform(expr));
}
return nullptr;
});
ctx.Clone();
}
} // namespace transform
} // namespace tint

View File

@ -0,0 +1,54 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TRANSFORM_DECOMPOSE_STRIDED_MATRIX_H_
#define SRC_TRANSFORM_DECOMPOSE_STRIDED_MATRIX_H_
#include "src/transform/transform.h"
namespace tint {
namespace transform {
/// DecomposeStridedMatrix transforms replaces matrix members of storage or
/// uniform buffer structures, that have a [[stride]] decoration, into an array
/// of N column vectors.
/// This transform is used by the SPIR-V reader to handle the SPIR-V
/// MatrixStride decoration.
class DecomposeStridedMatrix
: public Castable<DecomposeStridedMatrix, Transform> {
public:
/// Constructor
DecomposeStridedMatrix();
/// Destructor
~DecomposeStridedMatrix() override;
/// @param program the program to inspect
/// @returns true if this transform should be run for the given program
static bool ShouldRun(const Program* program);
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override;
};
} // namespace transform
} // namespace tint
#endif // SRC_TRANSFORM_DECOMPOSE_STRIDED_MATRIX_H_

View File

@ -0,0 +1,727 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/transform/decompose_strided_matrix.h"
#include <memory>
#include <utility>
#include <vector>
#include "src/ast/disable_validation_decoration.h"
#include "src/program_builder.h"
#include "src/transform/inline_pointer_lets.h"
#include "src/transform/simplify.h"
#include "src/transform/test_helper.h"
namespace tint {
namespace transform {
namespace {
using DecomposeStridedMatrixTest = TransformTest;
using f32 = ProgramBuilder::f32;
TEST_F(DecomposeStridedMatrixTest, Empty) {
auto* src = R"()";
auto* expect = src;
auto got = Run<InlinePointerLets, Simplify, DecomposeStridedMatrix>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedMatrixTest, MissingDependencyInlinePointerLets) {
auto* src = R"()";
auto* expect =
R"(error: tint::transform::DecomposeStridedMatrix depends on tint::transform::InlinePointerLets but the dependency was not run)";
auto got = Run<Simplify, DecomposeStridedMatrix>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedMatrixTest, MissingDependencySimplify) {
auto* src = R"()";
auto* expect =
R"(error: tint::transform::DecomposeStridedMatrix depends on tint::transform::Simplify but the dependency was not run)";
auto got = Run<InlinePointerLets, DecomposeStridedMatrix>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedMatrixTest, ReadUniformMatrix) {
// [[block]]
// struct S {
// [[offset(16), stride(32)]]
// [[internal(ignore_stride_decoration)]]
// m : mat2x2<f32>;
// };
// [[group(0), binding(0)]] var<uniform> s : S;
//
// [[stage(compute), workgroup_size(1)]]
// fn f() {
// let x : mat2x2<f32> = s.m;
// }
ProgramBuilder b;
auto* S = b.Structure(
"S",
{
b.Member(
"m", b.ty.mat2x2<f32>(),
{
b.create<ast::StructMemberOffsetDecoration>(16),
b.create<ast::StrideDecoration>(32),
b.ASTNodes().Create<ast::DisableValidationDecoration>(
b.ID(), ast::DisabledValidation::kIgnoreStrideDecoration),
}),
},
{
b.StructBlock(),
});
b.Global("s", b.ty.Of(S), ast::StorageClass::kUniform,
b.GroupAndBinding(0, 0));
b.Func(
"f", {}, b.ty.void_(),
{
b.Decl(b.Const("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))),
},
{
b.Stage(ast::PipelineStage::kCompute),
b.WorkgroupSize(1),
});
auto* expect = R"(
[[block]]
struct S {
[[size(16)]]
padding : u32;
m : [[stride(32)]] array<vec2<f32>, 2>;
};
[[group(0), binding(0)]] var<uniform> s : S;
fn arr_to_mat2x2_stride_32(arr : [[stride(32)]] array<vec2<f32>, 2>) -> mat2x2<f32> {
return mat2x2<f32>(arr[0u], arr[1u]);
}
[[stage(compute), workgroup_size(1)]]
fn f() {
let x : mat2x2<f32> = arr_to_mat2x2_stride_32(s.m);
}
)";
auto got = Run<InlinePointerLets, Simplify, DecomposeStridedMatrix>(
Program(std::move(b)));
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedMatrixTest, ReadUniformColumn) {
// [[block]]
// struct S {
// [[offset(16), stride(32)]]
// [[internal(ignore_stride_decoration)]]
// m : mat2x2<f32>;
// };
// [[group(0), binding(0)]] var<uniform> s : S;
//
// [[stage(compute), workgroup_size(1)]]
// fn f() {
// let x : vec2<f32> = s.m[1];
// }
ProgramBuilder b;
auto* S = b.Structure(
"S",
{
b.Member(
"m", b.ty.mat2x2<f32>(),
{
b.create<ast::StructMemberOffsetDecoration>(16),
b.create<ast::StrideDecoration>(32),
b.ASTNodes().Create<ast::DisableValidationDecoration>(
b.ID(), ast::DisabledValidation::kIgnoreStrideDecoration),
}),
},
{
b.StructBlock(),
});
b.Global("s", b.ty.Of(S), ast::StorageClass::kUniform,
b.GroupAndBinding(0, 0));
b.Func("f", {}, b.ty.void_(),
{
b.Decl(b.Const("x", b.ty.vec2<f32>(),
b.IndexAccessor(b.MemberAccessor("s", "m"), 1))),
},
{
b.Stage(ast::PipelineStage::kCompute),
b.WorkgroupSize(1),
});
auto* expect = R"(
[[block]]
struct S {
[[size(16)]]
padding : u32;
m : [[stride(32)]] array<vec2<f32>, 2>;
};
[[group(0), binding(0)]] var<uniform> s : S;
[[stage(compute), workgroup_size(1)]]
fn f() {
let x : vec2<f32> = s.m[1];
}
)";
auto got = Run<InlinePointerLets, Simplify, DecomposeStridedMatrix>(
Program(std::move(b)));
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedMatrixTest, ReadUniformMatrix_DefaultStride) {
// [[block]]
// struct S {
// [[offset(16), stride(8)]]
// [[internal(ignore_stride_decoration)]]
// m : mat2x2<f32>;
// };
// [[group(0), binding(0)]] var<uniform> s : S;
//
// [[stage(compute), workgroup_size(1)]]
// fn f() {
// let x : mat2x2<f32> = s.m;
// }
ProgramBuilder b;
auto* S = b.Structure(
"S",
{
b.Member(
"m", b.ty.mat2x2<f32>(),
{
b.create<ast::StructMemberOffsetDecoration>(16),
b.create<ast::StrideDecoration>(8),
b.ASTNodes().Create<ast::DisableValidationDecoration>(
b.ID(), ast::DisabledValidation::kIgnoreStrideDecoration),
}),
},
{
b.StructBlock(),
});
b.Global("s", b.ty.Of(S), ast::StorageClass::kUniform,
b.GroupAndBinding(0, 0));
b.Func(
"f", {}, b.ty.void_(),
{
b.Decl(b.Const("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))),
},
{
b.Stage(ast::PipelineStage::kCompute),
b.WorkgroupSize(1),
});
auto* expect = R"(
[[block]]
struct S {
[[size(16)]]
padding : u32;
[[stride(8), internal(disable_validation__ignore_stride)]]
m : mat2x2<f32>;
};
[[group(0), binding(0)]] var<uniform> s : S;
[[stage(compute), workgroup_size(1)]]
fn f() {
let x : mat2x2<f32> = s.m;
}
)";
auto got = Run<InlinePointerLets, Simplify, DecomposeStridedMatrix>(
Program(std::move(b)));
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedMatrixTest, ReadStorageMatrix) {
// [[block]]
// struct S {
// [[offset(8), stride(32)]]
// [[internal(ignore_stride_decoration)]]
// m : mat2x2<f32>;
// };
// [[group(0), binding(0)]] var<storage, read_write> s : S;
//
// [[stage(compute), workgroup_size(1)]]
// fn f() {
// let x : mat2x2<f32> = s.m;
// }
ProgramBuilder b;
auto* S = b.Structure(
"S",
{
b.Member(
"m", b.ty.mat2x2<f32>(),
{
b.create<ast::StructMemberOffsetDecoration>(8),
b.create<ast::StrideDecoration>(32),
b.ASTNodes().Create<ast::DisableValidationDecoration>(
b.ID(), ast::DisabledValidation::kIgnoreStrideDecoration),
}),
},
{
b.StructBlock(),
});
b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
ast::Access::kReadWrite, b.GroupAndBinding(0, 0));
b.Func(
"f", {}, b.ty.void_(),
{
b.Decl(b.Const("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))),
},
{
b.Stage(ast::PipelineStage::kCompute),
b.WorkgroupSize(1),
});
auto* expect = R"(
[[block]]
struct S {
[[size(8)]]
padding : u32;
m : [[stride(32)]] array<vec2<f32>, 2>;
};
[[group(0), binding(0)]] var<storage, read_write> s : S;
fn arr_to_mat2x2_stride_32(arr : [[stride(32)]] array<vec2<f32>, 2>) -> mat2x2<f32> {
return mat2x2<f32>(arr[0u], arr[1u]);
}
[[stage(compute), workgroup_size(1)]]
fn f() {
let x : mat2x2<f32> = arr_to_mat2x2_stride_32(s.m);
}
)";
auto got = Run<InlinePointerLets, Simplify, DecomposeStridedMatrix>(
Program(std::move(b)));
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedMatrixTest, ReadStorageColumn) {
// [[block]]
// struct S {
// [[offset(16), stride(32)]]
// [[internal(ignore_stride_decoration)]]
// m : mat2x2<f32>;
// };
// [[group(0), binding(0)]] var<storage, read_write> s : S;
//
// [[stage(compute), workgroup_size(1)]]
// fn f() {
// let x : vec2<f32> = s.m[1];
// }
ProgramBuilder b;
auto* S = b.Structure(
"S",
{
b.Member(
"m", b.ty.mat2x2<f32>(),
{
b.create<ast::StructMemberOffsetDecoration>(16),
b.create<ast::StrideDecoration>(32),
b.ASTNodes().Create<ast::DisableValidationDecoration>(
b.ID(), ast::DisabledValidation::kIgnoreStrideDecoration),
}),
},
{
b.StructBlock(),
});
b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
ast::Access::kReadWrite, b.GroupAndBinding(0, 0));
b.Func("f", {}, b.ty.void_(),
{
b.Decl(b.Const("x", b.ty.vec2<f32>(),
b.IndexAccessor(b.MemberAccessor("s", "m"), 1))),
},
{
b.Stage(ast::PipelineStage::kCompute),
b.WorkgroupSize(1),
});
auto* expect = R"(
[[block]]
struct S {
[[size(16)]]
padding : u32;
m : [[stride(32)]] array<vec2<f32>, 2>;
};
[[group(0), binding(0)]] var<storage, read_write> s : S;
[[stage(compute), workgroup_size(1)]]
fn f() {
let x : vec2<f32> = s.m[1];
}
)";
auto got = Run<InlinePointerLets, Simplify, DecomposeStridedMatrix>(
Program(std::move(b)));
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedMatrixTest, WriteStorageMatrix) {
// [[block]]
// struct S {
// [[offset(8), stride(32)]]
// [[internal(ignore_stride_decoration)]]
// m : mat2x2<f32>;
// };
// [[group(0), binding(0)]] var<storage, read_write> s : S;
//
// [[stage(compute), workgroup_size(1)]]
// fn f() {
// s.m = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
// }
ProgramBuilder b;
auto* S = b.Structure(
"S",
{
b.Member(
"m", b.ty.mat2x2<f32>(),
{
b.create<ast::StructMemberOffsetDecoration>(8),
b.create<ast::StrideDecoration>(32),
b.ASTNodes().Create<ast::DisableValidationDecoration>(
b.ID(), ast::DisabledValidation::kIgnoreStrideDecoration),
}),
},
{
b.StructBlock(),
});
b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
ast::Access::kReadWrite, b.GroupAndBinding(0, 0));
b.Func("f", {}, b.ty.void_(),
{
b.Assign(b.MemberAccessor("s", "m"),
b.mat2x2<f32>(b.vec2<f32>(1.0f, 2.0f),
b.vec2<f32>(3.0f, 4.0f))),
},
{
b.Stage(ast::PipelineStage::kCompute),
b.WorkgroupSize(1),
});
auto* expect = R"(
[[block]]
struct S {
[[size(8)]]
padding : u32;
m : [[stride(32)]] array<vec2<f32>, 2>;
};
[[group(0), binding(0)]] var<storage, read_write> s : S;
fn mat2x2_stride_32_to_arr(mat : mat2x2<f32>) -> [[stride(32)]] array<vec2<f32>, 2> {
return [[stride(32)]] array<vec2<f32>, 2>(mat[0u], mat[1u]);
}
[[stage(compute), workgroup_size(1)]]
fn f() {
s.m = mat2x2_stride_32_to_arr(mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0)));
}
)";
auto got = Run<InlinePointerLets, Simplify, DecomposeStridedMatrix>(
Program(std::move(b)));
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedMatrixTest, WriteStorageColumn) {
// [[block]]
// struct S {
// [[offset(8), stride(32)]]
// [[internal(ignore_stride_decoration)]]
// m : mat2x2<f32>;
// };
// [[group(0), binding(0)]] var<storage, read_write> s : S;
//
// [[stage(compute), workgroup_size(1)]]
// fn f() {
// s.m[1] = vec2<f32>(1.0, 2.0);
// }
ProgramBuilder b;
auto* S = b.Structure(
"S",
{
b.Member(
"m", b.ty.mat2x2<f32>(),
{
b.create<ast::StructMemberOffsetDecoration>(8),
b.create<ast::StrideDecoration>(32),
b.ASTNodes().Create<ast::DisableValidationDecoration>(
b.ID(), ast::DisabledValidation::kIgnoreStrideDecoration),
}),
},
{
b.StructBlock(),
});
b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
ast::Access::kReadWrite, b.GroupAndBinding(0, 0));
b.Func("f", {}, b.ty.void_(),
{
b.Assign(b.IndexAccessor(b.MemberAccessor("s", "m"), 1),
b.vec2<f32>(1.0f, 2.0f)),
},
{
b.Stage(ast::PipelineStage::kCompute),
b.WorkgroupSize(1),
});
auto* expect = R"(
[[block]]
struct S {
[[size(8)]]
padding : u32;
m : [[stride(32)]] array<vec2<f32>, 2>;
};
[[group(0), binding(0)]] var<storage, read_write> s : S;
[[stage(compute), workgroup_size(1)]]
fn f() {
s.m[1] = vec2<f32>(1.0, 2.0);
}
)";
auto got = Run<InlinePointerLets, Simplify, DecomposeStridedMatrix>(
Program(std::move(b)));
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedMatrixTest, ReadWriteViaPointerLets) {
// [[block]]
// struct S {
// [[offset(8), stride(32)]]
// [[internal(ignore_stride_decoration)]]
// m : mat2x2<f32>;
// };
// [[group(0), binding(0)]] var<storage, read_write> s : S;
//
// [[stage(compute), workgroup_size(1)]]
// fn f() {
// let a = &s.m;
// let b = &*&*(a);
// let x = *b;
// let y = (*b)[1];
// let z = x[1];
// (*b) = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
// (*b)[1] = vec2<f32>(5.0, 6.0);
// }
ProgramBuilder b;
auto* S = b.Structure(
"S",
{
b.Member(
"m", b.ty.mat2x2<f32>(),
{
b.create<ast::StructMemberOffsetDecoration>(8),
b.create<ast::StrideDecoration>(32),
b.ASTNodes().Create<ast::DisableValidationDecoration>(
b.ID(), ast::DisabledValidation::kIgnoreStrideDecoration),
}),
},
{
b.StructBlock(),
});
b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
ast::Access::kReadWrite, b.GroupAndBinding(0, 0));
b.Func(
"f", {}, b.ty.void_(),
{
b.Decl(
b.Const("a", nullptr, b.AddressOf(b.MemberAccessor("s", "m")))),
b.Decl(b.Const("b", nullptr,
b.AddressOf(b.Deref(b.AddressOf(b.Deref("a")))))),
b.Decl(b.Const("x", nullptr, b.Deref("b"))),
b.Decl(b.Const("y", nullptr, b.IndexAccessor(b.Deref("b"), 1))),
b.Decl(b.Const("z", nullptr, b.IndexAccessor("x", 1))),
b.Assign(b.Deref("b"), b.mat2x2<f32>(b.vec2<f32>(1.0f, 2.0f),
b.vec2<f32>(3.0f, 4.0f))),
b.Assign(b.IndexAccessor(b.Deref("b"), 1), b.vec2<f32>(5.0f, 6.0f)),
},
{
b.Stage(ast::PipelineStage::kCompute),
b.WorkgroupSize(1),
});
auto* expect = R"(
[[block]]
struct S {
[[size(8)]]
padding : u32;
m : [[stride(32)]] array<vec2<f32>, 2>;
};
[[group(0), binding(0)]] var<storage, read_write> s : S;
fn arr_to_mat2x2_stride_32(arr : [[stride(32)]] array<vec2<f32>, 2>) -> mat2x2<f32> {
return mat2x2<f32>(arr[0u], arr[1u]);
}
fn mat2x2_stride_32_to_arr(mat : mat2x2<f32>) -> [[stride(32)]] array<vec2<f32>, 2> {
return [[stride(32)]] array<vec2<f32>, 2>(mat[0u], mat[1u]);
}
[[stage(compute), workgroup_size(1)]]
fn f() {
let x = arr_to_mat2x2_stride_32(s.m);
let y = s.m[1];
let z = x[1];
s.m = mat2x2_stride_32_to_arr(mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0)));
s.m[1] = vec2<f32>(5.0, 6.0);
}
)";
auto got = Run<InlinePointerLets, Simplify, DecomposeStridedMatrix>(
Program(std::move(b)));
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedMatrixTest, ReadPrivateMatrix) {
// struct S {
// [[offset(8), stride(32)]]
// [[internal(ignore_stride_decoration)]]
// m : mat2x2<f32>;
// };
// var<private> s : S;
//
// [[stage(compute), workgroup_size(1)]]
// fn f() {
// let x : mat2x2<f32> = s.m;
// }
ProgramBuilder b;
auto* S = b.Structure(
"S",
{
b.Member(
"m", b.ty.mat2x2<f32>(),
{
b.create<ast::StructMemberOffsetDecoration>(8),
b.create<ast::StrideDecoration>(32),
b.ASTNodes().Create<ast::DisableValidationDecoration>(
b.ID(), ast::DisabledValidation::kIgnoreStrideDecoration),
}),
});
b.Global("s", b.ty.Of(S), ast::StorageClass::kPrivate);
b.Func(
"f", {}, b.ty.void_(),
{
b.Decl(b.Const("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))),
},
{
b.Stage(ast::PipelineStage::kCompute),
b.WorkgroupSize(1),
});
auto* expect = R"(
struct S {
[[size(8)]]
padding : u32;
[[stride(32), internal(disable_validation__ignore_stride)]]
m : mat2x2<f32>;
};
var<private> s : S;
[[stage(compute), workgroup_size(1)]]
fn f() {
let x : mat2x2<f32> = s.m;
}
)";
auto got = Run<InlinePointerLets, Simplify, DecomposeStridedMatrix>(
Program(std::move(b)));
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedMatrixTest, WritePrivateMatrix) {
// struct S {
// [[offset(8), stride(32)]]
// [[internal(ignore_stride_decoration)]]
// m : mat2x2<f32>;
// };
// var<private> s : S;
//
// [[stage(compute), workgroup_size(1)]]
// fn f() {
// s.m = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
// }
ProgramBuilder b;
auto* S = b.Structure(
"S",
{
b.Member(
"m", b.ty.mat2x2<f32>(),
{
b.create<ast::StructMemberOffsetDecoration>(8),
b.create<ast::StrideDecoration>(32),
b.ASTNodes().Create<ast::DisableValidationDecoration>(
b.ID(), ast::DisabledValidation::kIgnoreStrideDecoration),
}),
});
b.Global("s", b.ty.Of(S), ast::StorageClass::kPrivate);
b.Func("f", {}, b.ty.void_(),
{
b.Assign(b.MemberAccessor("s", "m"),
b.mat2x2<f32>(b.vec2<f32>(1.0f, 2.0f),
b.vec2<f32>(3.0f, 4.0f))),
},
{
b.Stage(ast::PipelineStage::kCompute),
b.WorkgroupSize(1),
});
auto* expect = R"(
struct S {
[[size(8)]]
padding : u32;
[[stride(32), internal(disable_validation__ignore_stride)]]
m : mat2x2<f32>;
};
var<private> s : S;
[[stage(compute), workgroup_size(1)]]
fn f() {
s.m = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
}
)";
auto got = Run<InlinePointerLets, Simplify, DecomposeStridedMatrix>(
Program(std::move(b)));
EXPECT_EQ(expect, str(got));
}
} // namespace
} // namespace transform
} // namespace tint

View File

@ -59,6 +59,16 @@ class TransformTestBase : public BASE {
// Keep this pointer alive after Transform() returns
files_.emplace_back(std::move(file));
return Run<TRANSFORMS...>(std::move(program), data);
}
/// Transforms and returns program `program`, transformed using a transform of
/// type `TRANSFORM`.
/// @param program the input Program
/// @param data the optional DataMap to pass to Transform::Run()
/// @return the transformed output
template <typename... TRANSFORMS>
Output Run(Program&& program, const DataMap& data = {}) {
if (!program.IsValid()) {
return Output(std::move(program));
}

View File

@ -696,6 +696,8 @@ bool GeneratorImpl::EmitDecorations(std::ostream& out,
out << "size(" << size->size() << ")";
} else if (auto* align = deco->As<ast::StructMemberAlignDecoration>()) {
out << "align(" << align->align() << ")";
} else if (auto* stride = deco->As<ast::StrideDecoration>()) {
out << "stride(" << stride->stride() << ")";
} else if (auto* internal = deco->As<ast::InternalDecoration>()) {
out << "internal(" << internal->InternalName() << ")";
} else {

View File

@ -289,6 +289,7 @@ tint_unittests_source_set("tint_unittests_core_src") {
"../src/transform/calculate_array_length_test.cc",
"../src/transform/canonicalize_entry_point_io_test.cc",
"../src/transform/decompose_memory_access_test.cc",
"../src/transform/decompose_strided_matrix_test.cc",
"../src/transform/external_texture_transform_test.cc",
"../src/transform/first_index_offset_test.cc",
"../src/transform/fold_constants_test.cc",

View File

@ -0,0 +1,11 @@
[[block]]
struct SSBO {
m : mat2x2<f32>;
};
[[group(0), binding(0)]] var<storage, read_write> ssbo : SSBO;
[[stage(compute), workgroup_size(1)]]
fn f() {
let v = ssbo.m;
ssbo.m = v;
}

View File

@ -0,0 +1,17 @@
RWByteAddressBuffer ssbo : register(u0, space0);
float2x2 tint_symbol(RWByteAddressBuffer buffer, uint offset) {
return float2x2(asfloat(buffer.Load2((offset + 0u))), asfloat(buffer.Load2((offset + 8u))));
}
void tint_symbol_2(RWByteAddressBuffer buffer, uint offset, float2x2 value) {
buffer.Store2((offset + 0u), asuint(value[0u]));
buffer.Store2((offset + 8u), asuint(value[1u]));
}
[numthreads(1, 1, 1)]
void f() {
const float2x2 v = tint_symbol(ssbo, 0u);
tint_symbol_2(ssbo, 0u, v);
return;
}

View File

@ -0,0 +1,13 @@
#include <metal_stdlib>
using namespace metal;
struct SSBO {
/* 0x0000 */ float2x2 m;
};
kernel void f(device SSBO& ssbo [[buffer(0)]]) {
float2x2 const v = ssbo.m;
ssbo.m = v;
return;
}

View File

@ -0,0 +1,38 @@
; SPIR-V
; Version: 1.3
; Generator: Google Tint Compiler; 0
; Bound: 17
; Schema: 0
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %f "f"
OpExecutionMode %f LocalSize 1 1 1
OpName %SSBO "SSBO"
OpMemberName %SSBO 0 "m"
OpName %ssbo "ssbo"
OpName %f "f"
OpDecorate %SSBO Block
OpMemberDecorate %SSBO 0 Offset 0
OpMemberDecorate %SSBO 0 ColMajor
OpMemberDecorate %SSBO 0 MatrixStride 8
OpDecorate %ssbo DescriptorSet 0
OpDecorate %ssbo Binding 0
%float = OpTypeFloat 32
%v2float = OpTypeVector %float 2
%mat2v2float = OpTypeMatrix %v2float 2
%SSBO = OpTypeStruct %mat2v2float
%_ptr_StorageBuffer_SSBO = OpTypePointer StorageBuffer %SSBO
%ssbo = OpVariable %_ptr_StorageBuffer_SSBO StorageBuffer
%void = OpTypeVoid
%7 = OpTypeFunction %void
%uint = OpTypeInt 32 0
%uint_0 = OpConstant %uint 0
%_ptr_StorageBuffer_mat2v2float = OpTypePointer StorageBuffer %mat2v2float
%f = OpFunction %void None %7
%10 = OpLabel
%14 = OpAccessChain %_ptr_StorageBuffer_mat2v2float %ssbo %uint_0
%15 = OpLoad %mat2v2float %14
%16 = OpAccessChain %_ptr_StorageBuffer_mat2v2float %ssbo %uint_0
OpStore %16 %15
OpReturn
OpFunctionEnd

View File

@ -0,0 +1,12 @@
[[block]]
struct SSBO {
m : mat2x2<f32>;
};
[[group(0), binding(0)]] var<storage, read_write> ssbo : SSBO;
[[stage(compute), workgroup_size(1)]]
fn f() {
let v = ssbo.m;
ssbo.m = v;
}

View File

@ -0,0 +1,33 @@
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %f "f"
OpExecutionMode %f LocalSize 1 1 1
OpName %SSBO "SSBO"
OpMemberName %SSBO 0 "m"
OpName %ssbo "ssbo"
OpName %f "f"
OpDecorate %SSBO Block
OpMemberDecorate %SSBO 0 Offset 0
OpMemberDecorate %SSBO 0 ColMajor
OpMemberDecorate %SSBO 0 MatrixStride 16
OpDecorate %ssbo DescriptorSet 0
OpDecorate %ssbo Binding 0
%float = OpTypeFloat 32
%v2float = OpTypeVector %float 2
%mat2v2float = OpTypeMatrix %v2float 2
%SSBO = OpTypeStruct %mat2v2float
%_ptr_StorageBuffer_SSBO = OpTypePointer StorageBuffer %SSBO
%ssbo = OpVariable %_ptr_StorageBuffer_SSBO StorageBuffer
%void = OpTypeVoid
%7 = OpTypeFunction %void
%uint = OpTypeInt 32 0
%uint_0 = OpConstant %uint 0
%_ptr_StorageBuffer_mat2v2float = OpTypePointer StorageBuffer %mat2v2float
%f = OpFunction %void None %7
%10 = OpLabel
%14 = OpAccessChain %_ptr_StorageBuffer_mat2v2float %ssbo %uint_0
%15 = OpLoad %mat2v2float %14
%16 = OpAccessChain %_ptr_StorageBuffer_mat2v2float %ssbo %uint_0
OpStore %16 %15
OpReturn
OpFunctionEnd

View File

@ -0,0 +1,47 @@
struct tint_padded_array_element {
float2 el;
};
RWByteAddressBuffer ssbo : register(u0, space0);
float2x2 arr_to_mat2x2_stride_16(tint_padded_array_element arr[2]) {
return float2x2(arr[0u].el, arr[1u].el);
}
typedef tint_padded_array_element mat2x2_stride_16_to_arr_ret[2];
mat2x2_stride_16_to_arr_ret mat2x2_stride_16_to_arr(float2x2 mat) {
const tint_padded_array_element tint_symbol_4[2] = {{mat[0u]}, {mat[1u]}};
return tint_symbol_4;
}
typedef tint_padded_array_element tint_symbol_ret[2];
tint_symbol_ret tint_symbol(RWByteAddressBuffer buffer, uint offset) {
tint_padded_array_element arr_1[2] = (tint_padded_array_element[2])0;
{
for(uint i = 0u; (i < 2u); i = (i + 1u)) {
arr_1[i].el = asfloat(buffer.Load2((offset + (i * 16u))));
}
}
return arr_1;
}
void tint_symbol_2(RWByteAddressBuffer buffer, uint offset, tint_padded_array_element value[2]) {
tint_padded_array_element array[2] = value;
{
for(uint i_1 = 0u; (i_1 < 2u); i_1 = (i_1 + 1u)) {
buffer.Store2((offset + (i_1 * 16u)), asuint(array[i_1].el));
}
}
}
void f_1() {
const float2x2 x_15 = arr_to_mat2x2_stride_16(tint_symbol(ssbo, 0u));
tint_symbol_2(ssbo, 0u, mat2x2_stride_16_to_arr(x_15));
return;
}
[numthreads(1, 1, 1)]
void f() {
f_1();
return;
}

View File

@ -0,0 +1,34 @@
#include <metal_stdlib>
using namespace metal;
struct tint_padded_array_element {
/* 0x0000 */ packed_float2 el;
/* 0x0008 */ int8_t tint_pad[8];
};
struct tint_array_wrapper {
/* 0x0000 */ tint_padded_array_element arr[2];
};
struct SSBO {
/* 0x0000 */ tint_array_wrapper m;
};
float2x2 arr_to_mat2x2_stride_16(tint_array_wrapper arr) {
return float2x2(arr.arr[0u].el, arr.arr[1u].el);
}
tint_array_wrapper mat2x2_stride_16_to_arr(float2x2 mat) {
tint_array_wrapper const tint_symbol = {.arr={{.el=mat[0u]}, {.el=mat[1u]}}};
return tint_symbol;
}
void f_1(device SSBO& ssbo) {
float2x2 const x_15 = arr_to_mat2x2_stride_16(ssbo.m);
ssbo.m = mat2x2_stride_16_to_arr(x_15);
return;
}
kernel void f(device SSBO& ssbo [[buffer(0)]]) {
f_1(ssbo);
return;
}

View File

@ -0,0 +1,70 @@
; SPIR-V
; Version: 1.3
; Generator: Google Tint Compiler; 0
; Bound: 39
; Schema: 0
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %f "f"
OpExecutionMode %f LocalSize 1 1 1
OpName %SSBO "SSBO"
OpMemberName %SSBO 0 "m"
OpName %ssbo "ssbo"
OpName %arr_to_mat2x2_stride_16 "arr_to_mat2x2_stride_16"
OpName %arr "arr"
OpName %mat2x2_stride_16_to_arr "mat2x2_stride_16_to_arr"
OpName %mat "mat"
OpName %f_1 "f_1"
OpName %f "f"
OpDecorate %SSBO Block
OpMemberDecorate %SSBO 0 Offset 0
OpDecorate %_arr_v2float_uint_2 ArrayStride 16
OpDecorate %ssbo DescriptorSet 0
OpDecorate %ssbo Binding 0
%float = OpTypeFloat 32
%v2float = OpTypeVector %float 2
%uint = OpTypeInt 32 0
%uint_2 = OpConstant %uint 2
%_arr_v2float_uint_2 = OpTypeArray %v2float %uint_2
%SSBO = OpTypeStruct %_arr_v2float_uint_2
%_ptr_StorageBuffer_SSBO = OpTypePointer StorageBuffer %SSBO
%ssbo = OpVariable %_ptr_StorageBuffer_SSBO StorageBuffer
%mat2v2float = OpTypeMatrix %v2float 2
%9 = OpTypeFunction %mat2v2float %_arr_v2float_uint_2
%uint_0 = OpConstant %uint 0
%uint_1 = OpConstant %uint 1
%19 = OpTypeFunction %_arr_v2float_uint_2 %mat2v2float
%void = OpTypeVoid
%26 = OpTypeFunction %void
%_ptr_StorageBuffer__arr_v2float_uint_2 = OpTypePointer StorageBuffer %_arr_v2float_uint_2
%arr_to_mat2x2_stride_16 = OpFunction %mat2v2float None %9
%arr = OpFunctionParameter %_arr_v2float_uint_2
%13 = OpLabel
%15 = OpCompositeExtract %v2float %arr 0
%17 = OpCompositeExtract %v2float %arr 1
%18 = OpCompositeConstruct %mat2v2float %15 %17
OpReturnValue %18
OpFunctionEnd
%mat2x2_stride_16_to_arr = OpFunction %_arr_v2float_uint_2 None %19
%mat = OpFunctionParameter %mat2v2float
%22 = OpLabel
%23 = OpCompositeExtract %v2float %mat 0
%24 = OpCompositeExtract %v2float %mat 1
%25 = OpCompositeConstruct %_arr_v2float_uint_2 %23 %24
OpReturnValue %25
OpFunctionEnd
%f_1 = OpFunction %void None %26
%29 = OpLabel
%32 = OpAccessChain %_ptr_StorageBuffer__arr_v2float_uint_2 %ssbo %uint_0
%33 = OpLoad %_arr_v2float_uint_2 %32
%30 = OpFunctionCall %mat2v2float %arr_to_mat2x2_stride_16 %33
%34 = OpAccessChain %_ptr_StorageBuffer__arr_v2float_uint_2 %ssbo %uint_0
%35 = OpFunctionCall %_arr_v2float_uint_2 %mat2x2_stride_16_to_arr %30
OpStore %34 %35
OpReturn
OpFunctionEnd
%f = OpFunction %void None %26
%37 = OpLabel
%38 = OpFunctionCall %void %f_1
OpReturn
OpFunctionEnd

View File

@ -0,0 +1,25 @@
[[block]]
struct SSBO {
m : [[stride(16)]] array<vec2<f32>, 2>;
};
[[group(0), binding(0)]] var<storage, read_write> ssbo : SSBO;
fn arr_to_mat2x2_stride_16(arr : [[stride(16)]] array<vec2<f32>, 2>) -> mat2x2<f32> {
return mat2x2<f32>(arr[0u], arr[1u]);
}
fn mat2x2_stride_16_to_arr(mat : mat2x2<f32>) -> [[stride(16)]] array<vec2<f32>, 2> {
return [[stride(16)]] array<vec2<f32>, 2>(mat[0u], mat[1u]);
}
fn f_1() {
let x_15 : mat2x2<f32> = arr_to_mat2x2_stride_16(ssbo.m);
ssbo.m = mat2x2_stride_16_to_arr(x_15);
return;
}
[[stage(compute), workgroup_size(1, 1, 1)]]
fn f() {
f_1();
}