mirror of
https://github.com/encounter/dawn-cmake.git
synced 2025-12-09 21:47:47 +00:00
tint/transform: Fix ICE when combining polyfills
There's a reason the overload of `ctx.Replace()` that takes a pointer to the replacement is deprecated - it doesn't play well when used as part of another replacement. Switch to using the callback overload of Replace() to fix bad transform output. Bug: tint:1386647 Change-Id: I94292eeb65d24d7b2446b16b8b4ad13bdd27965a Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/111000 Auto-Submit: Ben Clayton <bclayton@google.com> Commit-Queue: James Price <jrprice@google.com> Reviewed-by: James Price <jrprice@google.com> Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
committed by
Dawn LUCI CQ
parent
f5eec817de
commit
619f9bd639
@@ -41,6 +41,10 @@ struct BuiltinPolyfill::State {
|
||||
/// @param p the builtins to polyfill
|
||||
State(CloneContext& c, Builtins p) : ctx(c), polyfill(p) {}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////
|
||||
// Function polyfills
|
||||
////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Builds the polyfill function for the `acosh` builtin
|
||||
/// @param ty the parameter and return type for the function
|
||||
/// @return the polyfill function name
|
||||
@@ -559,63 +563,6 @@ struct BuiltinPolyfill::State {
|
||||
return name;
|
||||
}
|
||||
|
||||
/// Builds the polyfill function for a divide or modulo operator with integer scalar or vector
|
||||
/// operands.
|
||||
/// @param sig the signature of the binary operator
|
||||
/// @return the polyfill function name
|
||||
Symbol int_div_mod(const BinaryOpSignature& sig) {
|
||||
const auto op = std::get<0>(sig);
|
||||
const auto* lhs_ty = std::get<1>(sig);
|
||||
const auto* rhs_ty = std::get<2>(sig);
|
||||
const bool is_div = op == ast::BinaryOp::kDivide;
|
||||
|
||||
uint32_t lhs_width = 1;
|
||||
uint32_t rhs_width = 1;
|
||||
const auto* lhs_el_ty = sem::Type::ElementOf(lhs_ty, &lhs_width);
|
||||
const auto* rhs_el_ty = sem::Type::ElementOf(rhs_ty, &rhs_width);
|
||||
|
||||
const uint32_t width = std::max(lhs_width, rhs_width);
|
||||
|
||||
const char* lhs = "lhs";
|
||||
const char* rhs = "rhs";
|
||||
|
||||
utils::Vector<const ast::Statement*, 4> body;
|
||||
|
||||
if (lhs_width < width) {
|
||||
// lhs is scalar, rhs is vector. Convert lhs to vector.
|
||||
body.Push(b.Decl(b.Let("l", b.vec(T(lhs_el_ty), width, b.Expr(lhs)))));
|
||||
lhs = "l";
|
||||
}
|
||||
if (rhs_width < width) {
|
||||
// lhs is vector, rhs is scalar. Convert rhs to vector.
|
||||
body.Push(b.Decl(b.Let("r", b.vec(T(rhs_el_ty), width, b.Expr(rhs)))));
|
||||
rhs = "r";
|
||||
}
|
||||
|
||||
auto name = b.Symbols().New(is_div ? "tint_div" : "tint_mod");
|
||||
auto* use_one = b.Equal(rhs, ScalarOrVector(width, 0_a));
|
||||
if (lhs_ty->is_signed_scalar_or_vector()) {
|
||||
const auto bits = lhs_el_ty->Size() * 8;
|
||||
auto min_int = AInt(AInt::kLowestValue >> (AInt::kNumBits - bits));
|
||||
const ast::Expression* lhs_is_min = b.Equal(lhs, ScalarOrVector(width, min_int));
|
||||
const ast::Expression* rhs_is_minus_one = b.Equal(rhs, ScalarOrVector(width, -1_a));
|
||||
// use_one = use_one | ((lhs == MIN_INT) & (rhs == -1))
|
||||
use_one = b.Or(use_one, b.And(lhs_is_min, rhs_is_minus_one));
|
||||
}
|
||||
auto* select = b.Call("select", rhs, ScalarOrVector(width, 1_a), use_one);
|
||||
|
||||
body.Push(b.Return(is_div ? b.Div(lhs, select) : b.Mod(lhs, select)));
|
||||
b.Func(name,
|
||||
utils::Vector{
|
||||
b.Param("lhs", T(lhs_ty)),
|
||||
b.Param("rhs", T(rhs_ty)),
|
||||
},
|
||||
width == 1 ? T(lhs_ty) : b.ty.vec(T(lhs_el_ty), width), // return type
|
||||
std::move(body));
|
||||
|
||||
return name;
|
||||
}
|
||||
|
||||
/// Builds the polyfill function for the `saturate` builtin
|
||||
/// @param ty the parameter and return type for the function
|
||||
/// @return the polyfill function name
|
||||
@@ -677,6 +624,89 @@ struct BuiltinPolyfill::State {
|
||||
return name;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////
|
||||
// Inline polyfills
|
||||
////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Builds the polyfill inline expression for a bitshift left or bitshift right, ensuring that
|
||||
/// the RHS is modulo the bit-width of the LHS.
|
||||
/// @param bin_op the original BinaryExpression
|
||||
/// @return the polyfill value for bitshift operation
|
||||
const ast::Expression* BitshiftModulo(const ast::BinaryExpression* bin_op) {
|
||||
auto* lhs_ty = ctx.src->TypeOf(bin_op->lhs)->UnwrapRef();
|
||||
auto* rhs_ty = ctx.src->TypeOf(bin_op->rhs)->UnwrapRef();
|
||||
auto* lhs_el_ty = sem::Type::DeepestElementOf(lhs_ty);
|
||||
const ast::Expression* mask = b.Expr(AInt(lhs_el_ty->Size() * 8 - 1));
|
||||
if (rhs_ty->Is<sem::Vector>()) {
|
||||
mask = b.Construct(CreateASTTypeFor(ctx, rhs_ty), mask);
|
||||
}
|
||||
auto* lhs = ctx.Clone(bin_op->lhs);
|
||||
auto* rhs = b.And(ctx.Clone(bin_op->rhs), mask);
|
||||
return b.create<ast::BinaryExpression>(ctx.Clone(bin_op->source), bin_op->op, lhs, rhs);
|
||||
}
|
||||
|
||||
/// Builds the polyfill inline expression for a integer divide or modulo, preventing DBZs and
|
||||
/// integer overflows.
|
||||
/// @param bin_op the original BinaryExpression
|
||||
/// @return the polyfill divide or modulo
|
||||
const ast::Expression* IntDivMod(const ast::BinaryExpression* bin_op) {
|
||||
auto* lhs_ty = ctx.src->TypeOf(bin_op->lhs)->UnwrapRef();
|
||||
auto* rhs_ty = ctx.src->TypeOf(bin_op->rhs)->UnwrapRef();
|
||||
BinaryOpSignature sig{bin_op->op, lhs_ty, rhs_ty};
|
||||
auto fn = binary_op_polyfills.GetOrCreate(sig, [&] {
|
||||
const bool is_div = bin_op->op == ast::BinaryOp::kDivide;
|
||||
|
||||
uint32_t lhs_width = 1;
|
||||
uint32_t rhs_width = 1;
|
||||
const auto* lhs_el_ty = sem::Type::ElementOf(lhs_ty, &lhs_width);
|
||||
const auto* rhs_el_ty = sem::Type::ElementOf(rhs_ty, &rhs_width);
|
||||
|
||||
const uint32_t width = std::max(lhs_width, rhs_width);
|
||||
|
||||
const char* lhs = "lhs";
|
||||
const char* rhs = "rhs";
|
||||
|
||||
utils::Vector<const ast::Statement*, 4> body;
|
||||
|
||||
if (lhs_width < width) {
|
||||
// lhs is scalar, rhs is vector. Convert lhs to vector.
|
||||
body.Push(b.Decl(b.Let("l", b.vec(T(lhs_el_ty), width, b.Expr(lhs)))));
|
||||
lhs = "l";
|
||||
}
|
||||
if (rhs_width < width) {
|
||||
// lhs is vector, rhs is scalar. Convert rhs to vector.
|
||||
body.Push(b.Decl(b.Let("r", b.vec(T(rhs_el_ty), width, b.Expr(rhs)))));
|
||||
rhs = "r";
|
||||
}
|
||||
|
||||
auto name = b.Symbols().New(is_div ? "tint_div" : "tint_mod");
|
||||
auto* use_one = b.Equal(rhs, ScalarOrVector(width, 0_a));
|
||||
if (lhs_ty->is_signed_scalar_or_vector()) {
|
||||
const auto bits = lhs_el_ty->Size() * 8;
|
||||
auto min_int = AInt(AInt::kLowestValue >> (AInt::kNumBits - bits));
|
||||
const ast::Expression* lhs_is_min = b.Equal(lhs, ScalarOrVector(width, min_int));
|
||||
const ast::Expression* rhs_is_minus_one = b.Equal(rhs, ScalarOrVector(width, -1_a));
|
||||
// use_one = use_one | ((lhs == MIN_INT) & (rhs == -1))
|
||||
use_one = b.Or(use_one, b.And(lhs_is_min, rhs_is_minus_one));
|
||||
}
|
||||
auto* select = b.Call("select", rhs, ScalarOrVector(width, 1_a), use_one);
|
||||
|
||||
body.Push(b.Return(is_div ? b.Div(lhs, select) : b.Mod(lhs, select)));
|
||||
b.Func(name,
|
||||
utils::Vector{
|
||||
b.Param("lhs", T(lhs_ty)),
|
||||
b.Param("rhs", T(rhs_ty)),
|
||||
},
|
||||
width == 1 ? T(lhs_ty) : b.ty.vec(T(lhs_el_ty), width), // return type
|
||||
std::move(body));
|
||||
|
||||
return name;
|
||||
});
|
||||
auto* lhs = ctx.Clone(bin_op->lhs);
|
||||
auto* rhs = ctx.Clone(bin_op->rhs);
|
||||
return b.Call(fn, lhs, rhs);
|
||||
}
|
||||
|
||||
private:
|
||||
/// The clone context
|
||||
CloneContext& ctx;
|
||||
@@ -687,6 +717,9 @@ struct BuiltinPolyfill::State {
|
||||
/// The source clone context
|
||||
const sem::Info& sem = ctx.src->Sem();
|
||||
|
||||
// Polyfill functions for binary operators.
|
||||
utils::Hashmap<BinaryOpSignature, Symbol, 8> binary_op_polyfills;
|
||||
|
||||
/// @returns the AST type for the given sem type
|
||||
const ast::Type* T(const sem::Type* ty) const { return CreateASTTypeFor(ctx, ty); }
|
||||
|
||||
@@ -724,7 +757,6 @@ Transform::ApplyResult BuiltinPolyfill::Apply(const Program* src,
|
||||
auto& polyfill = cfg->builtins;
|
||||
|
||||
utils::Hashmap<const sem::Builtin*, Symbol, 8> builtin_polyfills;
|
||||
utils::Hashmap<BinaryOpSignature, Symbol, 8> binary_op_polyfills;
|
||||
|
||||
ProgramBuilder b;
|
||||
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
|
||||
@@ -849,15 +881,7 @@ Transform::ApplyResult BuiltinPolyfill::Apply(const Program* src,
|
||||
case ast::BinaryOp::kShiftLeft:
|
||||
case ast::BinaryOp::kShiftRight: {
|
||||
if (polyfill.bitshift_modulo) {
|
||||
auto* lhs_ty = src->TypeOf(bin_op->lhs)->UnwrapRef();
|
||||
auto* rhs_ty = src->TypeOf(bin_op->rhs)->UnwrapRef();
|
||||
auto* lhs_el_ty = sem::Type::DeepestElementOf(lhs_ty);
|
||||
const ast::Expression* mask = b.Expr(AInt(lhs_el_ty->Size() * 8 - 1));
|
||||
if (rhs_ty->Is<sem::Vector>()) {
|
||||
mask = b.Construct(CreateASTTypeFor(ctx, rhs_ty), mask);
|
||||
}
|
||||
auto* mod = b.And(ctx.Clone(bin_op->rhs), mask);
|
||||
ctx.Replace(bin_op->rhs, mod);
|
||||
ctx.Replace(bin_op, [bin_op, &s] { return s.BitshiftModulo(bin_op); });
|
||||
made_changes = true;
|
||||
}
|
||||
break;
|
||||
@@ -867,13 +891,7 @@ Transform::ApplyResult BuiltinPolyfill::Apply(const Program* src,
|
||||
if (polyfill.int_div_mod) {
|
||||
auto* lhs_ty = src->TypeOf(bin_op->lhs)->UnwrapRef();
|
||||
if (lhs_ty->is_integer_scalar_or_vector()) {
|
||||
auto* rhs_ty = src->TypeOf(bin_op->rhs)->UnwrapRef();
|
||||
BinaryOpSignature sig{bin_op->op, lhs_ty, rhs_ty};
|
||||
auto fn = binary_op_polyfills.GetOrCreate(
|
||||
sig, [&] { return s.int_div_mod(sig); });
|
||||
auto* lhs = ctx.Clone(bin_op->lhs);
|
||||
auto* rhs = ctx.Clone(bin_op->rhs);
|
||||
ctx.Replace(bin_op, b.Call(fn, lhs, rhs));
|
||||
ctx.Replace(bin_op, [bin_op, &s] { return s.IntDivMod(bin_op); });
|
||||
made_changes = true;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3000,5 +3000,37 @@ fn f() {
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Polyfill combinations
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST_F(BuiltinPolyfillTest, BitshiftAndModulo) {
|
||||
auto* src = R"(
|
||||
fn f(x : i32, y : u32, z : u32) {
|
||||
let l = x << (y % z);
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn tint_mod(lhs : u32, rhs : u32) -> u32 {
|
||||
return (lhs % select(rhs, 1, (rhs == 0)));
|
||||
}
|
||||
|
||||
fn f(x : i32, y : u32, z : u32) {
|
||||
let l = (x << (tint_mod(y, z) & 31));
|
||||
}
|
||||
)";
|
||||
|
||||
BuiltinPolyfill::Builtins builtins;
|
||||
builtins.bitshift_modulo = true;
|
||||
builtins.int_div_mod = true;
|
||||
DataMap data;
|
||||
data.Add<BuiltinPolyfill::Config>(builtins);
|
||||
|
||||
auto got = Run<BuiltinPolyfill>(src, std::move(data));
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tint::transform
|
||||
|
||||
Reference in New Issue
Block a user