tint/transform: Refactor BuiltinPolyfill transform
Move the bulk of the logic into the State class. Reduces deep indentation, and likely improves performance by reducing the number of variable that require lambda capture. Change-Id: I85c87298157f34645d0ae064439bb640f7af7c80 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/123200 Kokoro: Kokoro <noreply+kokoro@google.com> Commit-Queue: Ben Clayton <bclayton@chromium.org> Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
This commit is contained in:
parent
9eff4810c6
commit
ce2578bc99
|
@ -40,17 +40,100 @@ using BinaryOpSignature = std::tuple<ast::BinaryOp, const type::Type*, const typ
|
||||||
/// PIMPL state for the transform
|
/// PIMPL state for the transform
|
||||||
struct BuiltinPolyfill::State {
|
struct BuiltinPolyfill::State {
|
||||||
/// Constructor
|
/// Constructor
|
||||||
/// @param c the CloneContext
|
/// @param program the source program
|
||||||
/// @param p the builtins to polyfill
|
/// @param config the transform config
|
||||||
State(CloneContext& c, Builtins p) : ctx(c), polyfill(p) {
|
State(const Program* program, const Config& config) : src(program), cfg(config) {
|
||||||
has_full_ptr_params = false;
|
has_full_ptr_params = false;
|
||||||
for (auto* enable : c.src->AST().Enables()) {
|
for (auto* enable : src->AST().Enables()) {
|
||||||
if (enable->extension == builtin::Extension::kChromiumExperimentalFullPtrParameters) {
|
if (enable->extension == builtin::Extension::kChromiumExperimentalFullPtrParameters) {
|
||||||
has_full_ptr_params = true;
|
has_full_ptr_params = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Runs the transform
|
||||||
|
/// @returns the new program or SkipTransform if the transform is not required
|
||||||
|
Transform::ApplyResult Run() {
|
||||||
|
for (auto* node : src->ASTNodes().Objects()) {
|
||||||
|
Switch(
|
||||||
|
node, //
|
||||||
|
[&](const ast::CallExpression* expr) { Call(expr); },
|
||||||
|
[&](const ast::BinaryExpression* bin_op) {
|
||||||
|
if (auto* s = src->Sem().Get(bin_op);
|
||||||
|
!s || s->Stage() == sem::EvaluationStage::kConstant ||
|
||||||
|
s->Stage() == sem::EvaluationStage::kNotEvaluated) {
|
||||||
|
return; // Don't polyfill @const expressions
|
||||||
|
}
|
||||||
|
switch (bin_op->op) {
|
||||||
|
case ast::BinaryOp::kShiftLeft:
|
||||||
|
case ast::BinaryOp::kShiftRight: {
|
||||||
|
if (cfg.builtins.bitshift_modulo) {
|
||||||
|
ctx.Replace(bin_op,
|
||||||
|
[this, bin_op] { return BitshiftModulo(bin_op); });
|
||||||
|
made_changes = true;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case ast::BinaryOp::kDivide: {
|
||||||
|
if (cfg.builtins.int_div_mod) {
|
||||||
|
auto* lhs_ty = src->TypeOf(bin_op->lhs)->UnwrapRef();
|
||||||
|
if (lhs_ty->is_integer_scalar_or_vector()) {
|
||||||
|
ctx.Replace(bin_op,
|
||||||
|
[this, bin_op] { return IntDivMod(bin_op); });
|
||||||
|
made_changes = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case ast::BinaryOp::kModulo: {
|
||||||
|
if (cfg.builtins.int_div_mod) {
|
||||||
|
auto* lhs_ty = src->TypeOf(bin_op->lhs)->UnwrapRef();
|
||||||
|
if (lhs_ty->is_integer_scalar_or_vector()) {
|
||||||
|
ctx.Replace(bin_op,
|
||||||
|
[this, bin_op] { return IntDivMod(bin_op); });
|
||||||
|
made_changes = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (cfg.builtins.precise_float_mod) {
|
||||||
|
auto* lhs_ty = src->TypeOf(bin_op->lhs)->UnwrapRef();
|
||||||
|
if (lhs_ty->is_float_scalar_or_vector()) {
|
||||||
|
ctx.Replace(bin_op,
|
||||||
|
[this, bin_op] { return PreciseFloatMod(bin_op); });
|
||||||
|
made_changes = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
[&](const ast::Expression* expr) {
|
||||||
|
if (cfg.builtins.bgra8unorm) {
|
||||||
|
if (auto* ty_expr = src->Sem().Get<sem::TypeExpression>(expr)) {
|
||||||
|
if (auto* tex = ty_expr->Type()->As<type::StorageTexture>()) {
|
||||||
|
if (tex->texel_format() == builtin::TexelFormat::kBgra8Unorm) {
|
||||||
|
ctx.Replace(expr, [this, tex] {
|
||||||
|
return ctx.dst->Expr(ctx.dst->ty.storage_texture(
|
||||||
|
tex->dim(), builtin::TexelFormat::kRgba8Unorm,
|
||||||
|
tex->access()));
|
||||||
|
});
|
||||||
|
made_changes = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!made_changes) {
|
||||||
|
return SkipTransform;
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.Clone();
|
||||||
|
return Program(std::move(b));
|
||||||
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////
|
||||||
// Function polyfills
|
// Function polyfills
|
||||||
////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -71,7 +154,7 @@ struct BuiltinPolyfill::State {
|
||||||
};
|
};
|
||||||
|
|
||||||
utils::Vector<const ast::Statement*, 4> body;
|
utils::Vector<const ast::Statement*, 4> body;
|
||||||
switch (polyfill.acosh) {
|
switch (cfg.builtins.acosh) {
|
||||||
case Level::kFull:
|
case Level::kFull:
|
||||||
// return log(x + sqrt(x*x - 1));
|
// return log(x + sqrt(x*x - 1));
|
||||||
body.Push(b.Return(
|
body.Push(b.Return(
|
||||||
|
@ -85,7 +168,7 @@ struct BuiltinPolyfill::State {
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
TINT_ICE(Transform, b.Diagnostics())
|
TINT_ICE(Transform, b.Diagnostics())
|
||||||
<< "unhandled polyfill level: " << static_cast<int>(polyfill.acosh);
|
<< "unhandled polyfill level: " << static_cast<int>(cfg.builtins.acosh);
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -125,7 +208,7 @@ struct BuiltinPolyfill::State {
|
||||||
};
|
};
|
||||||
|
|
||||||
utils::Vector<const ast::Statement*, 1> body;
|
utils::Vector<const ast::Statement*, 1> body;
|
||||||
switch (polyfill.atanh) {
|
switch (cfg.builtins.atanh) {
|
||||||
case Level::kFull:
|
case Level::kFull:
|
||||||
// return log((1+x) / (1-x)) * 0.5
|
// return log((1+x) / (1-x)) * 0.5
|
||||||
body.Push(
|
body.Push(
|
||||||
|
@ -138,7 +221,7 @@ struct BuiltinPolyfill::State {
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
TINT_ICE(Transform, b.Diagnostics())
|
TINT_ICE(Transform, b.Diagnostics())
|
||||||
<< "unhandled polyfill level: " << static_cast<int>(polyfill.acosh);
|
<< "unhandled polyfill level: " << static_cast<int>(cfg.builtins.acosh);
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -306,7 +389,7 @@ struct BuiltinPolyfill::State {
|
||||||
b.Decl(b.Let("e", b.Call("min", u32(W), b.Add("s", "count")))),
|
b.Decl(b.Let("e", b.Call("min", u32(W), b.Add("s", "count")))),
|
||||||
};
|
};
|
||||||
|
|
||||||
switch (polyfill.extract_bits) {
|
switch (cfg.builtins.extract_bits) {
|
||||||
case Level::kFull:
|
case Level::kFull:
|
||||||
body.Push(b.Decl(b.Let("shl", b.Sub(u32(W), "e"))));
|
body.Push(b.Decl(b.Let("shl", b.Sub(u32(W), "e"))));
|
||||||
body.Push(b.Decl(b.Let("shr", b.Add("shl", "s"))));
|
body.Push(b.Decl(b.Let("shr", b.Add("shl", "s"))));
|
||||||
|
@ -328,7 +411,7 @@ struct BuiltinPolyfill::State {
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
TINT_ICE(Transform, b.Diagnostics())
|
TINT_ICE(Transform, b.Diagnostics())
|
||||||
<< "unhandled polyfill level: " << static_cast<int>(polyfill.extract_bits);
|
<< "unhandled polyfill level: " << static_cast<int>(cfg.builtins.extract_bits);
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -532,7 +615,7 @@ struct BuiltinPolyfill::State {
|
||||||
|
|
||||||
utils::Vector<const ast::Statement*, 8> body;
|
utils::Vector<const ast::Statement*, 8> body;
|
||||||
|
|
||||||
switch (polyfill.insert_bits) {
|
switch (cfg.builtins.insert_bits) {
|
||||||
case Level::kFull:
|
case Level::kFull:
|
||||||
// let e = offset + count;
|
// let e = offset + count;
|
||||||
body.Push(b.Decl(b.Let("e", b.Add("offset", "count"))));
|
body.Push(b.Decl(b.Let("e", b.Add("offset", "count"))));
|
||||||
|
@ -566,7 +649,7 @@ struct BuiltinPolyfill::State {
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
TINT_ICE(Transform, b.Diagnostics())
|
TINT_ICE(Transform, b.Diagnostics())
|
||||||
<< "unhandled polyfill level: " << static_cast<int>(polyfill.insert_bits);
|
<< "unhandled polyfill level: " << static_cast<int>(cfg.builtins.insert_bits);
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -727,8 +810,8 @@ struct BuiltinPolyfill::State {
|
||||||
/// @param bin_op the original BinaryExpression
|
/// @param bin_op the original BinaryExpression
|
||||||
/// @return the polyfill value for bitshift operation
|
/// @return the polyfill value for bitshift operation
|
||||||
const ast::Expression* BitshiftModulo(const ast::BinaryExpression* bin_op) {
|
const ast::Expression* BitshiftModulo(const ast::BinaryExpression* bin_op) {
|
||||||
auto* lhs_ty = ctx.src->TypeOf(bin_op->lhs)->UnwrapRef();
|
auto* lhs_ty = src->TypeOf(bin_op->lhs)->UnwrapRef();
|
||||||
auto* rhs_ty = ctx.src->TypeOf(bin_op->rhs)->UnwrapRef();
|
auto* rhs_ty = src->TypeOf(bin_op->rhs)->UnwrapRef();
|
||||||
auto* lhs_el_ty = type::Type::DeepestElementOf(lhs_ty);
|
auto* lhs_el_ty = type::Type::DeepestElementOf(lhs_ty);
|
||||||
const ast::Expression* mask = b.Expr(AInt(lhs_el_ty->Size() * 8 - 1));
|
const ast::Expression* mask = b.Expr(AInt(lhs_el_ty->Size() * 8 - 1));
|
||||||
if (rhs_ty->Is<type::Vector>()) {
|
if (rhs_ty->Is<type::Vector>()) {
|
||||||
|
@ -744,8 +827,8 @@ struct BuiltinPolyfill::State {
|
||||||
/// @param bin_op the original BinaryExpression
|
/// @param bin_op the original BinaryExpression
|
||||||
/// @return the polyfill divide or modulo
|
/// @return the polyfill divide or modulo
|
||||||
const ast::Expression* IntDivMod(const ast::BinaryExpression* bin_op) {
|
const ast::Expression* IntDivMod(const ast::BinaryExpression* bin_op) {
|
||||||
auto* lhs_ty = ctx.src->TypeOf(bin_op->lhs)->UnwrapRef();
|
auto* lhs_ty = src->TypeOf(bin_op->lhs)->UnwrapRef();
|
||||||
auto* rhs_ty = ctx.src->TypeOf(bin_op->rhs)->UnwrapRef();
|
auto* rhs_ty = src->TypeOf(bin_op->rhs)->UnwrapRef();
|
||||||
BinaryOpSignature sig{bin_op->op, lhs_ty, rhs_ty};
|
BinaryOpSignature sig{bin_op->op, lhs_ty, rhs_ty};
|
||||||
auto fn = binary_op_polyfills.GetOrCreate(sig, [&] {
|
auto fn = binary_op_polyfills.GetOrCreate(sig, [&] {
|
||||||
const bool is_div = bin_op->op == ast::BinaryOp::kDivide;
|
const bool is_div = bin_op->op == ast::BinaryOp::kDivide;
|
||||||
|
@ -839,8 +922,8 @@ struct BuiltinPolyfill::State {
|
||||||
/// @param bin_op the original BinaryExpression
|
/// @param bin_op the original BinaryExpression
|
||||||
/// @return the polyfill divide or modulo
|
/// @return the polyfill divide or modulo
|
||||||
const ast::Expression* PreciseFloatMod(const ast::BinaryExpression* bin_op) {
|
const ast::Expression* PreciseFloatMod(const ast::BinaryExpression* bin_op) {
|
||||||
auto* lhs_ty = ctx.src->TypeOf(bin_op->lhs)->UnwrapRef();
|
auto* lhs_ty = src->TypeOf(bin_op->lhs)->UnwrapRef();
|
||||||
auto* rhs_ty = ctx.src->TypeOf(bin_op->rhs)->UnwrapRef();
|
auto* rhs_ty = src->TypeOf(bin_op->rhs)->UnwrapRef();
|
||||||
BinaryOpSignature sig{bin_op->op, lhs_ty, rhs_ty};
|
BinaryOpSignature sig{bin_op->op, lhs_ty, rhs_ty};
|
||||||
auto fn = binary_op_polyfills.GetOrCreate(sig, [&] {
|
auto fn = binary_op_polyfills.GetOrCreate(sig, [&] {
|
||||||
uint32_t lhs_width = 1;
|
uint32_t lhs_width = 1;
|
||||||
|
@ -888,23 +971,27 @@ struct BuiltinPolyfill::State {
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/// The clone context
|
/// The source program
|
||||||
CloneContext& ctx;
|
Program const* const src;
|
||||||
/// The builtins to polyfill
|
/// The transform config
|
||||||
Builtins polyfill;
|
const Config& cfg;
|
||||||
/// The destination program builder
|
/// The destination program builder
|
||||||
ProgramBuilder& b = *ctx.dst;
|
ProgramBuilder b;
|
||||||
|
/// The clone context
|
||||||
|
CloneContext ctx{&b, src};
|
||||||
/// The source clone context
|
/// The source clone context
|
||||||
const sem::Info& sem = ctx.src->Sem();
|
const sem::Info& sem = src->Sem();
|
||||||
|
/// Polyfill functions for binary operators.
|
||||||
// Polyfill functions for binary operators.
|
|
||||||
utils::Hashmap<BinaryOpSignature, Symbol, 8> binary_op_polyfills;
|
utils::Hashmap<BinaryOpSignature, Symbol, 8> binary_op_polyfills;
|
||||||
|
/// Polyfill builtins.
|
||||||
|
utils::Hashmap<const sem::Builtin*, Symbol, 8> builtin_polyfills;
|
||||||
// Tracks whether the chromium_experimental_full_ptr_parameters extension has been enabled.
|
// Tracks whether the chromium_experimental_full_ptr_parameters extension has been enabled.
|
||||||
bool has_full_ptr_params;
|
bool has_full_ptr_params = false;
|
||||||
|
/// True if the transform has made changes (i.e. the program needs cloning)
|
||||||
|
bool made_changes = false;
|
||||||
|
|
||||||
/// @returns the AST type for the given sem type
|
/// @returns the AST type for the given sem type
|
||||||
ast::Type T(const type::Type* ty) const { return CreateASTTypeFor(ctx, ty); }
|
ast::Type T(const type::Type* ty) { return CreateASTTypeFor(ctx, ty); }
|
||||||
|
|
||||||
/// @returns 1 if `ty` is not a vector, otherwise the vector width
|
/// @returns 1 if `ty` is not a vector, otherwise the vector width
|
||||||
uint32_t WidthOf(const type::Type* ty) const {
|
uint32_t WidthOf(const type::Type* ty) const {
|
||||||
|
@ -917,7 +1004,7 @@ struct BuiltinPolyfill::State {
|
||||||
/// @returns a scalar or vector with the given width, with each element with
|
/// @returns a scalar or vector with the given width, with each element with
|
||||||
/// the given value.
|
/// the given value.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
const ast::Expression* ScalarOrVector(uint32_t width, T value) const {
|
const ast::Expression* ScalarOrVector(uint32_t width, T value) {
|
||||||
if (width == 1) {
|
if (width == 1) {
|
||||||
return b.Expr(value);
|
return b.Expr(value);
|
||||||
}
|
}
|
||||||
|
@ -931,158 +1018,144 @@ struct BuiltinPolyfill::State {
|
||||||
}
|
}
|
||||||
return b.Call(b.ty.vec<To>(width), e);
|
return b.Call(b.ty.vec<To>(width), e);
|
||||||
}
|
}
|
||||||
};
|
|
||||||
|
|
||||||
BuiltinPolyfill::BuiltinPolyfill() = default;
|
/// Examines the call expression @p expr, applying any necessary polyfill transforms
|
||||||
|
void Call(const ast::CallExpression* expr) {
|
||||||
BuiltinPolyfill::~BuiltinPolyfill() = default;
|
|
||||||
|
|
||||||
Transform::ApplyResult BuiltinPolyfill::Apply(const Program* src,
|
|
||||||
const DataMap& data,
|
|
||||||
DataMap&) const {
|
|
||||||
auto* cfg = data.Get<Config>();
|
|
||||||
if (!cfg) {
|
|
||||||
return SkipTransform;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto& polyfill = cfg->builtins;
|
|
||||||
|
|
||||||
utils::Hashmap<const sem::Builtin*, Symbol, 8> builtin_polyfills;
|
|
||||||
|
|
||||||
ProgramBuilder b;
|
|
||||||
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
|
|
||||||
State s{ctx, polyfill};
|
|
||||||
|
|
||||||
bool made_changes = false;
|
|
||||||
for (auto* node : src->ASTNodes().Objects()) {
|
|
||||||
Switch(
|
|
||||||
node,
|
|
||||||
[&](const ast::CallExpression* expr) {
|
|
||||||
auto* call = src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::Call>();
|
auto* call = src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::Call>();
|
||||||
if (!call || call->Stage() == sem::EvaluationStage::kConstant ||
|
if (!call || call->Stage() == sem::EvaluationStage::kConstant ||
|
||||||
call->Stage() == sem::EvaluationStage::kNotEvaluated) {
|
call->Stage() == sem::EvaluationStage::kNotEvaluated) {
|
||||||
return; // Don't polyfill @const expressions
|
return; // Don't polyfill @const expressions
|
||||||
}
|
}
|
||||||
auto* builtin = call->Target()->As<sem::Builtin>();
|
Symbol fn = Switch(
|
||||||
if (!builtin) {
|
call->Target(), //
|
||||||
return;
|
[&](const sem::Builtin* builtin) {
|
||||||
}
|
|
||||||
Symbol fn;
|
|
||||||
switch (builtin->Type()) {
|
switch (builtin->Type()) {
|
||||||
case sem::BuiltinType::kAcosh:
|
case sem::BuiltinType::kAcosh:
|
||||||
if (polyfill.acosh != Level::kNone) {
|
if (cfg.builtins.acosh != Level::kNone) {
|
||||||
fn = builtin_polyfills.GetOrCreate(
|
return builtin_polyfills.GetOrCreate(
|
||||||
builtin, [&] { return s.acosh(builtin->ReturnType()); });
|
builtin, [&] { return acosh(builtin->ReturnType()); });
|
||||||
}
|
}
|
||||||
break;
|
return Symbol{};
|
||||||
|
|
||||||
case sem::BuiltinType::kAsinh:
|
case sem::BuiltinType::kAsinh:
|
||||||
if (polyfill.asinh) {
|
if (cfg.builtins.asinh) {
|
||||||
fn = builtin_polyfills.GetOrCreate(
|
return builtin_polyfills.GetOrCreate(
|
||||||
builtin, [&] { return s.asinh(builtin->ReturnType()); });
|
builtin, [&] { return asinh(builtin->ReturnType()); });
|
||||||
}
|
}
|
||||||
break;
|
return Symbol{};
|
||||||
|
|
||||||
case sem::BuiltinType::kAtanh:
|
case sem::BuiltinType::kAtanh:
|
||||||
if (polyfill.atanh != Level::kNone) {
|
if (cfg.builtins.atanh != Level::kNone) {
|
||||||
fn = builtin_polyfills.GetOrCreate(
|
return builtin_polyfills.GetOrCreate(
|
||||||
builtin, [&] { return s.atanh(builtin->ReturnType()); });
|
builtin, [&] { return atanh(builtin->ReturnType()); });
|
||||||
}
|
}
|
||||||
break;
|
return Symbol{};
|
||||||
|
|
||||||
case sem::BuiltinType::kClamp:
|
case sem::BuiltinType::kClamp:
|
||||||
if (polyfill.clamp_int) {
|
if (cfg.builtins.clamp_int) {
|
||||||
auto& sig = builtin->Signature();
|
auto& sig = builtin->Signature();
|
||||||
if (sig.parameters[0]->Type()->is_integer_scalar_or_vector()) {
|
if (sig.parameters[0]->Type()->is_integer_scalar_or_vector()) {
|
||||||
fn = builtin_polyfills.GetOrCreate(
|
return builtin_polyfills.GetOrCreate(
|
||||||
builtin, [&] { return s.clampInteger(builtin->ReturnType()); });
|
builtin, [&] { return clampInteger(builtin->ReturnType()); });
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
break;
|
return Symbol{};
|
||||||
|
|
||||||
case sem::BuiltinType::kCountLeadingZeros:
|
case sem::BuiltinType::kCountLeadingZeros:
|
||||||
if (polyfill.count_leading_zeros) {
|
if (cfg.builtins.count_leading_zeros) {
|
||||||
fn = builtin_polyfills.GetOrCreate(builtin, [&] {
|
return builtin_polyfills.GetOrCreate(
|
||||||
return s.countLeadingZeros(builtin->ReturnType());
|
builtin, [&] { return countLeadingZeros(builtin->ReturnType()); });
|
||||||
});
|
|
||||||
}
|
}
|
||||||
break;
|
return Symbol{};
|
||||||
|
|
||||||
case sem::BuiltinType::kCountTrailingZeros:
|
case sem::BuiltinType::kCountTrailingZeros:
|
||||||
if (polyfill.count_trailing_zeros) {
|
if (cfg.builtins.count_trailing_zeros) {
|
||||||
fn = builtin_polyfills.GetOrCreate(builtin, [&] {
|
return builtin_polyfills.GetOrCreate(
|
||||||
return s.countTrailingZeros(builtin->ReturnType());
|
builtin, [&] { return countTrailingZeros(builtin->ReturnType()); });
|
||||||
});
|
|
||||||
}
|
}
|
||||||
break;
|
return Symbol{};
|
||||||
|
|
||||||
case sem::BuiltinType::kExtractBits:
|
case sem::BuiltinType::kExtractBits:
|
||||||
if (polyfill.extract_bits != Level::kNone) {
|
if (cfg.builtins.extract_bits != Level::kNone) {
|
||||||
fn = builtin_polyfills.GetOrCreate(
|
return builtin_polyfills.GetOrCreate(
|
||||||
builtin, [&] { return s.extractBits(builtin->ReturnType()); });
|
builtin, [&] { return extractBits(builtin->ReturnType()); });
|
||||||
}
|
}
|
||||||
break;
|
return Symbol{};
|
||||||
|
|
||||||
case sem::BuiltinType::kFirstLeadingBit:
|
case sem::BuiltinType::kFirstLeadingBit:
|
||||||
if (polyfill.first_leading_bit) {
|
if (cfg.builtins.first_leading_bit) {
|
||||||
fn = builtin_polyfills.GetOrCreate(
|
return builtin_polyfills.GetOrCreate(
|
||||||
builtin, [&] { return s.firstLeadingBit(builtin->ReturnType()); });
|
builtin, [&] { return firstLeadingBit(builtin->ReturnType()); });
|
||||||
}
|
}
|
||||||
break;
|
return Symbol{};
|
||||||
|
|
||||||
case sem::BuiltinType::kFirstTrailingBit:
|
case sem::BuiltinType::kFirstTrailingBit:
|
||||||
if (polyfill.first_trailing_bit) {
|
if (cfg.builtins.first_trailing_bit) {
|
||||||
fn = builtin_polyfills.GetOrCreate(
|
return builtin_polyfills.GetOrCreate(
|
||||||
builtin, [&] { return s.firstTrailingBit(builtin->ReturnType()); });
|
builtin, [&] { return firstTrailingBit(builtin->ReturnType()); });
|
||||||
}
|
}
|
||||||
break;
|
return Symbol{};
|
||||||
|
|
||||||
case sem::BuiltinType::kInsertBits:
|
case sem::BuiltinType::kInsertBits:
|
||||||
if (polyfill.insert_bits != Level::kNone) {
|
if (cfg.builtins.insert_bits != Level::kNone) {
|
||||||
fn = builtin_polyfills.GetOrCreate(
|
return builtin_polyfills.GetOrCreate(
|
||||||
builtin, [&] { return s.insertBits(builtin->ReturnType()); });
|
builtin, [&] { return insertBits(builtin->ReturnType()); });
|
||||||
}
|
}
|
||||||
break;
|
return Symbol{};
|
||||||
|
|
||||||
case sem::BuiltinType::kReflect:
|
case sem::BuiltinType::kReflect:
|
||||||
// Only polyfill for vec2<f32>. See https://crbug.com/tint/1798 for more
|
// Only polyfill for vec2<f32>. See https://crbug.com/tint/1798 for
|
||||||
// details.
|
// more details.
|
||||||
if (polyfill.reflect_vec2_f32) {
|
if (cfg.builtins.reflect_vec2_f32) {
|
||||||
auto& sig = builtin->Signature();
|
auto& sig = builtin->Signature();
|
||||||
auto* vec = sig.return_type->As<type::Vector>();
|
auto* vec = sig.return_type->As<type::Vector>();
|
||||||
if (vec && vec->Width() == 2 && vec->type()->Is<type::F32>()) {
|
if (vec && vec->Width() == 2 && vec->type()->Is<type::F32>()) {
|
||||||
fn = builtin_polyfills.GetOrCreate(
|
return builtin_polyfills.GetOrCreate(
|
||||||
builtin, [&] { return s.reflect(builtin->ReturnType()); });
|
builtin, [&] { return reflect(builtin->ReturnType()); });
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
break;
|
return Symbol{};
|
||||||
|
|
||||||
case sem::BuiltinType::kSaturate:
|
case sem::BuiltinType::kSaturate:
|
||||||
if (polyfill.saturate) {
|
if (cfg.builtins.saturate) {
|
||||||
fn = builtin_polyfills.GetOrCreate(
|
return builtin_polyfills.GetOrCreate(
|
||||||
builtin, [&] { return s.saturate(builtin->ReturnType()); });
|
builtin, [&] { return saturate(builtin->ReturnType()); });
|
||||||
}
|
}
|
||||||
break;
|
return Symbol{};
|
||||||
|
|
||||||
case sem::BuiltinType::kSign:
|
case sem::BuiltinType::kSign:
|
||||||
if (polyfill.sign_int) {
|
if (cfg.builtins.sign_int) {
|
||||||
auto* ty = builtin->ReturnType();
|
auto* ty = builtin->ReturnType();
|
||||||
if (ty->is_signed_integer_scalar_or_vector()) {
|
if (ty->is_signed_integer_scalar_or_vector()) {
|
||||||
fn = builtin_polyfills.GetOrCreate(builtin,
|
return builtin_polyfills.GetOrCreate(builtin,
|
||||||
[&] { return s.sign_int(ty); });
|
[&] { return sign_int(ty); });
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
break;
|
return Symbol{};
|
||||||
|
|
||||||
case sem::BuiltinType::kTextureSampleBaseClampToEdge:
|
case sem::BuiltinType::kTextureSampleBaseClampToEdge:
|
||||||
if (polyfill.texture_sample_base_clamp_to_edge_2d_f32) {
|
if (cfg.builtins.texture_sample_base_clamp_to_edge_2d_f32) {
|
||||||
auto& sig = builtin->Signature();
|
auto& sig = builtin->Signature();
|
||||||
auto* tex = sig.Parameter(sem::ParameterUsage::kTexture);
|
auto* tex = sig.Parameter(sem::ParameterUsage::kTexture);
|
||||||
if (auto* stex = tex->Type()->As<type::SampledTexture>()) {
|
if (auto* stex = tex->Type()->As<type::SampledTexture>()) {
|
||||||
if (stex->type()->Is<type::F32>()) {
|
if (stex->type()->Is<type::F32>()) {
|
||||||
fn = builtin_polyfills.GetOrCreate(builtin, [&] {
|
return builtin_polyfills.GetOrCreate(builtin, [&] {
|
||||||
return s.textureSampleBaseClampToEdge_2d_f32();
|
return textureSampleBaseClampToEdge_2d_f32();
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
break;
|
return Symbol{};
|
||||||
|
|
||||||
case sem::BuiltinType::kTextureStore:
|
case sem::BuiltinType::kTextureStore:
|
||||||
if (polyfill.bgra8unorm) {
|
if (cfg.builtins.bgra8unorm) {
|
||||||
auto& sig = builtin->Signature();
|
auto& sig = builtin->Signature();
|
||||||
auto* tex = sig.Parameter(sem::ParameterUsage::kTexture);
|
auto* tex = sig.Parameter(sem::ParameterUsage::kTexture);
|
||||||
if (auto* stex = tex->Type()->As<type::StorageTexture>()) {
|
if (auto* stex = tex->Type()->As<type::StorageTexture>()) {
|
||||||
if (stex->texel_format() == builtin::TexelFormat::kBgra8Unorm) {
|
if (stex->texel_format() == builtin::TexelFormat::kBgra8Unorm) {
|
||||||
size_t value_idx = static_cast<size_t>(
|
size_t value_idx = static_cast<size_t>(
|
||||||
sig.IndexOf(sem::ParameterUsage::kValue));
|
sig.IndexOf(sem::ParameterUsage::kValue));
|
||||||
ctx.Replace(expr, [&ctx, expr, value_idx] {
|
ctx.Replace(expr, [this, expr, value_idx] {
|
||||||
utils::Vector<const ast::Expression*, 3> args;
|
utils::Vector<const ast::Expression*, 3> args;
|
||||||
for (auto* arg : expr->args) {
|
for (auto* arg : expr->args) {
|
||||||
arg = ctx.Clone(arg);
|
arg = ctx.Clone(arg);
|
||||||
|
@ -1099,106 +1172,50 @@ Transform::ApplyResult BuiltinPolyfill::Apply(const Program* src,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
break;
|
return Symbol{};
|
||||||
|
|
||||||
case sem::BuiltinType::kQuantizeToF16:
|
case sem::BuiltinType::kQuantizeToF16:
|
||||||
if (polyfill.quantize_to_vec_f16) {
|
if (cfg.builtins.quantize_to_vec_f16) {
|
||||||
if (auto* vec = builtin->ReturnType()->As<type::Vector>()) {
|
if (auto* vec = builtin->ReturnType()->As<type::Vector>()) {
|
||||||
fn = builtin_polyfills.GetOrCreate(
|
return builtin_polyfills.GetOrCreate(
|
||||||
builtin, [&] { return s.quantizeToF16(vec); });
|
builtin, [&] { return quantizeToF16(vec); });
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
break;
|
return Symbol{};
|
||||||
|
|
||||||
case sem::BuiltinType::kWorkgroupUniformLoad:
|
case sem::BuiltinType::kWorkgroupUniformLoad:
|
||||||
if (polyfill.workgroup_uniform_load) {
|
if (cfg.builtins.workgroup_uniform_load) {
|
||||||
fn = builtin_polyfills.GetOrCreate(builtin, [&] {
|
return builtin_polyfills.GetOrCreate(builtin, [&] {
|
||||||
return s.workgroupUniformLoad(builtin->ReturnType());
|
return workgroupUniformLoad(builtin->ReturnType());
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
break;
|
return Symbol{};
|
||||||
|
|
||||||
default:
|
default:
|
||||||
break;
|
return Symbol{};
|
||||||
}
|
}
|
||||||
|
});
|
||||||
|
|
||||||
if (fn.IsValid()) {
|
if (fn.IsValid()) {
|
||||||
ctx.Replace(call->Declaration(), [&ctx, fn, expr] {
|
ctx.Replace(call->Declaration(),
|
||||||
return ctx.dst->Call(fn, ctx.Clone(expr->args));
|
[this, fn, expr] { return ctx.dst->Call(fn, ctx.Clone(expr->args)); });
|
||||||
});
|
|
||||||
made_changes = true;
|
|
||||||
}
|
|
||||||
},
|
|
||||||
[&](const ast::BinaryExpression* bin_op) {
|
|
||||||
if (auto* sem = src->Sem().Get(bin_op);
|
|
||||||
!sem || sem->Stage() == sem::EvaluationStage::kConstant ||
|
|
||||||
sem->Stage() == sem::EvaluationStage::kNotEvaluated) {
|
|
||||||
return; // Don't polyfill @const expressions
|
|
||||||
}
|
|
||||||
switch (bin_op->op) {
|
|
||||||
case ast::BinaryOp::kShiftLeft:
|
|
||||||
case ast::BinaryOp::kShiftRight: {
|
|
||||||
if (polyfill.bitshift_modulo) {
|
|
||||||
ctx.Replace(bin_op, [bin_op, &s] { return s.BitshiftModulo(bin_op); });
|
|
||||||
made_changes = true;
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case ast::BinaryOp::kDivide: {
|
|
||||||
if (polyfill.int_div_mod) {
|
|
||||||
auto* lhs_ty = src->TypeOf(bin_op->lhs)->UnwrapRef();
|
|
||||||
if (lhs_ty->is_integer_scalar_or_vector()) {
|
|
||||||
ctx.Replace(bin_op, [bin_op, &s] { return s.IntDivMod(bin_op); });
|
|
||||||
made_changes = true;
|
made_changes = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
break;
|
};
|
||||||
}
|
|
||||||
case ast::BinaryOp::kModulo: {
|
|
||||||
if (polyfill.int_div_mod) {
|
|
||||||
auto* lhs_ty = src->TypeOf(bin_op->lhs)->UnwrapRef();
|
|
||||||
if (lhs_ty->is_integer_scalar_or_vector()) {
|
|
||||||
ctx.Replace(bin_op, [bin_op, &s] { return s.IntDivMod(bin_op); });
|
|
||||||
made_changes = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (polyfill.precise_float_mod) {
|
|
||||||
auto* lhs_ty = src->TypeOf(bin_op->lhs)->UnwrapRef();
|
|
||||||
if (lhs_ty->is_float_scalar_or_vector()) {
|
|
||||||
ctx.Replace(bin_op,
|
|
||||||
[bin_op, &s] { return s.PreciseFloatMod(bin_op); });
|
|
||||||
made_changes = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
},
|
|
||||||
[&](const ast::Expression* expr) {
|
|
||||||
if (polyfill.bgra8unorm) {
|
|
||||||
if (auto* ty_expr = src->Sem().Get<sem::TypeExpression>(expr)) {
|
|
||||||
if (auto* tex = ty_expr->Type()->As<type::StorageTexture>()) {
|
|
||||||
if (tex->texel_format() == builtin::TexelFormat::kBgra8Unorm) {
|
|
||||||
ctx.Replace(expr, [&ctx, tex] {
|
|
||||||
return ctx.dst->Expr(ctx.dst->ty.storage_texture(
|
|
||||||
tex->dim(), builtin::TexelFormat::kRgba8Unorm,
|
|
||||||
tex->access()));
|
|
||||||
});
|
|
||||||
made_changes = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!made_changes) {
|
BuiltinPolyfill::BuiltinPolyfill() = default;
|
||||||
|
|
||||||
|
BuiltinPolyfill::~BuiltinPolyfill() = default;
|
||||||
|
|
||||||
|
Transform::ApplyResult BuiltinPolyfill::Apply(const Program* src,
|
||||||
|
const DataMap& data,
|
||||||
|
DataMap&) const {
|
||||||
|
auto* cfg = data.Get<Config>();
|
||||||
|
if (!cfg) {
|
||||||
return SkipTransform;
|
return SkipTransform;
|
||||||
}
|
}
|
||||||
|
return State{src, *cfg}.Run();
|
||||||
ctx.Clone();
|
|
||||||
return Program(std::move(b));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
BuiltinPolyfill::Config::Config(const Builtins& b) : builtins(b) {}
|
BuiltinPolyfill::Config::Config(const Builtins& b) : builtins(b) {}
|
||||||
|
|
Loading…
Reference in New Issue