diff --git a/src/BUILD.gn b/src/BUILD.gn index acb21d857a..32a2b15118 100644 --- a/src/BUILD.gn +++ b/src/BUILD.gn @@ -426,6 +426,8 @@ libtint_source_set("libtint_core_all_src") { "transform/canonicalize_entry_point_io.h", "transform/decompose_memory_access.cc", "transform/decompose_memory_access.h", + "transform/decompose_strided_matrix.cc", + "transform/decompose_strided_matrix.h", "transform/external_texture_transform.cc", "transform/external_texture_transform.h", "transform/first_index_offset.cc", diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index b56686d504..fbf50afe11 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -296,6 +296,8 @@ set(TINT_LIB_SRCS transform/canonicalize_entry_point_io.h transform/decompose_memory_access.cc transform/decompose_memory_access.h + transform/decompose_strided_matrix.cc + transform/decompose_strided_matrix.h transform/external_texture_transform.cc transform/external_texture_transform.h transform/first_index_offset.cc @@ -904,6 +906,7 @@ if(${TINT_BUILD_TESTS}) transform/calculate_array_length_test.cc transform/canonicalize_entry_point_io_test.cc transform/decompose_memory_access_test.cc + transform/decompose_strided_matrix_test.cc transform/external_texture_transform_test.cc transform/first_index_offset_test.cc transform/fold_constants_test.cc diff --git a/src/ast/disable_validation_decoration.cc b/src/ast/disable_validation_decoration.cc index ca59d80a80..846530a2e4 100644 --- a/src/ast/disable_validation_decoration.cc +++ b/src/ast/disable_validation_decoration.cc @@ -40,6 +40,8 @@ std::string DisableValidationDecoration::InternalName() const { return "disable_validation__entry_point_parameter"; case DisabledValidation::kIgnoreConstructibleFunctionParameter: return "disable_validation__ignore_constructible_function_parameter"; + case DisabledValidation::kIgnoreStrideDecoration: + return "disable_validation__ignore_stride"; } return ""; } diff --git a/src/ast/disable_validation_decoration.h b/src/ast/disable_validation_decoration.h index 60a9fb7f5f..3ebf0bc112 100644 --- a/src/ast/disable_validation_decoration.h +++ b/src/ast/disable_validation_decoration.h @@ -40,6 +40,9 @@ enum class DisabledValidation { /// When applied to a function parameter, the validator will not /// check if parameter type is constructible kIgnoreConstructibleFunctionParameter, + /// When applied to a member decoration, a stride decoration may be applied to + /// non-array types. + kIgnoreStrideDecoration, }; /// An internal decoration used to tell the validator to ignore specific diff --git a/src/ast/internal_decoration.cc b/src/ast/internal_decoration.cc index 47db806f26..f5e2237350 100644 --- a/src/ast/internal_decoration.cc +++ b/src/ast/internal_decoration.cc @@ -32,7 +32,7 @@ void InternalDecoration::to_str(const sem::Info&, std::ostream& out, size_t indent) const { make_indent(out, indent); - out << "tint_internal(" << InternalName() << ")" << std::endl; + out << "tint_internal(" << InternalName() << ")"; } } // namespace ast diff --git a/src/reader/spirv/parser.cc b/src/reader/spirv/parser.cc index c382d90eb6..bcbc19349e 100644 --- a/src/reader/spirv/parser.cc +++ b/src/reader/spirv/parser.cc @@ -17,6 +17,10 @@ #include #include "src/reader/spirv/parser_impl.h" +#include "src/transform/decompose_strided_matrix.h" +#include "src/transform/inline_pointer_lets.h" +#include "src/transform/manager.h" +#include "src/transform/simplify.h" namespace tint { namespace reader { @@ -40,7 +44,19 @@ Program Parse(const std::vector& input) { ProgramBuilder output; CloneContext(&output, &program_with_disjoint_ast, false).Clone(); - return Program(std::move(output)); + auto program = Program(std::move(output)); + + // If the generated program contains matrices with a custom MatrixStride + // attribute then we need to decompose these into an array of vectors + if (transform::DecomposeStridedMatrix::ShouldRun(&program)) { + transform::Manager manager; + manager.Add(); + manager.Add(); + manager.Add(); + return manager.Run(&program).program; + } + + return program; } } // namespace spirv diff --git a/src/reader/spirv/parser_impl.cc b/src/reader/spirv/parser_impl.cc index 7981c0ab69..de9fe80932 100644 --- a/src/reader/spirv/parser_impl.cc +++ b/src/reader/spirv/parser_impl.cc @@ -21,6 +21,7 @@ #include "source/opt/build_module.h" #include "src/ast/bitcast_expression.h" +#include "src/ast/disable_validation_decoration.h" #include "src/ast/interpolate_decoration.h" #include "src/ast/override_decoration.h" #include "src/ast/struct_block_decoration.h" @@ -439,13 +440,14 @@ std::string ParserImpl::ShowType(uint32_t type_id) { return "SPIR-V type " + std::to_string(type_id); } -ast::Decoration* ParserImpl::ConvertMemberDecoration( +ast::DecorationList ParserImpl::ConvertMemberDecoration( uint32_t struct_type_id, uint32_t member_index, + const Type* member_ty, const Decoration& decoration) { if (decoration.empty()) { Fail() << "malformed SPIR-V decoration: it's empty"; - return nullptr; + return {}; } switch (decoration[0]) { case SpvDecorationOffset: @@ -454,38 +456,49 @@ ast::Decoration* ParserImpl::ConvertMemberDecoration( << "malformed Offset decoration: expected 1 literal operand, has " << decoration.size() - 1 << ": member " << member_index << " of " << ShowType(struct_type_id); - return nullptr; + return {}; } - return create(Source{}, decoration[1]); + return { + create(Source{}, decoration[1]), + }; case SpvDecorationNonReadable: // WGSL doesn't have a member decoration for this. Silently drop it. - return nullptr; + return {}; case SpvDecorationNonWritable: // WGSL doesn't have a member decoration for this. - return nullptr; + return {}; case SpvDecorationColMajor: // WGSL only supports column major matrices. - return nullptr; + return {}; case SpvDecorationRelaxedPrecision: // WGSL doesn't support relaxed precision. - return nullptr; + return {}; case SpvDecorationRowMajor: Fail() << "WGSL does not support row-major matrices: can't " "translate member " << member_index << " of " << ShowType(struct_type_id); - return nullptr; + return {}; case SpvDecorationMatrixStride: { if (decoration.size() != 2) { Fail() << "malformed MatrixStride decoration: expected 1 literal " "operand, has " << decoration.size() - 1 << ": member " << member_index << " of " << ShowType(struct_type_id); - return nullptr; + return {}; } - // TODO(dneto): Fail if the matrix stride is not allocation size of the - // column vector of the underlying matrix. This would need to unpack - // any levels of array-ness. - return nullptr; + uint32_t stride = decoration[1]; + uint32_t natural_stride = 0; + if (auto* mat = member_ty->As()) { + natural_stride = (mat->rows == 2) ? 8 : 16; + } + if (stride == natural_stride) { + return {}; + } + return { + create(Source{}, decoration[1]), + builder_.ASTNodes().Create( + builder_.ID(), ast::DisabledValidation::kIgnoreStrideDecoration), + }; } default: // TODO(dneto): Support the remaining member decorations. @@ -493,7 +506,7 @@ ast::Decoration* ParserImpl::ConvertMemberDecoration( } Fail() << "unhandled member decoration: " << decoration[0] << " on member " << member_index << " of " << ShowType(struct_type_id); - return nullptr; + return {}; } bool ParserImpl::BuildInternalModule() { @@ -1126,14 +1139,14 @@ const Type* ParserImpl::ConvertType( // the members are non-writable. is_non_writable = true; } else { - auto* ast_member_decoration = - ConvertMemberDecoration(type_id, member_index, decoration); + auto decos = ConvertMemberDecoration(type_id, member_index, + ast_member_ty, decoration); + for (auto* deco : decos) { + ast_member_decorations.emplace_back(deco); + } if (!success_) { return nullptr; } - if (ast_member_decoration) { - ast_member_decorations.push_back(ast_member_decoration); - } } } diff --git a/src/reader/spirv/parser_impl.h b/src/reader/spirv/parser_impl.h index 7a697084ee..6f80bcb684 100644 --- a/src/reader/spirv/parser_impl.h +++ b/src/reader/spirv/parser_impl.h @@ -272,16 +272,19 @@ class ParserImpl : Reader { ast::Decoration* SetLocation(ast::DecorationList* decos, ast::Decoration* replacement); - /// Converts a SPIR-V struct member decoration. If the decoration is - /// recognized but deliberately dropped, then returns nullptr without a - /// diagnostic. On failure, emits a diagnostic and returns nullptr. + /// Converts a SPIR-V struct member decoration into a number of AST + /// decorations. If the decoration is recognized but deliberately dropped, + /// then returns an empty list without a diagnostic. On failure, emits a + /// diagnostic and returns an empty list. /// @param struct_type_id the ID of the struct type /// @param member_index the index of the member + /// @param member_ty the type of the member /// @param decoration an encoded SPIR-V Decoration - /// @returns the corresponding ast::StructuMemberDecoration - ast::Decoration* ConvertMemberDecoration(uint32_t struct_type_id, - uint32_t member_index, - const Decoration& decoration); + /// @returns the AST decorations + ast::DecorationList ConvertMemberDecoration(uint32_t struct_type_id, + uint32_t member_index, + const Type* member_ty, + const Decoration& decoration); /// Returns a string for the given type. If the type ID is invalid, /// then the resulting string only names the type ID. diff --git a/src/reader/spirv/parser_impl_convert_member_decoration_test.cc b/src/reader/spirv/parser_impl_convert_member_decoration_test.cc index 01f29d41e6..d431fb5e51 100644 --- a/src/reader/spirv/parser_impl_convert_member_decoration_test.cc +++ b/src/reader/spirv/parser_impl_convert_member_decoration_test.cc @@ -25,16 +25,17 @@ using ::testing::Eq; TEST_F(SpvParserTest, ConvertMemberDecoration_Empty) { auto p = parser(std::vector{}); - auto* result = p->ConvertMemberDecoration(1, 1, {}); - EXPECT_EQ(result, nullptr); + auto result = p->ConvertMemberDecoration(1, 1, nullptr, {}); + EXPECT_TRUE(result.empty()); EXPECT_THAT(p->error(), Eq("malformed SPIR-V decoration: it's empty")); } TEST_F(SpvParserTest, ConvertMemberDecoration_OffsetWithoutOperand) { auto p = parser(std::vector{}); - auto* result = p->ConvertMemberDecoration(12, 13, {SpvDecorationOffset}); - EXPECT_EQ(result, nullptr); + auto result = + p->ConvertMemberDecoration(12, 13, nullptr, {SpvDecorationOffset}); + EXPECT_TRUE(result.empty()); EXPECT_THAT(p->error(), Eq("malformed Offset decoration: expected 1 literal " "operand, has 0: member 13 of SPIR-V type 12")); } @@ -42,9 +43,9 @@ TEST_F(SpvParserTest, ConvertMemberDecoration_OffsetWithoutOperand) { TEST_F(SpvParserTest, ConvertMemberDecoration_OffsetWithTooManyOperands) { auto p = parser(std::vector{}); - auto* result = - p->ConvertMemberDecoration(12, 13, {SpvDecorationOffset, 3, 4}); - EXPECT_EQ(result, nullptr); + auto result = + p->ConvertMemberDecoration(12, 13, nullptr, {SpvDecorationOffset, 3, 4}); + EXPECT_TRUE(result.empty()); EXPECT_THAT(p->error(), Eq("malformed Offset decoration: expected 1 literal " "operand, has 2: member 13 of SPIR-V type 12")); } @@ -52,32 +53,100 @@ TEST_F(SpvParserTest, ConvertMemberDecoration_OffsetWithTooManyOperands) { TEST_F(SpvParserTest, ConvertMemberDecoration_Offset) { auto p = parser(std::vector{}); - auto* result = p->ConvertMemberDecoration(1, 1, {SpvDecorationOffset, 8}); - ASSERT_NE(result, nullptr); - EXPECT_TRUE(result->Is()); - auto* offset_deco = result->As(); + auto result = + p->ConvertMemberDecoration(1, 1, nullptr, {SpvDecorationOffset, 8}); + ASSERT_FALSE(result.empty()); + EXPECT_TRUE(result[0]->Is()); + auto* offset_deco = result[0]->As(); ASSERT_NE(offset_deco, nullptr); EXPECT_EQ(offset_deco->offset(), 8u); EXPECT_TRUE(p->error().empty()); } +TEST_F(SpvParserTest, ConvertMemberDecoration_Matrix2x2_Stride_Natural) { + auto p = parser(std::vector{}); + + spirv::F32 f32; + spirv::Matrix matrix(&f32, 2, 2); + auto result = + p->ConvertMemberDecoration(1, 1, &matrix, {SpvDecorationMatrixStride, 8}); + EXPECT_TRUE(result.empty()); + EXPECT_TRUE(p->error().empty()); +} + +TEST_F(SpvParserTest, ConvertMemberDecoration_Matrix2x2_Stride_Custom) { + auto p = parser(std::vector{}); + + spirv::F32 f32; + spirv::Matrix matrix(&f32, 2, 2); + auto result = p->ConvertMemberDecoration(1, 1, &matrix, + {SpvDecorationMatrixStride, 16}); + ASSERT_FALSE(result.empty()); + EXPECT_TRUE(result[0]->Is()); + auto* stride_deco = result[0]->As(); + ASSERT_NE(stride_deco, nullptr); + EXPECT_EQ(stride_deco->stride(), 16u); + EXPECT_TRUE(p->error().empty()); +} + +TEST_F(SpvParserTest, ConvertMemberDecoration_Matrix2x4_Stride_Natural) { + auto p = parser(std::vector{}); + + spirv::F32 f32; + spirv::Matrix matrix(&f32, 2, 4); + auto result = p->ConvertMemberDecoration(1, 1, &matrix, + {SpvDecorationMatrixStride, 16}); + EXPECT_TRUE(result.empty()); + EXPECT_TRUE(p->error().empty()); +} + +TEST_F(SpvParserTest, ConvertMemberDecoration_Matrix2x4_Stride_Custom) { + auto p = parser(std::vector{}); + + spirv::F32 f32; + spirv::Matrix matrix(&f32, 2, 4); + auto result = p->ConvertMemberDecoration(1, 1, &matrix, + {SpvDecorationMatrixStride, 64}); + ASSERT_FALSE(result.empty()); + EXPECT_TRUE(result[0]->Is()); + auto* stride_deco = result[0]->As(); + ASSERT_NE(stride_deco, nullptr); + EXPECT_EQ(stride_deco->stride(), 64u); + EXPECT_TRUE(p->error().empty()); +} + +TEST_F(SpvParserTest, ConvertMemberDecoration_Matrix2x3_Stride_Custom) { + auto p = parser(std::vector{}); + + spirv::F32 f32; + spirv::Matrix matrix(&f32, 2, 3); + auto result = p->ConvertMemberDecoration(1, 1, &matrix, + {SpvDecorationMatrixStride, 32}); + ASSERT_FALSE(result.empty()); + EXPECT_TRUE(result[0]->Is()); + auto* stride_deco = result[0]->As(); + ASSERT_NE(stride_deco, nullptr); + EXPECT_EQ(stride_deco->stride(), 32u); + EXPECT_TRUE(p->error().empty()); +} + TEST_F(SpvParserTest, ConvertMemberDecoration_RelaxedPrecision) { // WGSL does not support relaxed precision. Drop it. // It's functionally correct to use full precision f32 instead of // relaxed precision f32. auto p = parser(std::vector{}); - auto* result = - p->ConvertMemberDecoration(1, 1, {SpvDecorationRelaxedPrecision}); - EXPECT_EQ(result, nullptr); + auto result = p->ConvertMemberDecoration(1, 1, nullptr, + {SpvDecorationRelaxedPrecision}); + EXPECT_TRUE(result.empty()); EXPECT_TRUE(p->error().empty()); } TEST_F(SpvParserTest, ConvertMemberDecoration_UnhandledDecoration) { auto p = parser(std::vector{}); - auto* result = p->ConvertMemberDecoration(12, 13, {12345678}); - EXPECT_EQ(result, nullptr); + auto result = p->ConvertMemberDecoration(12, 13, nullptr, {12345678}); + EXPECT_TRUE(result.empty()); EXPECT_THAT(p->error(), Eq("unhandled member decoration: 12345678 on member " "13 of SPIR-V type 12")); } diff --git a/src/reader/spirv/parser_impl_module_var_test.cc b/src/reader/spirv/parser_impl_module_var_test.cc index b844a161a1..e884362ef9 100644 --- a/src/reader/spirv/parser_impl_module_var_test.cc +++ b/src/reader/spirv/parser_impl_module_var_test.cc @@ -2015,7 +2015,7 @@ TEST_F(SpvModuleScopeVarParserTest, ColMajorDecoration_Dropped) { })")) << module_str; } -TEST_F(SpvModuleScopeVarParserTest, MatrixStrideDecoration_Dropped) { +TEST_F(SpvModuleScopeVarParserTest, MatrixStrideDecoration_Natural_Dropped) { auto p = parser(test::Assemble(Preamble() + FragMain() + R"( OpName %myvar "myvar" OpDecorate %myvar DescriptorSet 0 @@ -2054,6 +2054,45 @@ TEST_F(SpvModuleScopeVarParserTest, MatrixStrideDecoration_Dropped) { })")) << module_str; } +TEST_F(SpvModuleScopeVarParserTest, MatrixStrideDecoration) { + auto p = parser(test::Assemble(Preamble() + FragMain() + R"( + OpName %myvar "myvar" + OpDecorate %myvar DescriptorSet 0 + OpDecorate %myvar Binding 0 + OpDecorate %s Block + OpMemberDecorate %s 0 MatrixStride 64 + OpMemberDecorate %s 0 Offset 0 + %void = OpTypeVoid + %voidfn = OpTypeFunction %void + %float = OpTypeFloat 32 + %v2float = OpTypeVector %float 2 + %m3v2float = OpTypeMatrix %v2float 3 + + %s = OpTypeStruct %m3v2float + %ptr_sb_s = OpTypePointer StorageBuffer %s + %myvar = OpVariable %ptr_sb_s StorageBuffer + )" + MainBody())); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error(); + EXPECT_TRUE(p->error().empty()); + const auto module_str = p->program().to_str(); + EXPECT_THAT(module_str, HasSubstr(R"( + Struct S { + [[block]] + StructMember{[[ stride 64 tint_internal(disable_validation__ignore_stride) offset 0 ]] field0: __mat_2_3__f32} + } + Variable{ + Decorations{ + GroupDecoration{0} + BindingDecoration{0} + } + myvar + storage + read_write + __type_name_S + } +})")) << module_str; +} + TEST_F(SpvModuleScopeVarParserTest, RowMajorDecoration_IsError) { auto p = parser(test::Assemble(Preamble() + FragMain() + R"( OpName %myvar "myvar" @@ -2620,7 +2659,8 @@ TEST_F(SpvModuleScopeVarParserTest, SampleId_I32_Load_AccessChain) { private undefined __i32 - })")) <Declaration()->decorations()) { - if (!(deco->Is() || - deco->Is() || - deco->Is() || - deco->Is() || - deco->Is() || - deco->Is() || - deco->Is())) { + if (!deco->IsAnyOf()) { + if (deco->Is() && + IsValidationDisabled( + member->Declaration()->decorations(), + ast::DisabledValidation::kIgnoreStrideDecoration)) { + continue; + } AddError("decoration is not valid for structure members", deco->source()); return false; diff --git a/src/transform/decompose_strided_matrix.cc b/src/transform/decompose_strided_matrix.cc new file mode 100644 index 0000000000..786f509b2c --- /dev/null +++ b/src/transform/decompose_strided_matrix.cc @@ -0,0 +1,251 @@ +// Copyright 2021 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/transform/decompose_strided_matrix.h" + +#include +#include +#include + +#include "src/program_builder.h" +#include "src/sem/expression.h" +#include "src/sem/member_accessor_expression.h" +#include "src/transform/inline_pointer_lets.h" +#include "src/transform/simplify.h" +#include "src/utils/get_or_create.h" +#include "src/utils/hash.h" + +TINT_INSTANTIATE_TYPEINFO(tint::transform::DecomposeStridedMatrix); + +namespace tint { +namespace transform { +namespace { + +/// MatrixInfo describes a matrix member with a custom stride +struct MatrixInfo { + /// The stride in bytes between columns of the matrix + uint32_t stride = 0; + /// The type of the matrix + sem::Matrix const* matrix = nullptr; + + /// @returns a new ast::Array that holds an vector column for each row of the + /// matrix. + ast::Array* array(ProgramBuilder* b) const { + return b->ty.array(b->ty.vec(matrix->rows()), + matrix->columns(), stride); + } + + /// Equality operator + bool operator==(const MatrixInfo& info) const { + return stride == info.stride && matrix == info.matrix; + } + /// Hash function + struct Hasher { + size_t operator()(const MatrixInfo& t) const { + return utils::Hash(t.stride, t.matrix); + } + }; +}; + +/// Return type of the callback function of GatherCustomStrideMatrixMembers +enum GatherResult { kContinue, kStop }; + +/// GatherCustomStrideMatrixMembers scans `program` for all matrix members of +/// storage and uniform structs, which are of a matrix type, and have a custom +/// matrix stride attribute. For each matrix member found, `callback` is called. +/// `callback` is a function with the signature: +/// GatherResult(const sem::StructMember* member, +/// sem::Matrix* matrix, +/// uint32_t stride) +/// If `callback` return GatherResult::kStop, then the scanning will immediately +/// terminate, and GatherCustomStrideMatrixMembers() will return, otherwise +/// scanning will continue. +template +void GatherCustomStrideMatrixMembers(const Program* program, F&& callback) { + for (auto* node : program->ASTNodes().Objects()) { + if (auto* str = node->As()) { + auto* str_ty = program->Sem().Get(str); + if (!str_ty->UsedAs(ast::StorageClass::kUniform) && + !str_ty->UsedAs(ast::StorageClass::kStorage)) { + continue; + } + for (auto* member : str_ty->Members()) { + auto* matrix = member->Type()->As(); + if (!matrix) { + continue; + } + auto* deco = ast::GetDecoration( + member->Declaration()->decorations()); + if (!deco) { + continue; + } + uint32_t stride = deco->stride(); + if (matrix->ColumnStride() == stride) { + continue; + } + if (callback(member, matrix, stride) == GatherResult::kStop) { + return; + } + } + } + } +} + +} // namespace + +DecomposeStridedMatrix::DecomposeStridedMatrix() = default; + +DecomposeStridedMatrix::~DecomposeStridedMatrix() = default; + +bool DecomposeStridedMatrix::ShouldRun(const Program* program) { + bool should_run = false; + GatherCustomStrideMatrixMembers( + program, [&](const sem::StructMember*, sem::Matrix*, uint32_t) { + should_run = true; + return GatherResult::kStop; + }); + return should_run; +} + +void DecomposeStridedMatrix::Run(CloneContext& ctx, const DataMap&, DataMap&) { + if (!Requires(ctx)) { + return; + } + + // Scan the program for all storage and uniform structure matrix members with + // a custom stride attribute. Replace these matrices with an equivalent array, + // and populate the `decomposed` map with the members that have been replaced. + std::unordered_map decomposed; + GatherCustomStrideMatrixMembers( + ctx.src, [&](const sem::StructMember* member, sem::Matrix* matrix, + uint32_t stride) { + // We've got ourselves a struct member of a matrix type with a custom + // stride. Replace this with an array of column vectors. + MatrixInfo info{stride, matrix}; + auto* replacement = ctx.dst->Member( + member->Offset(), ctx.Clone(member->Name()), info.array(ctx.dst)); + ctx.Replace(member->Declaration(), replacement); + decomposed.emplace(member->Declaration(), info); + return GatherResult::kContinue; + }); + + // For all expressions where a single matrix column vector was indexed, we can + // preserve these without calling conversion functions. + // Example: + // ssbo.mat[2] -> ssbo.mat[2] + ctx.ReplaceAll( + [&](ast::ArrayAccessorExpression* expr) -> ast::ArrayAccessorExpression* { + if (auto* access = + ctx.src->Sem().Get(expr->array())) { + auto it = decomposed.find(access->Member()->Declaration()); + if (it != decomposed.end()) { + auto* obj = ctx.CloneWithoutTransform(expr->array()); + auto* idx = ctx.Clone(expr->idx_expr()); + return ctx.dst->IndexAccessor(obj, idx); + } + } + return nullptr; + }); + + // For all struct member accesses to the matrix on the LHS of an assignment, + // we need to convert the matrix to the array before assigning to the + // structure. + // Example: + // ssbo.mat = mat_to_arr(m) + std::unordered_map mat_to_arr; + ctx.ReplaceAll([&](ast::AssignmentStatement* stmt) -> ast::Statement* { + if (auto* access = + ctx.src->Sem().Get(stmt->lhs())) { + auto it = decomposed.find(access->Member()->Declaration()); + if (it == decomposed.end()) { + return nullptr; + } + MatrixInfo info = it->second; + auto fn = utils::GetOrCreate(mat_to_arr, info, [&] { + auto name = ctx.dst->Symbols().New( + "mat" + std::to_string(info.matrix->columns()) + "x" + + std::to_string(info.matrix->rows()) + "_stride_" + + std::to_string(info.stride) + "_to_arr"); + + auto matrix = [&] { return CreateASTTypeFor(ctx, info.matrix); }; + auto array = [&] { return info.array(ctx.dst); }; + + auto mat = ctx.dst->Sym("mat"); + ast::ExpressionList columns(info.matrix->columns()); + for (uint32_t i = 0; i < static_cast(columns.size()); i++) { + columns[i] = ctx.dst->IndexAccessor(mat, i); + } + ctx.dst->Func(name, + { + ctx.dst->Param(mat, matrix()), + }, + array(), + { + ctx.dst->Return(ctx.dst->Construct(array(), columns)), + }); + return name; + }); + auto* lhs = ctx.CloneWithoutTransform(stmt->lhs()); + auto* rhs = ctx.dst->Call(fn, ctx.Clone(stmt->rhs())); + return ctx.dst->Assign(lhs, rhs); + } + return nullptr; + }); + + // For all other struct member accesses, we need to convert the array to the + // matrix type. Example: + // m = arr_to_mat(ssbo.mat) + std::unordered_map arr_to_mat; + ctx.ReplaceAll([&](ast::MemberAccessorExpression* expr) -> ast::Expression* { + if (auto* access = ctx.src->Sem().Get(expr)) { + auto it = decomposed.find(access->Member()->Declaration()); + if (it == decomposed.end()) { + return nullptr; + } + MatrixInfo info = it->second; + auto fn = utils::GetOrCreate(arr_to_mat, info, [&] { + auto name = ctx.dst->Symbols().New( + "arr_to_mat" + std::to_string(info.matrix->columns()) + "x" + + std::to_string(info.matrix->rows()) + "_stride_" + + std::to_string(info.stride)); + + auto matrix = [&] { return CreateASTTypeFor(ctx, info.matrix); }; + auto array = [&] { return info.array(ctx.dst); }; + + auto arr = ctx.dst->Sym("arr"); + ast::ExpressionList columns(info.matrix->columns()); + for (uint32_t i = 0; i < static_cast(columns.size()); i++) { + columns[i] = ctx.dst->IndexAccessor(arr, i); + } + ctx.dst->Func( + name, + { + ctx.dst->Param(arr, array()), + }, + matrix(), + { + ctx.dst->Return(ctx.dst->Construct(matrix(), columns)), + }); + return name; + }); + return ctx.dst->Call(fn, ctx.CloneWithoutTransform(expr)); + } + return nullptr; + }); + + ctx.Clone(); +} + +} // namespace transform +} // namespace tint diff --git a/src/transform/decompose_strided_matrix.h b/src/transform/decompose_strided_matrix.h new file mode 100644 index 0000000000..328304914c --- /dev/null +++ b/src/transform/decompose_strided_matrix.h @@ -0,0 +1,54 @@ +// Copyright 2021 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SRC_TRANSFORM_DECOMPOSE_STRIDED_MATRIX_H_ +#define SRC_TRANSFORM_DECOMPOSE_STRIDED_MATRIX_H_ + +#include "src/transform/transform.h" + +namespace tint { +namespace transform { + +/// DecomposeStridedMatrix transforms replaces matrix members of storage or +/// uniform buffer structures, that have a [[stride]] decoration, into an array +/// of N column vectors. +/// This transform is used by the SPIR-V reader to handle the SPIR-V +/// MatrixStride decoration. +class DecomposeStridedMatrix + : public Castable { + public: + /// Constructor + DecomposeStridedMatrix(); + + /// Destructor + ~DecomposeStridedMatrix() override; + + /// @param program the program to inspect + /// @returns true if this transform should be run for the given program + static bool ShouldRun(const Program* program); + + protected: + /// Runs the transform using the CloneContext built for transforming a + /// program. Run() is responsible for calling Clone() on the CloneContext. + /// @param ctx the CloneContext primed with the input program and + /// ProgramBuilder + /// @param inputs optional extra transform-specific input data + /// @param outputs optional extra transform-specific output data + void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override; +}; + +} // namespace transform +} // namespace tint + +#endif // SRC_TRANSFORM_DECOMPOSE_STRIDED_MATRIX_H_ diff --git a/src/transform/decompose_strided_matrix_test.cc b/src/transform/decompose_strided_matrix_test.cc new file mode 100644 index 0000000000..6a205a67a8 --- /dev/null +++ b/src/transform/decompose_strided_matrix_test.cc @@ -0,0 +1,727 @@ +// Copyright 2021 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/transform/decompose_strided_matrix.h" + +#include +#include +#include + +#include "src/ast/disable_validation_decoration.h" +#include "src/program_builder.h" +#include "src/transform/inline_pointer_lets.h" +#include "src/transform/simplify.h" +#include "src/transform/test_helper.h" + +namespace tint { +namespace transform { +namespace { + +using DecomposeStridedMatrixTest = TransformTest; +using f32 = ProgramBuilder::f32; + +TEST_F(DecomposeStridedMatrixTest, Empty) { + auto* src = R"()"; + auto* expect = src; + + auto got = Run(src); + + EXPECT_EQ(expect, str(got)); +} + +TEST_F(DecomposeStridedMatrixTest, MissingDependencyInlinePointerLets) { + auto* src = R"()"; + auto* expect = + R"(error: tint::transform::DecomposeStridedMatrix depends on tint::transform::InlinePointerLets but the dependency was not run)"; + + auto got = Run(src); + + EXPECT_EQ(expect, str(got)); +} + +TEST_F(DecomposeStridedMatrixTest, MissingDependencySimplify) { + auto* src = R"()"; + auto* expect = + R"(error: tint::transform::DecomposeStridedMatrix depends on tint::transform::Simplify but the dependency was not run)"; + + auto got = Run(src); + + EXPECT_EQ(expect, str(got)); +} + +TEST_F(DecomposeStridedMatrixTest, ReadUniformMatrix) { + // [[block]] + // struct S { + // [[offset(16), stride(32)]] + // [[internal(ignore_stride_decoration)]] + // m : mat2x2; + // }; + // [[group(0), binding(0)]] var s : S; + // + // [[stage(compute), workgroup_size(1)]] + // fn f() { + // let x : mat2x2 = s.m; + // } + ProgramBuilder b; + auto* S = b.Structure( + "S", + { + b.Member( + "m", b.ty.mat2x2(), + { + b.create(16), + b.create(32), + b.ASTNodes().Create( + b.ID(), ast::DisabledValidation::kIgnoreStrideDecoration), + }), + }, + { + b.StructBlock(), + }); + b.Global("s", b.ty.Of(S), ast::StorageClass::kUniform, + b.GroupAndBinding(0, 0)); + b.Func( + "f", {}, b.ty.void_(), + { + b.Decl(b.Const("x", b.ty.mat2x2(), b.MemberAccessor("s", "m"))), + }, + { + b.Stage(ast::PipelineStage::kCompute), + b.WorkgroupSize(1), + }); + + auto* expect = R"( +[[block]] +struct S { + [[size(16)]] + padding : u32; + m : [[stride(32)]] array, 2>; +}; + +[[group(0), binding(0)]] var s : S; + +fn arr_to_mat2x2_stride_32(arr : [[stride(32)]] array, 2>) -> mat2x2 { + return mat2x2(arr[0u], arr[1u]); +} + +[[stage(compute), workgroup_size(1)]] +fn f() { + let x : mat2x2 = arr_to_mat2x2_stride_32(s.m); +} +)"; + + auto got = Run( + Program(std::move(b))); + + EXPECT_EQ(expect, str(got)); +} + +TEST_F(DecomposeStridedMatrixTest, ReadUniformColumn) { + // [[block]] + // struct S { + // [[offset(16), stride(32)]] + // [[internal(ignore_stride_decoration)]] + // m : mat2x2; + // }; + // [[group(0), binding(0)]] var s : S; + // + // [[stage(compute), workgroup_size(1)]] + // fn f() { + // let x : vec2 = s.m[1]; + // } + ProgramBuilder b; + auto* S = b.Structure( + "S", + { + b.Member( + "m", b.ty.mat2x2(), + { + b.create(16), + b.create(32), + b.ASTNodes().Create( + b.ID(), ast::DisabledValidation::kIgnoreStrideDecoration), + }), + }, + { + b.StructBlock(), + }); + b.Global("s", b.ty.Of(S), ast::StorageClass::kUniform, + b.GroupAndBinding(0, 0)); + b.Func("f", {}, b.ty.void_(), + { + b.Decl(b.Const("x", b.ty.vec2(), + b.IndexAccessor(b.MemberAccessor("s", "m"), 1))), + }, + { + b.Stage(ast::PipelineStage::kCompute), + b.WorkgroupSize(1), + }); + + auto* expect = R"( +[[block]] +struct S { + [[size(16)]] + padding : u32; + m : [[stride(32)]] array, 2>; +}; + +[[group(0), binding(0)]] var s : S; + +[[stage(compute), workgroup_size(1)]] +fn f() { + let x : vec2 = s.m[1]; +} +)"; + + auto got = Run( + Program(std::move(b))); + + EXPECT_EQ(expect, str(got)); +} + +TEST_F(DecomposeStridedMatrixTest, ReadUniformMatrix_DefaultStride) { + // [[block]] + // struct S { + // [[offset(16), stride(8)]] + // [[internal(ignore_stride_decoration)]] + // m : mat2x2; + // }; + // [[group(0), binding(0)]] var s : S; + // + // [[stage(compute), workgroup_size(1)]] + // fn f() { + // let x : mat2x2 = s.m; + // } + ProgramBuilder b; + auto* S = b.Structure( + "S", + { + b.Member( + "m", b.ty.mat2x2(), + { + b.create(16), + b.create(8), + b.ASTNodes().Create( + b.ID(), ast::DisabledValidation::kIgnoreStrideDecoration), + }), + }, + { + b.StructBlock(), + }); + b.Global("s", b.ty.Of(S), ast::StorageClass::kUniform, + b.GroupAndBinding(0, 0)); + b.Func( + "f", {}, b.ty.void_(), + { + b.Decl(b.Const("x", b.ty.mat2x2(), b.MemberAccessor("s", "m"))), + }, + { + b.Stage(ast::PipelineStage::kCompute), + b.WorkgroupSize(1), + }); + + auto* expect = R"( +[[block]] +struct S { + [[size(16)]] + padding : u32; + [[stride(8), internal(disable_validation__ignore_stride)]] + m : mat2x2; +}; + +[[group(0), binding(0)]] var s : S; + +[[stage(compute), workgroup_size(1)]] +fn f() { + let x : mat2x2 = s.m; +} +)"; + + auto got = Run( + Program(std::move(b))); + + EXPECT_EQ(expect, str(got)); +} + +TEST_F(DecomposeStridedMatrixTest, ReadStorageMatrix) { + // [[block]] + // struct S { + // [[offset(8), stride(32)]] + // [[internal(ignore_stride_decoration)]] + // m : mat2x2; + // }; + // [[group(0), binding(0)]] var s : S; + // + // [[stage(compute), workgroup_size(1)]] + // fn f() { + // let x : mat2x2 = s.m; + // } + ProgramBuilder b; + auto* S = b.Structure( + "S", + { + b.Member( + "m", b.ty.mat2x2(), + { + b.create(8), + b.create(32), + b.ASTNodes().Create( + b.ID(), ast::DisabledValidation::kIgnoreStrideDecoration), + }), + }, + { + b.StructBlock(), + }); + b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage, + ast::Access::kReadWrite, b.GroupAndBinding(0, 0)); + b.Func( + "f", {}, b.ty.void_(), + { + b.Decl(b.Const("x", b.ty.mat2x2(), b.MemberAccessor("s", "m"))), + }, + { + b.Stage(ast::PipelineStage::kCompute), + b.WorkgroupSize(1), + }); + + auto* expect = R"( +[[block]] +struct S { + [[size(8)]] + padding : u32; + m : [[stride(32)]] array, 2>; +}; + +[[group(0), binding(0)]] var s : S; + +fn arr_to_mat2x2_stride_32(arr : [[stride(32)]] array, 2>) -> mat2x2 { + return mat2x2(arr[0u], arr[1u]); +} + +[[stage(compute), workgroup_size(1)]] +fn f() { + let x : mat2x2 = arr_to_mat2x2_stride_32(s.m); +} +)"; + + auto got = Run( + Program(std::move(b))); + + EXPECT_EQ(expect, str(got)); +} + +TEST_F(DecomposeStridedMatrixTest, ReadStorageColumn) { + // [[block]] + // struct S { + // [[offset(16), stride(32)]] + // [[internal(ignore_stride_decoration)]] + // m : mat2x2; + // }; + // [[group(0), binding(0)]] var s : S; + // + // [[stage(compute), workgroup_size(1)]] + // fn f() { + // let x : vec2 = s.m[1]; + // } + ProgramBuilder b; + auto* S = b.Structure( + "S", + { + b.Member( + "m", b.ty.mat2x2(), + { + b.create(16), + b.create(32), + b.ASTNodes().Create( + b.ID(), ast::DisabledValidation::kIgnoreStrideDecoration), + }), + }, + { + b.StructBlock(), + }); + b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage, + ast::Access::kReadWrite, b.GroupAndBinding(0, 0)); + b.Func("f", {}, b.ty.void_(), + { + b.Decl(b.Const("x", b.ty.vec2(), + b.IndexAccessor(b.MemberAccessor("s", "m"), 1))), + }, + { + b.Stage(ast::PipelineStage::kCompute), + b.WorkgroupSize(1), + }); + + auto* expect = R"( +[[block]] +struct S { + [[size(16)]] + padding : u32; + m : [[stride(32)]] array, 2>; +}; + +[[group(0), binding(0)]] var s : S; + +[[stage(compute), workgroup_size(1)]] +fn f() { + let x : vec2 = s.m[1]; +} +)"; + + auto got = Run( + Program(std::move(b))); + + EXPECT_EQ(expect, str(got)); +} + +TEST_F(DecomposeStridedMatrixTest, WriteStorageMatrix) { + // [[block]] + // struct S { + // [[offset(8), stride(32)]] + // [[internal(ignore_stride_decoration)]] + // m : mat2x2; + // }; + // [[group(0), binding(0)]] var s : S; + // + // [[stage(compute), workgroup_size(1)]] + // fn f() { + // s.m = mat2x2(vec2(1.0, 2.0), vec2(3.0, 4.0)); + // } + ProgramBuilder b; + auto* S = b.Structure( + "S", + { + b.Member( + "m", b.ty.mat2x2(), + { + b.create(8), + b.create(32), + b.ASTNodes().Create( + b.ID(), ast::DisabledValidation::kIgnoreStrideDecoration), + }), + }, + { + b.StructBlock(), + }); + b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage, + ast::Access::kReadWrite, b.GroupAndBinding(0, 0)); + b.Func("f", {}, b.ty.void_(), + { + b.Assign(b.MemberAccessor("s", "m"), + b.mat2x2(b.vec2(1.0f, 2.0f), + b.vec2(3.0f, 4.0f))), + }, + { + b.Stage(ast::PipelineStage::kCompute), + b.WorkgroupSize(1), + }); + + auto* expect = R"( +[[block]] +struct S { + [[size(8)]] + padding : u32; + m : [[stride(32)]] array, 2>; +}; + +[[group(0), binding(0)]] var s : S; + +fn mat2x2_stride_32_to_arr(mat : mat2x2) -> [[stride(32)]] array, 2> { + return [[stride(32)]] array, 2>(mat[0u], mat[1u]); +} + +[[stage(compute), workgroup_size(1)]] +fn f() { + s.m = mat2x2_stride_32_to_arr(mat2x2(vec2(1.0, 2.0), vec2(3.0, 4.0))); +} +)"; + + auto got = Run( + Program(std::move(b))); + + EXPECT_EQ(expect, str(got)); +} + +TEST_F(DecomposeStridedMatrixTest, WriteStorageColumn) { + // [[block]] + // struct S { + // [[offset(8), stride(32)]] + // [[internal(ignore_stride_decoration)]] + // m : mat2x2; + // }; + // [[group(0), binding(0)]] var s : S; + // + // [[stage(compute), workgroup_size(1)]] + // fn f() { + // s.m[1] = vec2(1.0, 2.0); + // } + ProgramBuilder b; + auto* S = b.Structure( + "S", + { + b.Member( + "m", b.ty.mat2x2(), + { + b.create(8), + b.create(32), + b.ASTNodes().Create( + b.ID(), ast::DisabledValidation::kIgnoreStrideDecoration), + }), + }, + { + b.StructBlock(), + }); + b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage, + ast::Access::kReadWrite, b.GroupAndBinding(0, 0)); + b.Func("f", {}, b.ty.void_(), + { + b.Assign(b.IndexAccessor(b.MemberAccessor("s", "m"), 1), + b.vec2(1.0f, 2.0f)), + }, + { + b.Stage(ast::PipelineStage::kCompute), + b.WorkgroupSize(1), + }); + + auto* expect = R"( +[[block]] +struct S { + [[size(8)]] + padding : u32; + m : [[stride(32)]] array, 2>; +}; + +[[group(0), binding(0)]] var s : S; + +[[stage(compute), workgroup_size(1)]] +fn f() { + s.m[1] = vec2(1.0, 2.0); +} +)"; + + auto got = Run( + Program(std::move(b))); + + EXPECT_EQ(expect, str(got)); +} + +TEST_F(DecomposeStridedMatrixTest, ReadWriteViaPointerLets) { + // [[block]] + // struct S { + // [[offset(8), stride(32)]] + // [[internal(ignore_stride_decoration)]] + // m : mat2x2; + // }; + // [[group(0), binding(0)]] var s : S; + // + // [[stage(compute), workgroup_size(1)]] + // fn f() { + // let a = &s.m; + // let b = &*&*(a); + // let x = *b; + // let y = (*b)[1]; + // let z = x[1]; + // (*b) = mat2x2(vec2(1.0, 2.0), vec2(3.0, 4.0)); + // (*b)[1] = vec2(5.0, 6.0); + // } + ProgramBuilder b; + auto* S = b.Structure( + "S", + { + b.Member( + "m", b.ty.mat2x2(), + { + b.create(8), + b.create(32), + b.ASTNodes().Create( + b.ID(), ast::DisabledValidation::kIgnoreStrideDecoration), + }), + }, + { + b.StructBlock(), + }); + b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage, + ast::Access::kReadWrite, b.GroupAndBinding(0, 0)); + b.Func( + "f", {}, b.ty.void_(), + { + b.Decl( + b.Const("a", nullptr, b.AddressOf(b.MemberAccessor("s", "m")))), + b.Decl(b.Const("b", nullptr, + b.AddressOf(b.Deref(b.AddressOf(b.Deref("a")))))), + b.Decl(b.Const("x", nullptr, b.Deref("b"))), + b.Decl(b.Const("y", nullptr, b.IndexAccessor(b.Deref("b"), 1))), + b.Decl(b.Const("z", nullptr, b.IndexAccessor("x", 1))), + b.Assign(b.Deref("b"), b.mat2x2(b.vec2(1.0f, 2.0f), + b.vec2(3.0f, 4.0f))), + b.Assign(b.IndexAccessor(b.Deref("b"), 1), b.vec2(5.0f, 6.0f)), + }, + { + b.Stage(ast::PipelineStage::kCompute), + b.WorkgroupSize(1), + }); + + auto* expect = R"( +[[block]] +struct S { + [[size(8)]] + padding : u32; + m : [[stride(32)]] array, 2>; +}; + +[[group(0), binding(0)]] var s : S; + +fn arr_to_mat2x2_stride_32(arr : [[stride(32)]] array, 2>) -> mat2x2 { + return mat2x2(arr[0u], arr[1u]); +} + +fn mat2x2_stride_32_to_arr(mat : mat2x2) -> [[stride(32)]] array, 2> { + return [[stride(32)]] array, 2>(mat[0u], mat[1u]); +} + +[[stage(compute), workgroup_size(1)]] +fn f() { + let x = arr_to_mat2x2_stride_32(s.m); + let y = s.m[1]; + let z = x[1]; + s.m = mat2x2_stride_32_to_arr(mat2x2(vec2(1.0, 2.0), vec2(3.0, 4.0))); + s.m[1] = vec2(5.0, 6.0); +} +)"; + + auto got = Run( + Program(std::move(b))); + + EXPECT_EQ(expect, str(got)); +} + +TEST_F(DecomposeStridedMatrixTest, ReadPrivateMatrix) { + // struct S { + // [[offset(8), stride(32)]] + // [[internal(ignore_stride_decoration)]] + // m : mat2x2; + // }; + // var s : S; + // + // [[stage(compute), workgroup_size(1)]] + // fn f() { + // let x : mat2x2 = s.m; + // } + ProgramBuilder b; + auto* S = b.Structure( + "S", + { + b.Member( + "m", b.ty.mat2x2(), + { + b.create(8), + b.create(32), + b.ASTNodes().Create( + b.ID(), ast::DisabledValidation::kIgnoreStrideDecoration), + }), + }); + b.Global("s", b.ty.Of(S), ast::StorageClass::kPrivate); + b.Func( + "f", {}, b.ty.void_(), + { + b.Decl(b.Const("x", b.ty.mat2x2(), b.MemberAccessor("s", "m"))), + }, + { + b.Stage(ast::PipelineStage::kCompute), + b.WorkgroupSize(1), + }); + + auto* expect = R"( +struct S { + [[size(8)]] + padding : u32; + [[stride(32), internal(disable_validation__ignore_stride)]] + m : mat2x2; +}; + +var s : S; + +[[stage(compute), workgroup_size(1)]] +fn f() { + let x : mat2x2 = s.m; +} +)"; + + auto got = Run( + Program(std::move(b))); + + EXPECT_EQ(expect, str(got)); +} + +TEST_F(DecomposeStridedMatrixTest, WritePrivateMatrix) { + // struct S { + // [[offset(8), stride(32)]] + // [[internal(ignore_stride_decoration)]] + // m : mat2x2; + // }; + // var s : S; + // + // [[stage(compute), workgroup_size(1)]] + // fn f() { + // s.m = mat2x2(vec2(1.0, 2.0), vec2(3.0, 4.0)); + // } + ProgramBuilder b; + auto* S = b.Structure( + "S", + { + b.Member( + "m", b.ty.mat2x2(), + { + b.create(8), + b.create(32), + b.ASTNodes().Create( + b.ID(), ast::DisabledValidation::kIgnoreStrideDecoration), + }), + }); + b.Global("s", b.ty.Of(S), ast::StorageClass::kPrivate); + b.Func("f", {}, b.ty.void_(), + { + b.Assign(b.MemberAccessor("s", "m"), + b.mat2x2(b.vec2(1.0f, 2.0f), + b.vec2(3.0f, 4.0f))), + }, + { + b.Stage(ast::PipelineStage::kCompute), + b.WorkgroupSize(1), + }); + + auto* expect = R"( +struct S { + [[size(8)]] + padding : u32; + [[stride(32), internal(disable_validation__ignore_stride)]] + m : mat2x2; +}; + +var s : S; + +[[stage(compute), workgroup_size(1)]] +fn f() { + s.m = mat2x2(vec2(1.0, 2.0), vec2(3.0, 4.0)); +} +)"; + + auto got = Run( + Program(std::move(b))); + + EXPECT_EQ(expect, str(got)); +} + +} // namespace +} // namespace transform +} // namespace tint diff --git a/src/transform/test_helper.h b/src/transform/test_helper.h index 4b35a71675..0c4618a6bc 100644 --- a/src/transform/test_helper.h +++ b/src/transform/test_helper.h @@ -59,6 +59,16 @@ class TransformTestBase : public BASE { // Keep this pointer alive after Transform() returns files_.emplace_back(std::move(file)); + return Run(std::move(program), data); + } + + /// Transforms and returns program `program`, transformed using a transform of + /// type `TRANSFORM`. + /// @param program the input Program + /// @param data the optional DataMap to pass to Transform::Run() + /// @return the transformed output + template + Output Run(Program&& program, const DataMap& data = {}) { if (!program.IsValid()) { return Output(std::move(program)); } diff --git a/src/writer/wgsl/generator_impl.cc b/src/writer/wgsl/generator_impl.cc index 698efcdaeb..4fadc64e99 100644 --- a/src/writer/wgsl/generator_impl.cc +++ b/src/writer/wgsl/generator_impl.cc @@ -696,6 +696,8 @@ bool GeneratorImpl::EmitDecorations(std::ostream& out, out << "size(" << size->size() << ")"; } else if (auto* align = deco->As()) { out << "align(" << align->align() << ")"; + } else if (auto* stride = deco->As()) { + out << "stride(" << stride->stride() << ")"; } else if (auto* internal = deco->As()) { out << "internal(" << internal->InternalName() << ")"; } else { diff --git a/test/BUILD.gn b/test/BUILD.gn index bd33d6243b..67d2d39c49 100644 --- a/test/BUILD.gn +++ b/test/BUILD.gn @@ -289,6 +289,7 @@ tint_unittests_source_set("tint_unittests_core_src") { "../src/transform/calculate_array_length_test.cc", "../src/transform/canonicalize_entry_point_io_test.cc", "../src/transform/decompose_memory_access_test.cc", + "../src/transform/decompose_strided_matrix_test.cc", "../src/transform/external_texture_transform_test.cc", "../src/transform/first_index_offset_test.cc", "../src/transform/fold_constants_test.cc", diff --git a/test/layout/storage/mat2x2/f32.wgsl b/test/layout/storage/mat2x2/f32.wgsl new file mode 100644 index 0000000000..c0cdb01878 --- /dev/null +++ b/test/layout/storage/mat2x2/f32.wgsl @@ -0,0 +1,11 @@ +[[block]] +struct SSBO { + m : mat2x2; +}; +[[group(0), binding(0)]] var ssbo : SSBO; + +[[stage(compute), workgroup_size(1)]] +fn f() { + let v = ssbo.m; + ssbo.m = v; +} diff --git a/test/layout/storage/mat2x2/f32.wgsl.expected.hlsl b/test/layout/storage/mat2x2/f32.wgsl.expected.hlsl new file mode 100644 index 0000000000..49afc068ce --- /dev/null +++ b/test/layout/storage/mat2x2/f32.wgsl.expected.hlsl @@ -0,0 +1,17 @@ +RWByteAddressBuffer ssbo : register(u0, space0); + +float2x2 tint_symbol(RWByteAddressBuffer buffer, uint offset) { + return float2x2(asfloat(buffer.Load2((offset + 0u))), asfloat(buffer.Load2((offset + 8u)))); +} + +void tint_symbol_2(RWByteAddressBuffer buffer, uint offset, float2x2 value) { + buffer.Store2((offset + 0u), asuint(value[0u])); + buffer.Store2((offset + 8u), asuint(value[1u])); +} + +[numthreads(1, 1, 1)] +void f() { + const float2x2 v = tint_symbol(ssbo, 0u); + tint_symbol_2(ssbo, 0u, v); + return; +} diff --git a/test/layout/storage/mat2x2/f32.wgsl.expected.msl b/test/layout/storage/mat2x2/f32.wgsl.expected.msl new file mode 100644 index 0000000000..32de92114b --- /dev/null +++ b/test/layout/storage/mat2x2/f32.wgsl.expected.msl @@ -0,0 +1,13 @@ +#include + +using namespace metal; +struct SSBO { + /* 0x0000 */ float2x2 m; +}; + +kernel void f(device SSBO& ssbo [[buffer(0)]]) { + float2x2 const v = ssbo.m; + ssbo.m = v; + return; +} + diff --git a/test/layout/storage/mat2x2/f32.wgsl.expected.spvasm b/test/layout/storage/mat2x2/f32.wgsl.expected.spvasm new file mode 100644 index 0000000000..2e4ee55c85 --- /dev/null +++ b/test/layout/storage/mat2x2/f32.wgsl.expected.spvasm @@ -0,0 +1,38 @@ +; SPIR-V +; Version: 1.3 +; Generator: Google Tint Compiler; 0 +; Bound: 17 +; Schema: 0 + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %f "f" + OpExecutionMode %f LocalSize 1 1 1 + OpName %SSBO "SSBO" + OpMemberName %SSBO 0 "m" + OpName %ssbo "ssbo" + OpName %f "f" + OpDecorate %SSBO Block + OpMemberDecorate %SSBO 0 Offset 0 + OpMemberDecorate %SSBO 0 ColMajor + OpMemberDecorate %SSBO 0 MatrixStride 8 + OpDecorate %ssbo DescriptorSet 0 + OpDecorate %ssbo Binding 0 + %float = OpTypeFloat 32 + %v2float = OpTypeVector %float 2 +%mat2v2float = OpTypeMatrix %v2float 2 + %SSBO = OpTypeStruct %mat2v2float +%_ptr_StorageBuffer_SSBO = OpTypePointer StorageBuffer %SSBO + %ssbo = OpVariable %_ptr_StorageBuffer_SSBO StorageBuffer + %void = OpTypeVoid + %7 = OpTypeFunction %void + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 +%_ptr_StorageBuffer_mat2v2float = OpTypePointer StorageBuffer %mat2v2float + %f = OpFunction %void None %7 + %10 = OpLabel + %14 = OpAccessChain %_ptr_StorageBuffer_mat2v2float %ssbo %uint_0 + %15 = OpLoad %mat2v2float %14 + %16 = OpAccessChain %_ptr_StorageBuffer_mat2v2float %ssbo %uint_0 + OpStore %16 %15 + OpReturn + OpFunctionEnd diff --git a/test/layout/storage/mat2x2/f32.wgsl.expected.wgsl b/test/layout/storage/mat2x2/f32.wgsl.expected.wgsl new file mode 100644 index 0000000000..b4f1f51cca --- /dev/null +++ b/test/layout/storage/mat2x2/f32.wgsl.expected.wgsl @@ -0,0 +1,12 @@ +[[block]] +struct SSBO { + m : mat2x2; +}; + +[[group(0), binding(0)]] var ssbo : SSBO; + +[[stage(compute), workgroup_size(1)]] +fn f() { + let v = ssbo.m; + ssbo.m = v; +} diff --git a/test/layout/storage/mat2x2/stride/16.spvasm b/test/layout/storage/mat2x2/stride/16.spvasm new file mode 100644 index 0000000000..fbf1122d52 --- /dev/null +++ b/test/layout/storage/mat2x2/stride/16.spvasm @@ -0,0 +1,33 @@ + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %f "f" + OpExecutionMode %f LocalSize 1 1 1 + OpName %SSBO "SSBO" + OpMemberName %SSBO 0 "m" + OpName %ssbo "ssbo" + OpName %f "f" + OpDecorate %SSBO Block + OpMemberDecorate %SSBO 0 Offset 0 + OpMemberDecorate %SSBO 0 ColMajor + OpMemberDecorate %SSBO 0 MatrixStride 16 + OpDecorate %ssbo DescriptorSet 0 + OpDecorate %ssbo Binding 0 + %float = OpTypeFloat 32 + %v2float = OpTypeVector %float 2 + %mat2v2float = OpTypeMatrix %v2float 2 + %SSBO = OpTypeStruct %mat2v2float + %_ptr_StorageBuffer_SSBO = OpTypePointer StorageBuffer %SSBO + %ssbo = OpVariable %_ptr_StorageBuffer_SSBO StorageBuffer + %void = OpTypeVoid + %7 = OpTypeFunction %void + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 +%_ptr_StorageBuffer_mat2v2float = OpTypePointer StorageBuffer %mat2v2float + %f = OpFunction %void None %7 + %10 = OpLabel + %14 = OpAccessChain %_ptr_StorageBuffer_mat2v2float %ssbo %uint_0 + %15 = OpLoad %mat2v2float %14 + %16 = OpAccessChain %_ptr_StorageBuffer_mat2v2float %ssbo %uint_0 + OpStore %16 %15 + OpReturn + OpFunctionEnd diff --git a/test/layout/storage/mat2x2/stride/16.spvasm.expected.hlsl b/test/layout/storage/mat2x2/stride/16.spvasm.expected.hlsl new file mode 100644 index 0000000000..f9ad3adefb --- /dev/null +++ b/test/layout/storage/mat2x2/stride/16.spvasm.expected.hlsl @@ -0,0 +1,47 @@ +struct tint_padded_array_element { + float2 el; +}; + +RWByteAddressBuffer ssbo : register(u0, space0); + +float2x2 arr_to_mat2x2_stride_16(tint_padded_array_element arr[2]) { + return float2x2(arr[0u].el, arr[1u].el); +} + +typedef tint_padded_array_element mat2x2_stride_16_to_arr_ret[2]; +mat2x2_stride_16_to_arr_ret mat2x2_stride_16_to_arr(float2x2 mat) { + const tint_padded_array_element tint_symbol_4[2] = {{mat[0u]}, {mat[1u]}}; + return tint_symbol_4; +} + +typedef tint_padded_array_element tint_symbol_ret[2]; +tint_symbol_ret tint_symbol(RWByteAddressBuffer buffer, uint offset) { + tint_padded_array_element arr_1[2] = (tint_padded_array_element[2])0; + { + for(uint i = 0u; (i < 2u); i = (i + 1u)) { + arr_1[i].el = asfloat(buffer.Load2((offset + (i * 16u)))); + } + } + return arr_1; +} + +void tint_symbol_2(RWByteAddressBuffer buffer, uint offset, tint_padded_array_element value[2]) { + tint_padded_array_element array[2] = value; + { + for(uint i_1 = 0u; (i_1 < 2u); i_1 = (i_1 + 1u)) { + buffer.Store2((offset + (i_1 * 16u)), asuint(array[i_1].el)); + } + } +} + +void f_1() { + const float2x2 x_15 = arr_to_mat2x2_stride_16(tint_symbol(ssbo, 0u)); + tint_symbol_2(ssbo, 0u, mat2x2_stride_16_to_arr(x_15)); + return; +} + +[numthreads(1, 1, 1)] +void f() { + f_1(); + return; +} diff --git a/test/layout/storage/mat2x2/stride/16.spvasm.expected.msl b/test/layout/storage/mat2x2/stride/16.spvasm.expected.msl new file mode 100644 index 0000000000..9bb899fd34 --- /dev/null +++ b/test/layout/storage/mat2x2/stride/16.spvasm.expected.msl @@ -0,0 +1,34 @@ +#include + +using namespace metal; +struct tint_padded_array_element { + /* 0x0000 */ packed_float2 el; + /* 0x0008 */ int8_t tint_pad[8]; +}; +struct tint_array_wrapper { + /* 0x0000 */ tint_padded_array_element arr[2]; +}; +struct SSBO { + /* 0x0000 */ tint_array_wrapper m; +}; + +float2x2 arr_to_mat2x2_stride_16(tint_array_wrapper arr) { + return float2x2(arr.arr[0u].el, arr.arr[1u].el); +} + +tint_array_wrapper mat2x2_stride_16_to_arr(float2x2 mat) { + tint_array_wrapper const tint_symbol = {.arr={{.el=mat[0u]}, {.el=mat[1u]}}}; + return tint_symbol; +} + +void f_1(device SSBO& ssbo) { + float2x2 const x_15 = arr_to_mat2x2_stride_16(ssbo.m); + ssbo.m = mat2x2_stride_16_to_arr(x_15); + return; +} + +kernel void f(device SSBO& ssbo [[buffer(0)]]) { + f_1(ssbo); + return; +} + diff --git a/test/layout/storage/mat2x2/stride/16.spvasm.expected.spvasm b/test/layout/storage/mat2x2/stride/16.spvasm.expected.spvasm new file mode 100644 index 0000000000..94c280b7cd --- /dev/null +++ b/test/layout/storage/mat2x2/stride/16.spvasm.expected.spvasm @@ -0,0 +1,70 @@ +; SPIR-V +; Version: 1.3 +; Generator: Google Tint Compiler; 0 +; Bound: 39 +; Schema: 0 + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %f "f" + OpExecutionMode %f LocalSize 1 1 1 + OpName %SSBO "SSBO" + OpMemberName %SSBO 0 "m" + OpName %ssbo "ssbo" + OpName %arr_to_mat2x2_stride_16 "arr_to_mat2x2_stride_16" + OpName %arr "arr" + OpName %mat2x2_stride_16_to_arr "mat2x2_stride_16_to_arr" + OpName %mat "mat" + OpName %f_1 "f_1" + OpName %f "f" + OpDecorate %SSBO Block + OpMemberDecorate %SSBO 0 Offset 0 + OpDecorate %_arr_v2float_uint_2 ArrayStride 16 + OpDecorate %ssbo DescriptorSet 0 + OpDecorate %ssbo Binding 0 + %float = OpTypeFloat 32 + %v2float = OpTypeVector %float 2 + %uint = OpTypeInt 32 0 + %uint_2 = OpConstant %uint 2 +%_arr_v2float_uint_2 = OpTypeArray %v2float %uint_2 + %SSBO = OpTypeStruct %_arr_v2float_uint_2 +%_ptr_StorageBuffer_SSBO = OpTypePointer StorageBuffer %SSBO + %ssbo = OpVariable %_ptr_StorageBuffer_SSBO StorageBuffer +%mat2v2float = OpTypeMatrix %v2float 2 + %9 = OpTypeFunction %mat2v2float %_arr_v2float_uint_2 + %uint_0 = OpConstant %uint 0 + %uint_1 = OpConstant %uint 1 + %19 = OpTypeFunction %_arr_v2float_uint_2 %mat2v2float + %void = OpTypeVoid + %26 = OpTypeFunction %void +%_ptr_StorageBuffer__arr_v2float_uint_2 = OpTypePointer StorageBuffer %_arr_v2float_uint_2 +%arr_to_mat2x2_stride_16 = OpFunction %mat2v2float None %9 + %arr = OpFunctionParameter %_arr_v2float_uint_2 + %13 = OpLabel + %15 = OpCompositeExtract %v2float %arr 0 + %17 = OpCompositeExtract %v2float %arr 1 + %18 = OpCompositeConstruct %mat2v2float %15 %17 + OpReturnValue %18 + OpFunctionEnd +%mat2x2_stride_16_to_arr = OpFunction %_arr_v2float_uint_2 None %19 + %mat = OpFunctionParameter %mat2v2float + %22 = OpLabel + %23 = OpCompositeExtract %v2float %mat 0 + %24 = OpCompositeExtract %v2float %mat 1 + %25 = OpCompositeConstruct %_arr_v2float_uint_2 %23 %24 + OpReturnValue %25 + OpFunctionEnd + %f_1 = OpFunction %void None %26 + %29 = OpLabel + %32 = OpAccessChain %_ptr_StorageBuffer__arr_v2float_uint_2 %ssbo %uint_0 + %33 = OpLoad %_arr_v2float_uint_2 %32 + %30 = OpFunctionCall %mat2v2float %arr_to_mat2x2_stride_16 %33 + %34 = OpAccessChain %_ptr_StorageBuffer__arr_v2float_uint_2 %ssbo %uint_0 + %35 = OpFunctionCall %_arr_v2float_uint_2 %mat2x2_stride_16_to_arr %30 + OpStore %34 %35 + OpReturn + OpFunctionEnd + %f = OpFunction %void None %26 + %37 = OpLabel + %38 = OpFunctionCall %void %f_1 + OpReturn + OpFunctionEnd diff --git a/test/layout/storage/mat2x2/stride/16.spvasm.expected.wgsl b/test/layout/storage/mat2x2/stride/16.spvasm.expected.wgsl new file mode 100644 index 0000000000..0ac837a1d7 --- /dev/null +++ b/test/layout/storage/mat2x2/stride/16.spvasm.expected.wgsl @@ -0,0 +1,25 @@ +[[block]] +struct SSBO { + m : [[stride(16)]] array, 2>; +}; + +[[group(0), binding(0)]] var ssbo : SSBO; + +fn arr_to_mat2x2_stride_16(arr : [[stride(16)]] array, 2>) -> mat2x2 { + return mat2x2(arr[0u], arr[1u]); +} + +fn mat2x2_stride_16_to_arr(mat : mat2x2) -> [[stride(16)]] array, 2> { + return [[stride(16)]] array, 2>(mat[0u], mat[1u]); +} + +fn f_1() { + let x_15 : mat2x2 = arr_to_mat2x2_stride_16(ssbo.m); + ssbo.m = mat2x2_stride_16_to_arr(x_15); + return; +} + +[[stage(compute), workgroup_size(1, 1, 1)]] +fn f() { + f_1(); +}