Update MSL generator to handle casts of packed types.

Currently in the MSL backend we cast int values to uint in order to get
the correct WGSL behaviour for over/under flow. This fails in the case
of host shareable buffers as they use `packed` types which need to get
cast to the non-packed version first.

Bug: tint:1677
Change-Id: I57b70abaa8ca614472a26d63f19c1aef2bd64668
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/103986
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
This commit is contained in:
dan sinclair 2022-09-29 19:44:58 +00:00 committed by Dawn LUCI CQ
parent d1a5f93630
commit 2bcade246a
27 changed files with 197 additions and 30 deletions

View File

@ -517,8 +517,18 @@ bool GeneratorImpl::EmitBinary(std::ostream& out, const ast::BinaryExpression* e
ScopedParen sp(out); ScopedParen sp(out);
{ {
ScopedBitCast lhs_uint_cast(this, out, lhs_type, unsigned_type_of(target_type)); ScopedBitCast lhs_uint_cast(this, out, lhs_type, unsigned_type_of(target_type));
if (!EmitExpression(out, expr->lhs)) {
return false; // In case the type is packed, cast to our own type in order to remove the packing.
// Otherwise, this just casts to itself.
if (lhs_type->is_signed_integer_vector()) {
ScopedBitCast lhs_self_cast(this, out, lhs_type, lhs_type);
if (!EmitExpression(out, expr->lhs)) {
return false;
}
} else {
if (!EmitExpression(out, expr->lhs)) {
return false;
}
} }
} }
if (!emit_op()) { if (!emit_op()) {
@ -526,8 +536,18 @@ bool GeneratorImpl::EmitBinary(std::ostream& out, const ast::BinaryExpression* e
} }
{ {
ScopedBitCast rhs_uint_cast(this, out, rhs_type, unsigned_type_of(target_type)); ScopedBitCast rhs_uint_cast(this, out, rhs_type, unsigned_type_of(target_type));
if (!EmitExpression(out, expr->rhs)) {
return false; // In case the type is packed, cast to our own type in order to remove the packing.
// Otherwise, this just casts to itself.
if (rhs_type->is_signed_integer_vector()) {
ScopedBitCast rhs_self_cast(this, out, rhs_type, rhs_type);
if (!EmitExpression(out, expr->rhs)) {
return false;
}
} else {
if (!EmitExpression(out, expr->rhs)) {
return false;
}
} }
} }
return true; return true;

View File

@ -88,7 +88,7 @@ void tint_symbol_inner(uint3 GlobalInvocationID, const constant Config* const ti
for(int x_1 = 0; (x_1 < TILE_COUNT_X); x_1 = as_type<int>((as_type<uint>(x_1) + as_type<uint>(1)))) { for(int x_1 = 0; (x_1 < TILE_COUNT_X); x_1 = as_type<int>((as_type<uint>(x_1) + as_type<uint>(1)))) {
int2 tilePixel0Idx = int2(as_type<int>((as_type<uint>(x_1) * as_type<uint>(TILE_SIZE))), as_type<int>((as_type<uint>(y_1) * as_type<uint>(TILE_SIZE)))); int2 tilePixel0Idx = int2(as_type<int>((as_type<uint>(x_1) * as_type<uint>(TILE_SIZE))), as_type<int>((as_type<uint>(y_1) * as_type<uint>(TILE_SIZE))));
float2 floorCoord = (((2.0f * float2(tilePixel0Idx)) / float4((*(tint_symbol_3)).fullScreenSize).xy) - float2(1.0f)); float2 floorCoord = (((2.0f * float2(tilePixel0Idx)) / float4((*(tint_symbol_3)).fullScreenSize).xy) - float2(1.0f));
float2 ceilCoord = (((2.0f * float2(as_type<int2>((as_type<uint2>(tilePixel0Idx) + as_type<uint2>(int2(TILE_SIZE)))))) / float4((*(tint_symbol_3)).fullScreenSize).xy) - float2(1.0f)); float2 ceilCoord = (((2.0f * float2(as_type<int2>((as_type<uint2>(as_type<int2>(tilePixel0Idx)) + as_type<uint2>(as_type<int2>(int2(TILE_SIZE))))))) / float4((*(tint_symbol_3)).fullScreenSize).xy) - float2(1.0f));
float2 viewFloorCoord = float2((((-(viewNear) * floorCoord[0]) - (M[2][0] * viewNear)) / M[0][0]), (((-(viewNear) * floorCoord[1]) - (M[2][1] * viewNear)) / M[1][1])); float2 viewFloorCoord = float2((((-(viewNear) * floorCoord[0]) - (M[2][0] * viewNear)) / M[0][0]), (((-(viewNear) * floorCoord[1]) - (M[2][1] * viewNear)) / M[1][1]));
float2 viewCeilCoord = float2((((-(viewNear) * ceilCoord[0]) - (M[2][0] * viewNear)) / M[0][0]), (((-(viewNear) * ceilCoord[1]) - (M[2][1] * viewNear)) / M[1][1])); float2 viewCeilCoord = float2((((-(viewNear) * ceilCoord[0]) - (M[2][0] * viewNear)) / M[0][0]), (((-(viewNear) * ceilCoord[1]) - (M[2][1] * viewNear)) / M[1][1]));
frustumPlanes[0] = float4(1.0f, 0.0f, (-(viewFloorCoord[0]) / viewNear), 0.0f); frustumPlanes[0] = float4(1.0f, 0.0f, (-(viewFloorCoord[0]) / viewNear), 0.0f);

View File

@ -45,13 +45,13 @@ bool test_int_S1_c0_b(const constant UniformBuffer* const tint_symbol_5) {
ok = x_41; ok = x_41;
int4 const x_44 = int4(x_27, x_27, x_27, x_27); int4 const x_44 = int4(x_27, x_27, x_27, x_27);
val = x_44; val = x_44;
int4 const x_47 = as_type<int4>((as_type<uint4>(x_44) + as_type<uint4>(int4(1)))); int4 const x_47 = as_type<int4>((as_type<uint4>(as_type<int4>(x_44)) + as_type<uint4>(as_type<int4>(int4(1)))));
val = x_47; val = x_47;
int4 const x_48 = as_type<int4>((as_type<uint4>(x_47) - as_type<uint4>(int4(1)))); int4 const x_48 = as_type<int4>((as_type<uint4>(as_type<int4>(x_47)) - as_type<uint4>(as_type<int4>(int4(1)))));
val = x_48; val = x_48;
int4 const x_49 = as_type<int4>((as_type<uint4>(x_48) + as_type<uint4>(int4(1)))); int4 const x_49 = as_type<int4>((as_type<uint4>(as_type<int4>(x_48)) + as_type<uint4>(as_type<int4>(int4(1)))));
val = x_49; val = x_49;
int4 const x_50 = as_type<int4>((as_type<uint4>(x_49) - as_type<uint4>(int4(1)))); int4 const x_50 = as_type<int4>((as_type<uint4>(as_type<int4>(x_49)) - as_type<uint4>(as_type<int4>(int4(1)))));
val = x_50; val = x_50;
x_55 = false; x_55 = false;
if (x_41) { if (x_41) {
@ -59,11 +59,11 @@ bool test_int_S1_c0_b(const constant UniformBuffer* const tint_symbol_5) {
x_55 = x_54; x_55 = x_54;
} }
ok = x_55; ok = x_55;
int4 const x_58 = as_type<int4>((as_type<uint4>(x_50) * as_type<uint4>(int4(2)))); int4 const x_58 = as_type<int4>((as_type<uint4>(as_type<int4>(x_50)) * as_type<uint4>(as_type<int4>(int4(2)))));
val = x_58; val = x_58;
int4 const x_59 = (x_58 / int4(2)); int4 const x_59 = (x_58 / int4(2));
val = x_59; val = x_59;
int4 const x_60 = as_type<int4>((as_type<uint4>(x_59) * as_type<uint4>(int4(2)))); int4 const x_60 = as_type<int4>((as_type<uint4>(as_type<int4>(x_59)) * as_type<uint4>(as_type<int4>(int4(2)))));
val = x_60; val = x_60;
int4 const x_61 = (x_60 / int4(2)); int4 const x_61 = (x_60 / int4(2));
val = x_61; val = x_61;

View File

@ -0,0 +1,10 @@
struct Input {
position : vec3<i32>,
}
@group(0) @binding(0) var<storage, read> input : Input;
@compute @workgroup_size(1, 1, 1)
fn main(@builtin(global_invocation_id) id : vec3<u32>) {
let pos = input.position - vec3<i32>(0);
}

View File

@ -0,0 +1,15 @@
ByteAddressBuffer input : register(t0, space0);
struct tint_symbol_1 {
uint3 id : SV_DispatchThreadID;
};
void main_inner(uint3 id) {
const int3 pos = (asint(input.Load3(0u)) - (0).xxx);
}
[numthreads(1, 1, 1)]
void main(tint_symbol_1 tint_symbol) {
main_inner(tint_symbol.id);
return;
}

View File

@ -0,0 +1,15 @@
ByteAddressBuffer input : register(t0, space0);
struct tint_symbol_1 {
uint3 id : SV_DispatchThreadID;
};
void main_inner(uint3 id) {
const int3 pos = (asint(input.Load3(0u)) - (0).xxx);
}
[numthreads(1, 1, 1)]
void main(tint_symbol_1 tint_symbol) {
main_inner(tint_symbol.id);
return;
}

View File

@ -0,0 +1,16 @@
#version 310 es
layout(binding = 0, std430) buffer Input_ssbo {
ivec3 position;
uint pad;
} tint_symbol;
void tint_symbol_1(uvec3 id) {
ivec3 pos = (tint_symbol.position - ivec3(0));
}
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
void main() {
tint_symbol_1(gl_GlobalInvocationID);
return;
}

View File

@ -0,0 +1,30 @@
#include <metal_stdlib>
using namespace metal;
template<typename T, size_t N>
struct tint_array {
const constant T& operator[](size_t i) const constant { return elements[i]; }
device T& operator[](size_t i) device { return elements[i]; }
const device T& operator[](size_t i) const device { return elements[i]; }
thread T& operator[](size_t i) thread { return elements[i]; }
const thread T& operator[](size_t i) const thread { return elements[i]; }
threadgroup T& operator[](size_t i) threadgroup { return elements[i]; }
const threadgroup T& operator[](size_t i) const threadgroup { return elements[i]; }
T elements[N];
};
struct Input {
/* 0x0000 */ packed_int3 position;
/* 0x000c */ tint_array<int8_t, 4> tint_pad;
};
void tint_symbol_inner(uint3 id, const device Input* const tint_symbol_1) {
int3 const pos = as_type<int3>((as_type<uint3>(as_type<int3>((*(tint_symbol_1)).position)) - as_type<uint3>(as_type<int3>(int3(0)))));
}
kernel void tint_symbol(const device Input* tint_symbol_2 [[buffer(0)]], uint3 id [[thread_position_in_grid]]) {
tint_symbol_inner(id, tint_symbol_2);
return;
}

View File

@ -0,0 +1,51 @@
; SPIR-V
; Version: 1.3
; Generator: Google Tint Compiler; 0
; Bound: 26
; Schema: 0
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main" %id_1
OpExecutionMode %main LocalSize 1 1 1
OpName %id_1 "id_1"
OpName %Input "Input"
OpMemberName %Input 0 "position"
OpName %input "input"
OpName %main_inner "main_inner"
OpName %id "id"
OpName %main "main"
OpDecorate %id_1 BuiltIn GlobalInvocationId
OpDecorate %Input Block
OpMemberDecorate %Input 0 Offset 0
OpDecorate %input NonWritable
OpDecorate %input DescriptorSet 0
OpDecorate %input Binding 0
%uint = OpTypeInt 32 0
%v3uint = OpTypeVector %uint 3
%_ptr_Input_v3uint = OpTypePointer Input %v3uint
%id_1 = OpVariable %_ptr_Input_v3uint Input
%int = OpTypeInt 32 1
%v3int = OpTypeVector %int 3
%Input = OpTypeStruct %v3int
%_ptr_StorageBuffer_Input = OpTypePointer StorageBuffer %Input
%input = OpVariable %_ptr_StorageBuffer_Input StorageBuffer
%void = OpTypeVoid
%10 = OpTypeFunction %void %v3uint
%uint_0 = OpConstant %uint 0
%_ptr_StorageBuffer_v3int = OpTypePointer StorageBuffer %v3int
%19 = OpConstantNull %v3int
%21 = OpTypeFunction %void
%main_inner = OpFunction %void None %10
%id = OpFunctionParameter %v3uint
%14 = OpLabel
%17 = OpAccessChain %_ptr_StorageBuffer_v3int %input %uint_0
%18 = OpLoad %v3int %17
%20 = OpISub %v3int %18 %19
OpReturn
OpFunctionEnd
%main = OpFunction %void None %21
%23 = OpLabel
%25 = OpLoad %v3uint %id_1
%24 = OpFunctionCall %void %main_inner %25
OpReturn
OpFunctionEnd

View File

@ -0,0 +1,10 @@
struct Input {
position : vec3<i32>,
}
@group(0) @binding(0) var<storage, read> input : Input;
@compute @workgroup_size(1, 1, 1)
fn main(@builtin(global_invocation_id) id : vec3<u32>) {
let pos = (input.position - vec3<i32>(0));
}

View File

@ -32,10 +32,10 @@ void tint_symbol_inner(uint3 WorkGroupID, uint3 LocalInvocationID, uint local_in
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
uint const filterOffset = (((*(tint_symbol_2)).filterDim - 1u) / 2u); uint const filterOffset = (((*(tint_symbol_2)).filterDim - 1u) / 2u);
int2 const dims = int2(tint_symbol_3.get_width(0), tint_symbol_3.get_height(0)); int2 const dims = int2(tint_symbol_3.get_width(0), tint_symbol_3.get_height(0));
int2 const baseIndex = as_type<int2>((as_type<uint2>(int2(((uint3(WorkGroupID).xy * uint2((*(tint_symbol_2)).blockDim, 4u)) + (uint3(LocalInvocationID).xy * uint2(4u, 1u))))) - as_type<uint2>(int2(int(filterOffset), 0)))); int2 const baseIndex = as_type<int2>((as_type<uint2>(as_type<int2>(int2(((uint3(WorkGroupID).xy * uint2((*(tint_symbol_2)).blockDim, 4u)) + (uint3(LocalInvocationID).xy * uint2(4u, 1u)))))) - as_type<uint2>(as_type<int2>(int2(int(filterOffset), 0)))));
for(uint r = 0u; (r < 4u); r = (r + 1u)) { for(uint r = 0u; (r < 4u); r = (r + 1u)) {
for(uint c = 0u; (c < 4u); c = (c + 1u)) { for(uint c = 0u; (c < 4u); c = (c + 1u)) {
int2 loadIndex = as_type<int2>((as_type<uint2>(baseIndex) + as_type<uint2>(int2(int(c), int(r))))); int2 loadIndex = as_type<int2>((as_type<uint2>(as_type<int2>(baseIndex)) + as_type<uint2>(as_type<int2>(int2(int(c), int(r))))));
if (((*(tint_symbol_4)).value != 0u)) { if (((*(tint_symbol_4)).value != 0u)) {
loadIndex = int2(loadIndex).yx; loadIndex = int2(loadIndex).yx;
} }
@ -45,7 +45,7 @@ void tint_symbol_inner(uint3 WorkGroupID, uint3 LocalInvocationID, uint local_in
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
for(uint r = 0u; (r < 4u); r = (r + 1u)) { for(uint r = 0u; (r < 4u); r = (r + 1u)) {
for(uint c = 0u; (c < 4u); c = (c + 1u)) { for(uint c = 0u; (c < 4u); c = (c + 1u)) {
int2 writeIndex = as_type<int2>((as_type<uint2>(baseIndex) + as_type<uint2>(int2(int(c), int(r))))); int2 writeIndex = as_type<int2>((as_type<uint2>(as_type<int2>(baseIndex)) + as_type<uint2>(as_type<int2>(int2(int(c), int(r))))));
if (((*(tint_symbol_4)).value != 0u)) { if (((*(tint_symbol_4)).value != 0u)) {
writeIndex = int2(writeIndex).yx; writeIndex = int2(writeIndex).yx;
} }

View File

@ -4,7 +4,7 @@ using namespace metal;
kernel void f() { kernel void f() {
int const a = 4; int const a = 4;
int3 const b = int3(1, 2, 3); int3 const b = int3(1, 2, 3);
int3 const r = as_type<int3>((as_type<uint>(a) + as_type<uint3>(b))); int3 const r = as_type<int3>((as_type<uint>(a) + as_type<uint3>(as_type<int3>(b))));
return; return;
} }

View File

@ -4,7 +4,7 @@ using namespace metal;
kernel void f() { kernel void f() {
int3 const a = int3(1, 2, 3); int3 const a = int3(1, 2, 3);
int const b = 4; int const b = 4;
int3 const r = as_type<int3>((as_type<uint3>(a) + as_type<uint>(b))); int3 const r = as_type<int3>((as_type<uint3>(as_type<int3>(a)) + as_type<uint>(b)));
return; return;
} }

View File

@ -4,7 +4,7 @@ using namespace metal;
kernel void f() { kernel void f() {
int3 const a = int3(1, 2, 3); int3 const a = int3(1, 2, 3);
int3 const b = int3(4, 5, 6); int3 const b = int3(4, 5, 6);
int3 const r = as_type<int3>((as_type<uint3>(a) + as_type<uint3>(b))); int3 const r = as_type<int3>((as_type<uint3>(as_type<int3>(a)) + as_type<uint3>(as_type<int3>(b))));
return; return;
} }

View File

@ -4,7 +4,7 @@ using namespace metal;
kernel void f() { kernel void f() {
int a = 4; int a = 4;
int3 b = int3(0, 2, 0); int3 b = int3(0, 2, 0);
int3 const r = (a / as_type<int3>((as_type<uint3>(b) + as_type<uint3>(b)))); int3 const r = (a / as_type<int3>((as_type<uint3>(as_type<int3>(b)) + as_type<uint3>(as_type<int3>(b)))));
return; return;
} }

View File

@ -4,7 +4,7 @@ using namespace metal;
kernel void f() { kernel void f() {
int3 a = int3(1, 2, 3); int3 a = int3(1, 2, 3);
int3 b = int3(0, 5, 0); int3 b = int3(0, 5, 0);
int3 const r = (a / as_type<int3>((as_type<uint3>(b) + as_type<uint3>(b)))); int3 const r = (a / as_type<int3>((as_type<uint3>(as_type<int3>(b)) + as_type<uint3>(as_type<int3>(b)))));
return; return;
} }

View File

@ -4,7 +4,7 @@ using namespace metal;
kernel void f() { kernel void f() {
int a = 4; int a = 4;
int3 b = int3(0, 2, 0); int3 b = int3(0, 2, 0);
int3 const r = (a % as_type<int3>((as_type<uint3>(b) + as_type<uint3>(b)))); int3 const r = (a % as_type<int3>((as_type<uint3>(as_type<int3>(b)) + as_type<uint3>(as_type<int3>(b)))));
return; return;
} }

View File

@ -4,7 +4,7 @@ using namespace metal;
kernel void f() { kernel void f() {
int3 a = int3(1, 2, 3); int3 a = int3(1, 2, 3);
int3 b = int3(0, 5, 0); int3 b = int3(0, 5, 0);
int3 const r = (a % as_type<int3>((as_type<uint3>(b) + as_type<uint3>(b)))); int3 const r = (a % as_type<int3>((as_type<uint3>(as_type<int3>(b)) + as_type<uint3>(as_type<int3>(b)))));
return; return;
} }

View File

@ -4,7 +4,7 @@ using namespace metal;
kernel void f() { kernel void f() {
int const a = 4; int const a = 4;
int3 const b = int3(1, 2, 3); int3 const b = int3(1, 2, 3);
int3 const r = as_type<int3>((as_type<uint>(a) * as_type<uint3>(b))); int3 const r = as_type<int3>((as_type<uint>(a) * as_type<uint3>(as_type<int3>(b))));
return; return;
} }

View File

@ -4,7 +4,7 @@ using namespace metal;
kernel void f() { kernel void f() {
int3 const a = int3(1, 2, 3); int3 const a = int3(1, 2, 3);
int const b = 4; int const b = 4;
int3 const r = as_type<int3>((as_type<uint3>(a) * as_type<uint>(b))); int3 const r = as_type<int3>((as_type<uint3>(as_type<int3>(a)) * as_type<uint>(b)));
return; return;
} }

View File

@ -4,7 +4,7 @@ using namespace metal;
kernel void f() { kernel void f() {
int3 const a = int3(1, 2, 3); int3 const a = int3(1, 2, 3);
int3 const b = int3(4, 5, 6); int3 const b = int3(4, 5, 6);
int3 const r = as_type<int3>((as_type<uint3>(a) * as_type<uint3>(b))); int3 const r = as_type<int3>((as_type<uint3>(as_type<int3>(a)) * as_type<uint3>(as_type<int3>(b))));
return; return;
} }

View File

@ -4,7 +4,7 @@ using namespace metal;
kernel void f() { kernel void f() {
int const a = 4; int const a = 4;
int3 const b = int3(1, 2, 3); int3 const b = int3(1, 2, 3);
int3 const r = as_type<int3>((as_type<uint>(a) - as_type<uint3>(b))); int3 const r = as_type<int3>((as_type<uint>(a) - as_type<uint3>(as_type<int3>(b))));
return; return;
} }

View File

@ -4,7 +4,7 @@ using namespace metal;
kernel void f() { kernel void f() {
int3 const a = int3(1, 2, 3); int3 const a = int3(1, 2, 3);
int const b = 4; int const b = 4;
int3 const r = as_type<int3>((as_type<uint3>(a) - as_type<uint>(b))); int3 const r = as_type<int3>((as_type<uint3>(as_type<int3>(a)) - as_type<uint>(b)));
return; return;
} }

View File

@ -4,7 +4,7 @@ using namespace metal;
kernel void f() { kernel void f() {
int3 const a = int3(1, 2, 3); int3 const a = int3(1, 2, 3);
int3 const b = int3(4, 5, 6); int3 const b = int3(4, 5, 6);
int3 const r = as_type<int3>((as_type<uint3>(a) - as_type<uint3>(b))); int3 const r = as_type<int3>((as_type<uint3>(as_type<int3>(a)) - as_type<uint3>(as_type<int3>(b))));
return; return;
} }

View File

@ -6,6 +6,6 @@ struct S {
}; };
void foo(device S* const tint_symbol) { void foo(device S* const tint_symbol) {
(*(tint_symbol)).a = as_type<int4>((as_type<uint4>((*(tint_symbol)).a) - as_type<uint4>(int4(2)))); (*(tint_symbol)).a = as_type<int4>((as_type<uint4>(as_type<int4>((*(tint_symbol)).a)) - as_type<uint4>(as_type<int4>(int4(2)))));
} }

View File

@ -6,6 +6,6 @@ struct S {
}; };
void foo(device S* const tint_symbol) { void foo(device S* const tint_symbol) {
(*(tint_symbol)).a = as_type<int4>((as_type<uint4>((*(tint_symbol)).a) + as_type<uint4>(int4(2)))); (*(tint_symbol)).a = as_type<int4>((as_type<uint4>(as_type<int4>((*(tint_symbol)).a)) + as_type<uint4>(as_type<int4>(int4(2)))));
} }

View File

@ -6,6 +6,6 @@ struct S {
}; };
void foo(device S* const tint_symbol) { void foo(device S* const tint_symbol) {
(*(tint_symbol)).a = as_type<int4>((as_type<uint4>((*(tint_symbol)).a) * as_type<uint4>(int4(2)))); (*(tint_symbol)).a = as_type<int4>((as_type<uint4>(as_type<int4>((*(tint_symbol)).a)) * as_type<uint4>(as_type<int4>(int4(2)))));
} }