mirror of
https://github.com/encounter/dawn-cmake.git
synced 2025-12-11 22:44:04 +00:00
tint: polyfill remainder to handle negative operands
Bug: tint:1802 Change-Id: Ie9baa045feda08523e5ca4f5ce94b6db7d4477e5 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/119100 Commit-Queue: Antonio Maiorano <amaiorano@google.com> Kokoro: Kokoro <noreply+kokoro@google.com> Reviewed-by: Ben Clayton <bclayton@google.com>
This commit is contained in:
committed by
Dawn LUCI CQ
parent
78583a14e6
commit
ec20758675
@@ -745,18 +745,52 @@ struct BuiltinPolyfill::State {
|
||||
}
|
||||
|
||||
auto name = b.Symbols().New(is_div ? "tint_div" : "tint_mod");
|
||||
auto* use_one = b.Equal(rhs, ScalarOrVector(width, 0_a));
|
||||
|
||||
auto* rhs_is_zero = b.Equal(rhs, ScalarOrVector(width, 0_a));
|
||||
|
||||
if (lhs_ty->is_signed_integer_scalar_or_vector()) {
|
||||
const auto bits = lhs_el_ty->Size() * 8;
|
||||
auto min_int = AInt(AInt::kLowestValue >> (AInt::kNumBits - bits));
|
||||
const ast::Expression* lhs_is_min = b.Equal(lhs, ScalarOrVector(width, min_int));
|
||||
const ast::Expression* rhs_is_minus_one = b.Equal(rhs, ScalarOrVector(width, -1_a));
|
||||
// use_one = use_one | ((lhs == MIN_INT) & (rhs == -1))
|
||||
use_one = b.Or(use_one, b.And(lhs_is_min, rhs_is_minus_one));
|
||||
}
|
||||
auto* select = b.Call("select", rhs, ScalarOrVector(width, 1_a), use_one);
|
||||
// use_one = rhs_is_zero | ((lhs == MIN_INT) & (rhs == -1))
|
||||
auto* use_one = b.Or(rhs_is_zero, b.And(lhs_is_min, rhs_is_minus_one));
|
||||
|
||||
// Special handling for mod in case either operand is negative, as negative operands
|
||||
// for % is undefined behaviour for most backends (HLSL, MSL, GLSL, SPIR-V).
|
||||
if (!is_div) {
|
||||
const char* rhs_or_one = "rhs_or_one";
|
||||
body.Push(b.Decl(b.Let(
|
||||
rhs_or_one, b.Call("select", rhs, ScalarOrVector(width, 1_a), use_one))));
|
||||
|
||||
// Is either operand negative?
|
||||
// (lhs | rhs) & (1<<31)
|
||||
auto sign_bit_mask = ScalarOrVector(width, u32(1 << (bits - 1)));
|
||||
auto* lhs_or_rhs = CastScalarOrVector<u32>(width, b.Or(lhs, rhs_or_one));
|
||||
auto* lhs_or_rhs_is_neg =
|
||||
b.NotEqual(b.And(lhs_or_rhs, sign_bit_mask), ScalarOrVector(width, 0_u));
|
||||
|
||||
// lhs - trunc(lhs / rhs) * rhs (note: integral division truncates)
|
||||
auto* slow_mod = b.Sub(lhs, b.Mul(b.Div(lhs, rhs_or_one), rhs_or_one));
|
||||
|
||||
// lhs % rhs
|
||||
auto* fast_mod = b.Mod(lhs, rhs_or_one);
|
||||
|
||||
auto* use_slow = b.Call("any", lhs_or_rhs_is_neg);
|
||||
|
||||
body.Push(b.If(use_slow, b.Block(b.Return(slow_mod)),
|
||||
b.Else(b.Block(b.Return(fast_mod)))));
|
||||
|
||||
} else {
|
||||
auto* rhs_or_one = b.Call("select", rhs, ScalarOrVector(width, 1_a), use_one);
|
||||
body.Push(b.Return(is_div ? b.Div(lhs, rhs_or_one) : b.Mod(lhs, rhs_or_one)));
|
||||
}
|
||||
|
||||
} else {
|
||||
auto* rhs_or_one = b.Call("select", rhs, ScalarOrVector(width, 1_a), rhs_is_zero);
|
||||
body.Push(b.Return(is_div ? b.Div(lhs, rhs_or_one) : b.Mod(lhs, rhs_or_one)));
|
||||
}
|
||||
|
||||
body.Push(b.Return(is_div ? b.Div(lhs, select) : b.Mod(lhs, select)));
|
||||
b.Func(name,
|
||||
utils::Vector{
|
||||
b.Param("lhs", T(lhs_ty)),
|
||||
@@ -808,6 +842,14 @@ struct BuiltinPolyfill::State {
|
||||
}
|
||||
return b.Call(b.ty.vec<T>(width), value);
|
||||
}
|
||||
|
||||
template <typename To>
|
||||
const ast::Expression* CastScalarOrVector(uint32_t width, const ast::Expression* e) {
|
||||
if (width == 1) {
|
||||
return b.Call(b.ty.Of<To>(), e);
|
||||
}
|
||||
return b.Call(b.ty.vec<To>(width), e);
|
||||
}
|
||||
};
|
||||
|
||||
BuiltinPolyfill::BuiltinPolyfill() = default;
|
||||
|
||||
@@ -2073,7 +2073,12 @@ fn f() {
|
||||
|
||||
auto* expect = R"(
|
||||
fn tint_mod(lhs : i32, rhs : i32) -> i32 {
|
||||
return (lhs % select(rhs, 1, ((rhs == 0) | ((lhs == -2147483648) & (rhs == -1)))));
|
||||
let rhs_or_one = select(rhs, 1, ((rhs == 0) | ((lhs == -2147483648) & (rhs == -1))));
|
||||
if (any(((u32((lhs | rhs_or_one)) & 2147483648u) != 0u))) {
|
||||
return (lhs - ((lhs / rhs_or_one) * rhs_or_one));
|
||||
} else {
|
||||
return (lhs % rhs_or_one);
|
||||
}
|
||||
}
|
||||
|
||||
fn f() {
|
||||
@@ -2121,7 +2126,12 @@ fn f() {
|
||||
|
||||
auto* expect = R"(
|
||||
fn tint_mod(lhs : i32, rhs : i32) -> i32 {
|
||||
return (lhs % select(rhs, 1, ((rhs == 0) | ((lhs == -2147483648) & (rhs == -1)))));
|
||||
let rhs_or_one = select(rhs, 1, ((rhs == 0) | ((lhs == -2147483648) & (rhs == -1))));
|
||||
if (any(((u32((lhs | rhs_or_one)) & 2147483648u) != 0u))) {
|
||||
return (lhs - ((lhs / rhs_or_one) * rhs_or_one));
|
||||
} else {
|
||||
return (lhs % rhs_or_one);
|
||||
}
|
||||
}
|
||||
|
||||
fn f() {
|
||||
@@ -2169,7 +2179,12 @@ fn f() {
|
||||
|
||||
auto* expect = R"(
|
||||
fn tint_mod(lhs : i32, rhs : i32) -> i32 {
|
||||
return (lhs % select(rhs, 1, ((rhs == 0) | ((lhs == -2147483648) & (rhs == -1)))));
|
||||
let rhs_or_one = select(rhs, 1, ((rhs == 0) | ((lhs == -2147483648) & (rhs == -1))));
|
||||
if (any(((u32((lhs | rhs_or_one)) & 2147483648u) != 0u))) {
|
||||
return (lhs - ((lhs / rhs_or_one) * rhs_or_one));
|
||||
} else {
|
||||
return (lhs % rhs_or_one);
|
||||
}
|
||||
}
|
||||
|
||||
fn f() {
|
||||
@@ -2363,7 +2378,12 @@ fn f() {
|
||||
auto* expect = R"(
|
||||
fn tint_mod(lhs : vec3<i32>, rhs : i32) -> vec3<i32> {
|
||||
let r = vec3<i32>(rhs);
|
||||
return (lhs % select(r, vec3(1), ((r == vec3(0)) | ((lhs == vec3(-2147483648)) & (r == vec3(-1))))));
|
||||
let rhs_or_one = select(r, vec3(1), ((r == vec3(0)) | ((lhs == vec3(-2147483648)) & (r == vec3(-1)))));
|
||||
if (any(((vec3<u32>((lhs | rhs_or_one)) & vec3<u32>(2147483648u)) != vec3<u32>(0u)))) {
|
||||
return (lhs - ((lhs / rhs_or_one) * rhs_or_one));
|
||||
} else {
|
||||
return (lhs % rhs_or_one);
|
||||
}
|
||||
}
|
||||
|
||||
fn f() {
|
||||
@@ -2413,7 +2433,12 @@ fn f() {
|
||||
auto* expect = R"(
|
||||
fn tint_mod(lhs : vec3<i32>, rhs : i32) -> vec3<i32> {
|
||||
let r = vec3<i32>(rhs);
|
||||
return (lhs % select(r, vec3(1), ((r == vec3(0)) | ((lhs == vec3(-2147483648)) & (r == vec3(-1))))));
|
||||
let rhs_or_one = select(r, vec3(1), ((r == vec3(0)) | ((lhs == vec3(-2147483648)) & (r == vec3(-1)))));
|
||||
if (any(((vec3<u32>((lhs | rhs_or_one)) & vec3<u32>(2147483648u)) != vec3<u32>(0u)))) {
|
||||
return (lhs - ((lhs / rhs_or_one) * rhs_or_one));
|
||||
} else {
|
||||
return (lhs % rhs_or_one);
|
||||
}
|
||||
}
|
||||
|
||||
fn f() {
|
||||
@@ -2463,7 +2488,12 @@ fn f() {
|
||||
auto* expect = R"(
|
||||
fn tint_mod(lhs : vec3<i32>, rhs : i32) -> vec3<i32> {
|
||||
let r = vec3<i32>(rhs);
|
||||
return (lhs % select(r, vec3(1), ((r == vec3(0)) | ((lhs == vec3(-2147483648)) & (r == vec3(-1))))));
|
||||
let rhs_or_one = select(r, vec3(1), ((r == vec3(0)) | ((lhs == vec3(-2147483648)) & (r == vec3(-1)))));
|
||||
if (any(((vec3<u32>((lhs | rhs_or_one)) & vec3<u32>(2147483648u)) != vec3<u32>(0u)))) {
|
||||
return (lhs - ((lhs / rhs_or_one) * rhs_or_one));
|
||||
} else {
|
||||
return (lhs % rhs_or_one);
|
||||
}
|
||||
}
|
||||
|
||||
fn f() {
|
||||
@@ -2563,7 +2593,12 @@ fn f() {
|
||||
auto* expect = R"(
|
||||
fn tint_mod(lhs : i32, rhs : vec3<i32>) -> vec3<i32> {
|
||||
let l = vec3<i32>(lhs);
|
||||
return (l % select(rhs, vec3(1), ((rhs == vec3(0)) | ((l == vec3(-2147483648)) & (rhs == vec3(-1))))));
|
||||
let rhs_or_one = select(rhs, vec3(1), ((rhs == vec3(0)) | ((l == vec3(-2147483648)) & (rhs == vec3(-1)))));
|
||||
if (any(((vec3<u32>((l | rhs_or_one)) & vec3<u32>(2147483648u)) != vec3<u32>(0u)))) {
|
||||
return (l - ((l / rhs_or_one) * rhs_or_one));
|
||||
} else {
|
||||
return (l % rhs_or_one);
|
||||
}
|
||||
}
|
||||
|
||||
fn f() {
|
||||
@@ -2613,7 +2648,12 @@ fn f() {
|
||||
auto* expect = R"(
|
||||
fn tint_mod(lhs : i32, rhs : vec3<i32>) -> vec3<i32> {
|
||||
let l = vec3<i32>(lhs);
|
||||
return (l % select(rhs, vec3(1), ((rhs == vec3(0)) | ((l == vec3(-2147483648)) & (rhs == vec3(-1))))));
|
||||
let rhs_or_one = select(rhs, vec3(1), ((rhs == vec3(0)) | ((l == vec3(-2147483648)) & (rhs == vec3(-1)))));
|
||||
if (any(((vec3<u32>((l | rhs_or_one)) & vec3<u32>(2147483648u)) != vec3<u32>(0u)))) {
|
||||
return (l - ((l / rhs_or_one) * rhs_or_one));
|
||||
} else {
|
||||
return (l % rhs_or_one);
|
||||
}
|
||||
}
|
||||
|
||||
fn f() {
|
||||
@@ -2711,7 +2751,12 @@ fn f() {
|
||||
|
||||
auto* expect = R"(
|
||||
fn tint_mod(lhs : vec3<i32>, rhs : vec3<i32>) -> vec3<i32> {
|
||||
return (lhs % select(rhs, vec3(1), ((rhs == vec3(0)) | ((lhs == vec3(-2147483648)) & (rhs == vec3(-1))))));
|
||||
let rhs_or_one = select(rhs, vec3(1), ((rhs == vec3(0)) | ((lhs == vec3(-2147483648)) & (rhs == vec3(-1)))));
|
||||
if (any(((vec3<u32>((lhs | rhs_or_one)) & vec3<u32>(2147483648u)) != vec3<u32>(0u)))) {
|
||||
return (lhs - ((lhs / rhs_or_one) * rhs_or_one));
|
||||
} else {
|
||||
return (lhs % rhs_or_one);
|
||||
}
|
||||
}
|
||||
|
||||
fn f() {
|
||||
|
||||
Reference in New Issue
Block a user