MSL: fix i32 INT_MIN literal emitted as `long` instead of `int`

Bug: tint:124
Change-Id: Ie632b78cd67948b65e823f0a3c52fda7ef7343f3
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/60440
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Reviewed-by: David Neto <dneto@google.com>
This commit is contained in:
Antonio Maiorano 2021-07-30 18:56:26 +00:00 committed by Tint LUCI CQ
parent 2c1fbe801b
commit 9bdf2dcc6b
12 changed files with 80 additions and 5 deletions

View File

@ -17,6 +17,7 @@
#include <algorithm>
#include <cmath>
#include <iomanip>
#include <limits>
#include <utility>
#include <vector>
@ -1246,7 +1247,17 @@ bool GeneratorImpl::EmitLiteral(std::ostream& out, ast::Literal* lit) {
out << FloatToString(fl->value()) << "f";
}
} else if (auto* sl = lit->As<ast::SintLiteral>()) {
out << sl->value();
// MSL (and C++) parse `-2147483648` as a `long` because it parses unary
// minus and `2147483648` as separate tokens, and the latter doesn't
// fit into an (32-bit) `int`. WGSL, OTOH, parses this as an `i32`. To avoid
// issues with `long` to `int` casts, emit `(2147483647 - 1)` instead, which
// ensures the expression type is `int`.
const auto int_min = std::numeric_limits<int32_t>::min();
if (sl->value_as_i32() == int_min) {
out << "(" << int_min + 1 << " - 1)";
} else {
out << sl->value();
}
} else if (auto* ul = lit->As<ast::UintLiteral>()) {
out << ul->value() << "u";
} else {

View File

@ -43,6 +43,17 @@ TEST_F(MslGeneratorImplTest, EmitExpression_Cast_Vector) {
EXPECT_EQ(out.str(), "float3(int3(1, 2, 3))");
}
TEST_F(MslGeneratorImplTest, EmitExpression_Cast_IntMin) {
auto* cast = Construct<u32>(std::numeric_limits<int32_t>::min());
WrapInFunction(cast);
GeneratorImpl& gen = Build();
std::stringstream out;
ASSERT_TRUE(gen.EmitExpression(out, cast)) << gen.error();
EXPECT_EQ(out.str(), "uint((-2147483647 - 1))");
}
} // namespace
} // namespace msl
} // namespace writer

View File

@ -88,6 +88,18 @@ TEST_F(MslUnaryOpTest, Negation) {
EXPECT_EQ(out.str(), "-(expr)");
}
TEST_F(MslUnaryOpTest, NegationOfIntMin) {
auto* op = create<ast::UnaryOpExpression>(
ast::UnaryOp::kNegation, Expr(std::numeric_limits<int32_t>::min()));
WrapInFunction(op);
GeneratorImpl& gen = Build();
std::stringstream out;
ASSERT_TRUE(gen.EmitExpression(out, op)) << gen.error();
EXPECT_EQ(out.str(), "-((-2147483647 - 1))");
}
} // namespace
} // namespace msl
} // namespace writer

View File

@ -0,0 +1,4 @@
[[stage(compute), workgroup_size(1)]]
fn f() {
let b : u32 = bitcast<u32>(-2147483648);
}

View File

@ -0,0 +1,5 @@
[numthreads(1, 1, 1)]
void f() {
const uint b = asuint(-2147483648);
return;
}

View File

@ -0,0 +1,8 @@
#include <metal_stdlib>
using namespace metal;
kernel void f() {
uint const b = as_type<uint>((-2147483647 - 1));
return;
}

View File

@ -0,0 +1,20 @@
; SPIR-V
; Version: 1.3
; Generator: Google Tint Compiler; 0
; Bound: 9
; Schema: 0
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %f "f"
OpExecutionMode %f LocalSize 1 1 1
OpName %f "f"
%void = OpTypeVoid
%1 = OpTypeFunction %void
%uint = OpTypeInt 32 0
%int = OpTypeInt 32 1
%int_n2147483648 = OpConstant %int -2147483648
%f = OpFunction %void None %1
%4 = OpLabel
%5 = OpBitcast %uint %int_n2147483648
OpReturn
OpFunctionEnd

View File

@ -0,0 +1,4 @@
[[stage(compute), workgroup_size(1)]]
fn f() {
let b : u32 = bitcast<u32>(-2147483648);
}

View File

@ -14,7 +14,7 @@ struct tint_symbol_1 {
void main_1(constant buf0& x_7, thread float4* const tint_symbol_4) {
int minValue = 0;
int negMinValue = 0;
minValue = -2147483648;
minValue = (-2147483647 - 1);
int const x_25 = minValue;
negMinValue = -(x_25);
int const x_27 = negMinValue;

View File

@ -14,7 +14,7 @@ struct tint_symbol_1 {
void main_1(constant buf0& x_7, thread float4* const tint_symbol_4) {
int minValue = 0;
int negMinValue = 0;
minValue = -2147483648;
minValue = (-2147483647 - 1);
int const x_25 = minValue;
negMinValue = -(x_25);
int const x_27 = negMinValue;

View File

@ -69,7 +69,7 @@ void main_1(constant buf1& x_6, constant buf0& x_9, thread float4* const tint_sy
}
case 0: {
int const x_70 = i;
if ((-2147483648 < x_70)) {
if (((-2147483647 - 1) < x_70)) {
{
int const x_82 = j;
j = (x_82 + 1);

View File

@ -69,7 +69,7 @@ void main_1(constant buf1& x_6, constant buf0& x_9, thread float4* const tint_sy
}
case 0: {
int const x_70 = i;
if ((-2147483648 < x_70)) {
if (((-2147483647 - 1) < x_70)) {
{
int const x_82 = j;
j = (x_82 + 1);