diff --git a/src/tint/transform/builtin_polyfill.cc b/src/tint/transform/builtin_polyfill.cc index 259eb1117b..587dcf05ad 100644 --- a/src/tint/transform/builtin_polyfill.cc +++ b/src/tint/transform/builtin_polyfill.cc @@ -40,17 +40,100 @@ using BinaryOpSignature = std::tupleAST().Enables()) { + for (auto* enable : src->AST().Enables()) { if (enable->extension == builtin::Extension::kChromiumExperimentalFullPtrParameters) { has_full_ptr_params = true; } } } + /// Runs the transform + /// @returns the new program or SkipTransform if the transform is not required + Transform::ApplyResult Run() { + for (auto* node : src->ASTNodes().Objects()) { + Switch( + node, // + [&](const ast::CallExpression* expr) { Call(expr); }, + [&](const ast::BinaryExpression* bin_op) { + if (auto* s = src->Sem().Get(bin_op); + !s || s->Stage() == sem::EvaluationStage::kConstant || + s->Stage() == sem::EvaluationStage::kNotEvaluated) { + return; // Don't polyfill @const expressions + } + switch (bin_op->op) { + case ast::BinaryOp::kShiftLeft: + case ast::BinaryOp::kShiftRight: { + if (cfg.builtins.bitshift_modulo) { + ctx.Replace(bin_op, + [this, bin_op] { return BitshiftModulo(bin_op); }); + made_changes = true; + } + break; + } + case ast::BinaryOp::kDivide: { + if (cfg.builtins.int_div_mod) { + auto* lhs_ty = src->TypeOf(bin_op->lhs)->UnwrapRef(); + if (lhs_ty->is_integer_scalar_or_vector()) { + ctx.Replace(bin_op, + [this, bin_op] { return IntDivMod(bin_op); }); + made_changes = true; + } + } + break; + } + case ast::BinaryOp::kModulo: { + if (cfg.builtins.int_div_mod) { + auto* lhs_ty = src->TypeOf(bin_op->lhs)->UnwrapRef(); + if (lhs_ty->is_integer_scalar_or_vector()) { + ctx.Replace(bin_op, + [this, bin_op] { return IntDivMod(bin_op); }); + made_changes = true; + } + } + if (cfg.builtins.precise_float_mod) { + auto* lhs_ty = src->TypeOf(bin_op->lhs)->UnwrapRef(); + if (lhs_ty->is_float_scalar_or_vector()) { + ctx.Replace(bin_op, + [this, bin_op] { return PreciseFloatMod(bin_op); }); + made_changes = true; + } + } + break; + } + default: + break; + } + }, + [&](const ast::Expression* expr) { + if (cfg.builtins.bgra8unorm) { + if (auto* ty_expr = src->Sem().Get(expr)) { + if (auto* tex = ty_expr->Type()->As()) { + if (tex->texel_format() == builtin::TexelFormat::kBgra8Unorm) { + ctx.Replace(expr, [this, tex] { + return ctx.dst->Expr(ctx.dst->ty.storage_texture( + tex->dim(), builtin::TexelFormat::kRgba8Unorm, + tex->access())); + }); + made_changes = true; + } + } + } + } + }); + } + + if (!made_changes) { + return SkipTransform; + } + + ctx.Clone(); + return Program(std::move(b)); + } + //////////////////////////////////////////////////////////////////////////// // Function polyfills //////////////////////////////////////////////////////////////////////////// @@ -71,7 +154,7 @@ struct BuiltinPolyfill::State { }; utils::Vector body; - switch (polyfill.acosh) { + switch (cfg.builtins.acosh) { case Level::kFull: // return log(x + sqrt(x*x - 1)); body.Push(b.Return( @@ -85,7 +168,7 @@ struct BuiltinPolyfill::State { } default: TINT_ICE(Transform, b.Diagnostics()) - << "unhandled polyfill level: " << static_cast(polyfill.acosh); + << "unhandled polyfill level: " << static_cast(cfg.builtins.acosh); return {}; } @@ -125,7 +208,7 @@ struct BuiltinPolyfill::State { }; utils::Vector body; - switch (polyfill.atanh) { + switch (cfg.builtins.atanh) { case Level::kFull: // return log((1+x) / (1-x)) * 0.5 body.Push( @@ -138,7 +221,7 @@ struct BuiltinPolyfill::State { break; default: TINT_ICE(Transform, b.Diagnostics()) - << "unhandled polyfill level: " << static_cast(polyfill.acosh); + << "unhandled polyfill level: " << static_cast(cfg.builtins.acosh); return {}; } @@ -306,7 +389,7 @@ struct BuiltinPolyfill::State { b.Decl(b.Let("e", b.Call("min", u32(W), b.Add("s", "count")))), }; - switch (polyfill.extract_bits) { + switch (cfg.builtins.extract_bits) { case Level::kFull: body.Push(b.Decl(b.Let("shl", b.Sub(u32(W), "e")))); body.Push(b.Decl(b.Let("shr", b.Add("shl", "s")))); @@ -328,7 +411,7 @@ struct BuiltinPolyfill::State { break; default: TINT_ICE(Transform, b.Diagnostics()) - << "unhandled polyfill level: " << static_cast(polyfill.extract_bits); + << "unhandled polyfill level: " << static_cast(cfg.builtins.extract_bits); return {}; } @@ -532,7 +615,7 @@ struct BuiltinPolyfill::State { utils::Vector body; - switch (polyfill.insert_bits) { + switch (cfg.builtins.insert_bits) { case Level::kFull: // let e = offset + count; body.Push(b.Decl(b.Let("e", b.Add("offset", "count")))); @@ -566,7 +649,7 @@ struct BuiltinPolyfill::State { break; default: TINT_ICE(Transform, b.Diagnostics()) - << "unhandled polyfill level: " << static_cast(polyfill.insert_bits); + << "unhandled polyfill level: " << static_cast(cfg.builtins.insert_bits); return {}; } @@ -727,8 +810,8 @@ struct BuiltinPolyfill::State { /// @param bin_op the original BinaryExpression /// @return the polyfill value for bitshift operation const ast::Expression* BitshiftModulo(const ast::BinaryExpression* bin_op) { - auto* lhs_ty = ctx.src->TypeOf(bin_op->lhs)->UnwrapRef(); - auto* rhs_ty = ctx.src->TypeOf(bin_op->rhs)->UnwrapRef(); + auto* lhs_ty = src->TypeOf(bin_op->lhs)->UnwrapRef(); + auto* rhs_ty = src->TypeOf(bin_op->rhs)->UnwrapRef(); auto* lhs_el_ty = type::Type::DeepestElementOf(lhs_ty); const ast::Expression* mask = b.Expr(AInt(lhs_el_ty->Size() * 8 - 1)); if (rhs_ty->Is()) { @@ -744,8 +827,8 @@ struct BuiltinPolyfill::State { /// @param bin_op the original BinaryExpression /// @return the polyfill divide or modulo const ast::Expression* IntDivMod(const ast::BinaryExpression* bin_op) { - auto* lhs_ty = ctx.src->TypeOf(bin_op->lhs)->UnwrapRef(); - auto* rhs_ty = ctx.src->TypeOf(bin_op->rhs)->UnwrapRef(); + auto* lhs_ty = src->TypeOf(bin_op->lhs)->UnwrapRef(); + auto* rhs_ty = src->TypeOf(bin_op->rhs)->UnwrapRef(); BinaryOpSignature sig{bin_op->op, lhs_ty, rhs_ty}; auto fn = binary_op_polyfills.GetOrCreate(sig, [&] { const bool is_div = bin_op->op == ast::BinaryOp::kDivide; @@ -839,8 +922,8 @@ struct BuiltinPolyfill::State { /// @param bin_op the original BinaryExpression /// @return the polyfill divide or modulo const ast::Expression* PreciseFloatMod(const ast::BinaryExpression* bin_op) { - auto* lhs_ty = ctx.src->TypeOf(bin_op->lhs)->UnwrapRef(); - auto* rhs_ty = ctx.src->TypeOf(bin_op->rhs)->UnwrapRef(); + auto* lhs_ty = src->TypeOf(bin_op->lhs)->UnwrapRef(); + auto* rhs_ty = src->TypeOf(bin_op->rhs)->UnwrapRef(); BinaryOpSignature sig{bin_op->op, lhs_ty, rhs_ty}; auto fn = binary_op_polyfills.GetOrCreate(sig, [&] { uint32_t lhs_width = 1; @@ -888,23 +971,27 @@ struct BuiltinPolyfill::State { } private: - /// The clone context - CloneContext& ctx; - /// The builtins to polyfill - Builtins polyfill; + /// The source program + Program const* const src; + /// The transform config + const Config& cfg; /// The destination program builder - ProgramBuilder& b = *ctx.dst; + ProgramBuilder b; + /// The clone context + CloneContext ctx{&b, src}; /// The source clone context - const sem::Info& sem = ctx.src->Sem(); - - // Polyfill functions for binary operators. + const sem::Info& sem = src->Sem(); + /// Polyfill functions for binary operators. utils::Hashmap binary_op_polyfills; - + /// Polyfill builtins. + utils::Hashmap builtin_polyfills; // Tracks whether the chromium_experimental_full_ptr_parameters extension has been enabled. - bool has_full_ptr_params; + bool has_full_ptr_params = false; + /// True if the transform has made changes (i.e. the program needs cloning) + bool made_changes = false; /// @returns the AST type for the given sem type - ast::Type T(const type::Type* ty) const { return CreateASTTypeFor(ctx, ty); } + ast::Type T(const type::Type* ty) { return CreateASTTypeFor(ctx, ty); } /// @returns 1 if `ty` is not a vector, otherwise the vector width uint32_t WidthOf(const type::Type* ty) const { @@ -917,7 +1004,7 @@ struct BuiltinPolyfill::State { /// @returns a scalar or vector with the given width, with each element with /// the given value. template - const ast::Expression* ScalarOrVector(uint32_t width, T value) const { + const ast::Expression* ScalarOrVector(uint32_t width, T value) { if (width == 1) { return b.Expr(value); } @@ -931,158 +1018,144 @@ struct BuiltinPolyfill::State { } return b.Call(b.ty.vec(width), e); } -}; -BuiltinPolyfill::BuiltinPolyfill() = default; - -BuiltinPolyfill::~BuiltinPolyfill() = default; - -Transform::ApplyResult BuiltinPolyfill::Apply(const Program* src, - const DataMap& data, - DataMap&) const { - auto* cfg = data.Get(); - if (!cfg) { - return SkipTransform; - } - - auto& polyfill = cfg->builtins; - - utils::Hashmap builtin_polyfills; - - ProgramBuilder b; - CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; - State s{ctx, polyfill}; - - bool made_changes = false; - for (auto* node : src->ASTNodes().Objects()) { - Switch( - node, - [&](const ast::CallExpression* expr) { - auto* call = src->Sem().Get(expr)->UnwrapMaterialize()->As(); - if (!call || call->Stage() == sem::EvaluationStage::kConstant || - call->Stage() == sem::EvaluationStage::kNotEvaluated) { - return; // Don't polyfill @const expressions - } - auto* builtin = call->Target()->As(); - if (!builtin) { - return; - } - Symbol fn; + /// Examines the call expression @p expr, applying any necessary polyfill transforms + void Call(const ast::CallExpression* expr) { + auto* call = src->Sem().Get(expr)->UnwrapMaterialize()->As(); + if (!call || call->Stage() == sem::EvaluationStage::kConstant || + call->Stage() == sem::EvaluationStage::kNotEvaluated) { + return; // Don't polyfill @const expressions + } + Symbol fn = Switch( + call->Target(), // + [&](const sem::Builtin* builtin) { switch (builtin->Type()) { case sem::BuiltinType::kAcosh: - if (polyfill.acosh != Level::kNone) { - fn = builtin_polyfills.GetOrCreate( - builtin, [&] { return s.acosh(builtin->ReturnType()); }); + if (cfg.builtins.acosh != Level::kNone) { + return builtin_polyfills.GetOrCreate( + builtin, [&] { return acosh(builtin->ReturnType()); }); } - break; + return Symbol{}; + case sem::BuiltinType::kAsinh: - if (polyfill.asinh) { - fn = builtin_polyfills.GetOrCreate( - builtin, [&] { return s.asinh(builtin->ReturnType()); }); + if (cfg.builtins.asinh) { + return builtin_polyfills.GetOrCreate( + builtin, [&] { return asinh(builtin->ReturnType()); }); } - break; + return Symbol{}; + case sem::BuiltinType::kAtanh: - if (polyfill.atanh != Level::kNone) { - fn = builtin_polyfills.GetOrCreate( - builtin, [&] { return s.atanh(builtin->ReturnType()); }); + if (cfg.builtins.atanh != Level::kNone) { + return builtin_polyfills.GetOrCreate( + builtin, [&] { return atanh(builtin->ReturnType()); }); } - break; + return Symbol{}; + case sem::BuiltinType::kClamp: - if (polyfill.clamp_int) { + if (cfg.builtins.clamp_int) { auto& sig = builtin->Signature(); if (sig.parameters[0]->Type()->is_integer_scalar_or_vector()) { - fn = builtin_polyfills.GetOrCreate( - builtin, [&] { return s.clampInteger(builtin->ReturnType()); }); + return builtin_polyfills.GetOrCreate( + builtin, [&] { return clampInteger(builtin->ReturnType()); }); } } - break; + return Symbol{}; + case sem::BuiltinType::kCountLeadingZeros: - if (polyfill.count_leading_zeros) { - fn = builtin_polyfills.GetOrCreate(builtin, [&] { - return s.countLeadingZeros(builtin->ReturnType()); - }); + if (cfg.builtins.count_leading_zeros) { + return builtin_polyfills.GetOrCreate( + builtin, [&] { return countLeadingZeros(builtin->ReturnType()); }); } - break; + return Symbol{}; + case sem::BuiltinType::kCountTrailingZeros: - if (polyfill.count_trailing_zeros) { - fn = builtin_polyfills.GetOrCreate(builtin, [&] { - return s.countTrailingZeros(builtin->ReturnType()); - }); + if (cfg.builtins.count_trailing_zeros) { + return builtin_polyfills.GetOrCreate( + builtin, [&] { return countTrailingZeros(builtin->ReturnType()); }); } - break; + return Symbol{}; + case sem::BuiltinType::kExtractBits: - if (polyfill.extract_bits != Level::kNone) { - fn = builtin_polyfills.GetOrCreate( - builtin, [&] { return s.extractBits(builtin->ReturnType()); }); + if (cfg.builtins.extract_bits != Level::kNone) { + return builtin_polyfills.GetOrCreate( + builtin, [&] { return extractBits(builtin->ReturnType()); }); } - break; + return Symbol{}; + case sem::BuiltinType::kFirstLeadingBit: - if (polyfill.first_leading_bit) { - fn = builtin_polyfills.GetOrCreate( - builtin, [&] { return s.firstLeadingBit(builtin->ReturnType()); }); + if (cfg.builtins.first_leading_bit) { + return builtin_polyfills.GetOrCreate( + builtin, [&] { return firstLeadingBit(builtin->ReturnType()); }); } - break; + return Symbol{}; + case sem::BuiltinType::kFirstTrailingBit: - if (polyfill.first_trailing_bit) { - fn = builtin_polyfills.GetOrCreate( - builtin, [&] { return s.firstTrailingBit(builtin->ReturnType()); }); + if (cfg.builtins.first_trailing_bit) { + return builtin_polyfills.GetOrCreate( + builtin, [&] { return firstTrailingBit(builtin->ReturnType()); }); } - break; + return Symbol{}; + case sem::BuiltinType::kInsertBits: - if (polyfill.insert_bits != Level::kNone) { - fn = builtin_polyfills.GetOrCreate( - builtin, [&] { return s.insertBits(builtin->ReturnType()); }); + if (cfg.builtins.insert_bits != Level::kNone) { + return builtin_polyfills.GetOrCreate( + builtin, [&] { return insertBits(builtin->ReturnType()); }); } - break; + return Symbol{}; + case sem::BuiltinType::kReflect: - // Only polyfill for vec2. See https://crbug.com/tint/1798 for more - // details. - if (polyfill.reflect_vec2_f32) { + // Only polyfill for vec2. See https://crbug.com/tint/1798 for + // more details. + if (cfg.builtins.reflect_vec2_f32) { auto& sig = builtin->Signature(); auto* vec = sig.return_type->As(); if (vec && vec->Width() == 2 && vec->type()->Is()) { - fn = builtin_polyfills.GetOrCreate( - builtin, [&] { return s.reflect(builtin->ReturnType()); }); + return builtin_polyfills.GetOrCreate( + builtin, [&] { return reflect(builtin->ReturnType()); }); } } - break; + return Symbol{}; + case sem::BuiltinType::kSaturate: - if (polyfill.saturate) { - fn = builtin_polyfills.GetOrCreate( - builtin, [&] { return s.saturate(builtin->ReturnType()); }); + if (cfg.builtins.saturate) { + return builtin_polyfills.GetOrCreate( + builtin, [&] { return saturate(builtin->ReturnType()); }); } - break; + return Symbol{}; + case sem::BuiltinType::kSign: - if (polyfill.sign_int) { + if (cfg.builtins.sign_int) { auto* ty = builtin->ReturnType(); if (ty->is_signed_integer_scalar_or_vector()) { - fn = builtin_polyfills.GetOrCreate(builtin, - [&] { return s.sign_int(ty); }); + return builtin_polyfills.GetOrCreate(builtin, + [&] { return sign_int(ty); }); } } - break; + return Symbol{}; + case sem::BuiltinType::kTextureSampleBaseClampToEdge: - if (polyfill.texture_sample_base_clamp_to_edge_2d_f32) { + if (cfg.builtins.texture_sample_base_clamp_to_edge_2d_f32) { auto& sig = builtin->Signature(); auto* tex = sig.Parameter(sem::ParameterUsage::kTexture); if (auto* stex = tex->Type()->As()) { if (stex->type()->Is()) { - fn = builtin_polyfills.GetOrCreate(builtin, [&] { - return s.textureSampleBaseClampToEdge_2d_f32(); + return builtin_polyfills.GetOrCreate(builtin, [&] { + return textureSampleBaseClampToEdge_2d_f32(); }); } } } - break; + return Symbol{}; + case sem::BuiltinType::kTextureStore: - if (polyfill.bgra8unorm) { + if (cfg.builtins.bgra8unorm) { auto& sig = builtin->Signature(); auto* tex = sig.Parameter(sem::ParameterUsage::kTexture); if (auto* stex = tex->Type()->As()) { if (stex->texel_format() == builtin::TexelFormat::kBgra8Unorm) { size_t value_idx = static_cast( sig.IndexOf(sem::ParameterUsage::kValue)); - ctx.Replace(expr, [&ctx, expr, value_idx] { + ctx.Replace(expr, [this, expr, value_idx] { utils::Vector args; for (auto* arg : expr->args) { arg = ctx.Clone(arg); @@ -1099,106 +1172,50 @@ Transform::ApplyResult BuiltinPolyfill::Apply(const Program* src, } } } - break; + return Symbol{}; + case sem::BuiltinType::kQuantizeToF16: - if (polyfill.quantize_to_vec_f16) { + if (cfg.builtins.quantize_to_vec_f16) { if (auto* vec = builtin->ReturnType()->As()) { - fn = builtin_polyfills.GetOrCreate( - builtin, [&] { return s.quantizeToF16(vec); }); + return builtin_polyfills.GetOrCreate( + builtin, [&] { return quantizeToF16(vec); }); } } - break; + return Symbol{}; case sem::BuiltinType::kWorkgroupUniformLoad: - if (polyfill.workgroup_uniform_load) { - fn = builtin_polyfills.GetOrCreate(builtin, [&] { - return s.workgroupUniformLoad(builtin->ReturnType()); + if (cfg.builtins.workgroup_uniform_load) { + return builtin_polyfills.GetOrCreate(builtin, [&] { + return workgroupUniformLoad(builtin->ReturnType()); }); } - break; + return Symbol{}; default: - break; - } - - if (fn.IsValid()) { - ctx.Replace(call->Declaration(), [&ctx, fn, expr] { - return ctx.dst->Call(fn, ctx.Clone(expr->args)); - }); - made_changes = true; - } - }, - [&](const ast::BinaryExpression* bin_op) { - if (auto* sem = src->Sem().Get(bin_op); - !sem || sem->Stage() == sem::EvaluationStage::kConstant || - sem->Stage() == sem::EvaluationStage::kNotEvaluated) { - return; // Don't polyfill @const expressions - } - switch (bin_op->op) { - case ast::BinaryOp::kShiftLeft: - case ast::BinaryOp::kShiftRight: { - if (polyfill.bitshift_modulo) { - ctx.Replace(bin_op, [bin_op, &s] { return s.BitshiftModulo(bin_op); }); - made_changes = true; - } - break; - } - case ast::BinaryOp::kDivide: { - if (polyfill.int_div_mod) { - auto* lhs_ty = src->TypeOf(bin_op->lhs)->UnwrapRef(); - if (lhs_ty->is_integer_scalar_or_vector()) { - ctx.Replace(bin_op, [bin_op, &s] { return s.IntDivMod(bin_op); }); - made_changes = true; - } - } - break; - } - case ast::BinaryOp::kModulo: { - if (polyfill.int_div_mod) { - auto* lhs_ty = src->TypeOf(bin_op->lhs)->UnwrapRef(); - if (lhs_ty->is_integer_scalar_or_vector()) { - ctx.Replace(bin_op, [bin_op, &s] { return s.IntDivMod(bin_op); }); - made_changes = true; - } - } - if (polyfill.precise_float_mod) { - auto* lhs_ty = src->TypeOf(bin_op->lhs)->UnwrapRef(); - if (lhs_ty->is_float_scalar_or_vector()) { - ctx.Replace(bin_op, - [bin_op, &s] { return s.PreciseFloatMod(bin_op); }); - made_changes = true; - } - } - break; - } - default: - break; - } - }, - [&](const ast::Expression* expr) { - if (polyfill.bgra8unorm) { - if (auto* ty_expr = src->Sem().Get(expr)) { - if (auto* tex = ty_expr->Type()->As()) { - if (tex->texel_format() == builtin::TexelFormat::kBgra8Unorm) { - ctx.Replace(expr, [&ctx, tex] { - return ctx.dst->Expr(ctx.dst->ty.storage_texture( - tex->dim(), builtin::TexelFormat::kRgba8Unorm, - tex->access())); - }); - made_changes = true; - } - } - } + return Symbol{}; } }); - } - if (!made_changes) { + if (fn.IsValid()) { + ctx.Replace(call->Declaration(), + [this, fn, expr] { return ctx.dst->Call(fn, ctx.Clone(expr->args)); }); + made_changes = true; + } + } +}; + +BuiltinPolyfill::BuiltinPolyfill() = default; + +BuiltinPolyfill::~BuiltinPolyfill() = default; + +Transform::ApplyResult BuiltinPolyfill::Apply(const Program* src, + const DataMap& data, + DataMap&) const { + auto* cfg = data.Get(); + if (!cfg) { return SkipTransform; } - - ctx.Clone(); - return Program(std::move(b)); + return State{src, *cfg}.Run(); } BuiltinPolyfill::Config::Config(const Builtins& b) : builtins(b) {}