msl: Fix non-struct runtime-sized array codegen
When these are used inside a function, we were not unwrapping the array from the struct that we wrapped it in. Fixed: tint:1385 Change-Id: Ide7bbd802394bf09819265be48d978ec9346adfe Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/77180 Kokoro: Kokoro <noreply+kokoro@google.com> Reviewed-by: Ben Clayton <bclayton@google.com> Commit-Queue: James Price <jrprice@google.com>
This commit is contained in:
parent
6b1e5f552b
commit
b28c5ad3c6
|
@ -141,7 +141,9 @@ struct ModuleScopeVarToEntryPointParam::State {
|
|||
struct NewVar {
|
||||
Symbol symbol;
|
||||
bool is_pointer;
|
||||
bool is_wrapped;
|
||||
};
|
||||
const char* kWrappedArrayMemberName = "arr";
|
||||
std::unordered_map<const sem::Variable*, NewVar> var_to_newvar;
|
||||
|
||||
// We aggregate all workgroup variables into a struct to avoid hitting
|
||||
|
@ -182,7 +184,6 @@ struct ModuleScopeVarToEntryPointParam::State {
|
|||
|
||||
// Track whether the new variable was wrapped in a struct or not.
|
||||
bool is_wrapped = false;
|
||||
const char* kWrappedArrayMemberName = "arr";
|
||||
|
||||
if (is_entry_point) {
|
||||
if (var->Type()->UnwrapRef()->is_handle()) {
|
||||
|
@ -309,7 +310,7 @@ struct ModuleScopeVarToEntryPointParam::State {
|
|||
}
|
||||
}
|
||||
|
||||
var_to_newvar[var] = {new_var_symbol, is_pointer};
|
||||
var_to_newvar[var] = {new_var_symbol, is_pointer, is_wrapped};
|
||||
}
|
||||
|
||||
if (!workgroup_parameter_members.empty()) {
|
||||
|
@ -344,7 +345,12 @@ struct ModuleScopeVarToEntryPointParam::State {
|
|||
auto new_var = var_to_newvar[target_var];
|
||||
bool is_handle = target_var->Type()->UnwrapRef()->is_handle();
|
||||
const ast::Expression* arg = ctx.dst->Expr(new_var.symbol);
|
||||
if (is_entry_point && !is_handle && !new_var.is_pointer) {
|
||||
if (new_var.is_wrapped) {
|
||||
// The variable is wrapped in a struct, so we need to pass a pointer
|
||||
// to the struct member instead.
|
||||
arg = ctx.dst->AddressOf(ctx.dst->MemberAccessor(
|
||||
ctx.dst->Deref(arg), kWrappedArrayMemberName));
|
||||
} else if (is_entry_point && !is_handle && !new_var.is_pointer) {
|
||||
// We need to pass a pointer and we don't already have one, so take
|
||||
// the address of the new variable.
|
||||
arg = ctx.dst->AddressOf(arg);
|
||||
|
|
|
@ -259,6 +259,41 @@ fn main([[group(0), binding(0), internal(disable_validation__entry_point_paramet
|
|||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(ModuleScopeVarToEntryPointParamTest, Buffer_RuntimeArrayInsideFunction) {
|
||||
auto* src = R"(
|
||||
[[group(0), binding(0)]]
|
||||
var<storage> buffer : array<f32>;
|
||||
|
||||
fn foo() {
|
||||
_ = buffer[0];
|
||||
}
|
||||
|
||||
[[stage(compute), workgroup_size(1)]]
|
||||
fn main() {
|
||||
foo();
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
struct tint_symbol_2 {
|
||||
arr : array<f32>;
|
||||
}
|
||||
|
||||
fn foo([[internal(disable_validation__ignore_storage_class), internal(disable_validation__ignore_invalid_pointer_argument)]] tint_symbol : ptr<storage, array<f32>>) {
|
||||
_ = (*(tint_symbol))[0];
|
||||
}
|
||||
|
||||
[[stage(compute), workgroup_size(1)]]
|
||||
fn main([[group(0), binding(0), internal(disable_validation__entry_point_parameter), internal(disable_validation__ignore_storage_class)]] tint_symbol_1 : ptr<storage, tint_symbol_2>) {
|
||||
foo(&((*(tint_symbol_1)).arr));
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<ModuleScopeVarToEntryPointParam>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(ModuleScopeVarToEntryPointParamTest, Buffer_RuntimeArray_Alias) {
|
||||
auto* src = R"(
|
||||
type myarray = array<f32>;
|
||||
|
|
|
@ -0,0 +1,11 @@
|
|||
[[group(0), binding(1)]]
|
||||
var<storage, read> data : array<i32>;
|
||||
|
||||
fn foo() -> i32 {
|
||||
return data[0];
|
||||
}
|
||||
|
||||
[[stage(compute), workgroup_size(16, 16, 1)]]
|
||||
fn main() {
|
||||
foo();
|
||||
}
|
|
@ -0,0 +1,22 @@
|
|||
#version 310 es
|
||||
precision mediump float;
|
||||
|
||||
|
||||
layout (binding = 1) buffer data_block_1 {
|
||||
int inner[];
|
||||
} data;
|
||||
|
||||
int foo() {
|
||||
return data.inner[0];
|
||||
}
|
||||
|
||||
layout(local_size_x = 16, local_size_y = 16, local_size_z = 1) in;
|
||||
void tint_symbol() {
|
||||
foo();
|
||||
return;
|
||||
}
|
||||
void main() {
|
||||
tint_symbol();
|
||||
}
|
||||
|
||||
|
|
@ -0,0 +1,11 @@
|
|||
ByteAddressBuffer data : register(t1, space0);
|
||||
|
||||
int foo() {
|
||||
return asint(data.Load((4u * uint(0))));
|
||||
}
|
||||
|
||||
[numthreads(16, 16, 1)]
|
||||
void main() {
|
||||
foo();
|
||||
return;
|
||||
}
|
|
@ -0,0 +1,16 @@
|
|||
#include <metal_stdlib>
|
||||
|
||||
using namespace metal;
|
||||
struct tint_symbol_3 {
|
||||
/* 0x0000 */ int arr[1];
|
||||
};
|
||||
|
||||
int foo(const device int (*const tint_symbol_1)[1]) {
|
||||
return (*(tint_symbol_1))[0];
|
||||
}
|
||||
|
||||
kernel void tint_symbol(const device tint_symbol_3* tint_symbol_2 [[buffer(0)]]) {
|
||||
foo(&((*(tint_symbol_2)).arr));
|
||||
return;
|
||||
}
|
||||
|
|
@ -0,0 +1,43 @@
|
|||
; SPIR-V
|
||||
; Version: 1.3
|
||||
; Generator: Google Tint Compiler; 0
|
||||
; Bound: 20
|
||||
; Schema: 0
|
||||
OpCapability Shader
|
||||
OpMemoryModel Logical GLSL450
|
||||
OpEntryPoint GLCompute %main "main"
|
||||
OpExecutionMode %main LocalSize 16 16 1
|
||||
OpName %data_block "data_block"
|
||||
OpMemberName %data_block 0 "inner"
|
||||
OpName %data "data"
|
||||
OpName %foo "foo"
|
||||
OpName %main "main"
|
||||
OpDecorate %data_block Block
|
||||
OpMemberDecorate %data_block 0 Offset 0
|
||||
OpDecorate %_runtimearr_int ArrayStride 4
|
||||
OpDecorate %data NonWritable
|
||||
OpDecorate %data DescriptorSet 0
|
||||
OpDecorate %data Binding 1
|
||||
%int = OpTypeInt 32 1
|
||||
%_runtimearr_int = OpTypeRuntimeArray %int
|
||||
%data_block = OpTypeStruct %_runtimearr_int
|
||||
%_ptr_StorageBuffer_data_block = OpTypePointer StorageBuffer %data_block
|
||||
%data = OpVariable %_ptr_StorageBuffer_data_block StorageBuffer
|
||||
%6 = OpTypeFunction %int
|
||||
%uint = OpTypeInt 32 0
|
||||
%uint_0 = OpConstant %uint 0
|
||||
%int_0 = OpConstant %int 0
|
||||
%_ptr_StorageBuffer_int = OpTypePointer StorageBuffer %int
|
||||
%void = OpTypeVoid
|
||||
%15 = OpTypeFunction %void
|
||||
%foo = OpFunction %int None %6
|
||||
%8 = OpLabel
|
||||
%13 = OpAccessChain %_ptr_StorageBuffer_int %data %uint_0 %int_0
|
||||
%14 = OpLoad %int %13
|
||||
OpReturnValue %14
|
||||
OpFunctionEnd
|
||||
%main = OpFunction %void None %15
|
||||
%18 = OpLabel
|
||||
%19 = OpFunctionCall %int %foo
|
||||
OpReturn
|
||||
OpFunctionEnd
|
|
@ -0,0 +1,10 @@
|
|||
[[group(0), binding(1)]] var<storage, read> data : array<i32>;
|
||||
|
||||
fn foo() -> i32 {
|
||||
return data[0];
|
||||
}
|
||||
|
||||
[[stage(compute), workgroup_size(16, 16, 1)]]
|
||||
fn main() {
|
||||
foo();
|
||||
}
|
Loading…
Reference in New Issue