tint/transform: Handle arrays of complex override lengths

Update CreateASTTypeFor() to handle a potential edge-case described in tint:1764.

We haven't seen this issue happen in production, nor can I find a way to trigger this with the tint executable, but try to handle this before we encounter a nasty bug.

Fixed: tint:1764
Change-Id: I496932955a6fdcbe26eacef8dcd04988f92545a1
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/111040
Auto-Submit: Ben Clayton <bclayton@google.com>
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
Ben Clayton 2022-11-21 19:05:24 +00:00 committed by Dawn LUCI CQ
parent efe9c49819
commit 87bccb74d8
9 changed files with 230 additions and 0 deletions

View File

@ -114,6 +114,19 @@ const ast::Type* Transform::CreateASTTypeFor(CloneContext& ctx, const sem::Type*
return ctx.dst->ty.array(el, count, std::move(attrs)); return ctx.dst->ty.array(el, count, std::move(attrs));
} }
if (auto* override = std::get_if<sem::UnnamedOverrideArrayCount>(&a->Count())) { if (auto* override = std::get_if<sem::UnnamedOverrideArrayCount>(&a->Count())) {
// If the array count is an unnamed (complex) override expression, then its not safe to
// redeclare this type as we'd end up with two types that would not compare equal.
// See crbug.com/tint/1764.
// Look for a type alias for this array.
for (auto* type_decl : ctx.src->AST().TypeDecls()) {
if (auto* alias = type_decl->As<ast::Alias>()) {
if (ty == ctx.src->Sem().Get(alias)) {
// Alias found. Use the alias name to ensure types compare equal.
return ctx.dst->ty.type_name(ctx.Clone(alias->name));
}
}
}
// Array is not aliased. Rebuild the array.
auto* count = ctx.Clone(override->expr->Declaration()); auto* count = ctx.Clone(override->expr->Declaration());
return ctx.dst->ty.array(el, count, std::move(attrs)); return ctx.dst->ty.array(el, count, std::move(attrs));
} }

View File

@ -21,6 +21,8 @@
namespace tint::transform { namespace tint::transform {
namespace { namespace {
using namespace tint::number_suffixes; // NOLINT
// Inherit from Transform so we have access to protected methods // Inherit from Transform so we have access to protected methods
struct CreateASTTypeForTest : public testing::Test, public Transform { struct CreateASTTypeForTest : public testing::Test, public Transform {
ApplyResult Apply(const Program*, const DataMap&, DataMap&) const override { ApplyResult Apply(const Program*, const DataMap&, DataMap&) const override {
@ -95,6 +97,28 @@ TEST_F(CreateASTTypeForTest, ArrayNonImplicitStride) {
EXPECT_EQ(size->value, 2); EXPECT_EQ(size->value, 2);
} }
// crbug.com/tint/1764
TEST_F(CreateASTTypeForTest, AliasedArrayWithComplexOverrideLength) {
// override O = 123;
// type A = array<i32, O*2>;
//
// var<workgroup> W : A;
//
ProgramBuilder b;
auto* arr_len = b.Mul("O", 2_a);
b.Override("O", b.Expr(123_a));
auto* alias = b.Alias("A", b.ty.array(b.ty.i32(), arr_len));
Program program(std::move(b));
auto* arr_ty = program.Sem().Get(alias);
CloneContext ctx(&ast_type_builder, &program, false);
auto* ast_ty = tint::As<ast::TypeName>(CreateASTTypeFor(ctx, arr_ty));
ASSERT_NE(ast_ty, nullptr);
EXPECT_EQ(ast_type_builder.Symbols().NameFor(ast_ty->name), "A");
}
TEST_F(CreateASTTypeForTest, Struct) { TEST_F(CreateASTTypeForTest, Struct) {
auto* str = create([](ProgramBuilder& b) { auto* str = create([](ProgramBuilder& b) {
auto* decl = b.Structure("S", {}); auto* decl = b.Structure("S", {});

View File

@ -0,0 +1,12 @@
// flags: --transform substitute_override
override O = 123;
type A = array<i32, O*2>;
var<workgroup> W : A;
@compute @workgroup_size(1)
fn main() {
let p : ptr<workgroup, A> = &W;
(*p)[0] = 42;
}

View File

@ -0,0 +1,22 @@
groupshared int W[246];
struct tint_symbol_1 {
uint local_invocation_index : SV_GroupIndex;
};
void main_inner(uint local_invocation_index) {
{
for(uint idx = local_invocation_index; (idx < 246u); idx = (idx + 1u)) {
const uint i = idx;
W[i] = 0;
}
}
GroupMemoryBarrierWithGroupSync();
W[0] = 42;
}
[numthreads(1, 1, 1)]
void main(tint_symbol_1 tint_symbol) {
main_inner(tint_symbol.local_invocation_index);
return;
}

View File

@ -0,0 +1,22 @@
groupshared int W[246];
struct tint_symbol_1 {
uint local_invocation_index : SV_GroupIndex;
};
void main_inner(uint local_invocation_index) {
{
for(uint idx = local_invocation_index; (idx < 246u); idx = (idx + 1u)) {
const uint i = idx;
W[i] = 0;
}
}
GroupMemoryBarrierWithGroupSync();
W[0] = 42;
}
[numthreads(1, 1, 1)]
void main(tint_symbol_1 tint_symbol) {
main_inner(tint_symbol.local_invocation_index);
return;
}

View File

@ -0,0 +1,19 @@
#version 310 es
shared int W[246];
void tint_symbol(uint local_invocation_index) {
{
for(uint idx = local_invocation_index; (idx < 246u); idx = (idx + 1u)) {
uint i = idx;
W[i] = 0;
}
}
barrier();
W[0] = 42;
}
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
void main() {
tint_symbol(gl_LocalInvocationIndex);
return;
}

View File

@ -0,0 +1,31 @@
#include <metal_stdlib>
using namespace metal;
template<typename T, size_t N>
struct tint_array {
const constant T& operator[](size_t i) const constant { return elements[i]; }
device T& operator[](size_t i) device { return elements[i]; }
const device T& operator[](size_t i) const device { return elements[i]; }
thread T& operator[](size_t i) thread { return elements[i]; }
const thread T& operator[](size_t i) const thread { return elements[i]; }
threadgroup T& operator[](size_t i) threadgroup { return elements[i]; }
const threadgroup T& operator[](size_t i) const threadgroup { return elements[i]; }
T elements[N];
};
void tint_symbol_inner(uint local_invocation_index, threadgroup tint_array<int, 246>* const tint_symbol_1) {
for(uint idx = local_invocation_index; (idx < 246u); idx = (idx + 1u)) {
uint const i = idx;
(*(tint_symbol_1))[i] = 0;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
(*(tint_symbol_1))[0] = 42;
}
kernel void tint_symbol(uint local_invocation_index [[thread_index_in_threadgroup]]) {
threadgroup tint_array<int, 246> tint_symbol_2;
tint_symbol_inner(local_invocation_index, &(tint_symbol_2));
return;
}

View File

@ -0,0 +1,76 @@
; SPIR-V
; Version: 1.3
; Generator: Google Tint Compiler; 0
; Bound: 44
; Schema: 0
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main" %local_invocation_index_1
OpExecutionMode %main LocalSize 1 1 1
OpName %local_invocation_index_1 "local_invocation_index_1"
OpName %W "W"
OpName %main_inner "main_inner"
OpName %local_invocation_index "local_invocation_index"
OpName %idx "idx"
OpName %main "main"
OpDecorate %local_invocation_index_1 BuiltIn LocalInvocationIndex
OpDecorate %_arr_int_uint_246 ArrayStride 4
%uint = OpTypeInt 32 0
%_ptr_Input_uint = OpTypePointer Input %uint
%local_invocation_index_1 = OpVariable %_ptr_Input_uint Input
%int = OpTypeInt 32 1
%uint_246 = OpConstant %uint 246
%_arr_int_uint_246 = OpTypeArray %int %uint_246
%_ptr_Workgroup__arr_int_uint_246 = OpTypePointer Workgroup %_arr_int_uint_246
%W = OpVariable %_ptr_Workgroup__arr_int_uint_246 Workgroup
%void = OpTypeVoid
%9 = OpTypeFunction %void %uint
%_ptr_Function_uint = OpTypePointer Function %uint
%16 = OpConstantNull %uint
%bool = OpTypeBool
%_ptr_Workgroup_int = OpTypePointer Workgroup %int
%30 = OpConstantNull %int
%uint_1 = OpConstant %uint 1
%uint_2 = OpConstant %uint 2
%uint_264 = OpConstant %uint 264
%int_42 = OpConstant %int 42
%39 = OpTypeFunction %void
%main_inner = OpFunction %void None %9
%local_invocation_index = OpFunctionParameter %uint
%13 = OpLabel
%idx = OpVariable %_ptr_Function_uint Function %16
OpStore %idx %local_invocation_index
OpBranch %17
%17 = OpLabel
OpLoopMerge %18 %19 None
OpBranch %20
%20 = OpLabel
%22 = OpLoad %uint %idx
%23 = OpULessThan %bool %22 %uint_246
%21 = OpLogicalNot %bool %23
OpSelectionMerge %25 None
OpBranchConditional %21 %26 %25
%26 = OpLabel
OpBranch %18
%25 = OpLabel
%27 = OpLoad %uint %idx
%29 = OpAccessChain %_ptr_Workgroup_int %W %27
OpStore %29 %30
OpBranch %19
%19 = OpLabel
%31 = OpLoad %uint %idx
%33 = OpIAdd %uint %31 %uint_1
OpStore %idx %33
OpBranch %17
%18 = OpLabel
OpControlBarrier %uint_2 %uint_2 %uint_264
%37 = OpAccessChain %_ptr_Workgroup_int %W %30
OpStore %37 %int_42
OpReturn
OpFunctionEnd
%main = OpFunction %void None %39
%41 = OpLabel
%43 = OpLoad %uint %local_invocation_index_1
%42 = OpFunctionCall %void %main_inner %43
OpReturn
OpFunctionEnd

View File

@ -0,0 +1,11 @@
const O = 123;
type A = array<i32, (O * 2)>;
var<workgroup> W : A;
@compute @workgroup_size(1)
fn main() {
let p : ptr<workgroup, A> = &(W);
(*(p))[0] = 42;
}