msl: Fold &* when converting module-scope vars

This transform was previously converting this code:
```
var<private> v : f32;
fn foo() {
  bar(&v);
}
```

into this:
```
fn foo(vp : ptr<private, f32>) {
  bar(&*vp); // Invalid, since ptr args must be &ident
}
```

Fixed: tint:1086
Change-Id: Ic9efafa219c89a11a4d6e1d11fc69b3c0b9a5464
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/60520
Kokoro: Kokoro <noreply+kokoro@google.com>
Auto-Submit: James Price <jrprice@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
James Price 2021-08-04 19:18:38 +00:00 committed by Tint LUCI CQ
parent 98fbf241d8
commit 5c61d6d12c
42 changed files with 159 additions and 36 deletions

View File

@ -181,6 +181,23 @@ void Msl::HandleModuleScopeVariables(CloneContext& ctx) const {
}
}
// Build a list of `&ident` expressions. We'll use this later to avoid
// generating expressions of the form `&*ident`, which break WGSL validation
// rules when this expression is passed to a function.
// TODO(jrprice): We should add support for bidirectional SEM tree traversal
// so that we can do this on the fly instead.
std::unordered_map<ast::IdentifierExpression*, ast::UnaryOpExpression*>
ident_to_address_of;
for (auto* node : ctx.src->ASTNodes().Objects()) {
auto* address_of = node->As<ast::UnaryOpExpression>();
if (!address_of || address_of->op() != ast::UnaryOp::kAddressOf) {
continue;
}
if (auto* ident = address_of->expr()->As<ast::IdentifierExpression>()) {
ident_to_address_of[ident] = address_of;
}
}
for (auto* func_ast : functions_to_process) {
auto* func_sem = ctx.src->Sem().Get(func_ast);
bool is_entry_point = func_ast->IsEntryPoint();
@ -241,6 +258,15 @@ void Msl::HandleModuleScopeVariables(CloneContext& ctx) const {
if (user->Stmt()->Function() == func_ast) {
ast::Expression* expr = ctx.dst->Expr(new_var_symbol);
if (!is_entry_point && !store_type->is_handle()) {
// If this identifier is used by an address-of operator, just remove
// the address-of instead of adding a deref, since we already have a
// pointer.
auto* ident = user->Declaration()->As<ast::IdentifierExpression>();
if (ident_to_address_of.count(ident)) {
ctx.Replace(ident_to_address_of[ident], expr);
continue;
}
expr = ctx.dst->Deref(expr);
}
ctx.Replace(user->Declaration(), expr);

14
test/bug/tint/1086.wgsl Normal file
View File

@ -0,0 +1,14 @@
var<private> v : f32;
fn x(p : ptr<private, f32>) {
(*p) = 0.0;
}
fn g() {
x(&v);
}
[[stage(fragment)]]
fn f() {
g();
}

View File

@ -0,0 +1,14 @@
static float v = 0.0f;
void x(inout float p) {
p = 0.0f;
}
void g() {
x(v);
}
void f() {
g();
return;
}

View File

@ -0,0 +1,17 @@
#include <metal_stdlib>
using namespace metal;
void x(thread float* const p) {
*(p) = 0.0f;
}
void g(thread float* const tint_symbol) {
x(tint_symbol);
}
fragment void f() {
thread float tint_symbol_1 = 0.0f;
g(&(tint_symbol_1));
return;
}

View File

@ -0,0 +1,38 @@
; SPIR-V
; Version: 1.3
; Generator: Google Tint Compiler; 0
; Bound: 20
; Schema: 0
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %f "f"
OpExecutionMode %f OriginUpperLeft
OpName %v "v"
OpName %x "x"
OpName %p "p"
OpName %g "g"
OpName %f "f"
%float = OpTypeFloat 32
%_ptr_Private_float = OpTypePointer Private %float
%4 = OpConstantNull %float
%v = OpVariable %_ptr_Private_float Private %4
%void = OpTypeVoid
%5 = OpTypeFunction %void %_ptr_Private_float
%float_0 = OpConstant %float 0
%12 = OpTypeFunction %void
%x = OpFunction %void None %5
%p = OpFunctionParameter %_ptr_Private_float
%9 = OpLabel
OpStore %p %float_0
OpReturn
OpFunctionEnd
%g = OpFunction %void None %12
%14 = OpLabel
%15 = OpFunctionCall %void %x %v
OpReturn
OpFunctionEnd
%f = OpFunction %void None %12
%18 = OpLabel
%19 = OpFunctionCall %void %g
OpReturn
OpFunctionEnd

View File

@ -0,0 +1,14 @@
var<private> v : f32;
fn x(p : ptr<private, f32>) {
*(p) = 0.0;
}
fn g() {
x(&(v));
}
[[stage(fragment)]]
fn f() {
g();
}

View File

@ -2,7 +2,7 @@
using namespace metal;
void atomicAdd_794055(threadgroup atomic_int* const tint_symbol_1) {
int res = atomic_fetch_add_explicit(&(*(tint_symbol_1)), 1, memory_order_relaxed);
int res = atomic_fetch_add_explicit(tint_symbol_1, 1, memory_order_relaxed);
}
kernel void compute_main(uint local_invocation_index [[thread_index_in_threadgroup]]) {

View File

@ -2,7 +2,7 @@
using namespace metal;
void atomicAdd_d5db1d(threadgroup atomic_uint* const tint_symbol_1) {
uint res = atomic_fetch_add_explicit(&(*(tint_symbol_1)), 1u, memory_order_relaxed);
uint res = atomic_fetch_add_explicit(tint_symbol_1, 1u, memory_order_relaxed);
}
kernel void compute_main(uint local_invocation_index [[thread_index_in_threadgroup]]) {

View File

@ -2,7 +2,7 @@
using namespace metal;
void atomicAnd_34edd3(threadgroup atomic_uint* const tint_symbol_1) {
uint res = atomic_fetch_and_explicit(&(*(tint_symbol_1)), 1u, memory_order_relaxed);
uint res = atomic_fetch_and_explicit(tint_symbol_1, 1u, memory_order_relaxed);
}
kernel void compute_main(uint local_invocation_index [[thread_index_in_threadgroup]]) {

View File

@ -2,7 +2,7 @@
using namespace metal;
void atomicAnd_45a819(threadgroup atomic_int* const tint_symbol_1) {
int res = atomic_fetch_and_explicit(&(*(tint_symbol_1)), 1, memory_order_relaxed);
int res = atomic_fetch_and_explicit(tint_symbol_1, 1, memory_order_relaxed);
}
kernel void compute_main(uint local_invocation_index [[thread_index_in_threadgroup]]) {

View File

@ -10,7 +10,7 @@ vec<T, 2> atomicCompareExchangeWeak_1(threadgroup A* atomic, T compare, T value)
}
void atomicCompareExchangeWeak_89ea3b(threadgroup atomic_int* const tint_symbol_1) {
int2 res = atomicCompareExchangeWeak_1(&(*(tint_symbol_1)), 1, 1);
int2 res = atomicCompareExchangeWeak_1(tint_symbol_1, 1, 1);
}
kernel void compute_main(uint local_invocation_index [[thread_index_in_threadgroup]]) {

View File

@ -10,7 +10,7 @@ vec<T, 2> atomicCompareExchangeWeak_1(threadgroup A* atomic, T compare, T value)
}
void atomicCompareExchangeWeak_b2ab2c(threadgroup atomic_uint* const tint_symbol_1) {
uint2 res = atomicCompareExchangeWeak_1(&(*(tint_symbol_1)), 1u, 1u);
uint2 res = atomicCompareExchangeWeak_1(tint_symbol_1, 1u, 1u);
}
kernel void compute_main(uint local_invocation_index [[thread_index_in_threadgroup]]) {

View File

@ -2,7 +2,7 @@
using namespace metal;
void atomicExchange_0a5dca(threadgroup atomic_uint* const tint_symbol_1) {
uint res = atomic_exchange_explicit(&(*(tint_symbol_1)), 1u, memory_order_relaxed);
uint res = atomic_exchange_explicit(tint_symbol_1, 1u, memory_order_relaxed);
}
kernel void compute_main(uint local_invocation_index [[thread_index_in_threadgroup]]) {

View File

@ -2,7 +2,7 @@
using namespace metal;
void atomicExchange_e114ba(threadgroup atomic_int* const tint_symbol_1) {
int res = atomic_exchange_explicit(&(*(tint_symbol_1)), 1, memory_order_relaxed);
int res = atomic_exchange_explicit(tint_symbol_1, 1, memory_order_relaxed);
}
kernel void compute_main(uint local_invocation_index [[thread_index_in_threadgroup]]) {

View File

@ -2,7 +2,7 @@
using namespace metal;
void atomicLoad_361bf1(threadgroup atomic_uint* const tint_symbol_1) {
uint res = atomic_load_explicit(&(*(tint_symbol_1)), memory_order_relaxed);
uint res = atomic_load_explicit(tint_symbol_1, memory_order_relaxed);
}
kernel void compute_main(uint local_invocation_index [[thread_index_in_threadgroup]]) {

View File

@ -2,7 +2,7 @@
using namespace metal;
void atomicLoad_afcc03(threadgroup atomic_int* const tint_symbol_1) {
int res = atomic_load_explicit(&(*(tint_symbol_1)), memory_order_relaxed);
int res = atomic_load_explicit(tint_symbol_1, memory_order_relaxed);
}
kernel void compute_main(uint local_invocation_index [[thread_index_in_threadgroup]]) {

View File

@ -2,7 +2,7 @@
using namespace metal;
void atomicMax_a89cc3(threadgroup atomic_int* const tint_symbol_1) {
int res = atomic_fetch_max_explicit(&(*(tint_symbol_1)), 1, memory_order_relaxed);
int res = atomic_fetch_max_explicit(tint_symbol_1, 1, memory_order_relaxed);
}
kernel void compute_main(uint local_invocation_index [[thread_index_in_threadgroup]]) {

View File

@ -2,7 +2,7 @@
using namespace metal;
void atomicMax_beccfc(threadgroup atomic_uint* const tint_symbol_1) {
uint res = atomic_fetch_max_explicit(&(*(tint_symbol_1)), 1u, memory_order_relaxed);
uint res = atomic_fetch_max_explicit(tint_symbol_1, 1u, memory_order_relaxed);
}
kernel void compute_main(uint local_invocation_index [[thread_index_in_threadgroup]]) {

View File

@ -2,7 +2,7 @@
using namespace metal;
void atomicMin_278235(threadgroup atomic_int* const tint_symbol_1) {
int res = atomic_fetch_min_explicit(&(*(tint_symbol_1)), 1, memory_order_relaxed);
int res = atomic_fetch_min_explicit(tint_symbol_1, 1, memory_order_relaxed);
}
kernel void compute_main(uint local_invocation_index [[thread_index_in_threadgroup]]) {

View File

@ -2,7 +2,7 @@
using namespace metal;
void atomicMin_69d383(threadgroup atomic_uint* const tint_symbol_1) {
uint res = atomic_fetch_min_explicit(&(*(tint_symbol_1)), 1u, memory_order_relaxed);
uint res = atomic_fetch_min_explicit(tint_symbol_1, 1u, memory_order_relaxed);
}
kernel void compute_main(uint local_invocation_index [[thread_index_in_threadgroup]]) {

View File

@ -2,7 +2,7 @@
using namespace metal;
void atomicOr_5e3d61(threadgroup atomic_uint* const tint_symbol_1) {
uint res = atomic_fetch_or_explicit(&(*(tint_symbol_1)), 1u, memory_order_relaxed);
uint res = atomic_fetch_or_explicit(tint_symbol_1, 1u, memory_order_relaxed);
}
kernel void compute_main(uint local_invocation_index [[thread_index_in_threadgroup]]) {

View File

@ -2,7 +2,7 @@
using namespace metal;
void atomicOr_d09248(threadgroup atomic_int* const tint_symbol_1) {
int res = atomic_fetch_or_explicit(&(*(tint_symbol_1)), 1, memory_order_relaxed);
int res = atomic_fetch_or_explicit(tint_symbol_1, 1, memory_order_relaxed);
}
kernel void compute_main(uint local_invocation_index [[thread_index_in_threadgroup]]) {

View File

@ -2,7 +2,7 @@
using namespace metal;
void atomicStore_726882(threadgroup atomic_uint* const tint_symbol_1) {
atomic_store_explicit(&(*(tint_symbol_1)), 1u, memory_order_relaxed);
atomic_store_explicit(tint_symbol_1, 1u, memory_order_relaxed);
}
kernel void compute_main(uint local_invocation_index [[thread_index_in_threadgroup]]) {

View File

@ -2,7 +2,7 @@
using namespace metal;
void atomicStore_8bea94(threadgroup atomic_int* const tint_symbol_1) {
atomic_store_explicit(&(*(tint_symbol_1)), 1, memory_order_relaxed);
atomic_store_explicit(tint_symbol_1, 1, memory_order_relaxed);
}
kernel void compute_main(uint local_invocation_index [[thread_index_in_threadgroup]]) {

View File

@ -2,7 +2,7 @@
using namespace metal;
void atomicXor_75dc95(threadgroup atomic_int* const tint_symbol_1) {
int res = atomic_fetch_xor_explicit(&(*(tint_symbol_1)), 1, memory_order_relaxed);
int res = atomic_fetch_xor_explicit(tint_symbol_1, 1, memory_order_relaxed);
}
kernel void compute_main(uint local_invocation_index [[thread_index_in_threadgroup]]) {

View File

@ -2,7 +2,7 @@
using namespace metal;
void atomicXor_c8e6be(threadgroup atomic_uint* const tint_symbol_1) {
uint res = atomic_fetch_xor_explicit(&(*(tint_symbol_1)), 1u, memory_order_relaxed);
uint res = atomic_fetch_xor_explicit(tint_symbol_1, 1u, memory_order_relaxed);
}
kernel void compute_main(uint local_invocation_index [[thread_index_in_threadgroup]]) {

View File

@ -14,7 +14,7 @@ float tint_frexp(float param_0, threadgroup int* param_1) {
}
void frexp_0da285(threadgroup int* const tint_symbol_1) {
float res = tint_frexp(1.0f, &(*(tint_symbol_1)));
float res = tint_frexp(1.0f, tint_symbol_1);
}
kernel void compute_main(uint local_invocation_index [[thread_index_in_threadgroup]]) {

View File

@ -14,7 +14,7 @@ float3 tint_frexp(float3 param_0, threadgroup int3* param_1) {
}
void frexp_40fc9b(threadgroup int3* const tint_symbol_1) {
float3 res = tint_frexp(float3(), &(*(tint_symbol_1)));
float3 res = tint_frexp(float3(), tint_symbol_1);
}
kernel void compute_main(uint local_invocation_index [[thread_index_in_threadgroup]]) {

View File

@ -18,7 +18,7 @@ struct tint_symbol {
};
void frexp_6efa09(thread int3* const tint_symbol_2) {
float3 res = tint_frexp(float3(), &(*(tint_symbol_2)));
float3 res = tint_frexp(float3(), tint_symbol_2);
}
vertex tint_symbol vertex_main() {

View File

@ -18,7 +18,7 @@ struct tint_symbol {
};
void frexp_a2a617(thread int* const tint_symbol_2) {
float res = tint_frexp(1.0f, &(*(tint_symbol_2)));
float res = tint_frexp(1.0f, tint_symbol_2);
}
vertex tint_symbol vertex_main() {

View File

@ -14,7 +14,7 @@ float2 tint_frexp(float2 param_0, threadgroup int2* param_1) {
}
void frexp_a3f940(threadgroup int2* const tint_symbol_1) {
float2 res = tint_frexp(float2(), &(*(tint_symbol_1)));
float2 res = tint_frexp(float2(), tint_symbol_1);
}
kernel void compute_main(uint local_invocation_index [[thread_index_in_threadgroup]]) {

View File

@ -18,7 +18,7 @@ struct tint_symbol {
};
void frexp_b45525(thread int4* const tint_symbol_2) {
float4 res = tint_frexp(float4(), &(*(tint_symbol_2)));
float4 res = tint_frexp(float4(), tint_symbol_2);
}
vertex tint_symbol vertex_main() {

View File

@ -14,7 +14,7 @@ float4 tint_frexp(float4 param_0, threadgroup int4* param_1) {
}
void frexp_b87f4e(threadgroup int4* const tint_symbol_1) {
float4 res = tint_frexp(float4(), &(*(tint_symbol_1)));
float4 res = tint_frexp(float4(), tint_symbol_1);
}
kernel void compute_main(uint local_invocation_index [[thread_index_in_threadgroup]]) {

View File

@ -18,7 +18,7 @@ struct tint_symbol {
};
void frexp_c084e3(thread int2* const tint_symbol_2) {
float2 res = tint_frexp(float2(), &(*(tint_symbol_2)));
float2 res = tint_frexp(float2(), tint_symbol_2);
}
vertex tint_symbol vertex_main() {

View File

@ -14,7 +14,7 @@ float4 tint_modf(float4 param_0, threadgroup float4* param_1) {
}
void modf_1d59e5(threadgroup float4* const tint_symbol_1) {
float4 res = tint_modf(float4(), &(*(tint_symbol_1)));
float4 res = tint_modf(float4(), tint_symbol_1);
}
kernel void compute_main(uint local_invocation_index [[thread_index_in_threadgroup]]) {

View File

@ -18,7 +18,7 @@ struct tint_symbol {
};
void modf_3d00e2(thread float4* const tint_symbol_2) {
float4 res = tint_modf(float4(), &(*(tint_symbol_2)));
float4 res = tint_modf(float4(), tint_symbol_2);
}
vertex tint_symbol vertex_main() {

View File

@ -18,7 +18,7 @@ struct tint_symbol {
};
void modf_5e8476(thread float* const tint_symbol_2) {
float res = tint_modf(1.0f, &(*(tint_symbol_2)));
float res = tint_modf(1.0f, tint_symbol_2);
}
vertex tint_symbol vertex_main() {

View File

@ -18,7 +18,7 @@ struct tint_symbol {
};
void modf_9c6a91(thread float2* const tint_symbol_2) {
float2 res = tint_modf(float2(), &(*(tint_symbol_2)));
float2 res = tint_modf(float2(), tint_symbol_2);
}
vertex tint_symbol vertex_main() {

View File

@ -18,7 +18,7 @@ struct tint_symbol {
};
void modf_9cecfc(thread float3* const tint_symbol_2) {
float3 res = tint_modf(float3(), &(*(tint_symbol_2)));
float3 res = tint_modf(float3(), tint_symbol_2);
}
vertex tint_symbol vertex_main() {

View File

@ -14,7 +14,7 @@ float2 tint_modf(float2 param_0, threadgroup float2* param_1) {
}
void modf_a128ab(threadgroup float2* const tint_symbol_1) {
float2 res = tint_modf(float2(), &(*(tint_symbol_1)));
float2 res = tint_modf(float2(), tint_symbol_1);
}
kernel void compute_main(uint local_invocation_index [[thread_index_in_threadgroup]]) {

View File

@ -14,7 +14,7 @@ float3 tint_modf(float3 param_0, threadgroup float3* param_1) {
}
void modf_bb9088(threadgroup float3* const tint_symbol_1) {
float3 res = tint_modf(float3(), &(*(tint_symbol_1)));
float3 res = tint_modf(float3(), tint_symbol_1);
}
kernel void compute_main(uint local_invocation_index [[thread_index_in_threadgroup]]) {

View File

@ -14,7 +14,7 @@ float tint_modf(float param_0, threadgroup float* param_1) {
}
void modf_e38ae6(threadgroup float* const tint_symbol_1) {
float res = tint_modf(1.0f, &(*(tint_symbol_1)));
float res = tint_modf(1.0f, tint_symbol_1);
}
kernel void compute_main(uint local_invocation_index [[thread_index_in_threadgroup]]) {