tint: polyfill remainder to handle negative operands

Bug: tint:1802
Change-Id: Ie9baa045feda08523e5ca4f5ce94b6db7d4477e5
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/119100
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
This commit is contained in:
Antonio Maiorano
2023-02-10 15:01:02 +00:00
committed by Dawn LUCI CQ
parent 78583a14e6
commit ec20758675
103 changed files with 1660 additions and 554 deletions

View File

@@ -745,18 +745,52 @@ struct BuiltinPolyfill::State {
}
auto name = b.Symbols().New(is_div ? "tint_div" : "tint_mod");
auto* use_one = b.Equal(rhs, ScalarOrVector(width, 0_a));
auto* rhs_is_zero = b.Equal(rhs, ScalarOrVector(width, 0_a));
if (lhs_ty->is_signed_integer_scalar_or_vector()) {
const auto bits = lhs_el_ty->Size() * 8;
auto min_int = AInt(AInt::kLowestValue >> (AInt::kNumBits - bits));
const ast::Expression* lhs_is_min = b.Equal(lhs, ScalarOrVector(width, min_int));
const ast::Expression* rhs_is_minus_one = b.Equal(rhs, ScalarOrVector(width, -1_a));
// use_one = use_one | ((lhs == MIN_INT) & (rhs == -1))
use_one = b.Or(use_one, b.And(lhs_is_min, rhs_is_minus_one));
}
auto* select = b.Call("select", rhs, ScalarOrVector(width, 1_a), use_one);
// use_one = rhs_is_zero | ((lhs == MIN_INT) & (rhs == -1))
auto* use_one = b.Or(rhs_is_zero, b.And(lhs_is_min, rhs_is_minus_one));
// Special handling for mod in case either operand is negative, as negative operands
// for % is undefined behaviour for most backends (HLSL, MSL, GLSL, SPIR-V).
if (!is_div) {
const char* rhs_or_one = "rhs_or_one";
body.Push(b.Decl(b.Let(
rhs_or_one, b.Call("select", rhs, ScalarOrVector(width, 1_a), use_one))));
// Is either operand negative?
// (lhs | rhs) & (1<<31)
auto sign_bit_mask = ScalarOrVector(width, u32(1 << (bits - 1)));
auto* lhs_or_rhs = CastScalarOrVector<u32>(width, b.Or(lhs, rhs_or_one));
auto* lhs_or_rhs_is_neg =
b.NotEqual(b.And(lhs_or_rhs, sign_bit_mask), ScalarOrVector(width, 0_u));
// lhs - trunc(lhs / rhs) * rhs (note: integral division truncates)
auto* slow_mod = b.Sub(lhs, b.Mul(b.Div(lhs, rhs_or_one), rhs_or_one));
// lhs % rhs
auto* fast_mod = b.Mod(lhs, rhs_or_one);
auto* use_slow = b.Call("any", lhs_or_rhs_is_neg);
body.Push(b.If(use_slow, b.Block(b.Return(slow_mod)),
b.Else(b.Block(b.Return(fast_mod)))));
} else {
auto* rhs_or_one = b.Call("select", rhs, ScalarOrVector(width, 1_a), use_one);
body.Push(b.Return(is_div ? b.Div(lhs, rhs_or_one) : b.Mod(lhs, rhs_or_one)));
}
} else {
auto* rhs_or_one = b.Call("select", rhs, ScalarOrVector(width, 1_a), rhs_is_zero);
body.Push(b.Return(is_div ? b.Div(lhs, rhs_or_one) : b.Mod(lhs, rhs_or_one)));
}
body.Push(b.Return(is_div ? b.Div(lhs, select) : b.Mod(lhs, select)));
b.Func(name,
utils::Vector{
b.Param("lhs", T(lhs_ty)),
@@ -808,6 +842,14 @@ struct BuiltinPolyfill::State {
}
return b.Call(b.ty.vec<T>(width), value);
}
template <typename To>
const ast::Expression* CastScalarOrVector(uint32_t width, const ast::Expression* e) {
if (width == 1) {
return b.Call(b.ty.Of<To>(), e);
}
return b.Call(b.ty.vec<To>(width), e);
}
};
BuiltinPolyfill::BuiltinPolyfill() = default;

View File

@@ -2073,7 +2073,12 @@ fn f() {
auto* expect = R"(
fn tint_mod(lhs : i32, rhs : i32) -> i32 {
return (lhs % select(rhs, 1, ((rhs == 0) | ((lhs == -2147483648) & (rhs == -1)))));
let rhs_or_one = select(rhs, 1, ((rhs == 0) | ((lhs == -2147483648) & (rhs == -1))));
if (any(((u32((lhs | rhs_or_one)) & 2147483648u) != 0u))) {
return (lhs - ((lhs / rhs_or_one) * rhs_or_one));
} else {
return (lhs % rhs_or_one);
}
}
fn f() {
@@ -2121,7 +2126,12 @@ fn f() {
auto* expect = R"(
fn tint_mod(lhs : i32, rhs : i32) -> i32 {
return (lhs % select(rhs, 1, ((rhs == 0) | ((lhs == -2147483648) & (rhs == -1)))));
let rhs_or_one = select(rhs, 1, ((rhs == 0) | ((lhs == -2147483648) & (rhs == -1))));
if (any(((u32((lhs | rhs_or_one)) & 2147483648u) != 0u))) {
return (lhs - ((lhs / rhs_or_one) * rhs_or_one));
} else {
return (lhs % rhs_or_one);
}
}
fn f() {
@@ -2169,7 +2179,12 @@ fn f() {
auto* expect = R"(
fn tint_mod(lhs : i32, rhs : i32) -> i32 {
return (lhs % select(rhs, 1, ((rhs == 0) | ((lhs == -2147483648) & (rhs == -1)))));
let rhs_or_one = select(rhs, 1, ((rhs == 0) | ((lhs == -2147483648) & (rhs == -1))));
if (any(((u32((lhs | rhs_or_one)) & 2147483648u) != 0u))) {
return (lhs - ((lhs / rhs_or_one) * rhs_or_one));
} else {
return (lhs % rhs_or_one);
}
}
fn f() {
@@ -2363,7 +2378,12 @@ fn f() {
auto* expect = R"(
fn tint_mod(lhs : vec3<i32>, rhs : i32) -> vec3<i32> {
let r = vec3<i32>(rhs);
return (lhs % select(r, vec3(1), ((r == vec3(0)) | ((lhs == vec3(-2147483648)) & (r == vec3(-1))))));
let rhs_or_one = select(r, vec3(1), ((r == vec3(0)) | ((lhs == vec3(-2147483648)) & (r == vec3(-1)))));
if (any(((vec3<u32>((lhs | rhs_or_one)) & vec3<u32>(2147483648u)) != vec3<u32>(0u)))) {
return (lhs - ((lhs / rhs_or_one) * rhs_or_one));
} else {
return (lhs % rhs_or_one);
}
}
fn f() {
@@ -2413,7 +2433,12 @@ fn f() {
auto* expect = R"(
fn tint_mod(lhs : vec3<i32>, rhs : i32) -> vec3<i32> {
let r = vec3<i32>(rhs);
return (lhs % select(r, vec3(1), ((r == vec3(0)) | ((lhs == vec3(-2147483648)) & (r == vec3(-1))))));
let rhs_or_one = select(r, vec3(1), ((r == vec3(0)) | ((lhs == vec3(-2147483648)) & (r == vec3(-1)))));
if (any(((vec3<u32>((lhs | rhs_or_one)) & vec3<u32>(2147483648u)) != vec3<u32>(0u)))) {
return (lhs - ((lhs / rhs_or_one) * rhs_or_one));
} else {
return (lhs % rhs_or_one);
}
}
fn f() {
@@ -2463,7 +2488,12 @@ fn f() {
auto* expect = R"(
fn tint_mod(lhs : vec3<i32>, rhs : i32) -> vec3<i32> {
let r = vec3<i32>(rhs);
return (lhs % select(r, vec3(1), ((r == vec3(0)) | ((lhs == vec3(-2147483648)) & (r == vec3(-1))))));
let rhs_or_one = select(r, vec3(1), ((r == vec3(0)) | ((lhs == vec3(-2147483648)) & (r == vec3(-1)))));
if (any(((vec3<u32>((lhs | rhs_or_one)) & vec3<u32>(2147483648u)) != vec3<u32>(0u)))) {
return (lhs - ((lhs / rhs_or_one) * rhs_or_one));
} else {
return (lhs % rhs_or_one);
}
}
fn f() {
@@ -2563,7 +2593,12 @@ fn f() {
auto* expect = R"(
fn tint_mod(lhs : i32, rhs : vec3<i32>) -> vec3<i32> {
let l = vec3<i32>(lhs);
return (l % select(rhs, vec3(1), ((rhs == vec3(0)) | ((l == vec3(-2147483648)) & (rhs == vec3(-1))))));
let rhs_or_one = select(rhs, vec3(1), ((rhs == vec3(0)) | ((l == vec3(-2147483648)) & (rhs == vec3(-1)))));
if (any(((vec3<u32>((l | rhs_or_one)) & vec3<u32>(2147483648u)) != vec3<u32>(0u)))) {
return (l - ((l / rhs_or_one) * rhs_or_one));
} else {
return (l % rhs_or_one);
}
}
fn f() {
@@ -2613,7 +2648,12 @@ fn f() {
auto* expect = R"(
fn tint_mod(lhs : i32, rhs : vec3<i32>) -> vec3<i32> {
let l = vec3<i32>(lhs);
return (l % select(rhs, vec3(1), ((rhs == vec3(0)) | ((l == vec3(-2147483648)) & (rhs == vec3(-1))))));
let rhs_or_one = select(rhs, vec3(1), ((rhs == vec3(0)) | ((l == vec3(-2147483648)) & (rhs == vec3(-1)))));
if (any(((vec3<u32>((l | rhs_or_one)) & vec3<u32>(2147483648u)) != vec3<u32>(0u)))) {
return (l - ((l / rhs_or_one) * rhs_or_one));
} else {
return (l % rhs_or_one);
}
}
fn f() {
@@ -2711,7 +2751,12 @@ fn f() {
auto* expect = R"(
fn tint_mod(lhs : vec3<i32>, rhs : vec3<i32>) -> vec3<i32> {
return (lhs % select(rhs, vec3(1), ((rhs == vec3(0)) | ((lhs == vec3(-2147483648)) & (rhs == vec3(-1))))));
let rhs_or_one = select(rhs, vec3(1), ((rhs == vec3(0)) | ((lhs == vec3(-2147483648)) & (rhs == vec3(-1)))));
if (any(((vec3<u32>((lhs | rhs_or_one)) & vec3<u32>(2147483648u)) != vec3<u32>(0u)))) {
return (lhs - ((lhs / rhs_or_one) * rhs_or_one));
} else {
return (lhs % rhs_or_one);
}
}
fn f() {