Make ArrayLengthFromUniform transform emit a valid UBO

The UBO must have a stride that is a multiple of 16 bytes.
Note that this change was part of https://dawn-review.googlesource.com/c/tint/+/56780
but the CL was reverted because it broke Dawn. This CL relands part of
the change, and adds the macro TINT_EXPECTS_UBOS_TO_BE_MULTIPLE_OF_16 so
that Dawn can conditionally compile against it.

Bug: tint:984
Bug: tint:643
Change-Id: I303b3fe81ff97c4933c489736d5d5432a59ce9b7
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/57921
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
This commit is contained in:
Antonio Maiorano 2021-07-14 17:28:01 +00:00 committed by Tint LUCI CQ
parent 4d22c1dc2f
commit de2b7db244
14 changed files with 106 additions and 42 deletions

View File

@ -59,4 +59,7 @@
#include "src/writer/hlsl/generator.h"
#endif // TINT_BUILD_HLSL_WRITER
// TODO(crbug/984): Remove once Dawn builds with this flag
#define TINT_EXPECTS_UBOS_TO_BE_MULTIPLE_OF_16
#endif // INCLUDE_TINT_TINT_H_

View File

@ -67,11 +67,16 @@ void ArrayLengthFromUniform::Run(CloneContext& ctx,
ast::Variable* buffer_size_ubo = nullptr;
auto get_ubo = [&]() {
if (!buffer_size_ubo) {
// Emit an array<vec4<u32>, N>, where N is 1/4 number of elements.
// We do this because UBOs require an element stride that is 16-byte
// aligned.
auto* buffer_size_struct = ctx.dst->Structure(
ctx.dst->Sym(),
{ctx.dst->Member(
kBufferSizeMemberName,
ctx.dst->ty.array(ctx.dst->ty.u32(), max_buffer_size_index + 1))},
ctx.dst->ty.array(ctx.dst->ty.vec4(ctx.dst->ty.u32()),
(max_buffer_size_index / 4) + 1))},
ast::DecorationList{ctx.dst->create<ast::StructBlockDecoration>()});
buffer_size_ubo = ctx.dst->Global(
ctx.dst->Sym(), ctx.dst->ty.Of(buffer_size_struct),
@ -99,18 +104,20 @@ void ArrayLengthFromUniform::Run(CloneContext& ctx,
// Get the storage buffer that contains the runtime array.
// We assume that the argument to `arrayLength` has the form
// `&resource.array`, which requires that `InlinePointerLets` and `Simplify`
// have been run before this transform.
// `&resource.array`, which requires that `InlinePointerLets` and
// `Simplify` have been run before this transform.
auto* param = call_expr->params()[0]->As<ast::UnaryOpExpression>();
if (!param || param->op() != ast::UnaryOp::kAddressOf) {
TINT_ICE(Transform, ctx.dst->Diagnostics())
<< "expected form of arrayLength argument to be &resource.array";
<< "expected form of arrayLength argument to be "
"&resource.array";
break;
}
auto* accessor = param->expr()->As<ast::MemberAccessorExpression>();
if (!accessor) {
TINT_ICE(Transform, ctx.dst->Diagnostics())
<< "expected form of arrayLength argument to be &resource.array";
<< "expected form of arrayLength argument to be "
"&resource.array";
break;
}
auto* storage_buffer_expr = accessor->structure();
@ -118,7 +125,8 @@ void ArrayLengthFromUniform::Run(CloneContext& ctx,
sem.Get(storage_buffer_expr)->As<sem::VariableUser>();
if (!storage_buffer_sem) {
TINT_ICE(Transform, ctx.dst->Diagnostics())
<< "expected form of arrayLength argument to be &resource.array";
<< "expected form of arrayLength argument to be "
"&resource.array";
break;
}
@ -135,9 +143,13 @@ void ArrayLengthFromUniform::Run(CloneContext& ctx,
}
// Load the total storage buffer size from the UBO.
auto* total_storage_buffer_size = ctx.dst->IndexAccessor(
uint32_t array_index = idx_itr->second / 4;
auto* vec_expr = ctx.dst->IndexAccessor(
ctx.dst->MemberAccessor(get_ubo()->symbol(), kBufferSizeMemberName),
idx_itr->second);
array_index);
uint32_t vec_index = idx_itr->second % 4;
auto* total_storage_buffer_size =
ctx.dst->IndexAccessor(vec_expr, vec_index);
// Calculate actual array length
// total_storage_buffer_size - array_offset

View File

@ -81,7 +81,7 @@ fn main() {
auto* expect = R"(
[[block]]
struct tint_symbol {
buffer_size : array<u32, 1>;
buffer_size : array<vec4<u32>, 1>;
};
[[group(0), binding(30)]] var<uniform> tint_symbol_1 : tint_symbol;
@ -96,7 +96,7 @@ struct SB {
[[stage(compute), workgroup_size(1)]]
fn main() {
var len : u32 = ((tint_symbol_1.buffer_size[0u] - 4u) / 4u);
var len : u32 = ((tint_symbol_1.buffer_size[0u][0u] - 4u) / 4u);
}
)";
@ -134,7 +134,7 @@ fn main() {
auto* expect = R"(
[[block]]
struct tint_symbol {
buffer_size : array<u32, 1>;
buffer_size : array<vec4<u32>, 1>;
};
[[group(0), binding(30)]] var<uniform> tint_symbol_1 : tint_symbol;
@ -150,7 +150,7 @@ struct SB {
[[stage(compute), workgroup_size(1)]]
fn main() {
var len : u32 = ((tint_symbol_1.buffer_size[0u] - 8u) / 64u);
var len : u32 = ((tint_symbol_1.buffer_size[0u][0u] - 8u) / 64u);
}
)";
@ -175,29 +175,48 @@ struct SB1 {
x : i32;
arr1 : array<i32>;
};
[[block]]
struct SB2 {
x : i32;
arr2 : array<vec4<f32>>;
};
[[block]]
struct SB3 {
x : i32;
arr3 : array<vec4<f32>>;
};
[[block]]
struct SB4 {
x : i32;
arr4 : array<vec4<f32>>;
};
[[block]]
struct SB5 {
x : i32;
arr5 : array<vec4<f32>>;
};
[[group(0), binding(2)]] var<storage, read> sb1 : SB1;
[[group(1), binding(2)]] var<storage, read> sb2 : SB2;
[[group(2), binding(2)]] var<storage, read> sb3 : SB3;
[[group(3), binding(2)]] var<storage, read> sb4 : SB4;
[[group(4), binding(2)]] var<storage, read> sb5 : SB5;
[[stage(compute), workgroup_size(1)]]
fn main() {
var len1 : u32 = arrayLength(&(sb1.arr1));
var len2 : u32 = arrayLength(&(sb2.arr2));
var x : u32 = (len1 + len2);
var len3 : u32 = arrayLength(&(sb3.arr3));
var len4 : u32 = arrayLength(&(sb4.arr4));
var len5 : u32 = arrayLength(&(sb5.arr5));
var x : u32 = (len1 + len2 + len3 + len4 + len5);
}
)";
auto* expect = R"(
[[block]]
struct tint_symbol {
buffer_size : array<u32, 2>;
buffer_size : array<vec4<u32>, 2>;
};
[[group(0), binding(30)]] var<uniform> tint_symbol_1 : tint_symbol;
@ -214,21 +233,51 @@ struct SB2 {
arr2 : array<vec4<f32>>;
};
[[block]]
struct SB3 {
x : i32;
arr3 : array<vec4<f32>>;
};
[[block]]
struct SB4 {
x : i32;
arr4 : array<vec4<f32>>;
};
[[block]]
struct SB5 {
x : i32;
arr5 : array<vec4<f32>>;
};
[[group(0), binding(2)]] var<storage, read> sb1 : SB1;
[[group(1), binding(2)]] var<storage, read> sb2 : SB2;
[[group(2), binding(2)]] var<storage, read> sb3 : SB3;
[[group(3), binding(2)]] var<storage, read> sb4 : SB4;
[[group(4), binding(2)]] var<storage, read> sb5 : SB5;
[[stage(compute), workgroup_size(1)]]
fn main() {
var len1 : u32 = ((tint_symbol_1.buffer_size[0u] - 4u) / 4u);
var len2 : u32 = ((tint_symbol_1.buffer_size[1u] - 16u) / 16u);
var x : u32 = (len1 + len2);
var len1 : u32 = ((tint_symbol_1.buffer_size[0u][0u] - 4u) / 4u);
var len2 : u32 = ((tint_symbol_1.buffer_size[0u][1u] - 16u) / 16u);
var len3 : u32 = ((tint_symbol_1.buffer_size[0u][2u] - 16u) / 16u);
var len4 : u32 = ((tint_symbol_1.buffer_size[0u][3u] - 16u) / 16u);
var len5 : u32 = ((tint_symbol_1.buffer_size[1u][0u] - 16u) / 16u);
var x : u32 = ((((len1 + len2) + len3) + len4) + len5);
}
)";
ArrayLengthFromUniform::Config cfg({0, 30u});
cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 2u}, 0);
cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{1u, 2u}, 1);
cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{2u, 2u}, 2);
cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{3u, 2u}, 3);
cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{4u, 2u}, 4);
DataMap data;
data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));

View File

@ -2,14 +2,14 @@
using namespace metal;
struct tint_symbol_1 {
/* 0x0000 */ uint buffer_size[1];
/* 0x0000 */ uint4 buffer_size[1];
};
struct S {
/* 0x0000 */ int a[1];
};
kernel void tint_symbol(constant tint_symbol_1& tint_symbol_2 [[buffer(30)]]) {
uint const l1 = ((tint_symbol_2.buffer_size[0u] - 0u) / 4u);
uint const l1 = ((tint_symbol_2.buffer_size[0u][0u] - 0u) / 4u);
return;
}

View File

@ -2,15 +2,15 @@
using namespace metal;
struct tint_symbol_1 {
/* 0x0000 */ uint buffer_size[1];
/* 0x0000 */ uint4 buffer_size[1];
};
struct S {
/* 0x0000 */ int a[1];
};
kernel void tint_symbol(constant tint_symbol_1& tint_symbol_2 [[buffer(30)]]) {
uint const l1 = ((tint_symbol_2.buffer_size[0u] - 0u) / 4u);
uint const l2 = ((tint_symbol_2.buffer_size[0u] - 0u) / 4u);
uint const l1 = ((tint_symbol_2.buffer_size[0u][0u] - 0u) / 4u);
uint const l2 = ((tint_symbol_2.buffer_size[0u][0u] - 0u) / 4u);
return;
}

View File

@ -2,14 +2,14 @@
using namespace metal;
struct tint_symbol_1 {
/* 0x0000 */ uint buffer_size[1];
/* 0x0000 */ uint4 buffer_size[1];
};
struct S {
/* 0x0000 */ int a[1];
};
kernel void tint_symbol(constant tint_symbol_1& tint_symbol_2 [[buffer(30)]]) {
uint const l1 = ((tint_symbol_2.buffer_size[0u] - 0u) / 4u);
uint const l1 = ((tint_symbol_2.buffer_size[0u][0u] - 0u) / 4u);
return;
}

View File

@ -2,14 +2,14 @@
using namespace metal;
struct tint_symbol_1 {
/* 0x0000 */ uint buffer_size[1];
/* 0x0000 */ uint4 buffer_size[1];
};
struct S {
/* 0x0000 */ int a[1];
};
kernel void tint_symbol(constant tint_symbol_1& tint_symbol_2 [[buffer(30)]]) {
uint const l1 = ((tint_symbol_2.buffer_size[0u] - 0u) / 4u);
uint const l1 = ((tint_symbol_2.buffer_size[0u][0u] - 0u) / 4u);
return;
}

View File

@ -2,14 +2,14 @@
using namespace metal;
struct tint_symbol_1 {
/* 0x0000 */ uint buffer_size[1];
/* 0x0000 */ uint4 buffer_size[1];
};
struct S {
/* 0x0000 */ int a[1];
};
kernel void tint_symbol(constant tint_symbol_1& tint_symbol_2 [[buffer(30)]]) {
uint const l1 = ((tint_symbol_2.buffer_size[0u] - 0u) / 4u);
uint const l1 = ((tint_symbol_2.buffer_size[0u][0u] - 0u) / 4u);
return;
}

View File

@ -2,7 +2,7 @@
using namespace metal;
struct tint_symbol_2 {
/* 0x0000 */ uint buffer_size[2];
/* 0x0000 */ uint4 buffer_size[1];
};
struct SB_RO {
/* 0x0000 */ int arg_0[1];
@ -12,7 +12,7 @@ struct tint_symbol {
};
void arrayLength_1588cd(constant tint_symbol_2& tint_symbol_3) {
uint res = ((tint_symbol_3.buffer_size[1u] - 0u) / 4u);
uint res = ((tint_symbol_3.buffer_size[0u][1u] - 0u) / 4u);
}
vertex tint_symbol vertex_main(constant tint_symbol_2& tint_symbol_3 [[buffer(30)]]) {

View File

@ -2,7 +2,7 @@
using namespace metal;
struct tint_symbol_2 {
/* 0x0000 */ uint buffer_size[1];
/* 0x0000 */ uint4 buffer_size[1];
};
struct SB_RW {
/* 0x0000 */ int arg_0[1];
@ -12,7 +12,7 @@ struct tint_symbol {
};
void arrayLength_61b1c7(constant tint_symbol_2& tint_symbol_3) {
uint res = ((tint_symbol_3.buffer_size[0u] - 0u) / 4u);
uint res = ((tint_symbol_3.buffer_size[0u][0u] - 0u) / 4u);
}
vertex tint_symbol vertex_main(constant tint_symbol_2& tint_symbol_3 [[buffer(30)]]) {

View File

@ -2,7 +2,7 @@
using namespace metal;
struct tint_symbol_2 {
/* 0x0000 */ uint buffer_size[2];
/* 0x0000 */ uint4 buffer_size[1];
};
struct SB_RO {
/* 0x0000 */ float arg_0[1];
@ -12,7 +12,7 @@ struct tint_symbol {
};
void arrayLength_a0f5ca(constant tint_symbol_2& tint_symbol_3) {
uint res = ((tint_symbol_3.buffer_size[1u] - 0u) / 4u);
uint res = ((tint_symbol_3.buffer_size[0u][1u] - 0u) / 4u);
}
vertex tint_symbol vertex_main(constant tint_symbol_2& tint_symbol_3 [[buffer(30)]]) {

View File

@ -2,7 +2,7 @@
using namespace metal;
struct tint_symbol_2 {
/* 0x0000 */ uint buffer_size[1];
/* 0x0000 */ uint4 buffer_size[1];
};
struct SB_RW {
/* 0x0000 */ float arg_0[1];
@ -12,7 +12,7 @@ struct tint_symbol {
};
void arrayLength_cdd123(constant tint_symbol_2& tint_symbol_3) {
uint res = ((tint_symbol_3.buffer_size[0u] - 0u) / 4u);
uint res = ((tint_symbol_3.buffer_size[0u][0u] - 0u) / 4u);
}
vertex tint_symbol vertex_main(constant tint_symbol_2& tint_symbol_3 [[buffer(30)]]) {

View File

@ -2,7 +2,7 @@
using namespace metal;
struct tint_symbol_2 {
/* 0x0000 */ uint buffer_size[2];
/* 0x0000 */ uint4 buffer_size[1];
};
struct SB_RO {
/* 0x0000 */ uint arg_0[1];
@ -12,7 +12,7 @@ struct tint_symbol {
};
void arrayLength_cfca0a(constant tint_symbol_2& tint_symbol_3) {
uint res = ((tint_symbol_3.buffer_size[1u] - 0u) / 4u);
uint res = ((tint_symbol_3.buffer_size[0u][1u] - 0u) / 4u);
}
vertex tint_symbol vertex_main(constant tint_symbol_2& tint_symbol_3 [[buffer(30)]]) {

View File

@ -2,7 +2,7 @@
using namespace metal;
struct tint_symbol_2 {
/* 0x0000 */ uint buffer_size[1];
/* 0x0000 */ uint4 buffer_size[1];
};
struct SB_RW {
/* 0x0000 */ uint arg_0[1];
@ -12,7 +12,7 @@ struct tint_symbol {
};
void arrayLength_eb510f(constant tint_symbol_2& tint_symbol_3) {
uint res = ((tint_symbol_3.buffer_size[0u] - 0u) / 4u);
uint res = ((tint_symbol_3.buffer_size[0u][0u] - 0u) / 4u);
}
vertex tint_symbol vertex_main(constant tint_symbol_2& tint_symbol_3 [[buffer(30)]]) {