msl: Fix non-struct runtime-sized array codegen

When these are used inside a function, we were not unwrapping the
array from the struct that we wrapped it in.

Fixed: tint:1385
Change-Id: Ide7bbd802394bf09819265be48d978ec9346adfe
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/77180
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: James Price <jrprice@google.com>
This commit is contained in:
James Price
2022-01-19 19:36:17 +00:00
committed by Tint LUCI CQ
parent 6b1e5f552b
commit b28c5ad3c6
8 changed files with 157 additions and 3 deletions

View File

@@ -141,7 +141,9 @@ struct ModuleScopeVarToEntryPointParam::State {
struct NewVar {
Symbol symbol;
bool is_pointer;
bool is_wrapped;
};
const char* kWrappedArrayMemberName = "arr";
std::unordered_map<const sem::Variable*, NewVar> var_to_newvar;
// We aggregate all workgroup variables into a struct to avoid hitting
@@ -182,7 +184,6 @@ struct ModuleScopeVarToEntryPointParam::State {
// Track whether the new variable was wrapped in a struct or not.
bool is_wrapped = false;
const char* kWrappedArrayMemberName = "arr";
if (is_entry_point) {
if (var->Type()->UnwrapRef()->is_handle()) {
@@ -309,7 +310,7 @@ struct ModuleScopeVarToEntryPointParam::State {
}
}
var_to_newvar[var] = {new_var_symbol, is_pointer};
var_to_newvar[var] = {new_var_symbol, is_pointer, is_wrapped};
}
if (!workgroup_parameter_members.empty()) {
@@ -344,7 +345,12 @@ struct ModuleScopeVarToEntryPointParam::State {
auto new_var = var_to_newvar[target_var];
bool is_handle = target_var->Type()->UnwrapRef()->is_handle();
const ast::Expression* arg = ctx.dst->Expr(new_var.symbol);
if (is_entry_point && !is_handle && !new_var.is_pointer) {
if (new_var.is_wrapped) {
// The variable is wrapped in a struct, so we need to pass a pointer
// to the struct member instead.
arg = ctx.dst->AddressOf(ctx.dst->MemberAccessor(
ctx.dst->Deref(arg), kWrappedArrayMemberName));
} else if (is_entry_point && !is_handle && !new_var.is_pointer) {
// We need to pass a pointer and we don't already have one, so take
// the address of the new variable.
arg = ctx.dst->AddressOf(arg);

View File

@@ -259,6 +259,41 @@ fn main([[group(0), binding(0), internal(disable_validation__entry_point_paramet
EXPECT_EQ(expect, str(got));
}
TEST_F(ModuleScopeVarToEntryPointParamTest, Buffer_RuntimeArrayInsideFunction) {
auto* src = R"(
[[group(0), binding(0)]]
var<storage> buffer : array<f32>;
fn foo() {
_ = buffer[0];
}
[[stage(compute), workgroup_size(1)]]
fn main() {
foo();
}
)";
auto* expect = R"(
struct tint_symbol_2 {
arr : array<f32>;
}
fn foo([[internal(disable_validation__ignore_storage_class), internal(disable_validation__ignore_invalid_pointer_argument)]] tint_symbol : ptr<storage, array<f32>>) {
_ = (*(tint_symbol))[0];
}
[[stage(compute), workgroup_size(1)]]
fn main([[group(0), binding(0), internal(disable_validation__entry_point_parameter), internal(disable_validation__ignore_storage_class)]] tint_symbol_1 : ptr<storage, tint_symbol_2>) {
foo(&((*(tint_symbol_1)).arr));
}
)";
auto got = Run<ModuleScopeVarToEntryPointParam>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(ModuleScopeVarToEntryPointParamTest, Buffer_RuntimeArray_Alias) {
auto* src = R"(
type myarray = array<f32>;