[msl] Handle packed conversions in shift expressions.

Similar to the handling of packed values in the arithmetic operators
the shift operators need to cast to the unpacked type before doing the
as_type casts.

Bug: tint:1542
Change-Id: I4289c45ab0a067ce122f61675fe5e251a83b6f8b
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/105720
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
This commit is contained in:
dan sinclair 2022-10-13 15:28:51 +00:00 committed by Dawn LUCI CQ
parent 840e42477d
commit d5139b4463
10 changed files with 132 additions and 4 deletions

View File

@ -590,8 +590,18 @@ bool GeneratorImpl::EmitBinary(std::ostream& out, const ast::BinaryExpression* e
ScopedParen sp(out);
{
ScopedBitCast lhs_uint_cast(this, out, lhs_type, unsigned_type_of(lhs_type));
if (!EmitExpression(out, expr->lhs)) {
return false;
// In case the type is packed, cast to our own type in order to remove the packing.
// Otherwise, this just casts to itself.
if (lhs_type->is_signed_integer_vector()) {
ScopedCast lhs_self_cast(this, out, lhs_type, lhs_type);
if (!EmitExpression(out, expr->lhs)) {
return false;
}
} else {
if (!EmitExpression(out, expr->lhs)) {
return false;
}
}
}
if (!emit_op()) {

View File

@ -0,0 +1,11 @@
struct UniformBuffer {
d: vec3<i32>,
}
@group(0) @binding(0)
var<uniform> u_input: UniformBuffer;
@compute @workgroup_size(1)
fn main() {
let temp: vec3<i32> = (u_input.d << vec3<u32>());
}

View File

@ -0,0 +1,9 @@
cbuffer cbuffer_u_input : register(b0, space0) {
uint4 u_input[1];
};
[numthreads(1, 1, 1)]
void main() {
const int3 temp = (asint(u_input[0].xyz) << (0u).xxx);
return;
}

View File

@ -0,0 +1,9 @@
cbuffer cbuffer_u_input : register(b0, space0) {
uint4 u_input[1];
};
[numthreads(1, 1, 1)]
void main() {
const int3 temp = (asint(u_input[0].xyz) << (0u).xxx);
return;
}

View File

@ -0,0 +1,16 @@
#version 310 es
layout(binding = 0, std140) uniform UniformBuffer_ubo {
ivec3 d;
uint pad;
} u_input;
void tint_symbol() {
ivec3 temp = (u_input.d << uvec3(0u));
}
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
void main() {
tint_symbol();
return;
}

View File

@ -0,0 +1,26 @@
#include <metal_stdlib>
using namespace metal;
template<typename T, size_t N>
struct tint_array {
const constant T& operator[](size_t i) const constant { return elements[i]; }
device T& operator[](size_t i) device { return elements[i]; }
const device T& operator[](size_t i) const device { return elements[i]; }
thread T& operator[](size_t i) thread { return elements[i]; }
const thread T& operator[](size_t i) const thread { return elements[i]; }
threadgroup T& operator[](size_t i) threadgroup { return elements[i]; }
const threadgroup T& operator[](size_t i) const threadgroup { return elements[i]; }
T elements[N];
};
struct UniformBuffer {
/* 0x0000 */ packed_int3 d;
/* 0x000c */ tint_array<int8_t, 4> tint_pad;
};
kernel void tint_symbol(const constant UniformBuffer* tint_symbol_1 [[buffer(0)]]) {
int3 const temp = as_type<int3>((as_type<uint3>(int3((*(tint_symbol_1)).d)) << uint3(0u)));
return;
}

View File

@ -0,0 +1,37 @@
; SPIR-V
; Version: 1.3
; Generator: Google Tint Compiler; 0
; Bound: 18
; Schema: 0
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main"
OpExecutionMode %main LocalSize 1 1 1
OpName %UniformBuffer "UniformBuffer"
OpMemberName %UniformBuffer 0 "d"
OpName %u_input "u_input"
OpName %main "main"
OpDecorate %UniformBuffer Block
OpMemberDecorate %UniformBuffer 0 Offset 0
OpDecorate %u_input NonWritable
OpDecorate %u_input DescriptorSet 0
OpDecorate %u_input Binding 0
%int = OpTypeInt 32 1
%v3int = OpTypeVector %int 3
%UniformBuffer = OpTypeStruct %v3int
%_ptr_Uniform_UniformBuffer = OpTypePointer Uniform %UniformBuffer
%u_input = OpVariable %_ptr_Uniform_UniformBuffer Uniform
%void = OpTypeVoid
%6 = OpTypeFunction %void
%uint = OpTypeInt 32 0
%uint_0 = OpConstant %uint 0
%_ptr_Uniform_v3int = OpTypePointer Uniform %v3int
%v3uint = OpTypeVector %uint 3
%16 = OpConstantNull %v3uint
%main = OpFunction %void None %6
%9 = OpLabel
%13 = OpAccessChain %_ptr_Uniform_v3int %u_input %uint_0
%14 = OpLoad %v3int %13
%17 = OpShiftLeftLogical %v3int %14 %16
OpReturn
OpFunctionEnd

View File

@ -0,0 +1,10 @@
struct UniformBuffer {
d : vec3<i32>,
}
@group(0) @binding(0) var<uniform> u_input : UniformBuffer;
@compute @workgroup_size(1)
fn main() {
let temp : vec3<i32> = (u_input.d << vec3<u32>());
}

View File

@ -4,7 +4,7 @@ using namespace metal;
kernel void f() {
int3 const a = int3(1, 2, 3);
uint3 const b = uint3(4u, 5u, 6u);
int3 const r = as_type<int3>((as_type<uint3>(a) << b));
int3 const r = as_type<int3>((as_type<uint3>(int3(a)) << b));
return;
}

View File

@ -6,6 +6,6 @@ struct S {
};
void foo(device S* const tint_symbol) {
(*(tint_symbol)).a = as_type<int4>((as_type<uint4>((*(tint_symbol)).a) << uint4(2u)));
(*(tint_symbol)).a = as_type<int4>((as_type<uint4>(int4((*(tint_symbol)).a)) << uint4(2u)));
}