msl: Implement more types for structure layout

Implement layout logic for vectors, matrices and default-stride arrays.

Custom stride arrays are complex, and will be tackled as a followup change.

This change also emits byte offsets for all structure members as comments. This is even emitted for non-storage uses, which can be cleaned up as a followup.

Fixes a whole lot of TINT_ICE() for non-complex WGSL shaders.

Bug: tint:626
Change-Id: I92a78451d29bdb04dbf28862ad22317f27bced60
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/44864
Commit-Queue: Ben Clayton <bclayton@chromium.org>
Reviewed-by: James Price <jrprice@google.com>
This commit is contained in:
Ben Clayton 2021-03-18 10:03:04 +00:00 committed by Commit Bot service account
parent a88090b04d
commit 5907039597
4 changed files with 347 additions and 47 deletions

View File

@ -94,6 +94,17 @@ class Struct : public Castable<Struct, Node> {
return storage_class_usage_.count(usage) > 0; return storage_class_usage_.count(usage) > 0;
} }
/// @returns true iff this structure has been used by storage class that's
/// host-sharable.
bool IsHostSharable() const {
for (auto sc : storage_class_usage_) {
if (ast::IsHostSharable(sc)) {
return true;
}
}
return false;
}
private: private:
type::Struct* const type_; type::Struct* const type_;
StructMemberList const members_; StructMemberList const members_;

View File

@ -15,6 +15,7 @@
#include "src/writer/msl/generator_impl.h" #include "src/writer/msl/generator_impl.h"
#include <algorithm> #include <algorithm>
#include <iomanip>
#include <utility> #include <utility>
#include <vector> #include <vector>
@ -27,6 +28,7 @@
#include "src/ast/sint_literal.h" #include "src/ast/sint_literal.h"
#include "src/ast/uint_literal.h" #include "src/ast/uint_literal.h"
#include "src/ast/variable_decl_statement.h" #include "src/ast/variable_decl_statement.h"
#include "src/semantic/array.h"
#include "src/semantic/call.h" #include "src/semantic/call.h"
#include "src/semantic/function.h" #include "src/semantic/function.h"
#include "src/semantic/member_accessor_expression.h" #include "src/semantic/member_accessor_expression.h"
@ -1977,6 +1979,23 @@ bool GeneratorImpl::EmitType(type::Type* type, const std::string& name) {
return true; return true;
} }
bool GeneratorImpl::EmitPackedType(type::Type* type, const std::string& name) {
if (auto* alias = type->As<type::Alias>()) {
return EmitPackedType(alias->type(), name);
}
if (auto* vec = type->As<type::Vector>()) {
out_ << "packed_";
if (!EmitType(vec->type(), "")) {
return false;
}
out_ << vec->size();
return true;
}
return EmitType(type, name);
}
bool GeneratorImpl::EmitStructType(const type::Struct* str) { bool GeneratorImpl::EmitStructType(const type::Struct* str) {
// TODO(dsinclair): Block decoration? // TODO(dsinclair): Block decoration?
// if (str->impl()->decoration() != ast::Decoration::kNone) { // if (str->impl()->decoration() != ast::Decoration::kNone) {
@ -1984,17 +2003,32 @@ bool GeneratorImpl::EmitStructType(const type::Struct* str) {
out_ << "struct " << program_->Symbols().NameFor(str->symbol()) << " {" out_ << "struct " << program_->Symbols().NameFor(str->symbol()) << " {"
<< std::endl; << std::endl;
auto* sem_str = program_->Sem().Get(str);
if (!sem_str) {
TINT_ICE(diagnostics_) << "struct missing semantic info";
return false;
}
bool is_host_sharable = sem_str->IsHostSharable();
// Emits a `/* 0xnnnn */` byte offset comment for a struct member.
auto add_byte_offset_comment = [&](uint32_t offset) {
std::ios_base::fmtflags saved_flag_state(out_.flags());
out_ << "/* 0x" << std::hex << std::setfill('0') << std::setw(4) << offset
<< " */ ";
out_.flags(saved_flag_state);
};
uint32_t pad_count = 0; uint32_t pad_count = 0;
auto add_padding = [&](uint32_t size) { auto add_padding = [&](uint32_t size) {
out_ << "int8_t pad_" << pad_count << "[" << size << "];" << std::endl; out_ << "int8_t _tint_pad_" << pad_count << "[" << size << "];"
<< std::endl;
pad_count++; pad_count++;
}; };
increment_indent(); increment_indent();
uint32_t current_offset = 0; uint32_t msl_offset = 0;
for (auto* mem : str->impl()->members()) { for (auto* mem : str->impl()->members()) {
std::string attributes;
make_indent(); make_indent();
auto* sem_mem = program_->Sem().Get(mem); auto* sem_mem = program_->Sem().Get(mem);
@ -2003,21 +2037,36 @@ bool GeneratorImpl::EmitStructType(const type::Struct* str) {
return false; return false;
} }
auto const offset = sem_mem->Offset(); auto wgsl_offset = sem_mem->Offset();
if (offset != current_offset) {
add_padding(offset - current_offset); if (is_host_sharable) {
if (wgsl_offset < msl_offset) {
// Unimplementable layout
TINT_ICE(diagnostics_)
<< "Structure member WGSL offset (" << wgsl_offset
<< ") is behind MSL offset (" << msl_offset << ")";
return false;
}
// Generate padding if required
if (auto padding = wgsl_offset - msl_offset) {
add_byte_offset_comment(msl_offset);
add_padding(padding);
msl_offset += padding;
make_indent(); make_indent();
} }
for (auto* deco : mem->decorations()) { add_byte_offset_comment(msl_offset);
if (auto* loc = deco->As<ast::LocationDecoration>()) {
attributes = " [[user(locn" + std::to_string(loc->value()) + ")]]";
}
}
if (!EmitPackedType(mem->type(),
program_->Symbols().NameFor(mem->symbol()))) {
return false;
}
} else {
if (!EmitType(mem->type(), program_->Symbols().NameFor(mem->symbol()))) { if (!EmitType(mem->type(), program_->Symbols().NameFor(mem->symbol()))) {
return false; return false;
} }
}
auto* ty = mem->type()->UnwrapAliasIfNeeded(); auto* ty = mem->type()->UnwrapAliasIfNeeded();
@ -2026,32 +2075,32 @@ bool GeneratorImpl::EmitStructType(const type::Struct* str) {
out_ << " " << program_->Symbols().NameFor(mem->symbol()); out_ << " " << program_->Symbols().NameFor(mem->symbol());
} }
out_ << attributes; // Emit decorations
for (auto* deco : mem->decorations()) {
if (auto* loc = deco->As<ast::LocationDecoration>()) {
out_ << " [[user(locn" + std::to_string(loc->value()) + ")]]";
}
}
out_ << ";" << std::endl; out_ << ";" << std::endl;
if (ty->is_scalar()) { if (is_host_sharable) {
current_offset = offset + 4; // Calculate new MSL offset
} else if (ty->Is<type::Struct>()) { auto size_align = MslPackedTypeSizeAndAlign(ty);
/// Structure will already contain padding matching the WGSL size if (msl_offset % size_align.align) {
current_offset = offset + sem_mem->Size(); TINT_ICE(diagnostics_) << "Misaligned MSL structure member "
} else { << ty->FriendlyName(program_->Symbols()) << " "
/// TODO(bclayton): Implement for vector, matrix, array and nested << program_->Symbols().NameFor(mem->symbol());
/// structures.
TINT_UNREACHABLE(diagnostics_)
<< "Unhandled type " << ty->TypeInfo().name;
return false; return false;
} }
msl_offset += size_align.size;
}
} }
auto* sem_str = program_->Sem().Get(str); if (is_host_sharable && sem_str->Size() != msl_offset) {
if (!sem_str) {
TINT_ICE(diagnostics_) << "struct missing semantic info";
return false;
}
if (sem_str->Size() != current_offset) {
make_indent(); make_indent();
add_padding(sem_str->Size() - current_offset); add_byte_offset_comment(msl_offset);
add_padding(sem_str->Size() - msl_offset);
} }
decrement_indent(); decrement_indent();
@ -2157,6 +2206,84 @@ bool GeneratorImpl::EmitProgramConstVariable(const ast::Variable* var) {
return true; return true;
} }
GeneratorImpl::SizeAndAlign GeneratorImpl::MslPackedTypeSizeAndAlign(
type::Type* ty) {
ty = ty->UnwrapAliasIfNeeded();
if (ty->IsAnyOf<type::U32, type::I32, type::F32>()) {
// https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
// 2.1 Scalar Data Types
return {4, 4};
}
if (auto* vec = ty->As<type::Vector>()) {
// https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
// 2.2.3 Packed Vector Types
auto num_els = vec->size();
auto* el_ty = vec->type()->UnwrapAll();
if (el_ty->IsAnyOf<type::U32, type::I32, type::F32>()) {
return SizeAndAlign{num_els * 4, 4};
}
}
if (auto* mat = ty->As<type::Matrix>()) {
// https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
// 2.3 Matrix Data Types
auto cols = mat->columns();
auto rows = mat->rows();
auto* el_ty = mat->type()->UnwrapAll();
if (el_ty->IsAnyOf<type::U32, type::I32, type::F32>()) {
static constexpr SizeAndAlign table[] = {
/* float2x2 */ {16, 8},
/* float2x3 */ {32, 16},
/* float2x4 */ {32, 16},
/* float3x2 */ {24, 8},
/* float3x3 */ {48, 16},
/* float3x4 */ {48, 16},
/* float4x2 */ {32, 8},
/* float4x3 */ {64, 16},
/* float4x4 */ {64, 16},
};
if (cols >= 2 && cols <= 4 && rows >= 2 && rows <= 4) {
return table[(3 * (cols - 2)) + (rows - 2)];
}
}
}
if (auto* arr = ty->As<type::Array>()) {
auto* sem = program_->Sem().Get(arr);
if (!sem) {
TINT_ICE(diagnostics_) << "Array missing semantic info";
return {};
}
auto el_size_align = MslPackedTypeSizeAndAlign(arr->type());
if (sem->Stride() != el_size_align.size) {
// TODO(crbug.com/tint/649): transform::Msl needs to replace these arrays
// with a new array type that has the element type padded to the required
// stride.
TINT_UNIMPLEMENTED(diagnostics_)
<< "Arrays with custom strides not yet implemented";
return {};
}
auto num_els = std::max<uint32_t>(arr->size(), 1);
return SizeAndAlign{el_size_align.size * num_els, el_size_align.align};
}
if (auto* str = ty->As<type::Struct>()) {
// TODO(crbug.com/tint/650): There's an assumption here that MSL's default
// structure size and alignment matches WGSL's. We need to confirm this.
auto* sem = program_->Sem().Get(str);
if (!sem) {
TINT_ICE(diagnostics_) << "Array missing semantic info";
return {};
}
return SizeAndAlign{sem->Size(), sem->Align()};
}
TINT_UNREACHABLE(diagnostics_) << "Unhandled type " << ty->TypeInfo().name;
return {};
}
} // namespace msl } // namespace msl
} // namespace writer } // namespace writer
} // namespace tint } // namespace tint

View File

@ -191,11 +191,18 @@ class GeneratorImpl : public TextGenerator {
/// @param stmt the statement to emit /// @param stmt the statement to emit
/// @returns true if the statement was emitted /// @returns true if the statement was emitted
bool EmitSwitch(ast::SwitchStatement* stmt); bool EmitSwitch(ast::SwitchStatement* stmt);
/// Handles generating type /// Handles generating a type
/// @param type the type to generate /// @param type the type to generate
/// @param name the name of the variable, only used for array emission /// @param name the name of the variable, only used for array emission
/// @returns true if the type is emitted /// @returns true if the type is emitted
bool EmitType(type::Type* type, const std::string& name); bool EmitType(type::Type* type, const std::string& name);
/// Handles generating an MSL-packed storage type.
/// If the type does not have a packed form, the standard non-packed form is
/// emitted.
/// @param type the type to generate
/// @param name the name of the variable, only used for array emission
/// @returns true if the type is emitted
bool EmitPackedType(type::Type* type, const std::string& name);
/// Handles generating a struct declaration /// Handles generating a struct declaration
/// @param str the struct to generate /// @param str the struct to generate
/// @returns true if the struct is emitted /// @returns true if the struct is emitted
@ -266,6 +273,16 @@ class GeneratorImpl : public TextGenerator {
return program_->TypeOf(expr); return program_->TypeOf(expr);
} }
// A pair of byte size and alignment `uint32_t`s.
struct SizeAndAlign {
uint32_t size;
uint32_t align;
};
/// @returns the MSL packed type size and alignment in bytes for the given
/// type.
SizeAndAlign MslPackedTypeSizeAndAlign(type::Type* ty);
ScopeStack<const semantic::Variable*> global_variables_; ScopeStack<const semantic::Variable*> global_variables_;
Symbol current_ep_sym_; Symbol current_ep_sym_;
bool generating_entry_point_ = false; bool generating_entry_point_ = false;

View File

@ -170,29 +170,172 @@ TEST_F(MslGeneratorImplTest, EmitType_StructDecl) {
)"); )");
} }
/// TODO(bclayton): Add tests for vector, matrix, array and nested structures. TEST_F(MslGeneratorImplTest, EmitType_Struct_Layout_NonComposites) {
TEST_F(MslGeneratorImplTest, EmitType_Struct_InjectPadding) {
auto* s = Structure( auto* s = Structure(
"S", { "S", {
Member("a", ty.i32(), {MemberSize(32)}), Member("a", ty.i32(), {MemberSize(32)}),
Member("b", ty.f32()), Member("b", ty.f32(), {MemberAlign(128), MemberSize(128)}),
Member("c", ty.f32(), {MemberAlign(128), MemberSize(128)}), Member("c", ty.vec2<f32>()),
Member("d", ty.u32()),
Member("e", ty.vec3<f32>()),
Member("f", ty.u32()),
Member("g", ty.vec4<f32>()),
Member("h", ty.u32()),
Member("i", ty.mat2x2<f32>()),
Member("j", ty.u32()),
Member("k", ty.mat2x3<f32>()),
Member("l", ty.u32()),
Member("m", ty.mat2x4<f32>()),
Member("n", ty.u32()),
Member("o", ty.mat3x2<f32>()),
Member("p", ty.u32()),
Member("q", ty.mat3x3<f32>()),
Member("r", ty.u32()),
Member("s", ty.mat3x4<f32>()),
Member("t", ty.u32()),
Member("u", ty.mat4x2<f32>()),
Member("v", ty.u32()),
Member("w", ty.mat4x3<f32>()),
Member("x", ty.u32()),
Member("y", ty.mat4x4<f32>()),
Member("z", ty.f32()),
}); });
Global("G", s, ast::StorageClass::kStorage);
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.EmitStructType(s)) << gen.error(); ASSERT_TRUE(gen.EmitStructType(s)) << gen.error();
EXPECT_EQ(gen.result(), R"(struct S { EXPECT_EQ(gen.result(), R"(struct S {
int a; /* 0x0000 */ int a;
int8_t pad_0[28]; /* 0x0004 */ int8_t _tint_pad_0[124];
float b; /* 0x0080 */ float b;
int8_t pad_1[92]; /* 0x0084 */ int8_t _tint_pad_1[124];
float c; /* 0x0100 */ packed_float2 c;
int8_t pad_2[124]; /* 0x0108 */ uint d;
/* 0x010c */ int8_t _tint_pad_2[4];
/* 0x0110 */ packed_float3 e;
/* 0x011c */ uint f;
/* 0x0120 */ packed_float4 g;
/* 0x0130 */ uint h;
/* 0x0134 */ int8_t _tint_pad_3[4];
/* 0x0138 */ float2x2 i;
/* 0x0148 */ uint j;
/* 0x014c */ int8_t _tint_pad_4[4];
/* 0x0150 */ float2x3 k;
/* 0x0170 */ uint l;
/* 0x0174 */ int8_t _tint_pad_5[12];
/* 0x0180 */ float2x4 m;
/* 0x01a0 */ uint n;
/* 0x01a4 */ int8_t _tint_pad_6[4];
/* 0x01a8 */ float3x2 o;
/* 0x01c0 */ uint p;
/* 0x01c4 */ int8_t _tint_pad_7[12];
/* 0x01d0 */ float3x3 q;
/* 0x0200 */ uint r;
/* 0x0204 */ int8_t _tint_pad_8[12];
/* 0x0210 */ float3x4 s;
/* 0x0240 */ uint t;
/* 0x0244 */ int8_t _tint_pad_9[4];
/* 0x0248 */ float4x2 u;
/* 0x0268 */ uint v;
/* 0x026c */ int8_t _tint_pad_10[4];
/* 0x0270 */ float4x3 w;
/* 0x02b0 */ uint x;
/* 0x02b4 */ int8_t _tint_pad_11[12];
/* 0x02c0 */ float4x4 y;
/* 0x0300 */ float z;
/* 0x0304 */ int8_t _tint_pad_12[124];
}; };
)"); )");
} }
TEST_F(MslGeneratorImplTest, EmitType_Struct_Layout_Structures) {
// inner_x: size(1024), align(512)
auto* inner_x =
Structure("inner_x", {
Member("a", ty.i32()),
Member("b", ty.f32(), {MemberAlign(512)}),
});
// inner_y: size(516), align(4)
auto* inner_y =
Structure("inner_y", {
Member("a", ty.i32(), {MemberSize(512)}),
Member("b", ty.f32()),
});
auto* s = Structure("S", {
Member("a", ty.i32()),
Member("b", inner_x),
Member("c", ty.f32()),
Member("d", inner_y),
Member("e", ty.f32()),
});
Global("G", s, ast::StorageClass::kStorage);
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.EmitStructType(s)) << gen.error();
EXPECT_EQ(gen.result(), R"(struct S {
/* 0x0000 */ int a;
/* 0x0004 */ int8_t _tint_pad_0[508];
/* 0x0200 */ inner_x b;
/* 0x0600 */ float c;
/* 0x0604 */ inner_y d;
/* 0x0808 */ float e;
/* 0x080c */ int8_t _tint_pad_1[500];
};
)");
}
TEST_F(MslGeneratorImplTest, EmitType_Struct_Layout_ArrayDefaultStride) {
// inner: size(1024), align(512)
auto* inner =
Structure("inner", {
Member("a", ty.i32()),
Member("b", ty.f32(), {MemberAlign(512)}),
});
// array_x: size(28), align(4)
auto* array_x = ty.array<f32, 7>();
// array_y: size(4096), align(512)
auto* array_y = ty.array(inner, 4);
// array_z: size(4), align(4)
auto* array_z = ty.array<f32>();
auto* s = Structure("S", {
Member("a", ty.i32()),
Member("b", array_x),
Member("c", ty.f32()),
Member("d", array_y),
Member("e", ty.f32()),
Member("f", array_z),
});
Global("G", s, ast::StorageClass::kStorage);
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.EmitStructType(s)) << gen.error();
EXPECT_EQ(gen.result(), R"(struct S {
/* 0x0000 */ int a;
/* 0x0004 */ float b[7];
/* 0x0020 */ float c;
/* 0x0024 */ int8_t _tint_pad_0[476];
/* 0x0200 */ inner d[4];
/* 0x1200 */ float e;
/* 0x1204 */ float f[1];
/* 0x1208 */ int8_t _tint_pad_1[504];
};
)");
}
// TODO(crbug.com/tint/649): Add tests for array with explicit stride.
// TODO(dsinclair): How to translate [[block]] // TODO(dsinclair): How to translate [[block]]
TEST_F(MslGeneratorImplTest, DISABLED_EmitType_Struct_WithDecoration) { TEST_F(MslGeneratorImplTest, DISABLED_EmitType_Struct_WithDecoration) {
auto* s = Structure("S", auto* s = Structure("S",
@ -202,12 +345,14 @@ TEST_F(MslGeneratorImplTest, DISABLED_EmitType_Struct_WithDecoration) {
}, },
{create<ast::StructBlockDecoration>()}); {create<ast::StructBlockDecoration>()});
Global("G", s, ast::StorageClass::kStorage);
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.EmitType(s, "")) << gen.error(); ASSERT_TRUE(gen.EmitType(s, "")) << gen.error();
EXPECT_EQ(gen.result(), R"(struct { EXPECT_EQ(gen.result(), R"(struct {
int a; /* 0x0000 */ int a;
float b; /* 0x0004 */ float b;
})"); })");
} }