reader/spirv: Handle the MatrixStride decoration
Add `transform::DecomposeStridedMatrix`, which replaces matrix members of storage or uniform buffer structures, that have a [[stride]] decoration, into an array of N column vectors. This is required to correctly handle `mat2x2` matrices in UBOs, as std140 rules will expect a default stride of 16 bytes, when in WGSL the default structure layout expects a stride of 8 bytes. Bug: tint:1047 Change-Id: If5ca3c6ec087bbc1ac31a8d9a657b99bf34042a4 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/59840 Reviewed-by: David Neto <dneto@google.com> Commit-Queue: Ben Clayton <bclayton@google.com> Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
parent
c6cbe3fda6
commit
97668c8c37
|
@ -426,6 +426,8 @@ libtint_source_set("libtint_core_all_src") {
|
|||
"transform/canonicalize_entry_point_io.h",
|
||||
"transform/decompose_memory_access.cc",
|
||||
"transform/decompose_memory_access.h",
|
||||
"transform/decompose_strided_matrix.cc",
|
||||
"transform/decompose_strided_matrix.h",
|
||||
"transform/external_texture_transform.cc",
|
||||
"transform/external_texture_transform.h",
|
||||
"transform/first_index_offset.cc",
|
||||
|
|
|
@ -296,6 +296,8 @@ set(TINT_LIB_SRCS
|
|||
transform/canonicalize_entry_point_io.h
|
||||
transform/decompose_memory_access.cc
|
||||
transform/decompose_memory_access.h
|
||||
transform/decompose_strided_matrix.cc
|
||||
transform/decompose_strided_matrix.h
|
||||
transform/external_texture_transform.cc
|
||||
transform/external_texture_transform.h
|
||||
transform/first_index_offset.cc
|
||||
|
@ -904,6 +906,7 @@ if(${TINT_BUILD_TESTS})
|
|||
transform/calculate_array_length_test.cc
|
||||
transform/canonicalize_entry_point_io_test.cc
|
||||
transform/decompose_memory_access_test.cc
|
||||
transform/decompose_strided_matrix_test.cc
|
||||
transform/external_texture_transform_test.cc
|
||||
transform/first_index_offset_test.cc
|
||||
transform/fold_constants_test.cc
|
||||
|
|
|
@ -40,6 +40,8 @@ std::string DisableValidationDecoration::InternalName() const {
|
|||
return "disable_validation__entry_point_parameter";
|
||||
case DisabledValidation::kIgnoreConstructibleFunctionParameter:
|
||||
return "disable_validation__ignore_constructible_function_parameter";
|
||||
case DisabledValidation::kIgnoreStrideDecoration:
|
||||
return "disable_validation__ignore_stride";
|
||||
}
|
||||
return "<invalid>";
|
||||
}
|
||||
|
|
|
@ -40,6 +40,9 @@ enum class DisabledValidation {
|
|||
/// When applied to a function parameter, the validator will not
|
||||
/// check if parameter type is constructible
|
||||
kIgnoreConstructibleFunctionParameter,
|
||||
/// When applied to a member decoration, a stride decoration may be applied to
|
||||
/// non-array types.
|
||||
kIgnoreStrideDecoration,
|
||||
};
|
||||
|
||||
/// An internal decoration used to tell the validator to ignore specific
|
||||
|
|
|
@ -32,7 +32,7 @@ void InternalDecoration::to_str(const sem::Info&,
|
|||
std::ostream& out,
|
||||
size_t indent) const {
|
||||
make_indent(out, indent);
|
||||
out << "tint_internal(" << InternalName() << ")" << std::endl;
|
||||
out << "tint_internal(" << InternalName() << ")";
|
||||
}
|
||||
|
||||
} // namespace ast
|
||||
|
|
|
@ -17,6 +17,10 @@
|
|||
#include <utility>
|
||||
|
||||
#include "src/reader/spirv/parser_impl.h"
|
||||
#include "src/transform/decompose_strided_matrix.h"
|
||||
#include "src/transform/inline_pointer_lets.h"
|
||||
#include "src/transform/manager.h"
|
||||
#include "src/transform/simplify.h"
|
||||
|
||||
namespace tint {
|
||||
namespace reader {
|
||||
|
@ -40,7 +44,19 @@ Program Parse(const std::vector<uint32_t>& input) {
|
|||
|
||||
ProgramBuilder output;
|
||||
CloneContext(&output, &program_with_disjoint_ast, false).Clone();
|
||||
return Program(std::move(output));
|
||||
auto program = Program(std::move(output));
|
||||
|
||||
// If the generated program contains matrices with a custom MatrixStride
|
||||
// attribute then we need to decompose these into an array of vectors
|
||||
if (transform::DecomposeStridedMatrix::ShouldRun(&program)) {
|
||||
transform::Manager manager;
|
||||
manager.Add<transform::InlinePointerLets>();
|
||||
manager.Add<transform::Simplify>();
|
||||
manager.Add<transform::DecomposeStridedMatrix>();
|
||||
return manager.Run(&program).program;
|
||||
}
|
||||
|
||||
return program;
|
||||
}
|
||||
|
||||
} // namespace spirv
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
|
||||
#include "source/opt/build_module.h"
|
||||
#include "src/ast/bitcast_expression.h"
|
||||
#include "src/ast/disable_validation_decoration.h"
|
||||
#include "src/ast/interpolate_decoration.h"
|
||||
#include "src/ast/override_decoration.h"
|
||||
#include "src/ast/struct_block_decoration.h"
|
||||
|
@ -439,13 +440,14 @@ std::string ParserImpl::ShowType(uint32_t type_id) {
|
|||
return "SPIR-V type " + std::to_string(type_id);
|
||||
}
|
||||
|
||||
ast::Decoration* ParserImpl::ConvertMemberDecoration(
|
||||
ast::DecorationList ParserImpl::ConvertMemberDecoration(
|
||||
uint32_t struct_type_id,
|
||||
uint32_t member_index,
|
||||
const Type* member_ty,
|
||||
const Decoration& decoration) {
|
||||
if (decoration.empty()) {
|
||||
Fail() << "malformed SPIR-V decoration: it's empty";
|
||||
return nullptr;
|
||||
return {};
|
||||
}
|
||||
switch (decoration[0]) {
|
||||
case SpvDecorationOffset:
|
||||
|
@ -454,38 +456,49 @@ ast::Decoration* ParserImpl::ConvertMemberDecoration(
|
|||
<< "malformed Offset decoration: expected 1 literal operand, has "
|
||||
<< decoration.size() - 1 << ": member " << member_index << " of "
|
||||
<< ShowType(struct_type_id);
|
||||
return nullptr;
|
||||
return {};
|
||||
}
|
||||
return create<ast::StructMemberOffsetDecoration>(Source{}, decoration[1]);
|
||||
return {
|
||||
create<ast::StructMemberOffsetDecoration>(Source{}, decoration[1]),
|
||||
};
|
||||
case SpvDecorationNonReadable:
|
||||
// WGSL doesn't have a member decoration for this. Silently drop it.
|
||||
return nullptr;
|
||||
return {};
|
||||
case SpvDecorationNonWritable:
|
||||
// WGSL doesn't have a member decoration for this.
|
||||
return nullptr;
|
||||
return {};
|
||||
case SpvDecorationColMajor:
|
||||
// WGSL only supports column major matrices.
|
||||
return nullptr;
|
||||
return {};
|
||||
case SpvDecorationRelaxedPrecision:
|
||||
// WGSL doesn't support relaxed precision.
|
||||
return nullptr;
|
||||
return {};
|
||||
case SpvDecorationRowMajor:
|
||||
Fail() << "WGSL does not support row-major matrices: can't "
|
||||
"translate member "
|
||||
<< member_index << " of " << ShowType(struct_type_id);
|
||||
return nullptr;
|
||||
return {};
|
||||
case SpvDecorationMatrixStride: {
|
||||
if (decoration.size() != 2) {
|
||||
Fail() << "malformed MatrixStride decoration: expected 1 literal "
|
||||
"operand, has "
|
||||
<< decoration.size() - 1 << ": member " << member_index << " of "
|
||||
<< ShowType(struct_type_id);
|
||||
return nullptr;
|
||||
return {};
|
||||
}
|
||||
// TODO(dneto): Fail if the matrix stride is not allocation size of the
|
||||
// column vector of the underlying matrix. This would need to unpack
|
||||
// any levels of array-ness.
|
||||
return nullptr;
|
||||
uint32_t stride = decoration[1];
|
||||
uint32_t natural_stride = 0;
|
||||
if (auto* mat = member_ty->As<Matrix>()) {
|
||||
natural_stride = (mat->rows == 2) ? 8 : 16;
|
||||
}
|
||||
if (stride == natural_stride) {
|
||||
return {};
|
||||
}
|
||||
return {
|
||||
create<ast::StrideDecoration>(Source{}, decoration[1]),
|
||||
builder_.ASTNodes().Create<ast::DisableValidationDecoration>(
|
||||
builder_.ID(), ast::DisabledValidation::kIgnoreStrideDecoration),
|
||||
};
|
||||
}
|
||||
default:
|
||||
// TODO(dneto): Support the remaining member decorations.
|
||||
|
@ -493,7 +506,7 @@ ast::Decoration* ParserImpl::ConvertMemberDecoration(
|
|||
}
|
||||
Fail() << "unhandled member decoration: " << decoration[0] << " on member "
|
||||
<< member_index << " of " << ShowType(struct_type_id);
|
||||
return nullptr;
|
||||
return {};
|
||||
}
|
||||
|
||||
bool ParserImpl::BuildInternalModule() {
|
||||
|
@ -1126,14 +1139,14 @@ const Type* ParserImpl::ConvertType(
|
|||
// the members are non-writable.
|
||||
is_non_writable = true;
|
||||
} else {
|
||||
auto* ast_member_decoration =
|
||||
ConvertMemberDecoration(type_id, member_index, decoration);
|
||||
auto decos = ConvertMemberDecoration(type_id, member_index,
|
||||
ast_member_ty, decoration);
|
||||
for (auto* deco : decos) {
|
||||
ast_member_decorations.emplace_back(deco);
|
||||
}
|
||||
if (!success_) {
|
||||
return nullptr;
|
||||
}
|
||||
if (ast_member_decoration) {
|
||||
ast_member_decorations.push_back(ast_member_decoration);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -272,16 +272,19 @@ class ParserImpl : Reader {
|
|||
ast::Decoration* SetLocation(ast::DecorationList* decos,
|
||||
ast::Decoration* replacement);
|
||||
|
||||
/// Converts a SPIR-V struct member decoration. If the decoration is
|
||||
/// recognized but deliberately dropped, then returns nullptr without a
|
||||
/// diagnostic. On failure, emits a diagnostic and returns nullptr.
|
||||
/// Converts a SPIR-V struct member decoration into a number of AST
|
||||
/// decorations. If the decoration is recognized but deliberately dropped,
|
||||
/// then returns an empty list without a diagnostic. On failure, emits a
|
||||
/// diagnostic and returns an empty list.
|
||||
/// @param struct_type_id the ID of the struct type
|
||||
/// @param member_index the index of the member
|
||||
/// @param member_ty the type of the member
|
||||
/// @param decoration an encoded SPIR-V Decoration
|
||||
/// @returns the corresponding ast::StructuMemberDecoration
|
||||
ast::Decoration* ConvertMemberDecoration(uint32_t struct_type_id,
|
||||
uint32_t member_index,
|
||||
const Decoration& decoration);
|
||||
/// @returns the AST decorations
|
||||
ast::DecorationList ConvertMemberDecoration(uint32_t struct_type_id,
|
||||
uint32_t member_index,
|
||||
const Type* member_ty,
|
||||
const Decoration& decoration);
|
||||
|
||||
/// Returns a string for the given type. If the type ID is invalid,
|
||||
/// then the resulting string only names the type ID.
|
||||
|
|
|
@ -25,16 +25,17 @@ using ::testing::Eq;
|
|||
TEST_F(SpvParserTest, ConvertMemberDecoration_Empty) {
|
||||
auto p = parser(std::vector<uint32_t>{});
|
||||
|
||||
auto* result = p->ConvertMemberDecoration(1, 1, {});
|
||||
EXPECT_EQ(result, nullptr);
|
||||
auto result = p->ConvertMemberDecoration(1, 1, nullptr, {});
|
||||
EXPECT_TRUE(result.empty());
|
||||
EXPECT_THAT(p->error(), Eq("malformed SPIR-V decoration: it's empty"));
|
||||
}
|
||||
|
||||
TEST_F(SpvParserTest, ConvertMemberDecoration_OffsetWithoutOperand) {
|
||||
auto p = parser(std::vector<uint32_t>{});
|
||||
|
||||
auto* result = p->ConvertMemberDecoration(12, 13, {SpvDecorationOffset});
|
||||
EXPECT_EQ(result, nullptr);
|
||||
auto result =
|
||||
p->ConvertMemberDecoration(12, 13, nullptr, {SpvDecorationOffset});
|
||||
EXPECT_TRUE(result.empty());
|
||||
EXPECT_THAT(p->error(), Eq("malformed Offset decoration: expected 1 literal "
|
||||
"operand, has 0: member 13 of SPIR-V type 12"));
|
||||
}
|
||||
|
@ -42,9 +43,9 @@ TEST_F(SpvParserTest, ConvertMemberDecoration_OffsetWithoutOperand) {
|
|||
TEST_F(SpvParserTest, ConvertMemberDecoration_OffsetWithTooManyOperands) {
|
||||
auto p = parser(std::vector<uint32_t>{});
|
||||
|
||||
auto* result =
|
||||
p->ConvertMemberDecoration(12, 13, {SpvDecorationOffset, 3, 4});
|
||||
EXPECT_EQ(result, nullptr);
|
||||
auto result =
|
||||
p->ConvertMemberDecoration(12, 13, nullptr, {SpvDecorationOffset, 3, 4});
|
||||
EXPECT_TRUE(result.empty());
|
||||
EXPECT_THAT(p->error(), Eq("malformed Offset decoration: expected 1 literal "
|
||||
"operand, has 2: member 13 of SPIR-V type 12"));
|
||||
}
|
||||
|
@ -52,32 +53,100 @@ TEST_F(SpvParserTest, ConvertMemberDecoration_OffsetWithTooManyOperands) {
|
|||
TEST_F(SpvParserTest, ConvertMemberDecoration_Offset) {
|
||||
auto p = parser(std::vector<uint32_t>{});
|
||||
|
||||
auto* result = p->ConvertMemberDecoration(1, 1, {SpvDecorationOffset, 8});
|
||||
ASSERT_NE(result, nullptr);
|
||||
EXPECT_TRUE(result->Is<ast::StructMemberOffsetDecoration>());
|
||||
auto* offset_deco = result->As<ast::StructMemberOffsetDecoration>();
|
||||
auto result =
|
||||
p->ConvertMemberDecoration(1, 1, nullptr, {SpvDecorationOffset, 8});
|
||||
ASSERT_FALSE(result.empty());
|
||||
EXPECT_TRUE(result[0]->Is<ast::StructMemberOffsetDecoration>());
|
||||
auto* offset_deco = result[0]->As<ast::StructMemberOffsetDecoration>();
|
||||
ASSERT_NE(offset_deco, nullptr);
|
||||
EXPECT_EQ(offset_deco->offset(), 8u);
|
||||
EXPECT_TRUE(p->error().empty());
|
||||
}
|
||||
|
||||
TEST_F(SpvParserTest, ConvertMemberDecoration_Matrix2x2_Stride_Natural) {
|
||||
auto p = parser(std::vector<uint32_t>{});
|
||||
|
||||
spirv::F32 f32;
|
||||
spirv::Matrix matrix(&f32, 2, 2);
|
||||
auto result =
|
||||
p->ConvertMemberDecoration(1, 1, &matrix, {SpvDecorationMatrixStride, 8});
|
||||
EXPECT_TRUE(result.empty());
|
||||
EXPECT_TRUE(p->error().empty());
|
||||
}
|
||||
|
||||
TEST_F(SpvParserTest, ConvertMemberDecoration_Matrix2x2_Stride_Custom) {
|
||||
auto p = parser(std::vector<uint32_t>{});
|
||||
|
||||
spirv::F32 f32;
|
||||
spirv::Matrix matrix(&f32, 2, 2);
|
||||
auto result = p->ConvertMemberDecoration(1, 1, &matrix,
|
||||
{SpvDecorationMatrixStride, 16});
|
||||
ASSERT_FALSE(result.empty());
|
||||
EXPECT_TRUE(result[0]->Is<ast::StrideDecoration>());
|
||||
auto* stride_deco = result[0]->As<ast::StrideDecoration>();
|
||||
ASSERT_NE(stride_deco, nullptr);
|
||||
EXPECT_EQ(stride_deco->stride(), 16u);
|
||||
EXPECT_TRUE(p->error().empty());
|
||||
}
|
||||
|
||||
TEST_F(SpvParserTest, ConvertMemberDecoration_Matrix2x4_Stride_Natural) {
|
||||
auto p = parser(std::vector<uint32_t>{});
|
||||
|
||||
spirv::F32 f32;
|
||||
spirv::Matrix matrix(&f32, 2, 4);
|
||||
auto result = p->ConvertMemberDecoration(1, 1, &matrix,
|
||||
{SpvDecorationMatrixStride, 16});
|
||||
EXPECT_TRUE(result.empty());
|
||||
EXPECT_TRUE(p->error().empty());
|
||||
}
|
||||
|
||||
TEST_F(SpvParserTest, ConvertMemberDecoration_Matrix2x4_Stride_Custom) {
|
||||
auto p = parser(std::vector<uint32_t>{});
|
||||
|
||||
spirv::F32 f32;
|
||||
spirv::Matrix matrix(&f32, 2, 4);
|
||||
auto result = p->ConvertMemberDecoration(1, 1, &matrix,
|
||||
{SpvDecorationMatrixStride, 64});
|
||||
ASSERT_FALSE(result.empty());
|
||||
EXPECT_TRUE(result[0]->Is<ast::StrideDecoration>());
|
||||
auto* stride_deco = result[0]->As<ast::StrideDecoration>();
|
||||
ASSERT_NE(stride_deco, nullptr);
|
||||
EXPECT_EQ(stride_deco->stride(), 64u);
|
||||
EXPECT_TRUE(p->error().empty());
|
||||
}
|
||||
|
||||
TEST_F(SpvParserTest, ConvertMemberDecoration_Matrix2x3_Stride_Custom) {
|
||||
auto p = parser(std::vector<uint32_t>{});
|
||||
|
||||
spirv::F32 f32;
|
||||
spirv::Matrix matrix(&f32, 2, 3);
|
||||
auto result = p->ConvertMemberDecoration(1, 1, &matrix,
|
||||
{SpvDecorationMatrixStride, 32});
|
||||
ASSERT_FALSE(result.empty());
|
||||
EXPECT_TRUE(result[0]->Is<ast::StrideDecoration>());
|
||||
auto* stride_deco = result[0]->As<ast::StrideDecoration>();
|
||||
ASSERT_NE(stride_deco, nullptr);
|
||||
EXPECT_EQ(stride_deco->stride(), 32u);
|
||||
EXPECT_TRUE(p->error().empty());
|
||||
}
|
||||
|
||||
TEST_F(SpvParserTest, ConvertMemberDecoration_RelaxedPrecision) {
|
||||
// WGSL does not support relaxed precision. Drop it.
|
||||
// It's functionally correct to use full precision f32 instead of
|
||||
// relaxed precision f32.
|
||||
auto p = parser(std::vector<uint32_t>{});
|
||||
|
||||
auto* result =
|
||||
p->ConvertMemberDecoration(1, 1, {SpvDecorationRelaxedPrecision});
|
||||
EXPECT_EQ(result, nullptr);
|
||||
auto result = p->ConvertMemberDecoration(1, 1, nullptr,
|
||||
{SpvDecorationRelaxedPrecision});
|
||||
EXPECT_TRUE(result.empty());
|
||||
EXPECT_TRUE(p->error().empty());
|
||||
}
|
||||
|
||||
TEST_F(SpvParserTest, ConvertMemberDecoration_UnhandledDecoration) {
|
||||
auto p = parser(std::vector<uint32_t>{});
|
||||
|
||||
auto* result = p->ConvertMemberDecoration(12, 13, {12345678});
|
||||
EXPECT_EQ(result, nullptr);
|
||||
auto result = p->ConvertMemberDecoration(12, 13, nullptr, {12345678});
|
||||
EXPECT_TRUE(result.empty());
|
||||
EXPECT_THAT(p->error(), Eq("unhandled member decoration: 12345678 on member "
|
||||
"13 of SPIR-V type 12"));
|
||||
}
|
||||
|
|
|
@ -2015,7 +2015,7 @@ TEST_F(SpvModuleScopeVarParserTest, ColMajorDecoration_Dropped) {
|
|||
})")) << module_str;
|
||||
}
|
||||
|
||||
TEST_F(SpvModuleScopeVarParserTest, MatrixStrideDecoration_Dropped) {
|
||||
TEST_F(SpvModuleScopeVarParserTest, MatrixStrideDecoration_Natural_Dropped) {
|
||||
auto p = parser(test::Assemble(Preamble() + FragMain() + R"(
|
||||
OpName %myvar "myvar"
|
||||
OpDecorate %myvar DescriptorSet 0
|
||||
|
@ -2054,6 +2054,45 @@ TEST_F(SpvModuleScopeVarParserTest, MatrixStrideDecoration_Dropped) {
|
|||
})")) << module_str;
|
||||
}
|
||||
|
||||
TEST_F(SpvModuleScopeVarParserTest, MatrixStrideDecoration) {
|
||||
auto p = parser(test::Assemble(Preamble() + FragMain() + R"(
|
||||
OpName %myvar "myvar"
|
||||
OpDecorate %myvar DescriptorSet 0
|
||||
OpDecorate %myvar Binding 0
|
||||
OpDecorate %s Block
|
||||
OpMemberDecorate %s 0 MatrixStride 64
|
||||
OpMemberDecorate %s 0 Offset 0
|
||||
%void = OpTypeVoid
|
||||
%voidfn = OpTypeFunction %void
|
||||
%float = OpTypeFloat 32
|
||||
%v2float = OpTypeVector %float 2
|
||||
%m3v2float = OpTypeMatrix %v2float 3
|
||||
|
||||
%s = OpTypeStruct %m3v2float
|
||||
%ptr_sb_s = OpTypePointer StorageBuffer %s
|
||||
%myvar = OpVariable %ptr_sb_s StorageBuffer
|
||||
)" + MainBody()));
|
||||
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
|
||||
EXPECT_TRUE(p->error().empty());
|
||||
const auto module_str = p->program().to_str();
|
||||
EXPECT_THAT(module_str, HasSubstr(R"(
|
||||
Struct S {
|
||||
[[block]]
|
||||
StructMember{[[ stride 64 tint_internal(disable_validation__ignore_stride) offset 0 ]] field0: __mat_2_3__f32}
|
||||
}
|
||||
Variable{
|
||||
Decorations{
|
||||
GroupDecoration{0}
|
||||
BindingDecoration{0}
|
||||
}
|
||||
myvar
|
||||
storage
|
||||
read_write
|
||||
__type_name_S
|
||||
}
|
||||
})")) << module_str;
|
||||
}
|
||||
|
||||
TEST_F(SpvModuleScopeVarParserTest, RowMajorDecoration_IsError) {
|
||||
auto p = parser(test::Assemble(Preamble() + FragMain() + R"(
|
||||
OpName %myvar "myvar"
|
||||
|
@ -2620,7 +2659,8 @@ TEST_F(SpvModuleScopeVarParserTest, SampleId_I32_Load_AccessChain) {
|
|||
private
|
||||
undefined
|
||||
__i32
|
||||
})")) <<module_str;
|
||||
})"))
|
||||
<< module_str;
|
||||
|
||||
// Correct creation of value
|
||||
EXPECT_THAT(module_str, HasSubstr(R"(
|
||||
|
@ -3006,7 +3046,8 @@ TEST_F(SpvModuleScopeVarParserTest, SampleMask_In_U32_Direct) {
|
|||
private
|
||||
undefined
|
||||
__array__u32_1
|
||||
})")) <<module_str;
|
||||
})"))
|
||||
<< module_str;
|
||||
|
||||
// Correct creation of value
|
||||
EXPECT_THAT(module_str, HasSubstr(R"(
|
||||
|
@ -3149,7 +3190,8 @@ TEST_F(SpvModuleScopeVarParserTest, SampleMask_In_U32_AccessChain) {
|
|||
private
|
||||
undefined
|
||||
__array__u32_1
|
||||
})")) <<module_str;
|
||||
})"))
|
||||
<< module_str;
|
||||
|
||||
// Correct creation of value
|
||||
EXPECT_THAT(module_str, HasSubstr(R"(
|
||||
|
@ -5543,7 +5585,6 @@ INSTANTIATE_TEST_SUITE_P(
|
|||
// {"NumWorkgroups", "%uint", "num_workgroups"}
|
||||
// {"NumWorkgroups", "%int", "num_workgroups"}
|
||||
|
||||
|
||||
TEST_F(SpvModuleScopeVarParserTest, RegisterInputOutputVars) {
|
||||
const std::string assembly =
|
||||
R"(
|
||||
|
|
|
@ -184,18 +184,21 @@ class ParserImplWrapperForTest {
|
|||
return impl_.GetDecorationsForMember(id, member_index);
|
||||
}
|
||||
|
||||
/// Converts a SPIR-V struct member decoration. If the decoration is
|
||||
/// recognized but deliberately dropped, then returns nullptr without a
|
||||
/// diagnostic. On failure, emits a diagnostic and returns nullptr.
|
||||
/// Converts a SPIR-V struct member decoration into a number of AST
|
||||
/// decorations. If the decoration is recognized but deliberately dropped,
|
||||
/// then returns an empty list without a diagnostic. On failure, emits a
|
||||
/// diagnostic and returns an empty list.
|
||||
/// @param struct_type_id the ID of the struct type
|
||||
/// @param member_index the index of the member
|
||||
/// @param member_ty the type of the member
|
||||
/// @param decoration an encoded SPIR-V Decoration
|
||||
/// @returns the corresponding ast::StructuMemberDecoration
|
||||
ast::Decoration* ConvertMemberDecoration(uint32_t struct_type_id,
|
||||
uint32_t member_index,
|
||||
const Decoration& decoration) {
|
||||
/// @returns the AST decorations
|
||||
ast::DecorationList ConvertMemberDecoration(uint32_t struct_type_id,
|
||||
uint32_t member_index,
|
||||
const Type* member_ty,
|
||||
const Decoration& decoration) {
|
||||
return impl_.ConvertMemberDecoration(struct_type_id, member_index,
|
||||
decoration);
|
||||
member_ty, decoration);
|
||||
}
|
||||
|
||||
/// For a SPIR-V ID that might define a sampler, image, or sampled image
|
||||
|
|
|
@ -3937,13 +3937,20 @@ bool Resolver::ValidateStructure(const sem::Struct* str) {
|
|||
auto has_position = false;
|
||||
ast::InvariantDecoration* invariant_attribute = nullptr;
|
||||
for (auto* deco : member->Declaration()->decorations()) {
|
||||
if (!(deco->Is<ast::BuiltinDecoration>() ||
|
||||
deco->Is<ast::InterpolateDecoration>() ||
|
||||
deco->Is<ast::InvariantDecoration>() ||
|
||||
deco->Is<ast::LocationDecoration>() ||
|
||||
deco->Is<ast::StructMemberOffsetDecoration>() ||
|
||||
deco->Is<ast::StructMemberSizeDecoration>() ||
|
||||
deco->Is<ast::StructMemberAlignDecoration>())) {
|
||||
if (!deco->IsAnyOf<ast::BuiltinDecoration, //
|
||||
ast::InternalDecoration, //
|
||||
ast::InterpolateDecoration, //
|
||||
ast::InvariantDecoration, //
|
||||
ast::LocationDecoration, //
|
||||
ast::StructMemberOffsetDecoration, //
|
||||
ast::StructMemberSizeDecoration, //
|
||||
ast::StructMemberAlignDecoration>()) {
|
||||
if (deco->Is<ast::StrideDecoration>() &&
|
||||
IsValidationDisabled(
|
||||
member->Declaration()->decorations(),
|
||||
ast::DisabledValidation::kIgnoreStrideDecoration)) {
|
||||
continue;
|
||||
}
|
||||
AddError("decoration is not valid for structure members",
|
||||
deco->source());
|
||||
return false;
|
||||
|
|
|
@ -0,0 +1,251 @@
|
|||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/transform/decompose_strided_matrix.h"
|
||||
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "src/program_builder.h"
|
||||
#include "src/sem/expression.h"
|
||||
#include "src/sem/member_accessor_expression.h"
|
||||
#include "src/transform/inline_pointer_lets.h"
|
||||
#include "src/transform/simplify.h"
|
||||
#include "src/utils/get_or_create.h"
|
||||
#include "src/utils/hash.h"
|
||||
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::transform::DecomposeStridedMatrix);
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
namespace {
|
||||
|
||||
/// MatrixInfo describes a matrix member with a custom stride
|
||||
struct MatrixInfo {
|
||||
/// The stride in bytes between columns of the matrix
|
||||
uint32_t stride = 0;
|
||||
/// The type of the matrix
|
||||
sem::Matrix const* matrix = nullptr;
|
||||
|
||||
/// @returns a new ast::Array that holds an vector column for each row of the
|
||||
/// matrix.
|
||||
ast::Array* array(ProgramBuilder* b) const {
|
||||
return b->ty.array(b->ty.vec<ProgramBuilder::f32>(matrix->rows()),
|
||||
matrix->columns(), stride);
|
||||
}
|
||||
|
||||
/// Equality operator
|
||||
bool operator==(const MatrixInfo& info) const {
|
||||
return stride == info.stride && matrix == info.matrix;
|
||||
}
|
||||
/// Hash function
|
||||
struct Hasher {
|
||||
size_t operator()(const MatrixInfo& t) const {
|
||||
return utils::Hash(t.stride, t.matrix);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
/// Return type of the callback function of GatherCustomStrideMatrixMembers
|
||||
enum GatherResult { kContinue, kStop };
|
||||
|
||||
/// GatherCustomStrideMatrixMembers scans `program` for all matrix members of
|
||||
/// storage and uniform structs, which are of a matrix type, and have a custom
|
||||
/// matrix stride attribute. For each matrix member found, `callback` is called.
|
||||
/// `callback` is a function with the signature:
|
||||
/// GatherResult(const sem::StructMember* member,
|
||||
/// sem::Matrix* matrix,
|
||||
/// uint32_t stride)
|
||||
/// If `callback` return GatherResult::kStop, then the scanning will immediately
|
||||
/// terminate, and GatherCustomStrideMatrixMembers() will return, otherwise
|
||||
/// scanning will continue.
|
||||
template <typename F>
|
||||
void GatherCustomStrideMatrixMembers(const Program* program, F&& callback) {
|
||||
for (auto* node : program->ASTNodes().Objects()) {
|
||||
if (auto* str = node->As<ast::Struct>()) {
|
||||
auto* str_ty = program->Sem().Get(str);
|
||||
if (!str_ty->UsedAs(ast::StorageClass::kUniform) &&
|
||||
!str_ty->UsedAs(ast::StorageClass::kStorage)) {
|
||||
continue;
|
||||
}
|
||||
for (auto* member : str_ty->Members()) {
|
||||
auto* matrix = member->Type()->As<sem::Matrix>();
|
||||
if (!matrix) {
|
||||
continue;
|
||||
}
|
||||
auto* deco = ast::GetDecoration<ast::StrideDecoration>(
|
||||
member->Declaration()->decorations());
|
||||
if (!deco) {
|
||||
continue;
|
||||
}
|
||||
uint32_t stride = deco->stride();
|
||||
if (matrix->ColumnStride() == stride) {
|
||||
continue;
|
||||
}
|
||||
if (callback(member, matrix, stride) == GatherResult::kStop) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
DecomposeStridedMatrix::DecomposeStridedMatrix() = default;
|
||||
|
||||
DecomposeStridedMatrix::~DecomposeStridedMatrix() = default;
|
||||
|
||||
bool DecomposeStridedMatrix::ShouldRun(const Program* program) {
|
||||
bool should_run = false;
|
||||
GatherCustomStrideMatrixMembers(
|
||||
program, [&](const sem::StructMember*, sem::Matrix*, uint32_t) {
|
||||
should_run = true;
|
||||
return GatherResult::kStop;
|
||||
});
|
||||
return should_run;
|
||||
}
|
||||
|
||||
void DecomposeStridedMatrix::Run(CloneContext& ctx, const DataMap&, DataMap&) {
|
||||
if (!Requires<InlinePointerLets, Simplify>(ctx)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Scan the program for all storage and uniform structure matrix members with
|
||||
// a custom stride attribute. Replace these matrices with an equivalent array,
|
||||
// and populate the `decomposed` map with the members that have been replaced.
|
||||
std::unordered_map<ast::StructMember*, MatrixInfo> decomposed;
|
||||
GatherCustomStrideMatrixMembers(
|
||||
ctx.src, [&](const sem::StructMember* member, sem::Matrix* matrix,
|
||||
uint32_t stride) {
|
||||
// We've got ourselves a struct member of a matrix type with a custom
|
||||
// stride. Replace this with an array of column vectors.
|
||||
MatrixInfo info{stride, matrix};
|
||||
auto* replacement = ctx.dst->Member(
|
||||
member->Offset(), ctx.Clone(member->Name()), info.array(ctx.dst));
|
||||
ctx.Replace(member->Declaration(), replacement);
|
||||
decomposed.emplace(member->Declaration(), info);
|
||||
return GatherResult::kContinue;
|
||||
});
|
||||
|
||||
// For all expressions where a single matrix column vector was indexed, we can
|
||||
// preserve these without calling conversion functions.
|
||||
// Example:
|
||||
// ssbo.mat[2] -> ssbo.mat[2]
|
||||
ctx.ReplaceAll(
|
||||
[&](ast::ArrayAccessorExpression* expr) -> ast::ArrayAccessorExpression* {
|
||||
if (auto* access =
|
||||
ctx.src->Sem().Get<sem::StructMemberAccess>(expr->array())) {
|
||||
auto it = decomposed.find(access->Member()->Declaration());
|
||||
if (it != decomposed.end()) {
|
||||
auto* obj = ctx.CloneWithoutTransform(expr->array());
|
||||
auto* idx = ctx.Clone(expr->idx_expr());
|
||||
return ctx.dst->IndexAccessor(obj, idx);
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
});
|
||||
|
||||
// For all struct member accesses to the matrix on the LHS of an assignment,
|
||||
// we need to convert the matrix to the array before assigning to the
|
||||
// structure.
|
||||
// Example:
|
||||
// ssbo.mat = mat_to_arr(m)
|
||||
std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> mat_to_arr;
|
||||
ctx.ReplaceAll([&](ast::AssignmentStatement* stmt) -> ast::Statement* {
|
||||
if (auto* access =
|
||||
ctx.src->Sem().Get<sem::StructMemberAccess>(stmt->lhs())) {
|
||||
auto it = decomposed.find(access->Member()->Declaration());
|
||||
if (it == decomposed.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
MatrixInfo info = it->second;
|
||||
auto fn = utils::GetOrCreate(mat_to_arr, info, [&] {
|
||||
auto name = ctx.dst->Symbols().New(
|
||||
"mat" + std::to_string(info.matrix->columns()) + "x" +
|
||||
std::to_string(info.matrix->rows()) + "_stride_" +
|
||||
std::to_string(info.stride) + "_to_arr");
|
||||
|
||||
auto matrix = [&] { return CreateASTTypeFor(ctx, info.matrix); };
|
||||
auto array = [&] { return info.array(ctx.dst); };
|
||||
|
||||
auto mat = ctx.dst->Sym("mat");
|
||||
ast::ExpressionList columns(info.matrix->columns());
|
||||
for (uint32_t i = 0; i < static_cast<uint32_t>(columns.size()); i++) {
|
||||
columns[i] = ctx.dst->IndexAccessor(mat, i);
|
||||
}
|
||||
ctx.dst->Func(name,
|
||||
{
|
||||
ctx.dst->Param(mat, matrix()),
|
||||
},
|
||||
array(),
|
||||
{
|
||||
ctx.dst->Return(ctx.dst->Construct(array(), columns)),
|
||||
});
|
||||
return name;
|
||||
});
|
||||
auto* lhs = ctx.CloneWithoutTransform(stmt->lhs());
|
||||
auto* rhs = ctx.dst->Call(fn, ctx.Clone(stmt->rhs()));
|
||||
return ctx.dst->Assign(lhs, rhs);
|
||||
}
|
||||
return nullptr;
|
||||
});
|
||||
|
||||
// For all other struct member accesses, we need to convert the array to the
|
||||
// matrix type. Example:
|
||||
// m = arr_to_mat(ssbo.mat)
|
||||
std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> arr_to_mat;
|
||||
ctx.ReplaceAll([&](ast::MemberAccessorExpression* expr) -> ast::Expression* {
|
||||
if (auto* access = ctx.src->Sem().Get<sem::StructMemberAccess>(expr)) {
|
||||
auto it = decomposed.find(access->Member()->Declaration());
|
||||
if (it == decomposed.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
MatrixInfo info = it->second;
|
||||
auto fn = utils::GetOrCreate(arr_to_mat, info, [&] {
|
||||
auto name = ctx.dst->Symbols().New(
|
||||
"arr_to_mat" + std::to_string(info.matrix->columns()) + "x" +
|
||||
std::to_string(info.matrix->rows()) + "_stride_" +
|
||||
std::to_string(info.stride));
|
||||
|
||||
auto matrix = [&] { return CreateASTTypeFor(ctx, info.matrix); };
|
||||
auto array = [&] { return info.array(ctx.dst); };
|
||||
|
||||
auto arr = ctx.dst->Sym("arr");
|
||||
ast::ExpressionList columns(info.matrix->columns());
|
||||
for (uint32_t i = 0; i < static_cast<uint32_t>(columns.size()); i++) {
|
||||
columns[i] = ctx.dst->IndexAccessor(arr, i);
|
||||
}
|
||||
ctx.dst->Func(
|
||||
name,
|
||||
{
|
||||
ctx.dst->Param(arr, array()),
|
||||
},
|
||||
matrix(),
|
||||
{
|
||||
ctx.dst->Return(ctx.dst->Construct(matrix(), columns)),
|
||||
});
|
||||
return name;
|
||||
});
|
||||
return ctx.dst->Call(fn, ctx.CloneWithoutTransform(expr));
|
||||
}
|
||||
return nullptr;
|
||||
});
|
||||
|
||||
ctx.Clone();
|
||||
}
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
|
@ -0,0 +1,54 @@
|
|||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef SRC_TRANSFORM_DECOMPOSE_STRIDED_MATRIX_H_
|
||||
#define SRC_TRANSFORM_DECOMPOSE_STRIDED_MATRIX_H_
|
||||
|
||||
#include "src/transform/transform.h"
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
|
||||
/// DecomposeStridedMatrix transforms replaces matrix members of storage or
|
||||
/// uniform buffer structures, that have a [[stride]] decoration, into an array
|
||||
/// of N column vectors.
|
||||
/// This transform is used by the SPIR-V reader to handle the SPIR-V
|
||||
/// MatrixStride decoration.
|
||||
class DecomposeStridedMatrix
|
||||
: public Castable<DecomposeStridedMatrix, Transform> {
|
||||
public:
|
||||
/// Constructor
|
||||
DecomposeStridedMatrix();
|
||||
|
||||
/// Destructor
|
||||
~DecomposeStridedMatrix() override;
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @returns true if this transform should be run for the given program
|
||||
static bool ShouldRun(const Program* program);
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override;
|
||||
};
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
|
||||
#endif // SRC_TRANSFORM_DECOMPOSE_STRIDED_MATRIX_H_
|
|
@ -0,0 +1,727 @@
|
|||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/transform/decompose_strided_matrix.h"
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "src/ast/disable_validation_decoration.h"
|
||||
#include "src/program_builder.h"
|
||||
#include "src/transform/inline_pointer_lets.h"
|
||||
#include "src/transform/simplify.h"
|
||||
#include "src/transform/test_helper.h"
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
namespace {
|
||||
|
||||
using DecomposeStridedMatrixTest = TransformTest;
|
||||
using f32 = ProgramBuilder::f32;
|
||||
|
||||
TEST_F(DecomposeStridedMatrixTest, Empty) {
|
||||
auto* src = R"()";
|
||||
auto* expect = src;
|
||||
|
||||
auto got = Run<InlinePointerLets, Simplify, DecomposeStridedMatrix>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(DecomposeStridedMatrixTest, MissingDependencyInlinePointerLets) {
|
||||
auto* src = R"()";
|
||||
auto* expect =
|
||||
R"(error: tint::transform::DecomposeStridedMatrix depends on tint::transform::InlinePointerLets but the dependency was not run)";
|
||||
|
||||
auto got = Run<Simplify, DecomposeStridedMatrix>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(DecomposeStridedMatrixTest, MissingDependencySimplify) {
|
||||
auto* src = R"()";
|
||||
auto* expect =
|
||||
R"(error: tint::transform::DecomposeStridedMatrix depends on tint::transform::Simplify but the dependency was not run)";
|
||||
|
||||
auto got = Run<InlinePointerLets, DecomposeStridedMatrix>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(DecomposeStridedMatrixTest, ReadUniformMatrix) {
|
||||
// [[block]]
|
||||
// struct S {
|
||||
// [[offset(16), stride(32)]]
|
||||
// [[internal(ignore_stride_decoration)]]
|
||||
// m : mat2x2<f32>;
|
||||
// };
|
||||
// [[group(0), binding(0)]] var<uniform> s : S;
|
||||
//
|
||||
// [[stage(compute), workgroup_size(1)]]
|
||||
// fn f() {
|
||||
// let x : mat2x2<f32> = s.m;
|
||||
// }
|
||||
ProgramBuilder b;
|
||||
auto* S = b.Structure(
|
||||
"S",
|
||||
{
|
||||
b.Member(
|
||||
"m", b.ty.mat2x2<f32>(),
|
||||
{
|
||||
b.create<ast::StructMemberOffsetDecoration>(16),
|
||||
b.create<ast::StrideDecoration>(32),
|
||||
b.ASTNodes().Create<ast::DisableValidationDecoration>(
|
||||
b.ID(), ast::DisabledValidation::kIgnoreStrideDecoration),
|
||||
}),
|
||||
},
|
||||
{
|
||||
b.StructBlock(),
|
||||
});
|
||||
b.Global("s", b.ty.Of(S), ast::StorageClass::kUniform,
|
||||
b.GroupAndBinding(0, 0));
|
||||
b.Func(
|
||||
"f", {}, b.ty.void_(),
|
||||
{
|
||||
b.Decl(b.Const("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))),
|
||||
},
|
||||
{
|
||||
b.Stage(ast::PipelineStage::kCompute),
|
||||
b.WorkgroupSize(1),
|
||||
});
|
||||
|
||||
auto* expect = R"(
|
||||
[[block]]
|
||||
struct S {
|
||||
[[size(16)]]
|
||||
padding : u32;
|
||||
m : [[stride(32)]] array<vec2<f32>, 2>;
|
||||
};
|
||||
|
||||
[[group(0), binding(0)]] var<uniform> s : S;
|
||||
|
||||
fn arr_to_mat2x2_stride_32(arr : [[stride(32)]] array<vec2<f32>, 2>) -> mat2x2<f32> {
|
||||
return mat2x2<f32>(arr[0u], arr[1u]);
|
||||
}
|
||||
|
||||
[[stage(compute), workgroup_size(1)]]
|
||||
fn f() {
|
||||
let x : mat2x2<f32> = arr_to_mat2x2_stride_32(s.m);
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<InlinePointerLets, Simplify, DecomposeStridedMatrix>(
|
||||
Program(std::move(b)));
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(DecomposeStridedMatrixTest, ReadUniformColumn) {
|
||||
// [[block]]
|
||||
// struct S {
|
||||
// [[offset(16), stride(32)]]
|
||||
// [[internal(ignore_stride_decoration)]]
|
||||
// m : mat2x2<f32>;
|
||||
// };
|
||||
// [[group(0), binding(0)]] var<uniform> s : S;
|
||||
//
|
||||
// [[stage(compute), workgroup_size(1)]]
|
||||
// fn f() {
|
||||
// let x : vec2<f32> = s.m[1];
|
||||
// }
|
||||
ProgramBuilder b;
|
||||
auto* S = b.Structure(
|
||||
"S",
|
||||
{
|
||||
b.Member(
|
||||
"m", b.ty.mat2x2<f32>(),
|
||||
{
|
||||
b.create<ast::StructMemberOffsetDecoration>(16),
|
||||
b.create<ast::StrideDecoration>(32),
|
||||
b.ASTNodes().Create<ast::DisableValidationDecoration>(
|
||||
b.ID(), ast::DisabledValidation::kIgnoreStrideDecoration),
|
||||
}),
|
||||
},
|
||||
{
|
||||
b.StructBlock(),
|
||||
});
|
||||
b.Global("s", b.ty.Of(S), ast::StorageClass::kUniform,
|
||||
b.GroupAndBinding(0, 0));
|
||||
b.Func("f", {}, b.ty.void_(),
|
||||
{
|
||||
b.Decl(b.Const("x", b.ty.vec2<f32>(),
|
||||
b.IndexAccessor(b.MemberAccessor("s", "m"), 1))),
|
||||
},
|
||||
{
|
||||
b.Stage(ast::PipelineStage::kCompute),
|
||||
b.WorkgroupSize(1),
|
||||
});
|
||||
|
||||
auto* expect = R"(
|
||||
[[block]]
|
||||
struct S {
|
||||
[[size(16)]]
|
||||
padding : u32;
|
||||
m : [[stride(32)]] array<vec2<f32>, 2>;
|
||||
};
|
||||
|
||||
[[group(0), binding(0)]] var<uniform> s : S;
|
||||
|
||||
[[stage(compute), workgroup_size(1)]]
|
||||
fn f() {
|
||||
let x : vec2<f32> = s.m[1];
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<InlinePointerLets, Simplify, DecomposeStridedMatrix>(
|
||||
Program(std::move(b)));
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(DecomposeStridedMatrixTest, ReadUniformMatrix_DefaultStride) {
|
||||
// [[block]]
|
||||
// struct S {
|
||||
// [[offset(16), stride(8)]]
|
||||
// [[internal(ignore_stride_decoration)]]
|
||||
// m : mat2x2<f32>;
|
||||
// };
|
||||
// [[group(0), binding(0)]] var<uniform> s : S;
|
||||
//
|
||||
// [[stage(compute), workgroup_size(1)]]
|
||||
// fn f() {
|
||||
// let x : mat2x2<f32> = s.m;
|
||||
// }
|
||||
ProgramBuilder b;
|
||||
auto* S = b.Structure(
|
||||
"S",
|
||||
{
|
||||
b.Member(
|
||||
"m", b.ty.mat2x2<f32>(),
|
||||
{
|
||||
b.create<ast::StructMemberOffsetDecoration>(16),
|
||||
b.create<ast::StrideDecoration>(8),
|
||||
b.ASTNodes().Create<ast::DisableValidationDecoration>(
|
||||
b.ID(), ast::DisabledValidation::kIgnoreStrideDecoration),
|
||||
}),
|
||||
},
|
||||
{
|
||||
b.StructBlock(),
|
||||
});
|
||||
b.Global("s", b.ty.Of(S), ast::StorageClass::kUniform,
|
||||
b.GroupAndBinding(0, 0));
|
||||
b.Func(
|
||||
"f", {}, b.ty.void_(),
|
||||
{
|
||||
b.Decl(b.Const("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))),
|
||||
},
|
||||
{
|
||||
b.Stage(ast::PipelineStage::kCompute),
|
||||
b.WorkgroupSize(1),
|
||||
});
|
||||
|
||||
auto* expect = R"(
|
||||
[[block]]
|
||||
struct S {
|
||||
[[size(16)]]
|
||||
padding : u32;
|
||||
[[stride(8), internal(disable_validation__ignore_stride)]]
|
||||
m : mat2x2<f32>;
|
||||
};
|
||||
|
||||
[[group(0), binding(0)]] var<uniform> s : S;
|
||||
|
||||
[[stage(compute), workgroup_size(1)]]
|
||||
fn f() {
|
||||
let x : mat2x2<f32> = s.m;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<InlinePointerLets, Simplify, DecomposeStridedMatrix>(
|
||||
Program(std::move(b)));
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(DecomposeStridedMatrixTest, ReadStorageMatrix) {
|
||||
// [[block]]
|
||||
// struct S {
|
||||
// [[offset(8), stride(32)]]
|
||||
// [[internal(ignore_stride_decoration)]]
|
||||
// m : mat2x2<f32>;
|
||||
// };
|
||||
// [[group(0), binding(0)]] var<storage, read_write> s : S;
|
||||
//
|
||||
// [[stage(compute), workgroup_size(1)]]
|
||||
// fn f() {
|
||||
// let x : mat2x2<f32> = s.m;
|
||||
// }
|
||||
ProgramBuilder b;
|
||||
auto* S = b.Structure(
|
||||
"S",
|
||||
{
|
||||
b.Member(
|
||||
"m", b.ty.mat2x2<f32>(),
|
||||
{
|
||||
b.create<ast::StructMemberOffsetDecoration>(8),
|
||||
b.create<ast::StrideDecoration>(32),
|
||||
b.ASTNodes().Create<ast::DisableValidationDecoration>(
|
||||
b.ID(), ast::DisabledValidation::kIgnoreStrideDecoration),
|
||||
}),
|
||||
},
|
||||
{
|
||||
b.StructBlock(),
|
||||
});
|
||||
b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
|
||||
ast::Access::kReadWrite, b.GroupAndBinding(0, 0));
|
||||
b.Func(
|
||||
"f", {}, b.ty.void_(),
|
||||
{
|
||||
b.Decl(b.Const("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))),
|
||||
},
|
||||
{
|
||||
b.Stage(ast::PipelineStage::kCompute),
|
||||
b.WorkgroupSize(1),
|
||||
});
|
||||
|
||||
auto* expect = R"(
|
||||
[[block]]
|
||||
struct S {
|
||||
[[size(8)]]
|
||||
padding : u32;
|
||||
m : [[stride(32)]] array<vec2<f32>, 2>;
|
||||
};
|
||||
|
||||
[[group(0), binding(0)]] var<storage, read_write> s : S;
|
||||
|
||||
fn arr_to_mat2x2_stride_32(arr : [[stride(32)]] array<vec2<f32>, 2>) -> mat2x2<f32> {
|
||||
return mat2x2<f32>(arr[0u], arr[1u]);
|
||||
}
|
||||
|
||||
[[stage(compute), workgroup_size(1)]]
|
||||
fn f() {
|
||||
let x : mat2x2<f32> = arr_to_mat2x2_stride_32(s.m);
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<InlinePointerLets, Simplify, DecomposeStridedMatrix>(
|
||||
Program(std::move(b)));
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(DecomposeStridedMatrixTest, ReadStorageColumn) {
|
||||
// [[block]]
|
||||
// struct S {
|
||||
// [[offset(16), stride(32)]]
|
||||
// [[internal(ignore_stride_decoration)]]
|
||||
// m : mat2x2<f32>;
|
||||
// };
|
||||
// [[group(0), binding(0)]] var<storage, read_write> s : S;
|
||||
//
|
||||
// [[stage(compute), workgroup_size(1)]]
|
||||
// fn f() {
|
||||
// let x : vec2<f32> = s.m[1];
|
||||
// }
|
||||
ProgramBuilder b;
|
||||
auto* S = b.Structure(
|
||||
"S",
|
||||
{
|
||||
b.Member(
|
||||
"m", b.ty.mat2x2<f32>(),
|
||||
{
|
||||
b.create<ast::StructMemberOffsetDecoration>(16),
|
||||
b.create<ast::StrideDecoration>(32),
|
||||
b.ASTNodes().Create<ast::DisableValidationDecoration>(
|
||||
b.ID(), ast::DisabledValidation::kIgnoreStrideDecoration),
|
||||
}),
|
||||
},
|
||||
{
|
||||
b.StructBlock(),
|
||||
});
|
||||
b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
|
||||
ast::Access::kReadWrite, b.GroupAndBinding(0, 0));
|
||||
b.Func("f", {}, b.ty.void_(),
|
||||
{
|
||||
b.Decl(b.Const("x", b.ty.vec2<f32>(),
|
||||
b.IndexAccessor(b.MemberAccessor("s", "m"), 1))),
|
||||
},
|
||||
{
|
||||
b.Stage(ast::PipelineStage::kCompute),
|
||||
b.WorkgroupSize(1),
|
||||
});
|
||||
|
||||
auto* expect = R"(
|
||||
[[block]]
|
||||
struct S {
|
||||
[[size(16)]]
|
||||
padding : u32;
|
||||
m : [[stride(32)]] array<vec2<f32>, 2>;
|
||||
};
|
||||
|
||||
[[group(0), binding(0)]] var<storage, read_write> s : S;
|
||||
|
||||
[[stage(compute), workgroup_size(1)]]
|
||||
fn f() {
|
||||
let x : vec2<f32> = s.m[1];
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<InlinePointerLets, Simplify, DecomposeStridedMatrix>(
|
||||
Program(std::move(b)));
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(DecomposeStridedMatrixTest, WriteStorageMatrix) {
|
||||
// [[block]]
|
||||
// struct S {
|
||||
// [[offset(8), stride(32)]]
|
||||
// [[internal(ignore_stride_decoration)]]
|
||||
// m : mat2x2<f32>;
|
||||
// };
|
||||
// [[group(0), binding(0)]] var<storage, read_write> s : S;
|
||||
//
|
||||
// [[stage(compute), workgroup_size(1)]]
|
||||
// fn f() {
|
||||
// s.m = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
|
||||
// }
|
||||
ProgramBuilder b;
|
||||
auto* S = b.Structure(
|
||||
"S",
|
||||
{
|
||||
b.Member(
|
||||
"m", b.ty.mat2x2<f32>(),
|
||||
{
|
||||
b.create<ast::StructMemberOffsetDecoration>(8),
|
||||
b.create<ast::StrideDecoration>(32),
|
||||
b.ASTNodes().Create<ast::DisableValidationDecoration>(
|
||||
b.ID(), ast::DisabledValidation::kIgnoreStrideDecoration),
|
||||
}),
|
||||
},
|
||||
{
|
||||
b.StructBlock(),
|
||||
});
|
||||
b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
|
||||
ast::Access::kReadWrite, b.GroupAndBinding(0, 0));
|
||||
b.Func("f", {}, b.ty.void_(),
|
||||
{
|
||||
b.Assign(b.MemberAccessor("s", "m"),
|
||||
b.mat2x2<f32>(b.vec2<f32>(1.0f, 2.0f),
|
||||
b.vec2<f32>(3.0f, 4.0f))),
|
||||
},
|
||||
{
|
||||
b.Stage(ast::PipelineStage::kCompute),
|
||||
b.WorkgroupSize(1),
|
||||
});
|
||||
|
||||
auto* expect = R"(
|
||||
[[block]]
|
||||
struct S {
|
||||
[[size(8)]]
|
||||
padding : u32;
|
||||
m : [[stride(32)]] array<vec2<f32>, 2>;
|
||||
};
|
||||
|
||||
[[group(0), binding(0)]] var<storage, read_write> s : S;
|
||||
|
||||
fn mat2x2_stride_32_to_arr(mat : mat2x2<f32>) -> [[stride(32)]] array<vec2<f32>, 2> {
|
||||
return [[stride(32)]] array<vec2<f32>, 2>(mat[0u], mat[1u]);
|
||||
}
|
||||
|
||||
[[stage(compute), workgroup_size(1)]]
|
||||
fn f() {
|
||||
s.m = mat2x2_stride_32_to_arr(mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0)));
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<InlinePointerLets, Simplify, DecomposeStridedMatrix>(
|
||||
Program(std::move(b)));
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(DecomposeStridedMatrixTest, WriteStorageColumn) {
|
||||
// [[block]]
|
||||
// struct S {
|
||||
// [[offset(8), stride(32)]]
|
||||
// [[internal(ignore_stride_decoration)]]
|
||||
// m : mat2x2<f32>;
|
||||
// };
|
||||
// [[group(0), binding(0)]] var<storage, read_write> s : S;
|
||||
//
|
||||
// [[stage(compute), workgroup_size(1)]]
|
||||
// fn f() {
|
||||
// s.m[1] = vec2<f32>(1.0, 2.0);
|
||||
// }
|
||||
ProgramBuilder b;
|
||||
auto* S = b.Structure(
|
||||
"S",
|
||||
{
|
||||
b.Member(
|
||||
"m", b.ty.mat2x2<f32>(),
|
||||
{
|
||||
b.create<ast::StructMemberOffsetDecoration>(8),
|
||||
b.create<ast::StrideDecoration>(32),
|
||||
b.ASTNodes().Create<ast::DisableValidationDecoration>(
|
||||
b.ID(), ast::DisabledValidation::kIgnoreStrideDecoration),
|
||||
}),
|
||||
},
|
||||
{
|
||||
b.StructBlock(),
|
||||
});
|
||||
b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
|
||||
ast::Access::kReadWrite, b.GroupAndBinding(0, 0));
|
||||
b.Func("f", {}, b.ty.void_(),
|
||||
{
|
||||
b.Assign(b.IndexAccessor(b.MemberAccessor("s", "m"), 1),
|
||||
b.vec2<f32>(1.0f, 2.0f)),
|
||||
},
|
||||
{
|
||||
b.Stage(ast::PipelineStage::kCompute),
|
||||
b.WorkgroupSize(1),
|
||||
});
|
||||
|
||||
auto* expect = R"(
|
||||
[[block]]
|
||||
struct S {
|
||||
[[size(8)]]
|
||||
padding : u32;
|
||||
m : [[stride(32)]] array<vec2<f32>, 2>;
|
||||
};
|
||||
|
||||
[[group(0), binding(0)]] var<storage, read_write> s : S;
|
||||
|
||||
[[stage(compute), workgroup_size(1)]]
|
||||
fn f() {
|
||||
s.m[1] = vec2<f32>(1.0, 2.0);
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<InlinePointerLets, Simplify, DecomposeStridedMatrix>(
|
||||
Program(std::move(b)));
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(DecomposeStridedMatrixTest, ReadWriteViaPointerLets) {
|
||||
// [[block]]
|
||||
// struct S {
|
||||
// [[offset(8), stride(32)]]
|
||||
// [[internal(ignore_stride_decoration)]]
|
||||
// m : mat2x2<f32>;
|
||||
// };
|
||||
// [[group(0), binding(0)]] var<storage, read_write> s : S;
|
||||
//
|
||||
// [[stage(compute), workgroup_size(1)]]
|
||||
// fn f() {
|
||||
// let a = &s.m;
|
||||
// let b = &*&*(a);
|
||||
// let x = *b;
|
||||
// let y = (*b)[1];
|
||||
// let z = x[1];
|
||||
// (*b) = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
|
||||
// (*b)[1] = vec2<f32>(5.0, 6.0);
|
||||
// }
|
||||
ProgramBuilder b;
|
||||
auto* S = b.Structure(
|
||||
"S",
|
||||
{
|
||||
b.Member(
|
||||
"m", b.ty.mat2x2<f32>(),
|
||||
{
|
||||
b.create<ast::StructMemberOffsetDecoration>(8),
|
||||
b.create<ast::StrideDecoration>(32),
|
||||
b.ASTNodes().Create<ast::DisableValidationDecoration>(
|
||||
b.ID(), ast::DisabledValidation::kIgnoreStrideDecoration),
|
||||
}),
|
||||
},
|
||||
{
|
||||
b.StructBlock(),
|
||||
});
|
||||
b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
|
||||
ast::Access::kReadWrite, b.GroupAndBinding(0, 0));
|
||||
b.Func(
|
||||
"f", {}, b.ty.void_(),
|
||||
{
|
||||
b.Decl(
|
||||
b.Const("a", nullptr, b.AddressOf(b.MemberAccessor("s", "m")))),
|
||||
b.Decl(b.Const("b", nullptr,
|
||||
b.AddressOf(b.Deref(b.AddressOf(b.Deref("a")))))),
|
||||
b.Decl(b.Const("x", nullptr, b.Deref("b"))),
|
||||
b.Decl(b.Const("y", nullptr, b.IndexAccessor(b.Deref("b"), 1))),
|
||||
b.Decl(b.Const("z", nullptr, b.IndexAccessor("x", 1))),
|
||||
b.Assign(b.Deref("b"), b.mat2x2<f32>(b.vec2<f32>(1.0f, 2.0f),
|
||||
b.vec2<f32>(3.0f, 4.0f))),
|
||||
b.Assign(b.IndexAccessor(b.Deref("b"), 1), b.vec2<f32>(5.0f, 6.0f)),
|
||||
},
|
||||
{
|
||||
b.Stage(ast::PipelineStage::kCompute),
|
||||
b.WorkgroupSize(1),
|
||||
});
|
||||
|
||||
auto* expect = R"(
|
||||
[[block]]
|
||||
struct S {
|
||||
[[size(8)]]
|
||||
padding : u32;
|
||||
m : [[stride(32)]] array<vec2<f32>, 2>;
|
||||
};
|
||||
|
||||
[[group(0), binding(0)]] var<storage, read_write> s : S;
|
||||
|
||||
fn arr_to_mat2x2_stride_32(arr : [[stride(32)]] array<vec2<f32>, 2>) -> mat2x2<f32> {
|
||||
return mat2x2<f32>(arr[0u], arr[1u]);
|
||||
}
|
||||
|
||||
fn mat2x2_stride_32_to_arr(mat : mat2x2<f32>) -> [[stride(32)]] array<vec2<f32>, 2> {
|
||||
return [[stride(32)]] array<vec2<f32>, 2>(mat[0u], mat[1u]);
|
||||
}
|
||||
|
||||
[[stage(compute), workgroup_size(1)]]
|
||||
fn f() {
|
||||
let x = arr_to_mat2x2_stride_32(s.m);
|
||||
let y = s.m[1];
|
||||
let z = x[1];
|
||||
s.m = mat2x2_stride_32_to_arr(mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0)));
|
||||
s.m[1] = vec2<f32>(5.0, 6.0);
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<InlinePointerLets, Simplify, DecomposeStridedMatrix>(
|
||||
Program(std::move(b)));
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(DecomposeStridedMatrixTest, ReadPrivateMatrix) {
|
||||
// struct S {
|
||||
// [[offset(8), stride(32)]]
|
||||
// [[internal(ignore_stride_decoration)]]
|
||||
// m : mat2x2<f32>;
|
||||
// };
|
||||
// var<private> s : S;
|
||||
//
|
||||
// [[stage(compute), workgroup_size(1)]]
|
||||
// fn f() {
|
||||
// let x : mat2x2<f32> = s.m;
|
||||
// }
|
||||
ProgramBuilder b;
|
||||
auto* S = b.Structure(
|
||||
"S",
|
||||
{
|
||||
b.Member(
|
||||
"m", b.ty.mat2x2<f32>(),
|
||||
{
|
||||
b.create<ast::StructMemberOffsetDecoration>(8),
|
||||
b.create<ast::StrideDecoration>(32),
|
||||
b.ASTNodes().Create<ast::DisableValidationDecoration>(
|
||||
b.ID(), ast::DisabledValidation::kIgnoreStrideDecoration),
|
||||
}),
|
||||
});
|
||||
b.Global("s", b.ty.Of(S), ast::StorageClass::kPrivate);
|
||||
b.Func(
|
||||
"f", {}, b.ty.void_(),
|
||||
{
|
||||
b.Decl(b.Const("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))),
|
||||
},
|
||||
{
|
||||
b.Stage(ast::PipelineStage::kCompute),
|
||||
b.WorkgroupSize(1),
|
||||
});
|
||||
|
||||
auto* expect = R"(
|
||||
struct S {
|
||||
[[size(8)]]
|
||||
padding : u32;
|
||||
[[stride(32), internal(disable_validation__ignore_stride)]]
|
||||
m : mat2x2<f32>;
|
||||
};
|
||||
|
||||
var<private> s : S;
|
||||
|
||||
[[stage(compute), workgroup_size(1)]]
|
||||
fn f() {
|
||||
let x : mat2x2<f32> = s.m;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<InlinePointerLets, Simplify, DecomposeStridedMatrix>(
|
||||
Program(std::move(b)));
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(DecomposeStridedMatrixTest, WritePrivateMatrix) {
|
||||
// struct S {
|
||||
// [[offset(8), stride(32)]]
|
||||
// [[internal(ignore_stride_decoration)]]
|
||||
// m : mat2x2<f32>;
|
||||
// };
|
||||
// var<private> s : S;
|
||||
//
|
||||
// [[stage(compute), workgroup_size(1)]]
|
||||
// fn f() {
|
||||
// s.m = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
|
||||
// }
|
||||
ProgramBuilder b;
|
||||
auto* S = b.Structure(
|
||||
"S",
|
||||
{
|
||||
b.Member(
|
||||
"m", b.ty.mat2x2<f32>(),
|
||||
{
|
||||
b.create<ast::StructMemberOffsetDecoration>(8),
|
||||
b.create<ast::StrideDecoration>(32),
|
||||
b.ASTNodes().Create<ast::DisableValidationDecoration>(
|
||||
b.ID(), ast::DisabledValidation::kIgnoreStrideDecoration),
|
||||
}),
|
||||
});
|
||||
b.Global("s", b.ty.Of(S), ast::StorageClass::kPrivate);
|
||||
b.Func("f", {}, b.ty.void_(),
|
||||
{
|
||||
b.Assign(b.MemberAccessor("s", "m"),
|
||||
b.mat2x2<f32>(b.vec2<f32>(1.0f, 2.0f),
|
||||
b.vec2<f32>(3.0f, 4.0f))),
|
||||
},
|
||||
{
|
||||
b.Stage(ast::PipelineStage::kCompute),
|
||||
b.WorkgroupSize(1),
|
||||
});
|
||||
|
||||
auto* expect = R"(
|
||||
struct S {
|
||||
[[size(8)]]
|
||||
padding : u32;
|
||||
[[stride(32), internal(disable_validation__ignore_stride)]]
|
||||
m : mat2x2<f32>;
|
||||
};
|
||||
|
||||
var<private> s : S;
|
||||
|
||||
[[stage(compute), workgroup_size(1)]]
|
||||
fn f() {
|
||||
s.m = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<InlinePointerLets, Simplify, DecomposeStridedMatrix>(
|
||||
Program(std::move(b)));
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace transform
|
||||
} // namespace tint
|
|
@ -59,6 +59,16 @@ class TransformTestBase : public BASE {
|
|||
// Keep this pointer alive after Transform() returns
|
||||
files_.emplace_back(std::move(file));
|
||||
|
||||
return Run<TRANSFORMS...>(std::move(program), data);
|
||||
}
|
||||
|
||||
/// Transforms and returns program `program`, transformed using a transform of
|
||||
/// type `TRANSFORM`.
|
||||
/// @param program the input Program
|
||||
/// @param data the optional DataMap to pass to Transform::Run()
|
||||
/// @return the transformed output
|
||||
template <typename... TRANSFORMS>
|
||||
Output Run(Program&& program, const DataMap& data = {}) {
|
||||
if (!program.IsValid()) {
|
||||
return Output(std::move(program));
|
||||
}
|
||||
|
|
|
@ -696,6 +696,8 @@ bool GeneratorImpl::EmitDecorations(std::ostream& out,
|
|||
out << "size(" << size->size() << ")";
|
||||
} else if (auto* align = deco->As<ast::StructMemberAlignDecoration>()) {
|
||||
out << "align(" << align->align() << ")";
|
||||
} else if (auto* stride = deco->As<ast::StrideDecoration>()) {
|
||||
out << "stride(" << stride->stride() << ")";
|
||||
} else if (auto* internal = deco->As<ast::InternalDecoration>()) {
|
||||
out << "internal(" << internal->InternalName() << ")";
|
||||
} else {
|
||||
|
|
|
@ -289,6 +289,7 @@ tint_unittests_source_set("tint_unittests_core_src") {
|
|||
"../src/transform/calculate_array_length_test.cc",
|
||||
"../src/transform/canonicalize_entry_point_io_test.cc",
|
||||
"../src/transform/decompose_memory_access_test.cc",
|
||||
"../src/transform/decompose_strided_matrix_test.cc",
|
||||
"../src/transform/external_texture_transform_test.cc",
|
||||
"../src/transform/first_index_offset_test.cc",
|
||||
"../src/transform/fold_constants_test.cc",
|
||||
|
|
|
@ -0,0 +1,11 @@
|
|||
[[block]]
|
||||
struct SSBO {
|
||||
m : mat2x2<f32>;
|
||||
};
|
||||
[[group(0), binding(0)]] var<storage, read_write> ssbo : SSBO;
|
||||
|
||||
[[stage(compute), workgroup_size(1)]]
|
||||
fn f() {
|
||||
let v = ssbo.m;
|
||||
ssbo.m = v;
|
||||
}
|
|
@ -0,0 +1,17 @@
|
|||
RWByteAddressBuffer ssbo : register(u0, space0);
|
||||
|
||||
float2x2 tint_symbol(RWByteAddressBuffer buffer, uint offset) {
|
||||
return float2x2(asfloat(buffer.Load2((offset + 0u))), asfloat(buffer.Load2((offset + 8u))));
|
||||
}
|
||||
|
||||
void tint_symbol_2(RWByteAddressBuffer buffer, uint offset, float2x2 value) {
|
||||
buffer.Store2((offset + 0u), asuint(value[0u]));
|
||||
buffer.Store2((offset + 8u), asuint(value[1u]));
|
||||
}
|
||||
|
||||
[numthreads(1, 1, 1)]
|
||||
void f() {
|
||||
const float2x2 v = tint_symbol(ssbo, 0u);
|
||||
tint_symbol_2(ssbo, 0u, v);
|
||||
return;
|
||||
}
|
|
@ -0,0 +1,13 @@
|
|||
#include <metal_stdlib>
|
||||
|
||||
using namespace metal;
|
||||
struct SSBO {
|
||||
/* 0x0000 */ float2x2 m;
|
||||
};
|
||||
|
||||
kernel void f(device SSBO& ssbo [[buffer(0)]]) {
|
||||
float2x2 const v = ssbo.m;
|
||||
ssbo.m = v;
|
||||
return;
|
||||
}
|
||||
|
|
@ -0,0 +1,38 @@
|
|||
; SPIR-V
|
||||
; Version: 1.3
|
||||
; Generator: Google Tint Compiler; 0
|
||||
; Bound: 17
|
||||
; Schema: 0
|
||||
OpCapability Shader
|
||||
OpMemoryModel Logical GLSL450
|
||||
OpEntryPoint GLCompute %f "f"
|
||||
OpExecutionMode %f LocalSize 1 1 1
|
||||
OpName %SSBO "SSBO"
|
||||
OpMemberName %SSBO 0 "m"
|
||||
OpName %ssbo "ssbo"
|
||||
OpName %f "f"
|
||||
OpDecorate %SSBO Block
|
||||
OpMemberDecorate %SSBO 0 Offset 0
|
||||
OpMemberDecorate %SSBO 0 ColMajor
|
||||
OpMemberDecorate %SSBO 0 MatrixStride 8
|
||||
OpDecorate %ssbo DescriptorSet 0
|
||||
OpDecorate %ssbo Binding 0
|
||||
%float = OpTypeFloat 32
|
||||
%v2float = OpTypeVector %float 2
|
||||
%mat2v2float = OpTypeMatrix %v2float 2
|
||||
%SSBO = OpTypeStruct %mat2v2float
|
||||
%_ptr_StorageBuffer_SSBO = OpTypePointer StorageBuffer %SSBO
|
||||
%ssbo = OpVariable %_ptr_StorageBuffer_SSBO StorageBuffer
|
||||
%void = OpTypeVoid
|
||||
%7 = OpTypeFunction %void
|
||||
%uint = OpTypeInt 32 0
|
||||
%uint_0 = OpConstant %uint 0
|
||||
%_ptr_StorageBuffer_mat2v2float = OpTypePointer StorageBuffer %mat2v2float
|
||||
%f = OpFunction %void None %7
|
||||
%10 = OpLabel
|
||||
%14 = OpAccessChain %_ptr_StorageBuffer_mat2v2float %ssbo %uint_0
|
||||
%15 = OpLoad %mat2v2float %14
|
||||
%16 = OpAccessChain %_ptr_StorageBuffer_mat2v2float %ssbo %uint_0
|
||||
OpStore %16 %15
|
||||
OpReturn
|
||||
OpFunctionEnd
|
|
@ -0,0 +1,12 @@
|
|||
[[block]]
|
||||
struct SSBO {
|
||||
m : mat2x2<f32>;
|
||||
};
|
||||
|
||||
[[group(0), binding(0)]] var<storage, read_write> ssbo : SSBO;
|
||||
|
||||
[[stage(compute), workgroup_size(1)]]
|
||||
fn f() {
|
||||
let v = ssbo.m;
|
||||
ssbo.m = v;
|
||||
}
|
|
@ -0,0 +1,33 @@
|
|||
OpCapability Shader
|
||||
OpMemoryModel Logical GLSL450
|
||||
OpEntryPoint GLCompute %f "f"
|
||||
OpExecutionMode %f LocalSize 1 1 1
|
||||
OpName %SSBO "SSBO"
|
||||
OpMemberName %SSBO 0 "m"
|
||||
OpName %ssbo "ssbo"
|
||||
OpName %f "f"
|
||||
OpDecorate %SSBO Block
|
||||
OpMemberDecorate %SSBO 0 Offset 0
|
||||
OpMemberDecorate %SSBO 0 ColMajor
|
||||
OpMemberDecorate %SSBO 0 MatrixStride 16
|
||||
OpDecorate %ssbo DescriptorSet 0
|
||||
OpDecorate %ssbo Binding 0
|
||||
%float = OpTypeFloat 32
|
||||
%v2float = OpTypeVector %float 2
|
||||
%mat2v2float = OpTypeMatrix %v2float 2
|
||||
%SSBO = OpTypeStruct %mat2v2float
|
||||
%_ptr_StorageBuffer_SSBO = OpTypePointer StorageBuffer %SSBO
|
||||
%ssbo = OpVariable %_ptr_StorageBuffer_SSBO StorageBuffer
|
||||
%void = OpTypeVoid
|
||||
%7 = OpTypeFunction %void
|
||||
%uint = OpTypeInt 32 0
|
||||
%uint_0 = OpConstant %uint 0
|
||||
%_ptr_StorageBuffer_mat2v2float = OpTypePointer StorageBuffer %mat2v2float
|
||||
%f = OpFunction %void None %7
|
||||
%10 = OpLabel
|
||||
%14 = OpAccessChain %_ptr_StorageBuffer_mat2v2float %ssbo %uint_0
|
||||
%15 = OpLoad %mat2v2float %14
|
||||
%16 = OpAccessChain %_ptr_StorageBuffer_mat2v2float %ssbo %uint_0
|
||||
OpStore %16 %15
|
||||
OpReturn
|
||||
OpFunctionEnd
|
|
@ -0,0 +1,47 @@
|
|||
struct tint_padded_array_element {
|
||||
float2 el;
|
||||
};
|
||||
|
||||
RWByteAddressBuffer ssbo : register(u0, space0);
|
||||
|
||||
float2x2 arr_to_mat2x2_stride_16(tint_padded_array_element arr[2]) {
|
||||
return float2x2(arr[0u].el, arr[1u].el);
|
||||
}
|
||||
|
||||
typedef tint_padded_array_element mat2x2_stride_16_to_arr_ret[2];
|
||||
mat2x2_stride_16_to_arr_ret mat2x2_stride_16_to_arr(float2x2 mat) {
|
||||
const tint_padded_array_element tint_symbol_4[2] = {{mat[0u]}, {mat[1u]}};
|
||||
return tint_symbol_4;
|
||||
}
|
||||
|
||||
typedef tint_padded_array_element tint_symbol_ret[2];
|
||||
tint_symbol_ret tint_symbol(RWByteAddressBuffer buffer, uint offset) {
|
||||
tint_padded_array_element arr_1[2] = (tint_padded_array_element[2])0;
|
||||
{
|
||||
for(uint i = 0u; (i < 2u); i = (i + 1u)) {
|
||||
arr_1[i].el = asfloat(buffer.Load2((offset + (i * 16u))));
|
||||
}
|
||||
}
|
||||
return arr_1;
|
||||
}
|
||||
|
||||
void tint_symbol_2(RWByteAddressBuffer buffer, uint offset, tint_padded_array_element value[2]) {
|
||||
tint_padded_array_element array[2] = value;
|
||||
{
|
||||
for(uint i_1 = 0u; (i_1 < 2u); i_1 = (i_1 + 1u)) {
|
||||
buffer.Store2((offset + (i_1 * 16u)), asuint(array[i_1].el));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void f_1() {
|
||||
const float2x2 x_15 = arr_to_mat2x2_stride_16(tint_symbol(ssbo, 0u));
|
||||
tint_symbol_2(ssbo, 0u, mat2x2_stride_16_to_arr(x_15));
|
||||
return;
|
||||
}
|
||||
|
||||
[numthreads(1, 1, 1)]
|
||||
void f() {
|
||||
f_1();
|
||||
return;
|
||||
}
|
|
@ -0,0 +1,34 @@
|
|||
#include <metal_stdlib>
|
||||
|
||||
using namespace metal;
|
||||
struct tint_padded_array_element {
|
||||
/* 0x0000 */ packed_float2 el;
|
||||
/* 0x0008 */ int8_t tint_pad[8];
|
||||
};
|
||||
struct tint_array_wrapper {
|
||||
/* 0x0000 */ tint_padded_array_element arr[2];
|
||||
};
|
||||
struct SSBO {
|
||||
/* 0x0000 */ tint_array_wrapper m;
|
||||
};
|
||||
|
||||
float2x2 arr_to_mat2x2_stride_16(tint_array_wrapper arr) {
|
||||
return float2x2(arr.arr[0u].el, arr.arr[1u].el);
|
||||
}
|
||||
|
||||
tint_array_wrapper mat2x2_stride_16_to_arr(float2x2 mat) {
|
||||
tint_array_wrapper const tint_symbol = {.arr={{.el=mat[0u]}, {.el=mat[1u]}}};
|
||||
return tint_symbol;
|
||||
}
|
||||
|
||||
void f_1(device SSBO& ssbo) {
|
||||
float2x2 const x_15 = arr_to_mat2x2_stride_16(ssbo.m);
|
||||
ssbo.m = mat2x2_stride_16_to_arr(x_15);
|
||||
return;
|
||||
}
|
||||
|
||||
kernel void f(device SSBO& ssbo [[buffer(0)]]) {
|
||||
f_1(ssbo);
|
||||
return;
|
||||
}
|
||||
|
|
@ -0,0 +1,70 @@
|
|||
; SPIR-V
|
||||
; Version: 1.3
|
||||
; Generator: Google Tint Compiler; 0
|
||||
; Bound: 39
|
||||
; Schema: 0
|
||||
OpCapability Shader
|
||||
OpMemoryModel Logical GLSL450
|
||||
OpEntryPoint GLCompute %f "f"
|
||||
OpExecutionMode %f LocalSize 1 1 1
|
||||
OpName %SSBO "SSBO"
|
||||
OpMemberName %SSBO 0 "m"
|
||||
OpName %ssbo "ssbo"
|
||||
OpName %arr_to_mat2x2_stride_16 "arr_to_mat2x2_stride_16"
|
||||
OpName %arr "arr"
|
||||
OpName %mat2x2_stride_16_to_arr "mat2x2_stride_16_to_arr"
|
||||
OpName %mat "mat"
|
||||
OpName %f_1 "f_1"
|
||||
OpName %f "f"
|
||||
OpDecorate %SSBO Block
|
||||
OpMemberDecorate %SSBO 0 Offset 0
|
||||
OpDecorate %_arr_v2float_uint_2 ArrayStride 16
|
||||
OpDecorate %ssbo DescriptorSet 0
|
||||
OpDecorate %ssbo Binding 0
|
||||
%float = OpTypeFloat 32
|
||||
%v2float = OpTypeVector %float 2
|
||||
%uint = OpTypeInt 32 0
|
||||
%uint_2 = OpConstant %uint 2
|
||||
%_arr_v2float_uint_2 = OpTypeArray %v2float %uint_2
|
||||
%SSBO = OpTypeStruct %_arr_v2float_uint_2
|
||||
%_ptr_StorageBuffer_SSBO = OpTypePointer StorageBuffer %SSBO
|
||||
%ssbo = OpVariable %_ptr_StorageBuffer_SSBO StorageBuffer
|
||||
%mat2v2float = OpTypeMatrix %v2float 2
|
||||
%9 = OpTypeFunction %mat2v2float %_arr_v2float_uint_2
|
||||
%uint_0 = OpConstant %uint 0
|
||||
%uint_1 = OpConstant %uint 1
|
||||
%19 = OpTypeFunction %_arr_v2float_uint_2 %mat2v2float
|
||||
%void = OpTypeVoid
|
||||
%26 = OpTypeFunction %void
|
||||
%_ptr_StorageBuffer__arr_v2float_uint_2 = OpTypePointer StorageBuffer %_arr_v2float_uint_2
|
||||
%arr_to_mat2x2_stride_16 = OpFunction %mat2v2float None %9
|
||||
%arr = OpFunctionParameter %_arr_v2float_uint_2
|
||||
%13 = OpLabel
|
||||
%15 = OpCompositeExtract %v2float %arr 0
|
||||
%17 = OpCompositeExtract %v2float %arr 1
|
||||
%18 = OpCompositeConstruct %mat2v2float %15 %17
|
||||
OpReturnValue %18
|
||||
OpFunctionEnd
|
||||
%mat2x2_stride_16_to_arr = OpFunction %_arr_v2float_uint_2 None %19
|
||||
%mat = OpFunctionParameter %mat2v2float
|
||||
%22 = OpLabel
|
||||
%23 = OpCompositeExtract %v2float %mat 0
|
||||
%24 = OpCompositeExtract %v2float %mat 1
|
||||
%25 = OpCompositeConstruct %_arr_v2float_uint_2 %23 %24
|
||||
OpReturnValue %25
|
||||
OpFunctionEnd
|
||||
%f_1 = OpFunction %void None %26
|
||||
%29 = OpLabel
|
||||
%32 = OpAccessChain %_ptr_StorageBuffer__arr_v2float_uint_2 %ssbo %uint_0
|
||||
%33 = OpLoad %_arr_v2float_uint_2 %32
|
||||
%30 = OpFunctionCall %mat2v2float %arr_to_mat2x2_stride_16 %33
|
||||
%34 = OpAccessChain %_ptr_StorageBuffer__arr_v2float_uint_2 %ssbo %uint_0
|
||||
%35 = OpFunctionCall %_arr_v2float_uint_2 %mat2x2_stride_16_to_arr %30
|
||||
OpStore %34 %35
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
%f = OpFunction %void None %26
|
||||
%37 = OpLabel
|
||||
%38 = OpFunctionCall %void %f_1
|
||||
OpReturn
|
||||
OpFunctionEnd
|
|
@ -0,0 +1,25 @@
|
|||
[[block]]
|
||||
struct SSBO {
|
||||
m : [[stride(16)]] array<vec2<f32>, 2>;
|
||||
};
|
||||
|
||||
[[group(0), binding(0)]] var<storage, read_write> ssbo : SSBO;
|
||||
|
||||
fn arr_to_mat2x2_stride_16(arr : [[stride(16)]] array<vec2<f32>, 2>) -> mat2x2<f32> {
|
||||
return mat2x2<f32>(arr[0u], arr[1u]);
|
||||
}
|
||||
|
||||
fn mat2x2_stride_16_to_arr(mat : mat2x2<f32>) -> [[stride(16)]] array<vec2<f32>, 2> {
|
||||
return [[stride(16)]] array<vec2<f32>, 2>(mat[0u], mat[1u]);
|
||||
}
|
||||
|
||||
fn f_1() {
|
||||
let x_15 : mat2x2<f32> = arr_to_mat2x2_stride_16(ssbo.m);
|
||||
ssbo.m = mat2x2_stride_16_to_arr(x_15);
|
||||
return;
|
||||
}
|
||||
|
||||
[[stage(compute), workgroup_size(1, 1, 1)]]
|
||||
fn f() {
|
||||
f_1();
|
||||
}
|
Loading…
Reference in New Issue