msl: Fix non-struct runtime-sized array codegen
When these are used inside a function, we were not unwrapping the array from the struct that we wrapped it in. Fixed: tint:1385 Change-Id: Ide7bbd802394bf09819265be48d978ec9346adfe Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/77180 Kokoro: Kokoro <noreply+kokoro@google.com> Reviewed-by: Ben Clayton <bclayton@google.com> Commit-Queue: James Price <jrprice@google.com>
This commit is contained in:
parent
6b1e5f552b
commit
b28c5ad3c6
|
@ -141,7 +141,9 @@ struct ModuleScopeVarToEntryPointParam::State {
|
||||||
struct NewVar {
|
struct NewVar {
|
||||||
Symbol symbol;
|
Symbol symbol;
|
||||||
bool is_pointer;
|
bool is_pointer;
|
||||||
|
bool is_wrapped;
|
||||||
};
|
};
|
||||||
|
const char* kWrappedArrayMemberName = "arr";
|
||||||
std::unordered_map<const sem::Variable*, NewVar> var_to_newvar;
|
std::unordered_map<const sem::Variable*, NewVar> var_to_newvar;
|
||||||
|
|
||||||
// We aggregate all workgroup variables into a struct to avoid hitting
|
// We aggregate all workgroup variables into a struct to avoid hitting
|
||||||
|
@ -182,7 +184,6 @@ struct ModuleScopeVarToEntryPointParam::State {
|
||||||
|
|
||||||
// Track whether the new variable was wrapped in a struct or not.
|
// Track whether the new variable was wrapped in a struct or not.
|
||||||
bool is_wrapped = false;
|
bool is_wrapped = false;
|
||||||
const char* kWrappedArrayMemberName = "arr";
|
|
||||||
|
|
||||||
if (is_entry_point) {
|
if (is_entry_point) {
|
||||||
if (var->Type()->UnwrapRef()->is_handle()) {
|
if (var->Type()->UnwrapRef()->is_handle()) {
|
||||||
|
@ -309,7 +310,7 @@ struct ModuleScopeVarToEntryPointParam::State {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var_to_newvar[var] = {new_var_symbol, is_pointer};
|
var_to_newvar[var] = {new_var_symbol, is_pointer, is_wrapped};
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!workgroup_parameter_members.empty()) {
|
if (!workgroup_parameter_members.empty()) {
|
||||||
|
@ -344,7 +345,12 @@ struct ModuleScopeVarToEntryPointParam::State {
|
||||||
auto new_var = var_to_newvar[target_var];
|
auto new_var = var_to_newvar[target_var];
|
||||||
bool is_handle = target_var->Type()->UnwrapRef()->is_handle();
|
bool is_handle = target_var->Type()->UnwrapRef()->is_handle();
|
||||||
const ast::Expression* arg = ctx.dst->Expr(new_var.symbol);
|
const ast::Expression* arg = ctx.dst->Expr(new_var.symbol);
|
||||||
if (is_entry_point && !is_handle && !new_var.is_pointer) {
|
if (new_var.is_wrapped) {
|
||||||
|
// The variable is wrapped in a struct, so we need to pass a pointer
|
||||||
|
// to the struct member instead.
|
||||||
|
arg = ctx.dst->AddressOf(ctx.dst->MemberAccessor(
|
||||||
|
ctx.dst->Deref(arg), kWrappedArrayMemberName));
|
||||||
|
} else if (is_entry_point && !is_handle && !new_var.is_pointer) {
|
||||||
// We need to pass a pointer and we don't already have one, so take
|
// We need to pass a pointer and we don't already have one, so take
|
||||||
// the address of the new variable.
|
// the address of the new variable.
|
||||||
arg = ctx.dst->AddressOf(arg);
|
arg = ctx.dst->AddressOf(arg);
|
||||||
|
|
|
@ -259,6 +259,41 @@ fn main([[group(0), binding(0), internal(disable_validation__entry_point_paramet
|
||||||
EXPECT_EQ(expect, str(got));
|
EXPECT_EQ(expect, str(got));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(ModuleScopeVarToEntryPointParamTest, Buffer_RuntimeArrayInsideFunction) {
|
||||||
|
auto* src = R"(
|
||||||
|
[[group(0), binding(0)]]
|
||||||
|
var<storage> buffer : array<f32>;
|
||||||
|
|
||||||
|
fn foo() {
|
||||||
|
_ = buffer[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
[[stage(compute), workgroup_size(1)]]
|
||||||
|
fn main() {
|
||||||
|
foo();
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
|
||||||
|
auto* expect = R"(
|
||||||
|
struct tint_symbol_2 {
|
||||||
|
arr : array<f32>;
|
||||||
|
}
|
||||||
|
|
||||||
|
fn foo([[internal(disable_validation__ignore_storage_class), internal(disable_validation__ignore_invalid_pointer_argument)]] tint_symbol : ptr<storage, array<f32>>) {
|
||||||
|
_ = (*(tint_symbol))[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
[[stage(compute), workgroup_size(1)]]
|
||||||
|
fn main([[group(0), binding(0), internal(disable_validation__entry_point_parameter), internal(disable_validation__ignore_storage_class)]] tint_symbol_1 : ptr<storage, tint_symbol_2>) {
|
||||||
|
foo(&((*(tint_symbol_1)).arr));
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
|
||||||
|
auto got = Run<ModuleScopeVarToEntryPointParam>(src);
|
||||||
|
|
||||||
|
EXPECT_EQ(expect, str(got));
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(ModuleScopeVarToEntryPointParamTest, Buffer_RuntimeArray_Alias) {
|
TEST_F(ModuleScopeVarToEntryPointParamTest, Buffer_RuntimeArray_Alias) {
|
||||||
auto* src = R"(
|
auto* src = R"(
|
||||||
type myarray = array<f32>;
|
type myarray = array<f32>;
|
||||||
|
|
|
@ -0,0 +1,11 @@
|
||||||
|
[[group(0), binding(1)]]
|
||||||
|
var<storage, read> data : array<i32>;
|
||||||
|
|
||||||
|
fn foo() -> i32 {
|
||||||
|
return data[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
[[stage(compute), workgroup_size(16, 16, 1)]]
|
||||||
|
fn main() {
|
||||||
|
foo();
|
||||||
|
}
|
|
@ -0,0 +1,22 @@
|
||||||
|
#version 310 es
|
||||||
|
precision mediump float;
|
||||||
|
|
||||||
|
|
||||||
|
layout (binding = 1) buffer data_block_1 {
|
||||||
|
int inner[];
|
||||||
|
} data;
|
||||||
|
|
||||||
|
int foo() {
|
||||||
|
return data.inner[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
layout(local_size_x = 16, local_size_y = 16, local_size_z = 1) in;
|
||||||
|
void tint_symbol() {
|
||||||
|
foo();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
void main() {
|
||||||
|
tint_symbol();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,11 @@
|
||||||
|
ByteAddressBuffer data : register(t1, space0);
|
||||||
|
|
||||||
|
int foo() {
|
||||||
|
return asint(data.Load((4u * uint(0))));
|
||||||
|
}
|
||||||
|
|
||||||
|
[numthreads(16, 16, 1)]
|
||||||
|
void main() {
|
||||||
|
foo();
|
||||||
|
return;
|
||||||
|
}
|
|
@ -0,0 +1,16 @@
|
||||||
|
#include <metal_stdlib>
|
||||||
|
|
||||||
|
using namespace metal;
|
||||||
|
struct tint_symbol_3 {
|
||||||
|
/* 0x0000 */ int arr[1];
|
||||||
|
};
|
||||||
|
|
||||||
|
int foo(const device int (*const tint_symbol_1)[1]) {
|
||||||
|
return (*(tint_symbol_1))[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
kernel void tint_symbol(const device tint_symbol_3* tint_symbol_2 [[buffer(0)]]) {
|
||||||
|
foo(&((*(tint_symbol_2)).arr));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,43 @@
|
||||||
|
; SPIR-V
|
||||||
|
; Version: 1.3
|
||||||
|
; Generator: Google Tint Compiler; 0
|
||||||
|
; Bound: 20
|
||||||
|
; Schema: 0
|
||||||
|
OpCapability Shader
|
||||||
|
OpMemoryModel Logical GLSL450
|
||||||
|
OpEntryPoint GLCompute %main "main"
|
||||||
|
OpExecutionMode %main LocalSize 16 16 1
|
||||||
|
OpName %data_block "data_block"
|
||||||
|
OpMemberName %data_block 0 "inner"
|
||||||
|
OpName %data "data"
|
||||||
|
OpName %foo "foo"
|
||||||
|
OpName %main "main"
|
||||||
|
OpDecorate %data_block Block
|
||||||
|
OpMemberDecorate %data_block 0 Offset 0
|
||||||
|
OpDecorate %_runtimearr_int ArrayStride 4
|
||||||
|
OpDecorate %data NonWritable
|
||||||
|
OpDecorate %data DescriptorSet 0
|
||||||
|
OpDecorate %data Binding 1
|
||||||
|
%int = OpTypeInt 32 1
|
||||||
|
%_runtimearr_int = OpTypeRuntimeArray %int
|
||||||
|
%data_block = OpTypeStruct %_runtimearr_int
|
||||||
|
%_ptr_StorageBuffer_data_block = OpTypePointer StorageBuffer %data_block
|
||||||
|
%data = OpVariable %_ptr_StorageBuffer_data_block StorageBuffer
|
||||||
|
%6 = OpTypeFunction %int
|
||||||
|
%uint = OpTypeInt 32 0
|
||||||
|
%uint_0 = OpConstant %uint 0
|
||||||
|
%int_0 = OpConstant %int 0
|
||||||
|
%_ptr_StorageBuffer_int = OpTypePointer StorageBuffer %int
|
||||||
|
%void = OpTypeVoid
|
||||||
|
%15 = OpTypeFunction %void
|
||||||
|
%foo = OpFunction %int None %6
|
||||||
|
%8 = OpLabel
|
||||||
|
%13 = OpAccessChain %_ptr_StorageBuffer_int %data %uint_0 %int_0
|
||||||
|
%14 = OpLoad %int %13
|
||||||
|
OpReturnValue %14
|
||||||
|
OpFunctionEnd
|
||||||
|
%main = OpFunction %void None %15
|
||||||
|
%18 = OpLabel
|
||||||
|
%19 = OpFunctionCall %int %foo
|
||||||
|
OpReturn
|
||||||
|
OpFunctionEnd
|
|
@ -0,0 +1,10 @@
|
||||||
|
[[group(0), binding(1)]] var<storage, read> data : array<i32>;
|
||||||
|
|
||||||
|
fn foo() -> i32 {
|
||||||
|
return data[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
[[stage(compute), workgroup_size(16, 16, 1)]]
|
||||||
|
fn main() {
|
||||||
|
foo();
|
||||||
|
}
|
Loading…
Reference in New Issue