mirror of
https://github.com/encounter/dawn-cmake.git
synced 2025-12-16 08:27:05 +00:00
reader/spirv: Handle the MatrixStride decoration
Add `transform::DecomposeStridedMatrix`, which replaces matrix members of storage or uniform buffer structures, that have a [[stride]] decoration, into an array of N column vectors. This is required to correctly handle `mat2x2` matrices in UBOs, as std140 rules will expect a default stride of 16 bytes, when in WGSL the default structure layout expects a stride of 8 bytes. Bug: tint:1047 Change-Id: If5ca3c6ec087bbc1ac31a8d9a657b99bf34042a4 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/59840 Reviewed-by: David Neto <dneto@google.com> Commit-Queue: Ben Clayton <bclayton@google.com> Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
committed by
Tint LUCI CQ
parent
c6cbe3fda6
commit
97668c8c37
@@ -17,6 +17,10 @@
|
||||
#include <utility>
|
||||
|
||||
#include "src/reader/spirv/parser_impl.h"
|
||||
#include "src/transform/decompose_strided_matrix.h"
|
||||
#include "src/transform/inline_pointer_lets.h"
|
||||
#include "src/transform/manager.h"
|
||||
#include "src/transform/simplify.h"
|
||||
|
||||
namespace tint {
|
||||
namespace reader {
|
||||
@@ -40,7 +44,19 @@ Program Parse(const std::vector<uint32_t>& input) {
|
||||
|
||||
ProgramBuilder output;
|
||||
CloneContext(&output, &program_with_disjoint_ast, false).Clone();
|
||||
return Program(std::move(output));
|
||||
auto program = Program(std::move(output));
|
||||
|
||||
// If the generated program contains matrices with a custom MatrixStride
|
||||
// attribute then we need to decompose these into an array of vectors
|
||||
if (transform::DecomposeStridedMatrix::ShouldRun(&program)) {
|
||||
transform::Manager manager;
|
||||
manager.Add<transform::InlinePointerLets>();
|
||||
manager.Add<transform::Simplify>();
|
||||
manager.Add<transform::DecomposeStridedMatrix>();
|
||||
return manager.Run(&program).program;
|
||||
}
|
||||
|
||||
return program;
|
||||
}
|
||||
|
||||
} // namespace spirv
|
||||
|
||||
@@ -21,6 +21,7 @@
|
||||
|
||||
#include "source/opt/build_module.h"
|
||||
#include "src/ast/bitcast_expression.h"
|
||||
#include "src/ast/disable_validation_decoration.h"
|
||||
#include "src/ast/interpolate_decoration.h"
|
||||
#include "src/ast/override_decoration.h"
|
||||
#include "src/ast/struct_block_decoration.h"
|
||||
@@ -439,13 +440,14 @@ std::string ParserImpl::ShowType(uint32_t type_id) {
|
||||
return "SPIR-V type " + std::to_string(type_id);
|
||||
}
|
||||
|
||||
ast::Decoration* ParserImpl::ConvertMemberDecoration(
|
||||
ast::DecorationList ParserImpl::ConvertMemberDecoration(
|
||||
uint32_t struct_type_id,
|
||||
uint32_t member_index,
|
||||
const Type* member_ty,
|
||||
const Decoration& decoration) {
|
||||
if (decoration.empty()) {
|
||||
Fail() << "malformed SPIR-V decoration: it's empty";
|
||||
return nullptr;
|
||||
return {};
|
||||
}
|
||||
switch (decoration[0]) {
|
||||
case SpvDecorationOffset:
|
||||
@@ -454,38 +456,49 @@ ast::Decoration* ParserImpl::ConvertMemberDecoration(
|
||||
<< "malformed Offset decoration: expected 1 literal operand, has "
|
||||
<< decoration.size() - 1 << ": member " << member_index << " of "
|
||||
<< ShowType(struct_type_id);
|
||||
return nullptr;
|
||||
return {};
|
||||
}
|
||||
return create<ast::StructMemberOffsetDecoration>(Source{}, decoration[1]);
|
||||
return {
|
||||
create<ast::StructMemberOffsetDecoration>(Source{}, decoration[1]),
|
||||
};
|
||||
case SpvDecorationNonReadable:
|
||||
// WGSL doesn't have a member decoration for this. Silently drop it.
|
||||
return nullptr;
|
||||
return {};
|
||||
case SpvDecorationNonWritable:
|
||||
// WGSL doesn't have a member decoration for this.
|
||||
return nullptr;
|
||||
return {};
|
||||
case SpvDecorationColMajor:
|
||||
// WGSL only supports column major matrices.
|
||||
return nullptr;
|
||||
return {};
|
||||
case SpvDecorationRelaxedPrecision:
|
||||
// WGSL doesn't support relaxed precision.
|
||||
return nullptr;
|
||||
return {};
|
||||
case SpvDecorationRowMajor:
|
||||
Fail() << "WGSL does not support row-major matrices: can't "
|
||||
"translate member "
|
||||
<< member_index << " of " << ShowType(struct_type_id);
|
||||
return nullptr;
|
||||
return {};
|
||||
case SpvDecorationMatrixStride: {
|
||||
if (decoration.size() != 2) {
|
||||
Fail() << "malformed MatrixStride decoration: expected 1 literal "
|
||||
"operand, has "
|
||||
<< decoration.size() - 1 << ": member " << member_index << " of "
|
||||
<< ShowType(struct_type_id);
|
||||
return nullptr;
|
||||
return {};
|
||||
}
|
||||
// TODO(dneto): Fail if the matrix stride is not allocation size of the
|
||||
// column vector of the underlying matrix. This would need to unpack
|
||||
// any levels of array-ness.
|
||||
return nullptr;
|
||||
uint32_t stride = decoration[1];
|
||||
uint32_t natural_stride = 0;
|
||||
if (auto* mat = member_ty->As<Matrix>()) {
|
||||
natural_stride = (mat->rows == 2) ? 8 : 16;
|
||||
}
|
||||
if (stride == natural_stride) {
|
||||
return {};
|
||||
}
|
||||
return {
|
||||
create<ast::StrideDecoration>(Source{}, decoration[1]),
|
||||
builder_.ASTNodes().Create<ast::DisableValidationDecoration>(
|
||||
builder_.ID(), ast::DisabledValidation::kIgnoreStrideDecoration),
|
||||
};
|
||||
}
|
||||
default:
|
||||
// TODO(dneto): Support the remaining member decorations.
|
||||
@@ -493,7 +506,7 @@ ast::Decoration* ParserImpl::ConvertMemberDecoration(
|
||||
}
|
||||
Fail() << "unhandled member decoration: " << decoration[0] << " on member "
|
||||
<< member_index << " of " << ShowType(struct_type_id);
|
||||
return nullptr;
|
||||
return {};
|
||||
}
|
||||
|
||||
bool ParserImpl::BuildInternalModule() {
|
||||
@@ -1126,14 +1139,14 @@ const Type* ParserImpl::ConvertType(
|
||||
// the members are non-writable.
|
||||
is_non_writable = true;
|
||||
} else {
|
||||
auto* ast_member_decoration =
|
||||
ConvertMemberDecoration(type_id, member_index, decoration);
|
||||
auto decos = ConvertMemberDecoration(type_id, member_index,
|
||||
ast_member_ty, decoration);
|
||||
for (auto* deco : decos) {
|
||||
ast_member_decorations.emplace_back(deco);
|
||||
}
|
||||
if (!success_) {
|
||||
return nullptr;
|
||||
}
|
||||
if (ast_member_decoration) {
|
||||
ast_member_decorations.push_back(ast_member_decoration);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -272,16 +272,19 @@ class ParserImpl : Reader {
|
||||
ast::Decoration* SetLocation(ast::DecorationList* decos,
|
||||
ast::Decoration* replacement);
|
||||
|
||||
/// Converts a SPIR-V struct member decoration. If the decoration is
|
||||
/// recognized but deliberately dropped, then returns nullptr without a
|
||||
/// diagnostic. On failure, emits a diagnostic and returns nullptr.
|
||||
/// Converts a SPIR-V struct member decoration into a number of AST
|
||||
/// decorations. If the decoration is recognized but deliberately dropped,
|
||||
/// then returns an empty list without a diagnostic. On failure, emits a
|
||||
/// diagnostic and returns an empty list.
|
||||
/// @param struct_type_id the ID of the struct type
|
||||
/// @param member_index the index of the member
|
||||
/// @param member_ty the type of the member
|
||||
/// @param decoration an encoded SPIR-V Decoration
|
||||
/// @returns the corresponding ast::StructuMemberDecoration
|
||||
ast::Decoration* ConvertMemberDecoration(uint32_t struct_type_id,
|
||||
uint32_t member_index,
|
||||
const Decoration& decoration);
|
||||
/// @returns the AST decorations
|
||||
ast::DecorationList ConvertMemberDecoration(uint32_t struct_type_id,
|
||||
uint32_t member_index,
|
||||
const Type* member_ty,
|
||||
const Decoration& decoration);
|
||||
|
||||
/// Returns a string for the given type. If the type ID is invalid,
|
||||
/// then the resulting string only names the type ID.
|
||||
|
||||
@@ -25,16 +25,17 @@ using ::testing::Eq;
|
||||
TEST_F(SpvParserTest, ConvertMemberDecoration_Empty) {
|
||||
auto p = parser(std::vector<uint32_t>{});
|
||||
|
||||
auto* result = p->ConvertMemberDecoration(1, 1, {});
|
||||
EXPECT_EQ(result, nullptr);
|
||||
auto result = p->ConvertMemberDecoration(1, 1, nullptr, {});
|
||||
EXPECT_TRUE(result.empty());
|
||||
EXPECT_THAT(p->error(), Eq("malformed SPIR-V decoration: it's empty"));
|
||||
}
|
||||
|
||||
TEST_F(SpvParserTest, ConvertMemberDecoration_OffsetWithoutOperand) {
|
||||
auto p = parser(std::vector<uint32_t>{});
|
||||
|
||||
auto* result = p->ConvertMemberDecoration(12, 13, {SpvDecorationOffset});
|
||||
EXPECT_EQ(result, nullptr);
|
||||
auto result =
|
||||
p->ConvertMemberDecoration(12, 13, nullptr, {SpvDecorationOffset});
|
||||
EXPECT_TRUE(result.empty());
|
||||
EXPECT_THAT(p->error(), Eq("malformed Offset decoration: expected 1 literal "
|
||||
"operand, has 0: member 13 of SPIR-V type 12"));
|
||||
}
|
||||
@@ -42,9 +43,9 @@ TEST_F(SpvParserTest, ConvertMemberDecoration_OffsetWithoutOperand) {
|
||||
TEST_F(SpvParserTest, ConvertMemberDecoration_OffsetWithTooManyOperands) {
|
||||
auto p = parser(std::vector<uint32_t>{});
|
||||
|
||||
auto* result =
|
||||
p->ConvertMemberDecoration(12, 13, {SpvDecorationOffset, 3, 4});
|
||||
EXPECT_EQ(result, nullptr);
|
||||
auto result =
|
||||
p->ConvertMemberDecoration(12, 13, nullptr, {SpvDecorationOffset, 3, 4});
|
||||
EXPECT_TRUE(result.empty());
|
||||
EXPECT_THAT(p->error(), Eq("malformed Offset decoration: expected 1 literal "
|
||||
"operand, has 2: member 13 of SPIR-V type 12"));
|
||||
}
|
||||
@@ -52,32 +53,100 @@ TEST_F(SpvParserTest, ConvertMemberDecoration_OffsetWithTooManyOperands) {
|
||||
TEST_F(SpvParserTest, ConvertMemberDecoration_Offset) {
|
||||
auto p = parser(std::vector<uint32_t>{});
|
||||
|
||||
auto* result = p->ConvertMemberDecoration(1, 1, {SpvDecorationOffset, 8});
|
||||
ASSERT_NE(result, nullptr);
|
||||
EXPECT_TRUE(result->Is<ast::StructMemberOffsetDecoration>());
|
||||
auto* offset_deco = result->As<ast::StructMemberOffsetDecoration>();
|
||||
auto result =
|
||||
p->ConvertMemberDecoration(1, 1, nullptr, {SpvDecorationOffset, 8});
|
||||
ASSERT_FALSE(result.empty());
|
||||
EXPECT_TRUE(result[0]->Is<ast::StructMemberOffsetDecoration>());
|
||||
auto* offset_deco = result[0]->As<ast::StructMemberOffsetDecoration>();
|
||||
ASSERT_NE(offset_deco, nullptr);
|
||||
EXPECT_EQ(offset_deco->offset(), 8u);
|
||||
EXPECT_TRUE(p->error().empty());
|
||||
}
|
||||
|
||||
TEST_F(SpvParserTest, ConvertMemberDecoration_Matrix2x2_Stride_Natural) {
|
||||
auto p = parser(std::vector<uint32_t>{});
|
||||
|
||||
spirv::F32 f32;
|
||||
spirv::Matrix matrix(&f32, 2, 2);
|
||||
auto result =
|
||||
p->ConvertMemberDecoration(1, 1, &matrix, {SpvDecorationMatrixStride, 8});
|
||||
EXPECT_TRUE(result.empty());
|
||||
EXPECT_TRUE(p->error().empty());
|
||||
}
|
||||
|
||||
TEST_F(SpvParserTest, ConvertMemberDecoration_Matrix2x2_Stride_Custom) {
|
||||
auto p = parser(std::vector<uint32_t>{});
|
||||
|
||||
spirv::F32 f32;
|
||||
spirv::Matrix matrix(&f32, 2, 2);
|
||||
auto result = p->ConvertMemberDecoration(1, 1, &matrix,
|
||||
{SpvDecorationMatrixStride, 16});
|
||||
ASSERT_FALSE(result.empty());
|
||||
EXPECT_TRUE(result[0]->Is<ast::StrideDecoration>());
|
||||
auto* stride_deco = result[0]->As<ast::StrideDecoration>();
|
||||
ASSERT_NE(stride_deco, nullptr);
|
||||
EXPECT_EQ(stride_deco->stride(), 16u);
|
||||
EXPECT_TRUE(p->error().empty());
|
||||
}
|
||||
|
||||
TEST_F(SpvParserTest, ConvertMemberDecoration_Matrix2x4_Stride_Natural) {
|
||||
auto p = parser(std::vector<uint32_t>{});
|
||||
|
||||
spirv::F32 f32;
|
||||
spirv::Matrix matrix(&f32, 2, 4);
|
||||
auto result = p->ConvertMemberDecoration(1, 1, &matrix,
|
||||
{SpvDecorationMatrixStride, 16});
|
||||
EXPECT_TRUE(result.empty());
|
||||
EXPECT_TRUE(p->error().empty());
|
||||
}
|
||||
|
||||
TEST_F(SpvParserTest, ConvertMemberDecoration_Matrix2x4_Stride_Custom) {
|
||||
auto p = parser(std::vector<uint32_t>{});
|
||||
|
||||
spirv::F32 f32;
|
||||
spirv::Matrix matrix(&f32, 2, 4);
|
||||
auto result = p->ConvertMemberDecoration(1, 1, &matrix,
|
||||
{SpvDecorationMatrixStride, 64});
|
||||
ASSERT_FALSE(result.empty());
|
||||
EXPECT_TRUE(result[0]->Is<ast::StrideDecoration>());
|
||||
auto* stride_deco = result[0]->As<ast::StrideDecoration>();
|
||||
ASSERT_NE(stride_deco, nullptr);
|
||||
EXPECT_EQ(stride_deco->stride(), 64u);
|
||||
EXPECT_TRUE(p->error().empty());
|
||||
}
|
||||
|
||||
TEST_F(SpvParserTest, ConvertMemberDecoration_Matrix2x3_Stride_Custom) {
|
||||
auto p = parser(std::vector<uint32_t>{});
|
||||
|
||||
spirv::F32 f32;
|
||||
spirv::Matrix matrix(&f32, 2, 3);
|
||||
auto result = p->ConvertMemberDecoration(1, 1, &matrix,
|
||||
{SpvDecorationMatrixStride, 32});
|
||||
ASSERT_FALSE(result.empty());
|
||||
EXPECT_TRUE(result[0]->Is<ast::StrideDecoration>());
|
||||
auto* stride_deco = result[0]->As<ast::StrideDecoration>();
|
||||
ASSERT_NE(stride_deco, nullptr);
|
||||
EXPECT_EQ(stride_deco->stride(), 32u);
|
||||
EXPECT_TRUE(p->error().empty());
|
||||
}
|
||||
|
||||
TEST_F(SpvParserTest, ConvertMemberDecoration_RelaxedPrecision) {
|
||||
// WGSL does not support relaxed precision. Drop it.
|
||||
// It's functionally correct to use full precision f32 instead of
|
||||
// relaxed precision f32.
|
||||
auto p = parser(std::vector<uint32_t>{});
|
||||
|
||||
auto* result =
|
||||
p->ConvertMemberDecoration(1, 1, {SpvDecorationRelaxedPrecision});
|
||||
EXPECT_EQ(result, nullptr);
|
||||
auto result = p->ConvertMemberDecoration(1, 1, nullptr,
|
||||
{SpvDecorationRelaxedPrecision});
|
||||
EXPECT_TRUE(result.empty());
|
||||
EXPECT_TRUE(p->error().empty());
|
||||
}
|
||||
|
||||
TEST_F(SpvParserTest, ConvertMemberDecoration_UnhandledDecoration) {
|
||||
auto p = parser(std::vector<uint32_t>{});
|
||||
|
||||
auto* result = p->ConvertMemberDecoration(12, 13, {12345678});
|
||||
EXPECT_EQ(result, nullptr);
|
||||
auto result = p->ConvertMemberDecoration(12, 13, nullptr, {12345678});
|
||||
EXPECT_TRUE(result.empty());
|
||||
EXPECT_THAT(p->error(), Eq("unhandled member decoration: 12345678 on member "
|
||||
"13 of SPIR-V type 12"));
|
||||
}
|
||||
|
||||
@@ -2015,7 +2015,7 @@ TEST_F(SpvModuleScopeVarParserTest, ColMajorDecoration_Dropped) {
|
||||
})")) << module_str;
|
||||
}
|
||||
|
||||
TEST_F(SpvModuleScopeVarParserTest, MatrixStrideDecoration_Dropped) {
|
||||
TEST_F(SpvModuleScopeVarParserTest, MatrixStrideDecoration_Natural_Dropped) {
|
||||
auto p = parser(test::Assemble(Preamble() + FragMain() + R"(
|
||||
OpName %myvar "myvar"
|
||||
OpDecorate %myvar DescriptorSet 0
|
||||
@@ -2054,6 +2054,45 @@ TEST_F(SpvModuleScopeVarParserTest, MatrixStrideDecoration_Dropped) {
|
||||
})")) << module_str;
|
||||
}
|
||||
|
||||
TEST_F(SpvModuleScopeVarParserTest, MatrixStrideDecoration) {
|
||||
auto p = parser(test::Assemble(Preamble() + FragMain() + R"(
|
||||
OpName %myvar "myvar"
|
||||
OpDecorate %myvar DescriptorSet 0
|
||||
OpDecorate %myvar Binding 0
|
||||
OpDecorate %s Block
|
||||
OpMemberDecorate %s 0 MatrixStride 64
|
||||
OpMemberDecorate %s 0 Offset 0
|
||||
%void = OpTypeVoid
|
||||
%voidfn = OpTypeFunction %void
|
||||
%float = OpTypeFloat 32
|
||||
%v2float = OpTypeVector %float 2
|
||||
%m3v2float = OpTypeMatrix %v2float 3
|
||||
|
||||
%s = OpTypeStruct %m3v2float
|
||||
%ptr_sb_s = OpTypePointer StorageBuffer %s
|
||||
%myvar = OpVariable %ptr_sb_s StorageBuffer
|
||||
)" + MainBody()));
|
||||
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
|
||||
EXPECT_TRUE(p->error().empty());
|
||||
const auto module_str = p->program().to_str();
|
||||
EXPECT_THAT(module_str, HasSubstr(R"(
|
||||
Struct S {
|
||||
[[block]]
|
||||
StructMember{[[ stride 64 tint_internal(disable_validation__ignore_stride) offset 0 ]] field0: __mat_2_3__f32}
|
||||
}
|
||||
Variable{
|
||||
Decorations{
|
||||
GroupDecoration{0}
|
||||
BindingDecoration{0}
|
||||
}
|
||||
myvar
|
||||
storage
|
||||
read_write
|
||||
__type_name_S
|
||||
}
|
||||
})")) << module_str;
|
||||
}
|
||||
|
||||
TEST_F(SpvModuleScopeVarParserTest, RowMajorDecoration_IsError) {
|
||||
auto p = parser(test::Assemble(Preamble() + FragMain() + R"(
|
||||
OpName %myvar "myvar"
|
||||
@@ -2620,7 +2659,8 @@ TEST_F(SpvModuleScopeVarParserTest, SampleId_I32_Load_AccessChain) {
|
||||
private
|
||||
undefined
|
||||
__i32
|
||||
})")) <<module_str;
|
||||
})"))
|
||||
<< module_str;
|
||||
|
||||
// Correct creation of value
|
||||
EXPECT_THAT(module_str, HasSubstr(R"(
|
||||
@@ -3006,7 +3046,8 @@ TEST_F(SpvModuleScopeVarParserTest, SampleMask_In_U32_Direct) {
|
||||
private
|
||||
undefined
|
||||
__array__u32_1
|
||||
})")) <<module_str;
|
||||
})"))
|
||||
<< module_str;
|
||||
|
||||
// Correct creation of value
|
||||
EXPECT_THAT(module_str, HasSubstr(R"(
|
||||
@@ -3149,7 +3190,8 @@ TEST_F(SpvModuleScopeVarParserTest, SampleMask_In_U32_AccessChain) {
|
||||
private
|
||||
undefined
|
||||
__array__u32_1
|
||||
})")) <<module_str;
|
||||
})"))
|
||||
<< module_str;
|
||||
|
||||
// Correct creation of value
|
||||
EXPECT_THAT(module_str, HasSubstr(R"(
|
||||
@@ -5543,7 +5585,6 @@ INSTANTIATE_TEST_SUITE_P(
|
||||
// {"NumWorkgroups", "%uint", "num_workgroups"}
|
||||
// {"NumWorkgroups", "%int", "num_workgroups"}
|
||||
|
||||
|
||||
TEST_F(SpvModuleScopeVarParserTest, RegisterInputOutputVars) {
|
||||
const std::string assembly =
|
||||
R"(
|
||||
|
||||
@@ -184,18 +184,21 @@ class ParserImplWrapperForTest {
|
||||
return impl_.GetDecorationsForMember(id, member_index);
|
||||
}
|
||||
|
||||
/// Converts a SPIR-V struct member decoration. If the decoration is
|
||||
/// recognized but deliberately dropped, then returns nullptr without a
|
||||
/// diagnostic. On failure, emits a diagnostic and returns nullptr.
|
||||
/// Converts a SPIR-V struct member decoration into a number of AST
|
||||
/// decorations. If the decoration is recognized but deliberately dropped,
|
||||
/// then returns an empty list without a diagnostic. On failure, emits a
|
||||
/// diagnostic and returns an empty list.
|
||||
/// @param struct_type_id the ID of the struct type
|
||||
/// @param member_index the index of the member
|
||||
/// @param member_ty the type of the member
|
||||
/// @param decoration an encoded SPIR-V Decoration
|
||||
/// @returns the corresponding ast::StructuMemberDecoration
|
||||
ast::Decoration* ConvertMemberDecoration(uint32_t struct_type_id,
|
||||
uint32_t member_index,
|
||||
const Decoration& decoration) {
|
||||
/// @returns the AST decorations
|
||||
ast::DecorationList ConvertMemberDecoration(uint32_t struct_type_id,
|
||||
uint32_t member_index,
|
||||
const Type* member_ty,
|
||||
const Decoration& decoration) {
|
||||
return impl_.ConvertMemberDecoration(struct_type_id, member_index,
|
||||
decoration);
|
||||
member_ty, decoration);
|
||||
}
|
||||
|
||||
/// For a SPIR-V ID that might define a sampler, image, or sampled image
|
||||
|
||||
Reference in New Issue
Block a user