mirror of
https://github.com/encounter/dawn-cmake.git
synced 2025-12-17 00:47:13 +00:00
msl: Fix non-struct runtime-sized array codegen
When these are used inside a function, we were not unwrapping the array from the struct that we wrapped it in. Fixed: tint:1385 Change-Id: Ide7bbd802394bf09819265be48d978ec9346adfe Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/77180 Kokoro: Kokoro <noreply+kokoro@google.com> Reviewed-by: Ben Clayton <bclayton@google.com> Commit-Queue: James Price <jrprice@google.com>
This commit is contained in:
committed by
Tint LUCI CQ
parent
6b1e5f552b
commit
b28c5ad3c6
@@ -141,7 +141,9 @@ struct ModuleScopeVarToEntryPointParam::State {
|
||||
struct NewVar {
|
||||
Symbol symbol;
|
||||
bool is_pointer;
|
||||
bool is_wrapped;
|
||||
};
|
||||
const char* kWrappedArrayMemberName = "arr";
|
||||
std::unordered_map<const sem::Variable*, NewVar> var_to_newvar;
|
||||
|
||||
// We aggregate all workgroup variables into a struct to avoid hitting
|
||||
@@ -182,7 +184,6 @@ struct ModuleScopeVarToEntryPointParam::State {
|
||||
|
||||
// Track whether the new variable was wrapped in a struct or not.
|
||||
bool is_wrapped = false;
|
||||
const char* kWrappedArrayMemberName = "arr";
|
||||
|
||||
if (is_entry_point) {
|
||||
if (var->Type()->UnwrapRef()->is_handle()) {
|
||||
@@ -309,7 +310,7 @@ struct ModuleScopeVarToEntryPointParam::State {
|
||||
}
|
||||
}
|
||||
|
||||
var_to_newvar[var] = {new_var_symbol, is_pointer};
|
||||
var_to_newvar[var] = {new_var_symbol, is_pointer, is_wrapped};
|
||||
}
|
||||
|
||||
if (!workgroup_parameter_members.empty()) {
|
||||
@@ -344,7 +345,12 @@ struct ModuleScopeVarToEntryPointParam::State {
|
||||
auto new_var = var_to_newvar[target_var];
|
||||
bool is_handle = target_var->Type()->UnwrapRef()->is_handle();
|
||||
const ast::Expression* arg = ctx.dst->Expr(new_var.symbol);
|
||||
if (is_entry_point && !is_handle && !new_var.is_pointer) {
|
||||
if (new_var.is_wrapped) {
|
||||
// The variable is wrapped in a struct, so we need to pass a pointer
|
||||
// to the struct member instead.
|
||||
arg = ctx.dst->AddressOf(ctx.dst->MemberAccessor(
|
||||
ctx.dst->Deref(arg), kWrappedArrayMemberName));
|
||||
} else if (is_entry_point && !is_handle && !new_var.is_pointer) {
|
||||
// We need to pass a pointer and we don't already have one, so take
|
||||
// the address of the new variable.
|
||||
arg = ctx.dst->AddressOf(arg);
|
||||
|
||||
@@ -259,6 +259,41 @@ fn main([[group(0), binding(0), internal(disable_validation__entry_point_paramet
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(ModuleScopeVarToEntryPointParamTest, Buffer_RuntimeArrayInsideFunction) {
|
||||
auto* src = R"(
|
||||
[[group(0), binding(0)]]
|
||||
var<storage> buffer : array<f32>;
|
||||
|
||||
fn foo() {
|
||||
_ = buffer[0];
|
||||
}
|
||||
|
||||
[[stage(compute), workgroup_size(1)]]
|
||||
fn main() {
|
||||
foo();
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
struct tint_symbol_2 {
|
||||
arr : array<f32>;
|
||||
}
|
||||
|
||||
fn foo([[internal(disable_validation__ignore_storage_class), internal(disable_validation__ignore_invalid_pointer_argument)]] tint_symbol : ptr<storage, array<f32>>) {
|
||||
_ = (*(tint_symbol))[0];
|
||||
}
|
||||
|
||||
[[stage(compute), workgroup_size(1)]]
|
||||
fn main([[group(0), binding(0), internal(disable_validation__entry_point_parameter), internal(disable_validation__ignore_storage_class)]] tint_symbol_1 : ptr<storage, tint_symbol_2>) {
|
||||
foo(&((*(tint_symbol_1)).arr));
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<ModuleScopeVarToEntryPointParam>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(ModuleScopeVarToEntryPointParamTest, Buffer_RuntimeArray_Alias) {
|
||||
auto* src = R"(
|
||||
type myarray = array<f32>;
|
||||
|
||||
Reference in New Issue
Block a user