msl: Fix non-struct runtime-sized array codegen

When these are used inside a function, we were not unwrapping the
array from the struct that we wrapped it in.

Fixed: tint:1385
Change-Id: Ide7bbd802394bf09819265be48d978ec9346adfe
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/77180
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: James Price <jrprice@google.com>
This commit is contained in:
James Price 2022-01-19 19:36:17 +00:00 committed by Tint LUCI CQ
parent 6b1e5f552b
commit b28c5ad3c6
8 changed files with 157 additions and 3 deletions

View File

@ -141,7 +141,9 @@ struct ModuleScopeVarToEntryPointParam::State {
struct NewVar { struct NewVar {
Symbol symbol; Symbol symbol;
bool is_pointer; bool is_pointer;
bool is_wrapped;
}; };
const char* kWrappedArrayMemberName = "arr";
std::unordered_map<const sem::Variable*, NewVar> var_to_newvar; std::unordered_map<const sem::Variable*, NewVar> var_to_newvar;
// We aggregate all workgroup variables into a struct to avoid hitting // We aggregate all workgroup variables into a struct to avoid hitting
@ -182,7 +184,6 @@ struct ModuleScopeVarToEntryPointParam::State {
// Track whether the new variable was wrapped in a struct or not. // Track whether the new variable was wrapped in a struct or not.
bool is_wrapped = false; bool is_wrapped = false;
const char* kWrappedArrayMemberName = "arr";
if (is_entry_point) { if (is_entry_point) {
if (var->Type()->UnwrapRef()->is_handle()) { if (var->Type()->UnwrapRef()->is_handle()) {
@ -309,7 +310,7 @@ struct ModuleScopeVarToEntryPointParam::State {
} }
} }
var_to_newvar[var] = {new_var_symbol, is_pointer}; var_to_newvar[var] = {new_var_symbol, is_pointer, is_wrapped};
} }
if (!workgroup_parameter_members.empty()) { if (!workgroup_parameter_members.empty()) {
@ -344,7 +345,12 @@ struct ModuleScopeVarToEntryPointParam::State {
auto new_var = var_to_newvar[target_var]; auto new_var = var_to_newvar[target_var];
bool is_handle = target_var->Type()->UnwrapRef()->is_handle(); bool is_handle = target_var->Type()->UnwrapRef()->is_handle();
const ast::Expression* arg = ctx.dst->Expr(new_var.symbol); const ast::Expression* arg = ctx.dst->Expr(new_var.symbol);
if (is_entry_point && !is_handle && !new_var.is_pointer) { if (new_var.is_wrapped) {
// The variable is wrapped in a struct, so we need to pass a pointer
// to the struct member instead.
arg = ctx.dst->AddressOf(ctx.dst->MemberAccessor(
ctx.dst->Deref(arg), kWrappedArrayMemberName));
} else if (is_entry_point && !is_handle && !new_var.is_pointer) {
// We need to pass a pointer and we don't already have one, so take // We need to pass a pointer and we don't already have one, so take
// the address of the new variable. // the address of the new variable.
arg = ctx.dst->AddressOf(arg); arg = ctx.dst->AddressOf(arg);

View File

@ -259,6 +259,41 @@ fn main([[group(0), binding(0), internal(disable_validation__entry_point_paramet
EXPECT_EQ(expect, str(got)); EXPECT_EQ(expect, str(got));
} }
TEST_F(ModuleScopeVarToEntryPointParamTest, Buffer_RuntimeArrayInsideFunction) {
auto* src = R"(
[[group(0), binding(0)]]
var<storage> buffer : array<f32>;
fn foo() {
_ = buffer[0];
}
[[stage(compute), workgroup_size(1)]]
fn main() {
foo();
}
)";
auto* expect = R"(
struct tint_symbol_2 {
arr : array<f32>;
}
fn foo([[internal(disable_validation__ignore_storage_class), internal(disable_validation__ignore_invalid_pointer_argument)]] tint_symbol : ptr<storage, array<f32>>) {
_ = (*(tint_symbol))[0];
}
[[stage(compute), workgroup_size(1)]]
fn main([[group(0), binding(0), internal(disable_validation__entry_point_parameter), internal(disable_validation__ignore_storage_class)]] tint_symbol_1 : ptr<storage, tint_symbol_2>) {
foo(&((*(tint_symbol_1)).arr));
}
)";
auto got = Run<ModuleScopeVarToEntryPointParam>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(ModuleScopeVarToEntryPointParamTest, Buffer_RuntimeArray_Alias) { TEST_F(ModuleScopeVarToEntryPointParamTest, Buffer_RuntimeArray_Alias) {
auto* src = R"( auto* src = R"(
type myarray = array<f32>; type myarray = array<f32>;

11
test/bug/tint/1385.wgsl Normal file
View File

@ -0,0 +1,11 @@
[[group(0), binding(1)]]
var<storage, read> data : array<i32>;
fn foo() -> i32 {
return data[0];
}
[[stage(compute), workgroup_size(16, 16, 1)]]
fn main() {
foo();
}

View File

@ -0,0 +1,22 @@
#version 310 es
precision mediump float;
layout (binding = 1) buffer data_block_1 {
int inner[];
} data;
int foo() {
return data.inner[0];
}
layout(local_size_x = 16, local_size_y = 16, local_size_z = 1) in;
void tint_symbol() {
foo();
return;
}
void main() {
tint_symbol();
}

View File

@ -0,0 +1,11 @@
ByteAddressBuffer data : register(t1, space0);
int foo() {
return asint(data.Load((4u * uint(0))));
}
[numthreads(16, 16, 1)]
void main() {
foo();
return;
}

View File

@ -0,0 +1,16 @@
#include <metal_stdlib>
using namespace metal;
struct tint_symbol_3 {
/* 0x0000 */ int arr[1];
};
int foo(const device int (*const tint_symbol_1)[1]) {
return (*(tint_symbol_1))[0];
}
kernel void tint_symbol(const device tint_symbol_3* tint_symbol_2 [[buffer(0)]]) {
foo(&((*(tint_symbol_2)).arr));
return;
}

View File

@ -0,0 +1,43 @@
; SPIR-V
; Version: 1.3
; Generator: Google Tint Compiler; 0
; Bound: 20
; Schema: 0
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main"
OpExecutionMode %main LocalSize 16 16 1
OpName %data_block "data_block"
OpMemberName %data_block 0 "inner"
OpName %data "data"
OpName %foo "foo"
OpName %main "main"
OpDecorate %data_block Block
OpMemberDecorate %data_block 0 Offset 0
OpDecorate %_runtimearr_int ArrayStride 4
OpDecorate %data NonWritable
OpDecorate %data DescriptorSet 0
OpDecorate %data Binding 1
%int = OpTypeInt 32 1
%_runtimearr_int = OpTypeRuntimeArray %int
%data_block = OpTypeStruct %_runtimearr_int
%_ptr_StorageBuffer_data_block = OpTypePointer StorageBuffer %data_block
%data = OpVariable %_ptr_StorageBuffer_data_block StorageBuffer
%6 = OpTypeFunction %int
%uint = OpTypeInt 32 0
%uint_0 = OpConstant %uint 0
%int_0 = OpConstant %int 0
%_ptr_StorageBuffer_int = OpTypePointer StorageBuffer %int
%void = OpTypeVoid
%15 = OpTypeFunction %void
%foo = OpFunction %int None %6
%8 = OpLabel
%13 = OpAccessChain %_ptr_StorageBuffer_int %data %uint_0 %int_0
%14 = OpLoad %int %13
OpReturnValue %14
OpFunctionEnd
%main = OpFunction %void None %15
%18 = OpLabel
%19 = OpFunctionCall %int %foo
OpReturn
OpFunctionEnd

View File

@ -0,0 +1,10 @@
[[group(0), binding(1)]] var<storage, read> data : array<i32>;
fn foo() -> i32 {
return data[0];
}
[[stage(compute), workgroup_size(16, 16, 1)]]
fn main() {
foo();
}