reader/spirv: Handle the MatrixStride decoration

Add `transform::DecomposeStridedMatrix`, which replaces matrix members of storage or uniform buffer structures, that have a [[stride]] decoration, into an array
of N column vectors.

This is required to correctly handle `mat2x2` matrices in UBOs, as std140 rules will expect a default stride of 16 bytes, when in WGSL the default structure layout expects a stride of 8 bytes.

Bug: tint:1047
Change-Id: If5ca3c6ec087bbc1ac31a8d9a657b99bf34042a4
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/59840
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
Ben Clayton
2021-07-27 08:17:29 +00:00
committed by Tint LUCI CQ
parent c6cbe3fda6
commit 97668c8c37
28 changed files with 1572 additions and 65 deletions

View File

@@ -17,6 +17,10 @@
#include <utility>
#include "src/reader/spirv/parser_impl.h"
#include "src/transform/decompose_strided_matrix.h"
#include "src/transform/inline_pointer_lets.h"
#include "src/transform/manager.h"
#include "src/transform/simplify.h"
namespace tint {
namespace reader {
@@ -40,7 +44,19 @@ Program Parse(const std::vector<uint32_t>& input) {
ProgramBuilder output;
CloneContext(&output, &program_with_disjoint_ast, false).Clone();
return Program(std::move(output));
auto program = Program(std::move(output));
// If the generated program contains matrices with a custom MatrixStride
// attribute then we need to decompose these into an array of vectors
if (transform::DecomposeStridedMatrix::ShouldRun(&program)) {
transform::Manager manager;
manager.Add<transform::InlinePointerLets>();
manager.Add<transform::Simplify>();
manager.Add<transform::DecomposeStridedMatrix>();
return manager.Run(&program).program;
}
return program;
}
} // namespace spirv

View File

@@ -21,6 +21,7 @@
#include "source/opt/build_module.h"
#include "src/ast/bitcast_expression.h"
#include "src/ast/disable_validation_decoration.h"
#include "src/ast/interpolate_decoration.h"
#include "src/ast/override_decoration.h"
#include "src/ast/struct_block_decoration.h"
@@ -439,13 +440,14 @@ std::string ParserImpl::ShowType(uint32_t type_id) {
return "SPIR-V type " + std::to_string(type_id);
}
ast::Decoration* ParserImpl::ConvertMemberDecoration(
ast::DecorationList ParserImpl::ConvertMemberDecoration(
uint32_t struct_type_id,
uint32_t member_index,
const Type* member_ty,
const Decoration& decoration) {
if (decoration.empty()) {
Fail() << "malformed SPIR-V decoration: it's empty";
return nullptr;
return {};
}
switch (decoration[0]) {
case SpvDecorationOffset:
@@ -454,38 +456,49 @@ ast::Decoration* ParserImpl::ConvertMemberDecoration(
<< "malformed Offset decoration: expected 1 literal operand, has "
<< decoration.size() - 1 << ": member " << member_index << " of "
<< ShowType(struct_type_id);
return nullptr;
return {};
}
return create<ast::StructMemberOffsetDecoration>(Source{}, decoration[1]);
return {
create<ast::StructMemberOffsetDecoration>(Source{}, decoration[1]),
};
case SpvDecorationNonReadable:
// WGSL doesn't have a member decoration for this. Silently drop it.
return nullptr;
return {};
case SpvDecorationNonWritable:
// WGSL doesn't have a member decoration for this.
return nullptr;
return {};
case SpvDecorationColMajor:
// WGSL only supports column major matrices.
return nullptr;
return {};
case SpvDecorationRelaxedPrecision:
// WGSL doesn't support relaxed precision.
return nullptr;
return {};
case SpvDecorationRowMajor:
Fail() << "WGSL does not support row-major matrices: can't "
"translate member "
<< member_index << " of " << ShowType(struct_type_id);
return nullptr;
return {};
case SpvDecorationMatrixStride: {
if (decoration.size() != 2) {
Fail() << "malformed MatrixStride decoration: expected 1 literal "
"operand, has "
<< decoration.size() - 1 << ": member " << member_index << " of "
<< ShowType(struct_type_id);
return nullptr;
return {};
}
// TODO(dneto): Fail if the matrix stride is not allocation size of the
// column vector of the underlying matrix. This would need to unpack
// any levels of array-ness.
return nullptr;
uint32_t stride = decoration[1];
uint32_t natural_stride = 0;
if (auto* mat = member_ty->As<Matrix>()) {
natural_stride = (mat->rows == 2) ? 8 : 16;
}
if (stride == natural_stride) {
return {};
}
return {
create<ast::StrideDecoration>(Source{}, decoration[1]),
builder_.ASTNodes().Create<ast::DisableValidationDecoration>(
builder_.ID(), ast::DisabledValidation::kIgnoreStrideDecoration),
};
}
default:
// TODO(dneto): Support the remaining member decorations.
@@ -493,7 +506,7 @@ ast::Decoration* ParserImpl::ConvertMemberDecoration(
}
Fail() << "unhandled member decoration: " << decoration[0] << " on member "
<< member_index << " of " << ShowType(struct_type_id);
return nullptr;
return {};
}
bool ParserImpl::BuildInternalModule() {
@@ -1126,14 +1139,14 @@ const Type* ParserImpl::ConvertType(
// the members are non-writable.
is_non_writable = true;
} else {
auto* ast_member_decoration =
ConvertMemberDecoration(type_id, member_index, decoration);
auto decos = ConvertMemberDecoration(type_id, member_index,
ast_member_ty, decoration);
for (auto* deco : decos) {
ast_member_decorations.emplace_back(deco);
}
if (!success_) {
return nullptr;
}
if (ast_member_decoration) {
ast_member_decorations.push_back(ast_member_decoration);
}
}
}

View File

@@ -272,16 +272,19 @@ class ParserImpl : Reader {
ast::Decoration* SetLocation(ast::DecorationList* decos,
ast::Decoration* replacement);
/// Converts a SPIR-V struct member decoration. If the decoration is
/// recognized but deliberately dropped, then returns nullptr without a
/// diagnostic. On failure, emits a diagnostic and returns nullptr.
/// Converts a SPIR-V struct member decoration into a number of AST
/// decorations. If the decoration is recognized but deliberately dropped,
/// then returns an empty list without a diagnostic. On failure, emits a
/// diagnostic and returns an empty list.
/// @param struct_type_id the ID of the struct type
/// @param member_index the index of the member
/// @param member_ty the type of the member
/// @param decoration an encoded SPIR-V Decoration
/// @returns the corresponding ast::StructuMemberDecoration
ast::Decoration* ConvertMemberDecoration(uint32_t struct_type_id,
uint32_t member_index,
const Decoration& decoration);
/// @returns the AST decorations
ast::DecorationList ConvertMemberDecoration(uint32_t struct_type_id,
uint32_t member_index,
const Type* member_ty,
const Decoration& decoration);
/// Returns a string for the given type. If the type ID is invalid,
/// then the resulting string only names the type ID.

View File

@@ -25,16 +25,17 @@ using ::testing::Eq;
TEST_F(SpvParserTest, ConvertMemberDecoration_Empty) {
auto p = parser(std::vector<uint32_t>{});
auto* result = p->ConvertMemberDecoration(1, 1, {});
EXPECT_EQ(result, nullptr);
auto result = p->ConvertMemberDecoration(1, 1, nullptr, {});
EXPECT_TRUE(result.empty());
EXPECT_THAT(p->error(), Eq("malformed SPIR-V decoration: it's empty"));
}
TEST_F(SpvParserTest, ConvertMemberDecoration_OffsetWithoutOperand) {
auto p = parser(std::vector<uint32_t>{});
auto* result = p->ConvertMemberDecoration(12, 13, {SpvDecorationOffset});
EXPECT_EQ(result, nullptr);
auto result =
p->ConvertMemberDecoration(12, 13, nullptr, {SpvDecorationOffset});
EXPECT_TRUE(result.empty());
EXPECT_THAT(p->error(), Eq("malformed Offset decoration: expected 1 literal "
"operand, has 0: member 13 of SPIR-V type 12"));
}
@@ -42,9 +43,9 @@ TEST_F(SpvParserTest, ConvertMemberDecoration_OffsetWithoutOperand) {
TEST_F(SpvParserTest, ConvertMemberDecoration_OffsetWithTooManyOperands) {
auto p = parser(std::vector<uint32_t>{});
auto* result =
p->ConvertMemberDecoration(12, 13, {SpvDecorationOffset, 3, 4});
EXPECT_EQ(result, nullptr);
auto result =
p->ConvertMemberDecoration(12, 13, nullptr, {SpvDecorationOffset, 3, 4});
EXPECT_TRUE(result.empty());
EXPECT_THAT(p->error(), Eq("malformed Offset decoration: expected 1 literal "
"operand, has 2: member 13 of SPIR-V type 12"));
}
@@ -52,32 +53,100 @@ TEST_F(SpvParserTest, ConvertMemberDecoration_OffsetWithTooManyOperands) {
TEST_F(SpvParserTest, ConvertMemberDecoration_Offset) {
auto p = parser(std::vector<uint32_t>{});
auto* result = p->ConvertMemberDecoration(1, 1, {SpvDecorationOffset, 8});
ASSERT_NE(result, nullptr);
EXPECT_TRUE(result->Is<ast::StructMemberOffsetDecoration>());
auto* offset_deco = result->As<ast::StructMemberOffsetDecoration>();
auto result =
p->ConvertMemberDecoration(1, 1, nullptr, {SpvDecorationOffset, 8});
ASSERT_FALSE(result.empty());
EXPECT_TRUE(result[0]->Is<ast::StructMemberOffsetDecoration>());
auto* offset_deco = result[0]->As<ast::StructMemberOffsetDecoration>();
ASSERT_NE(offset_deco, nullptr);
EXPECT_EQ(offset_deco->offset(), 8u);
EXPECT_TRUE(p->error().empty());
}
TEST_F(SpvParserTest, ConvertMemberDecoration_Matrix2x2_Stride_Natural) {
auto p = parser(std::vector<uint32_t>{});
spirv::F32 f32;
spirv::Matrix matrix(&f32, 2, 2);
auto result =
p->ConvertMemberDecoration(1, 1, &matrix, {SpvDecorationMatrixStride, 8});
EXPECT_TRUE(result.empty());
EXPECT_TRUE(p->error().empty());
}
TEST_F(SpvParserTest, ConvertMemberDecoration_Matrix2x2_Stride_Custom) {
auto p = parser(std::vector<uint32_t>{});
spirv::F32 f32;
spirv::Matrix matrix(&f32, 2, 2);
auto result = p->ConvertMemberDecoration(1, 1, &matrix,
{SpvDecorationMatrixStride, 16});
ASSERT_FALSE(result.empty());
EXPECT_TRUE(result[0]->Is<ast::StrideDecoration>());
auto* stride_deco = result[0]->As<ast::StrideDecoration>();
ASSERT_NE(stride_deco, nullptr);
EXPECT_EQ(stride_deco->stride(), 16u);
EXPECT_TRUE(p->error().empty());
}
TEST_F(SpvParserTest, ConvertMemberDecoration_Matrix2x4_Stride_Natural) {
auto p = parser(std::vector<uint32_t>{});
spirv::F32 f32;
spirv::Matrix matrix(&f32, 2, 4);
auto result = p->ConvertMemberDecoration(1, 1, &matrix,
{SpvDecorationMatrixStride, 16});
EXPECT_TRUE(result.empty());
EXPECT_TRUE(p->error().empty());
}
TEST_F(SpvParserTest, ConvertMemberDecoration_Matrix2x4_Stride_Custom) {
auto p = parser(std::vector<uint32_t>{});
spirv::F32 f32;
spirv::Matrix matrix(&f32, 2, 4);
auto result = p->ConvertMemberDecoration(1, 1, &matrix,
{SpvDecorationMatrixStride, 64});
ASSERT_FALSE(result.empty());
EXPECT_TRUE(result[0]->Is<ast::StrideDecoration>());
auto* stride_deco = result[0]->As<ast::StrideDecoration>();
ASSERT_NE(stride_deco, nullptr);
EXPECT_EQ(stride_deco->stride(), 64u);
EXPECT_TRUE(p->error().empty());
}
TEST_F(SpvParserTest, ConvertMemberDecoration_Matrix2x3_Stride_Custom) {
auto p = parser(std::vector<uint32_t>{});
spirv::F32 f32;
spirv::Matrix matrix(&f32, 2, 3);
auto result = p->ConvertMemberDecoration(1, 1, &matrix,
{SpvDecorationMatrixStride, 32});
ASSERT_FALSE(result.empty());
EXPECT_TRUE(result[0]->Is<ast::StrideDecoration>());
auto* stride_deco = result[0]->As<ast::StrideDecoration>();
ASSERT_NE(stride_deco, nullptr);
EXPECT_EQ(stride_deco->stride(), 32u);
EXPECT_TRUE(p->error().empty());
}
TEST_F(SpvParserTest, ConvertMemberDecoration_RelaxedPrecision) {
// WGSL does not support relaxed precision. Drop it.
// It's functionally correct to use full precision f32 instead of
// relaxed precision f32.
auto p = parser(std::vector<uint32_t>{});
auto* result =
p->ConvertMemberDecoration(1, 1, {SpvDecorationRelaxedPrecision});
EXPECT_EQ(result, nullptr);
auto result = p->ConvertMemberDecoration(1, 1, nullptr,
{SpvDecorationRelaxedPrecision});
EXPECT_TRUE(result.empty());
EXPECT_TRUE(p->error().empty());
}
TEST_F(SpvParserTest, ConvertMemberDecoration_UnhandledDecoration) {
auto p = parser(std::vector<uint32_t>{});
auto* result = p->ConvertMemberDecoration(12, 13, {12345678});
EXPECT_EQ(result, nullptr);
auto result = p->ConvertMemberDecoration(12, 13, nullptr, {12345678});
EXPECT_TRUE(result.empty());
EXPECT_THAT(p->error(), Eq("unhandled member decoration: 12345678 on member "
"13 of SPIR-V type 12"));
}

View File

@@ -2015,7 +2015,7 @@ TEST_F(SpvModuleScopeVarParserTest, ColMajorDecoration_Dropped) {
})")) << module_str;
}
TEST_F(SpvModuleScopeVarParserTest, MatrixStrideDecoration_Dropped) {
TEST_F(SpvModuleScopeVarParserTest, MatrixStrideDecoration_Natural_Dropped) {
auto p = parser(test::Assemble(Preamble() + FragMain() + R"(
OpName %myvar "myvar"
OpDecorate %myvar DescriptorSet 0
@@ -2054,6 +2054,45 @@ TEST_F(SpvModuleScopeVarParserTest, MatrixStrideDecoration_Dropped) {
})")) << module_str;
}
TEST_F(SpvModuleScopeVarParserTest, MatrixStrideDecoration) {
auto p = parser(test::Assemble(Preamble() + FragMain() + R"(
OpName %myvar "myvar"
OpDecorate %myvar DescriptorSet 0
OpDecorate %myvar Binding 0
OpDecorate %s Block
OpMemberDecorate %s 0 MatrixStride 64
OpMemberDecorate %s 0 Offset 0
%void = OpTypeVoid
%voidfn = OpTypeFunction %void
%float = OpTypeFloat 32
%v2float = OpTypeVector %float 2
%m3v2float = OpTypeMatrix %v2float 3
%s = OpTypeStruct %m3v2float
%ptr_sb_s = OpTypePointer StorageBuffer %s
%myvar = OpVariable %ptr_sb_s StorageBuffer
)" + MainBody()));
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
EXPECT_TRUE(p->error().empty());
const auto module_str = p->program().to_str();
EXPECT_THAT(module_str, HasSubstr(R"(
Struct S {
[[block]]
StructMember{[[ stride 64 tint_internal(disable_validation__ignore_stride) offset 0 ]] field0: __mat_2_3__f32}
}
Variable{
Decorations{
GroupDecoration{0}
BindingDecoration{0}
}
myvar
storage
read_write
__type_name_S
}
})")) << module_str;
}
TEST_F(SpvModuleScopeVarParserTest, RowMajorDecoration_IsError) {
auto p = parser(test::Assemble(Preamble() + FragMain() + R"(
OpName %myvar "myvar"
@@ -2620,7 +2659,8 @@ TEST_F(SpvModuleScopeVarParserTest, SampleId_I32_Load_AccessChain) {
private
undefined
__i32
})")) <<module_str;
})"))
<< module_str;
// Correct creation of value
EXPECT_THAT(module_str, HasSubstr(R"(
@@ -3006,7 +3046,8 @@ TEST_F(SpvModuleScopeVarParserTest, SampleMask_In_U32_Direct) {
private
undefined
__array__u32_1
})")) <<module_str;
})"))
<< module_str;
// Correct creation of value
EXPECT_THAT(module_str, HasSubstr(R"(
@@ -3149,7 +3190,8 @@ TEST_F(SpvModuleScopeVarParserTest, SampleMask_In_U32_AccessChain) {
private
undefined
__array__u32_1
})")) <<module_str;
})"))
<< module_str;
// Correct creation of value
EXPECT_THAT(module_str, HasSubstr(R"(
@@ -5543,7 +5585,6 @@ INSTANTIATE_TEST_SUITE_P(
// {"NumWorkgroups", "%uint", "num_workgroups"}
// {"NumWorkgroups", "%int", "num_workgroups"}
TEST_F(SpvModuleScopeVarParserTest, RegisterInputOutputVars) {
const std::string assembly =
R"(

View File

@@ -184,18 +184,21 @@ class ParserImplWrapperForTest {
return impl_.GetDecorationsForMember(id, member_index);
}
/// Converts a SPIR-V struct member decoration. If the decoration is
/// recognized but deliberately dropped, then returns nullptr without a
/// diagnostic. On failure, emits a diagnostic and returns nullptr.
/// Converts a SPIR-V struct member decoration into a number of AST
/// decorations. If the decoration is recognized but deliberately dropped,
/// then returns an empty list without a diagnostic. On failure, emits a
/// diagnostic and returns an empty list.
/// @param struct_type_id the ID of the struct type
/// @param member_index the index of the member
/// @param member_ty the type of the member
/// @param decoration an encoded SPIR-V Decoration
/// @returns the corresponding ast::StructuMemberDecoration
ast::Decoration* ConvertMemberDecoration(uint32_t struct_type_id,
uint32_t member_index,
const Decoration& decoration) {
/// @returns the AST decorations
ast::DecorationList ConvertMemberDecoration(uint32_t struct_type_id,
uint32_t member_index,
const Type* member_ty,
const Decoration& decoration) {
return impl_.ConvertMemberDecoration(struct_type_id, member_index,
decoration);
member_ty, decoration);
}
/// For a SPIR-V ID that might define a sampler, image, or sampled image