tint/transform: Fix ICE when combining polyfills

There's a reason the overload of `ctx.Replace()` that takes a pointer to the replacement is deprecated - it doesn't play well when used as part of another replacement.
Switch to using the callback overload of Replace() to fix bad transform output.

Bug: tint:1386647
Change-Id: I94292eeb65d24d7b2446b16b8b4ad13bdd27965a
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/111000
Auto-Submit: Ben Clayton <bclayton@google.com>
Commit-Queue: James Price <jrprice@google.com>
Reviewed-by: James Price <jrprice@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
Ben Clayton
2022-11-21 17:11:05 +00:00
committed by Dawn LUCI CQ
parent f5eec817de
commit 619f9bd639
69 changed files with 941 additions and 756 deletions

View File

@@ -41,6 +41,10 @@ struct BuiltinPolyfill::State {
/// @param p the builtins to polyfill
State(CloneContext& c, Builtins p) : ctx(c), polyfill(p) {}
////////////////////////////////////////////////////////////////////////////
// Function polyfills
////////////////////////////////////////////////////////////////////////////
/// Builds the polyfill function for the `acosh` builtin
/// @param ty the parameter and return type for the function
/// @return the polyfill function name
@@ -559,63 +563,6 @@ struct BuiltinPolyfill::State {
return name;
}
/// Builds the polyfill function for a divide or modulo operator with integer scalar or vector
/// operands.
/// @param sig the signature of the binary operator
/// @return the polyfill function name
Symbol int_div_mod(const BinaryOpSignature& sig) {
const auto op = std::get<0>(sig);
const auto* lhs_ty = std::get<1>(sig);
const auto* rhs_ty = std::get<2>(sig);
const bool is_div = op == ast::BinaryOp::kDivide;
uint32_t lhs_width = 1;
uint32_t rhs_width = 1;
const auto* lhs_el_ty = sem::Type::ElementOf(lhs_ty, &lhs_width);
const auto* rhs_el_ty = sem::Type::ElementOf(rhs_ty, &rhs_width);
const uint32_t width = std::max(lhs_width, rhs_width);
const char* lhs = "lhs";
const char* rhs = "rhs";
utils::Vector<const ast::Statement*, 4> body;
if (lhs_width < width) {
// lhs is scalar, rhs is vector. Convert lhs to vector.
body.Push(b.Decl(b.Let("l", b.vec(T(lhs_el_ty), width, b.Expr(lhs)))));
lhs = "l";
}
if (rhs_width < width) {
// lhs is vector, rhs is scalar. Convert rhs to vector.
body.Push(b.Decl(b.Let("r", b.vec(T(rhs_el_ty), width, b.Expr(rhs)))));
rhs = "r";
}
auto name = b.Symbols().New(is_div ? "tint_div" : "tint_mod");
auto* use_one = b.Equal(rhs, ScalarOrVector(width, 0_a));
if (lhs_ty->is_signed_scalar_or_vector()) {
const auto bits = lhs_el_ty->Size() * 8;
auto min_int = AInt(AInt::kLowestValue >> (AInt::kNumBits - bits));
const ast::Expression* lhs_is_min = b.Equal(lhs, ScalarOrVector(width, min_int));
const ast::Expression* rhs_is_minus_one = b.Equal(rhs, ScalarOrVector(width, -1_a));
// use_one = use_one | ((lhs == MIN_INT) & (rhs == -1))
use_one = b.Or(use_one, b.And(lhs_is_min, rhs_is_minus_one));
}
auto* select = b.Call("select", rhs, ScalarOrVector(width, 1_a), use_one);
body.Push(b.Return(is_div ? b.Div(lhs, select) : b.Mod(lhs, select)));
b.Func(name,
utils::Vector{
b.Param("lhs", T(lhs_ty)),
b.Param("rhs", T(rhs_ty)),
},
width == 1 ? T(lhs_ty) : b.ty.vec(T(lhs_el_ty), width), // return type
std::move(body));
return name;
}
/// Builds the polyfill function for the `saturate` builtin
/// @param ty the parameter and return type for the function
/// @return the polyfill function name
@@ -677,6 +624,89 @@ struct BuiltinPolyfill::State {
return name;
}
////////////////////////////////////////////////////////////////////////////
// Inline polyfills
////////////////////////////////////////////////////////////////////////////
/// Builds the polyfill inline expression for a bitshift left or bitshift right, ensuring that
/// the RHS is modulo the bit-width of the LHS.
/// @param bin_op the original BinaryExpression
/// @return the polyfill value for bitshift operation
const ast::Expression* BitshiftModulo(const ast::BinaryExpression* bin_op) {
auto* lhs_ty = ctx.src->TypeOf(bin_op->lhs)->UnwrapRef();
auto* rhs_ty = ctx.src->TypeOf(bin_op->rhs)->UnwrapRef();
auto* lhs_el_ty = sem::Type::DeepestElementOf(lhs_ty);
const ast::Expression* mask = b.Expr(AInt(lhs_el_ty->Size() * 8 - 1));
if (rhs_ty->Is<sem::Vector>()) {
mask = b.Construct(CreateASTTypeFor(ctx, rhs_ty), mask);
}
auto* lhs = ctx.Clone(bin_op->lhs);
auto* rhs = b.And(ctx.Clone(bin_op->rhs), mask);
return b.create<ast::BinaryExpression>(ctx.Clone(bin_op->source), bin_op->op, lhs, rhs);
}
/// Builds the polyfill inline expression for a integer divide or modulo, preventing DBZs and
/// integer overflows.
/// @param bin_op the original BinaryExpression
/// @return the polyfill divide or modulo
const ast::Expression* IntDivMod(const ast::BinaryExpression* bin_op) {
auto* lhs_ty = ctx.src->TypeOf(bin_op->lhs)->UnwrapRef();
auto* rhs_ty = ctx.src->TypeOf(bin_op->rhs)->UnwrapRef();
BinaryOpSignature sig{bin_op->op, lhs_ty, rhs_ty};
auto fn = binary_op_polyfills.GetOrCreate(sig, [&] {
const bool is_div = bin_op->op == ast::BinaryOp::kDivide;
uint32_t lhs_width = 1;
uint32_t rhs_width = 1;
const auto* lhs_el_ty = sem::Type::ElementOf(lhs_ty, &lhs_width);
const auto* rhs_el_ty = sem::Type::ElementOf(rhs_ty, &rhs_width);
const uint32_t width = std::max(lhs_width, rhs_width);
const char* lhs = "lhs";
const char* rhs = "rhs";
utils::Vector<const ast::Statement*, 4> body;
if (lhs_width < width) {
// lhs is scalar, rhs is vector. Convert lhs to vector.
body.Push(b.Decl(b.Let("l", b.vec(T(lhs_el_ty), width, b.Expr(lhs)))));
lhs = "l";
}
if (rhs_width < width) {
// lhs is vector, rhs is scalar. Convert rhs to vector.
body.Push(b.Decl(b.Let("r", b.vec(T(rhs_el_ty), width, b.Expr(rhs)))));
rhs = "r";
}
auto name = b.Symbols().New(is_div ? "tint_div" : "tint_mod");
auto* use_one = b.Equal(rhs, ScalarOrVector(width, 0_a));
if (lhs_ty->is_signed_scalar_or_vector()) {
const auto bits = lhs_el_ty->Size() * 8;
auto min_int = AInt(AInt::kLowestValue >> (AInt::kNumBits - bits));
const ast::Expression* lhs_is_min = b.Equal(lhs, ScalarOrVector(width, min_int));
const ast::Expression* rhs_is_minus_one = b.Equal(rhs, ScalarOrVector(width, -1_a));
// use_one = use_one | ((lhs == MIN_INT) & (rhs == -1))
use_one = b.Or(use_one, b.And(lhs_is_min, rhs_is_minus_one));
}
auto* select = b.Call("select", rhs, ScalarOrVector(width, 1_a), use_one);
body.Push(b.Return(is_div ? b.Div(lhs, select) : b.Mod(lhs, select)));
b.Func(name,
utils::Vector{
b.Param("lhs", T(lhs_ty)),
b.Param("rhs", T(rhs_ty)),
},
width == 1 ? T(lhs_ty) : b.ty.vec(T(lhs_el_ty), width), // return type
std::move(body));
return name;
});
auto* lhs = ctx.Clone(bin_op->lhs);
auto* rhs = ctx.Clone(bin_op->rhs);
return b.Call(fn, lhs, rhs);
}
private:
/// The clone context
CloneContext& ctx;
@@ -687,6 +717,9 @@ struct BuiltinPolyfill::State {
/// The source clone context
const sem::Info& sem = ctx.src->Sem();
// Polyfill functions for binary operators.
utils::Hashmap<BinaryOpSignature, Symbol, 8> binary_op_polyfills;
/// @returns the AST type for the given sem type
const ast::Type* T(const sem::Type* ty) const { return CreateASTTypeFor(ctx, ty); }
@@ -724,7 +757,6 @@ Transform::ApplyResult BuiltinPolyfill::Apply(const Program* src,
auto& polyfill = cfg->builtins;
utils::Hashmap<const sem::Builtin*, Symbol, 8> builtin_polyfills;
utils::Hashmap<BinaryOpSignature, Symbol, 8> binary_op_polyfills;
ProgramBuilder b;
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
@@ -849,15 +881,7 @@ Transform::ApplyResult BuiltinPolyfill::Apply(const Program* src,
case ast::BinaryOp::kShiftLeft:
case ast::BinaryOp::kShiftRight: {
if (polyfill.bitshift_modulo) {
auto* lhs_ty = src->TypeOf(bin_op->lhs)->UnwrapRef();
auto* rhs_ty = src->TypeOf(bin_op->rhs)->UnwrapRef();
auto* lhs_el_ty = sem::Type::DeepestElementOf(lhs_ty);
const ast::Expression* mask = b.Expr(AInt(lhs_el_ty->Size() * 8 - 1));
if (rhs_ty->Is<sem::Vector>()) {
mask = b.Construct(CreateASTTypeFor(ctx, rhs_ty), mask);
}
auto* mod = b.And(ctx.Clone(bin_op->rhs), mask);
ctx.Replace(bin_op->rhs, mod);
ctx.Replace(bin_op, [bin_op, &s] { return s.BitshiftModulo(bin_op); });
made_changes = true;
}
break;
@@ -867,13 +891,7 @@ Transform::ApplyResult BuiltinPolyfill::Apply(const Program* src,
if (polyfill.int_div_mod) {
auto* lhs_ty = src->TypeOf(bin_op->lhs)->UnwrapRef();
if (lhs_ty->is_integer_scalar_or_vector()) {
auto* rhs_ty = src->TypeOf(bin_op->rhs)->UnwrapRef();
BinaryOpSignature sig{bin_op->op, lhs_ty, rhs_ty};
auto fn = binary_op_polyfills.GetOrCreate(
sig, [&] { return s.int_div_mod(sig); });
auto* lhs = ctx.Clone(bin_op->lhs);
auto* rhs = ctx.Clone(bin_op->rhs);
ctx.Replace(bin_op, b.Call(fn, lhs, rhs));
ctx.Replace(bin_op, [bin_op, &s] { return s.IntDivMod(bin_op); });
made_changes = true;
}
}

View File

@@ -3000,5 +3000,37 @@ fn f() {
EXPECT_EQ(expect, str(got));
}
////////////////////////////////////////////////////////////////////////////////
// Polyfill combinations
////////////////////////////////////////////////////////////////////////////////
TEST_F(BuiltinPolyfillTest, BitshiftAndModulo) {
auto* src = R"(
fn f(x : i32, y : u32, z : u32) {
let l = x << (y % z);
}
)";
auto* expect = R"(
fn tint_mod(lhs : u32, rhs : u32) -> u32 {
return (lhs % select(rhs, 1, (rhs == 0)));
}
fn f(x : i32, y : u32, z : u32) {
let l = (x << (tint_mod(y, z) & 31));
}
)";
BuiltinPolyfill::Builtins builtins;
builtins.bitshift_modulo = true;
builtins.int_div_mod = true;
DataMap data;
data.Add<BuiltinPolyfill::Config>(builtins);
auto got = Run<BuiltinPolyfill>(src, std::move(data));
EXPECT_EQ(expect, str(got));
}
} // namespace
} // namespace tint::transform