Allow non-struct buffer store types

For SPIR-V, wrap non-struct types in structs in the
AddSpirvBlockDecoration transform.

For MSL, wrap runtime-sized arrays in structs in the
ModuleScopeVarToEntryPointParam transform.

Bug: tint:1372
Change-Id: Icced5d77b4538e816aa9fab57a634a9f4c52fdab
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/76162
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
James Price
2022-01-19 15:55:56 +00:00
parent c3cec4d901
commit 7395e29e70
132 changed files with 2309 additions and 143 deletions

View File

@@ -459,42 +459,6 @@ bool Resolver::ValidateGlobalVariable(const sem::Variable* var) {
return false;
}
switch (var->StorageClass()) {
case ast::StorageClass::kStorage: {
// https://gpuweb.github.io/gpuweb/wgsl/#module-scope-variables
// A variable in the storage storage class is a storage buffer variable.
// Its store type must be a host-shareable structure type with block
// attribute, satisfying the storage class constraints.
auto* str = var->Type()->UnwrapRef()->As<sem::Struct>();
if (!str) {
AddError(
"variables declared in the <storage> storage class must be of a "
"structure type",
decl->source);
return false;
}
break;
}
case ast::StorageClass::kUniform: {
// https://gpuweb.github.io/gpuweb/wgsl/#module-scope-variables
// A variable in the uniform storage class is a uniform buffer variable.
// Its store type must be a host-shareable structure type with block
// attribute, satisfying the storage class constraints.
auto* str = var->Type()->UnwrapRef()->As<sem::Struct>();
if (!str) {
AddError(
"variables declared in the <uniform> storage class must be of a "
"structure type",
decl->source);
return false;
}
break;
}
default:
break;
}
if (!decl->is_const) {
if (!ValidateAtomicVariable(var)) {
return false;
@@ -580,14 +544,6 @@ bool Resolver::ValidateVariable(const sem::Variable* var) {
return false;
}
if (auto* r = storage_ty->As<sem::Array>()) {
if (r->IsRuntimeSized()) {
AddError("runtime arrays may only appear as the last member of a struct",
decl->source);
return false;
}
}
if (auto* r = storage_ty->As<sem::MultisampledTexture>()) {
if (r->dim() != ast::TextureDimension::k2d) {
AddError("only 2d multisampled textures are supported", decl->source);

View File

@@ -92,6 +92,40 @@ note: while analysing structure member S.m
}
TEST_F(ResolverStorageClassValidationTest, StorageBufferBool) {
// var<storage> g : bool;
Global(Source{{56, 78}}, "g", ty.bool_(), ast::StorageClass::kStorage,
ast::DecorationList{
create<ast::BindingDecoration>(0),
create<ast::GroupDecoration>(0),
});
ASSERT_FALSE(r()->Resolve());
EXPECT_EQ(
r()->error(),
R"(56:78 error: Type 'bool' cannot be used in storage class 'storage' as it is non-host-shareable
56:78 note: while instantiating variable g)");
}
TEST_F(ResolverStorageClassValidationTest, StorageBufferPointer) {
// var<storage> g : ptr<private, f32>;
Global(Source{{56, 78}}, "g",
ty.pointer(ty.f32(), ast::StorageClass::kPrivate),
ast::StorageClass::kStorage,
ast::DecorationList{
create<ast::BindingDecoration>(0),
create<ast::GroupDecoration>(0),
});
ASSERT_FALSE(r()->Resolve());
EXPECT_EQ(
r()->error(),
R"(56:78 error: Type 'ptr<private, f32, read_write>' cannot be used in storage class 'storage' as it is non-host-shareable
56:78 note: while instantiating variable g)");
}
TEST_F(ResolverStorageClassValidationTest, StorageBufferIntScalar) {
// var<storage> g : i32;
Global(Source{{56, 78}}, "g", ty.i32(), ast::StorageClass::kStorage,
ast::DecorationList{
@@ -99,14 +133,10 @@ TEST_F(ResolverStorageClassValidationTest, StorageBufferBool) {
create<ast::GroupDecoration>(0),
});
ASSERT_FALSE(r()->Resolve());
EXPECT_EQ(
r()->error(),
R"(56:78 error: variables declared in the <storage> storage class must be of a structure type)");
ASSERT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverStorageClassValidationTest, StorageBufferPointer) {
TEST_F(ResolverStorageClassValidationTest, StorageBufferVector) {
// var<storage> g : vec4<f32>;
Global(Source{{56, 78}}, "g", ty.vec4<f32>(), ast::StorageClass::kStorage,
ast::DecorationList{
@@ -114,11 +144,7 @@ TEST_F(ResolverStorageClassValidationTest, StorageBufferPointer) {
create<ast::GroupDecoration>(0),
});
ASSERT_FALSE(r()->Resolve());
EXPECT_EQ(
r()->error(),
R"(56:78 error: variables declared in the <storage> storage class must be of a structure type)");
ASSERT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverStorageClassValidationTest, StorageBufferArray) {
@@ -132,11 +158,7 @@ TEST_F(ResolverStorageClassValidationTest, StorageBufferArray) {
create<ast::GroupDecoration>(0),
});
ASSERT_FALSE(r()->Resolve());
EXPECT_EQ(
r()->error(),
R"(56:78 error: variables declared in the <storage> storage class must be of a structure type)");
ASSERT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverStorageClassValidationTest, StorageBufferBoolAlias) {
@@ -240,8 +262,10 @@ TEST_F(ResolverStorageClassValidationTest, UniformBufferBool) {
}
TEST_F(ResolverStorageClassValidationTest, UniformBufferPointer) {
// var<uniform> g : vec4<f32>;
Global(Source{{56, 78}}, "g", ty.vec4<f32>(), ast::StorageClass::kUniform,
// var<uniform> g : ptr<private, f32>;
Global(Source{{56, 78}}, "g",
ty.pointer(ty.f32(), ast::StorageClass::kPrivate),
ast::StorageClass::kUniform,
ast::DecorationList{
create<ast::BindingDecoration>(0),
create<ast::GroupDecoration>(0),
@@ -251,7 +275,30 @@ TEST_F(ResolverStorageClassValidationTest, UniformBufferPointer) {
EXPECT_EQ(
r()->error(),
R"(56:78 error: variables declared in the <uniform> storage class must be of a structure type)");
R"(56:78 error: Type 'ptr<private, f32, read_write>' cannot be used in storage class 'uniform' as it is non-host-shareable
56:78 note: while instantiating variable g)");
}
TEST_F(ResolverStorageClassValidationTest, UniformBufferIntScalar) {
// var<uniform> g : i32;
Global(Source{{56, 78}}, "g", ty.i32(), ast::StorageClass::kUniform,
ast::DecorationList{
create<ast::BindingDecoration>(0),
create<ast::GroupDecoration>(0),
});
ASSERT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverStorageClassValidationTest, UniformBufferVector) {
// var<uniform> g : vec4<f32>;
Global(Source{{56, 78}}, "g", ty.vec4<f32>(), ast::StorageClass::kUniform,
ast::DecorationList{
create<ast::BindingDecoration>(0),
create<ast::GroupDecoration>(0),
});
ASSERT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverStorageClassValidationTest, UniformBufferArray) {
@@ -264,11 +311,7 @@ TEST_F(ResolverStorageClassValidationTest, UniformBufferArray) {
create<ast::GroupDecoration>(0),
});
ASSERT_FALSE(r()->Resolve());
EXPECT_EQ(
r()->error(),
R"(56:78 error: variables declared in the <uniform> storage class must be of a structure type)");
ASSERT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverStorageClassValidationTest, UniformBufferBoolAlias) {

View File

@@ -52,9 +52,9 @@ void AddSpirvBlockDecoration::Run(CloneContext& ctx, const DataMap&, DataMap&) {
}
}
// A map from a struct in the source program to a block-decorated wrapper that
// A map from a type in the source program to a block-decorated wrapper that
// contains it in the destination program.
std::unordered_map<const sem::Struct*, const ast::Struct*> wrapper_structs;
std::unordered_map<const sem::Type*, const ast::Struct*> wrapper_structs;
// Process global variables that are buffers.
for (auto* var : ctx.src->AST().GlobalVariables()) {
@@ -64,40 +64,33 @@ void AddSpirvBlockDecoration::Run(CloneContext& ctx, const DataMap&, DataMap&) {
continue;
}
auto* str = sem.Get<sem::Struct>(var->type);
if (!str) {
// TODO(jrprice): We'll need to wrap these too, when WGSL supports this.
TINT_ICE(Transform, ctx.dst->Diagnostics())
<< "non-struct buffer types are not yet supported";
continue;
}
auto* ty = sem.Get(var->type);
auto* str = ty->As<sem::Struct>();
if (!str || nested_structs.count(str)) {
const char* kMemberName = "inner";
if (nested_structs.count(str)) {
const char* kInnerStructMemberName = "inner";
// This struct is nested somewhere else, so we need to wrap it first.
auto* wrapper = utils::GetOrCreate(wrapper_structs, str, [&]() {
// This is a non-struct or a struct that is nested somewhere else, so we
// need to wrap it first.
auto* wrapper = utils::GetOrCreate(wrapper_structs, ty, [&]() {
auto* block =
ctx.dst->ASTNodes().Create<SpirvBlockDecoration>(ctx.dst->ID());
auto wrapper_name =
ctx.src->Symbols().NameFor(str->Declaration()->name) + "_block";
auto wrapper_name = ctx.src->Symbols().NameFor(var->symbol) + "_block";
auto* ret = ctx.dst->create<ast::Struct>(
ctx.dst->Symbols().New(wrapper_name),
ast::StructMemberList{ctx.dst->Member(kInnerStructMemberName,
CreateASTTypeFor(ctx, str))},
ast::StructMemberList{
ctx.dst->Member(kMemberName, CreateASTTypeFor(ctx, ty))},
ast::DecorationList{block});
ctx.InsertAfter(ctx.src->AST().GlobalDeclarations(), str->Declaration(),
ret);
ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), var, ret);
return ret;
});
ctx.Replace(var->type, ctx.dst->ty.Of(wrapper));
// Insert a member accessor to get the original struct from the wrapper at
// Insert a member accessor to get the original type from the wrapper at
// any usage of the original variable.
for (auto* user : sem_var->Users()) {
ctx.Replace(user->Declaration(),
ctx.dst->MemberAccessor(ctx.Clone(var->symbol),
kInnerStructMemberName));
ctx.Replace(
user->Declaration(),
ctx.dst->MemberAccessor(ctx.Clone(var->symbol), kMemberName));
}
} else {
// Add a block decoration to this struct directly.

View File

@@ -73,7 +73,98 @@ fn main() -> S {
EXPECT_EQ(expect, str(got));
}
TEST_F(AddSpirvBlockDecorationTest, Basic) {
TEST_F(AddSpirvBlockDecorationTest, BasicScalar) {
auto* src = R"(
[[group(0), binding(0)]]
var<uniform> u : f32;
[[stage(fragment)]]
fn main() {
let f = u;
}
)";
auto* expect = R"(
[[internal(spirv_block)]]
struct u_block {
inner : f32;
};
[[group(0), binding(0)]] var<uniform> u : u_block;
[[stage(fragment)]]
fn main() {
let f = u.inner;
}
)";
auto got = Run<AddSpirvBlockDecoration>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(AddSpirvBlockDecorationTest, BasicArray) {
auto* src = R"(
[[group(0), binding(0)]]
var<uniform> u : array<vec4<f32>, 4u>;
[[stage(fragment)]]
fn main() {
let a = u;
}
)";
auto* expect = R"(
[[internal(spirv_block)]]
struct u_block {
inner : array<vec4<f32>, 4u>;
};
[[group(0), binding(0)]] var<uniform> u : u_block;
[[stage(fragment)]]
fn main() {
let a = u.inner;
}
)";
auto got = Run<AddSpirvBlockDecoration>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(AddSpirvBlockDecorationTest, BasicArray_Alias) {
auto* src = R"(
type Numbers = array<vec4<f32>, 4u>;
[[group(0), binding(0)]]
var<uniform> u : Numbers;
[[stage(fragment)]]
fn main() {
let a = u;
}
)";
auto* expect = R"(
type Numbers = array<vec4<f32>, 4u>;
[[internal(spirv_block)]]
struct u_block {
inner : array<vec4<f32>, 4u>;
};
[[group(0), binding(0)]] var<uniform> u : u_block;
[[stage(fragment)]]
fn main() {
let a = u.inner;
}
)";
auto got = Run<AddSpirvBlockDecoration>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(AddSpirvBlockDecorationTest, BasicStruct) {
auto* src = R"(
struct S {
f : f32;
@@ -174,11 +265,6 @@ struct Inner {
f : f32;
};
[[internal(spirv_block)]]
struct Inner_block {
inner : Inner;
};
[[internal(spirv_block)]]
struct Outer {
i : Inner;
@@ -186,7 +272,12 @@ struct Outer {
[[group(0), binding(0)]] var<uniform> u0 : Outer;
[[group(0), binding(1)]] var<uniform> u1 : Inner_block;
[[internal(spirv_block)]]
struct u1_block {
inner : Inner;
};
[[group(0), binding(1)]] var<uniform> u1 : u1_block;
[[stage(fragment)]]
fn main() {
@@ -226,18 +317,18 @@ struct Inner {
f : f32;
};
[[internal(spirv_block)]]
struct Inner_block {
inner : Inner;
};
struct Outer {
i : Inner;
};
var<private> p : Outer;
[[group(0), binding(1)]] var<uniform> u : Inner_block;
[[internal(spirv_block)]]
struct u_block {
inner : Inner;
};
[[group(0), binding(1)]] var<uniform> u : u_block;
[[stage(fragment)]]
fn main() {
@@ -282,11 +373,6 @@ struct Inner {
f : f32;
};
[[internal(spirv_block)]]
struct Inner_block {
inner : Inner;
};
[[internal(spirv_block)]]
struct S {
i : Inner;
@@ -294,9 +380,14 @@ struct S {
[[group(0), binding(0)]] var<uniform> u0 : S;
[[group(0), binding(1)]] var<uniform> u1 : Inner_block;
[[internal(spirv_block)]]
struct u1_block {
inner : Inner;
};
[[group(0), binding(2)]] var<uniform> u2 : Inner_block;
[[group(0), binding(1)]] var<uniform> u1 : u1_block;
[[group(0), binding(2)]] var<uniform> u2 : u1_block;
[[stage(fragment)]]
fn main() {
@@ -332,11 +423,11 @@ struct S {
};
[[internal(spirv_block)]]
struct S_block {
struct u_block {
inner : S;
};
[[group(0), binding(0)]] var<uniform> u : S_block;
[[group(0), binding(0)]] var<uniform> u : u_block;
[[stage(fragment)]]
fn main() {
@@ -375,13 +466,13 @@ struct S {
};
[[internal(spirv_block)]]
struct S_block {
struct u0_block {
inner : S;
};
[[group(0), binding(0)]] var<uniform> u0 : S_block;
[[group(0), binding(0)]] var<uniform> u0 : u0_block;
[[group(0), binding(1)]] var<uniform> u1 : S_block;
[[group(0), binding(1)]] var<uniform> u1 : u0_block;
[[stage(fragment)]]
fn main() {
@@ -427,11 +518,6 @@ struct Inner {
f : f32;
};
[[internal(spirv_block)]]
struct Inner_block {
inner : Inner;
};
type MyInner = Inner;
[[internal(spirv_block)]]
@@ -443,7 +529,12 @@ type MyOuter = Outer;
[[group(0), binding(0)]] var<uniform> u0 : MyOuter;
[[group(0), binding(1)]] var<uniform> u1 : Inner_block;
[[internal(spirv_block)]]
struct u1_block {
inner : Inner;
};
[[group(0), binding(1)]] var<uniform> u1 : u1_block;
[[stage(fragment)]]
fn main() {

View File

@@ -157,6 +157,7 @@ struct ModuleScopeVarToEntryPointParam::State {
for (auto* var : func_sem->TransitivelyReferencedGlobals()) {
auto sc = var->StorageClass();
auto* ty = var->Type()->UnwrapRef();
if (sc == ast::StorageClass::kNone) {
continue;
}
@@ -174,13 +175,15 @@ struct ModuleScopeVarToEntryPointParam::State {
auto new_var_symbol = ctx.dst->Sym();
// Helper to create an AST node for the store type of the variable.
auto store_type = [&]() {
return CreateASTTypeFor(ctx, var->Type()->UnwrapRef());
};
auto store_type = [&]() { return CreateASTTypeFor(ctx, ty); };
// Track whether the new variable is a pointer or not.
bool is_pointer = false;
// Track whether the new variable was wrapped in a struct or not.
bool is_wrapped = false;
const char* kWrappedArrayMemberName = "arr";
if (is_entry_point) {
if (var->Type()->UnwrapRef()->is_handle()) {
// For a texture or sampler variable, redeclare it as an entry point
@@ -200,8 +203,23 @@ struct ModuleScopeVarToEntryPointParam::State {
ast::DisabledValidation::kEntryPointParameter));
attributes.push_back(
ctx.dst->Disable(ast::DisabledValidation::kIgnoreStorageClass));
auto* param_type = ctx.dst->ty.pointer(
store_type(), sc, var->Declaration()->declared_access);
auto* param_type = store_type();
if (auto* arr = ty->As<sem::Array>();
arr && arr->IsRuntimeSized()) {
// Wrap runtime-sized arrays in structures, so that we can declare
// pointers to them. Ideally we'd just emit the array itself as a
// pointer, but this is not representable in Tint's AST.
CloneStructTypes(ty);
auto* wrapper = ctx.dst->Structure(
ctx.dst->Sym(),
{ctx.dst->Member(kWrappedArrayMemberName, param_type)});
param_type = ctx.dst->ty.Of(wrapper);
is_wrapped = true;
}
param_type = ctx.dst->ty.pointer(
param_type, sc, var->Declaration()->declared_access);
auto* param =
ctx.dst->Param(new_var_symbol, param_type, attributes);
ctx.InsertFront(func_ast->params, param);
@@ -283,6 +301,10 @@ struct ModuleScopeVarToEntryPointParam::State {
expr = ctx.dst->Deref(expr);
}
if (is_wrapped) {
// Get the member from the wrapper structure.
expr = ctx.dst->MemberAccessor(expr, kWrappedArrayMemberName);
}
ctx.Replace(user->Declaration(), expr);
}
}

View File

@@ -232,6 +232,99 @@ fn main([[group(0), binding(0), internal(disable_validation__entry_point_paramet
EXPECT_EQ(expect, str(got));
}
TEST_F(ModuleScopeVarToEntryPointParamTest, Buffer_RuntimeArray) {
auto* src = R"(
[[group(0), binding(0)]]
var<storage> buffer : array<f32>;
[[stage(compute), workgroup_size(1)]]
fn main() {
_ = buffer[0];
}
)";
auto* expect = R"(
struct tint_symbol_1 {
arr : array<f32>;
};
[[stage(compute), workgroup_size(1)]]
fn main([[group(0), binding(0), internal(disable_validation__entry_point_parameter), internal(disable_validation__ignore_storage_class)]] tint_symbol : ptr<storage, tint_symbol_1>) {
_ = (*(tint_symbol)).arr[0];
}
)";
auto got = Run<ModuleScopeVarToEntryPointParam>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(ModuleScopeVarToEntryPointParamTest, Buffer_RuntimeArray_Alias) {
auto* src = R"(
type myarray = array<f32>;
[[group(0), binding(0)]]
var<storage> buffer : myarray;
[[stage(compute), workgroup_size(1)]]
fn main() {
_ = buffer[0];
}
)";
auto* expect = R"(
struct tint_symbol_1 {
arr : array<f32>;
};
type myarray = array<f32>;
[[stage(compute), workgroup_size(1)]]
fn main([[group(0), binding(0), internal(disable_validation__entry_point_parameter), internal(disable_validation__ignore_storage_class)]] tint_symbol : ptr<storage, tint_symbol_1>) {
_ = (*(tint_symbol)).arr[0];
}
)";
auto got = Run<ModuleScopeVarToEntryPointParam>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(ModuleScopeVarToEntryPointParamTest, Buffer_ArrayOfStruct) {
auto* src = R"(
struct S {
f : f32;
};
[[group(0), binding(0)]]
var<storage> buffer : array<S>;
[[stage(compute), workgroup_size(1)]]
fn main() {
_ = buffer[0];
}
)";
auto* expect = R"(
struct S {
f : f32;
};
struct tint_symbol_1 {
arr : array<S>;
};
[[stage(compute), workgroup_size(1)]]
fn main([[group(0), binding(0), internal(disable_validation__entry_point_parameter), internal(disable_validation__ignore_storage_class)]] tint_symbol : ptr<storage, tint_symbol_1>) {
_ = (*(tint_symbol)).arr[0];
}
)";
auto got = Run<ModuleScopeVarToEntryPointParam>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(ModuleScopeVarToEntryPointParamTest, Buffers_FunctionCalls) {
auto* src = R"(
struct S {

View File

@@ -2869,14 +2869,6 @@ bool GeneratorImpl::EmitUniformVariable(const sem::Variable* var) {
auto* decl = var->Declaration();
auto binding_point = decl->BindingPoint();
auto* type = var->Type()->UnwrapRef();
auto* str = type->As<sem::Struct>();
if (!str) {
// https://www.w3.org/TR/WGSL/#module-scope-variables
TINT_ICE(Writer, diagnostics_)
<< "variables with uniform storage must be structure";
}
auto name = builder_.Symbols().NameFor(decl->symbol);
line() << "cbuffer cbuffer_" << name << RegisterAndSpace('b', binding_point)
<< " {";
@@ -3513,13 +3505,7 @@ bool GeneratorImpl::EmitType(std::ostream& out,
out << "ByteAddressBuffer";
return true;
case ast::StorageClass::kUniform: {
auto* str = type->As<sem::Struct>();
if (!str) {
// https://www.w3.org/TR/WGSL/#module-scope-variables
TINT_ICE(Writer, diagnostics_)
<< "variables with uniform storage must be structure";
}
auto array_length = (str->Size() + 15) / 16;
auto array_length = (type->Size() + 15) / 16;
out << "uint4 " << name << "[" << array_length << "]";
if (name_printed) {
*name_printed = true;