tint/transform: Polyfill bit-shift with RHS modulo

Fixed: tint:1453
Fixed: tint:1543
Change-Id: Idb5af752d7a3bb9e181cc47430ad4ddfb707873d
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/108440
Auto-Submit: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
This commit is contained in:
Ben Clayton
2022-11-03 19:15:17 +00:00
committed by Dawn LUCI CQ
parent 91e27f25f9
commit 02f04d914d
47 changed files with 345 additions and 157 deletions

View File

@@ -614,7 +614,7 @@ Transform::ApplyResult BuiltinPolyfill::Apply(const Program* src,
auto& builtins = cfg->builtins;
utils::Hashmap<const sem::Builtin*, Symbol, 8> polyfills;
utils::Hashmap<const sem::Builtin*, Symbol, 8> builtin_polyfills;
ProgramBuilder b;
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
@@ -622,114 +622,137 @@ Transform::ApplyResult BuiltinPolyfill::Apply(const Program* src,
bool made_changes = false;
for (auto* node : src->ASTNodes().Objects()) {
if (auto* call = src->Sem().Get<sem::Call>(node)) {
if (auto* builtin = call->Target()->As<sem::Builtin>()) {
if (call->Stage() == sem::EvaluationStage::kConstant) {
continue; // Don't polyfill @const expressions
}
Symbol polyfill;
switch (builtin->Type()) {
case sem::BuiltinType::kAcosh:
if (builtins.acosh != Level::kNone) {
polyfill = polyfills.GetOrCreate(
builtin, [&] { return s.acosh(builtin->ReturnType()); });
}
break;
case sem::BuiltinType::kAsinh:
if (builtins.asinh) {
polyfill = polyfills.GetOrCreate(
builtin, [&] { return s.asinh(builtin->ReturnType()); });
}
break;
case sem::BuiltinType::kAtanh:
if (builtins.atanh != Level::kNone) {
polyfill = polyfills.GetOrCreate(
builtin, [&] { return s.atanh(builtin->ReturnType()); });
}
break;
case sem::BuiltinType::kClamp:
if (builtins.clamp_int) {
auto& sig = builtin->Signature();
if (sig.parameters[0]->Type()->is_integer_scalar_or_vector()) {
polyfill = polyfills.GetOrCreate(
builtin, [&] { return s.clampInteger(builtin->ReturnType()); });
}
}
break;
case sem::BuiltinType::kCountLeadingZeros:
if (builtins.count_leading_zeros) {
polyfill = polyfills.GetOrCreate(builtin, [&] {
return s.countLeadingZeros(builtin->ReturnType());
});
}
break;
case sem::BuiltinType::kCountTrailingZeros:
if (builtins.count_trailing_zeros) {
polyfill = polyfills.GetOrCreate(builtin, [&] {
return s.countTrailingZeros(builtin->ReturnType());
});
}
break;
case sem::BuiltinType::kExtractBits:
if (builtins.extract_bits != Level::kNone) {
polyfill = polyfills.GetOrCreate(
builtin, [&] { return s.extractBits(builtin->ReturnType()); });
}
break;
case sem::BuiltinType::kFirstLeadingBit:
if (builtins.first_leading_bit) {
polyfill = polyfills.GetOrCreate(
builtin, [&] { return s.firstLeadingBit(builtin->ReturnType()); });
}
break;
case sem::BuiltinType::kFirstTrailingBit:
if (builtins.first_trailing_bit) {
polyfill = polyfills.GetOrCreate(
builtin, [&] { return s.firstTrailingBit(builtin->ReturnType()); });
}
break;
case sem::BuiltinType::kInsertBits:
if (builtins.insert_bits != Level::kNone) {
polyfill = polyfills.GetOrCreate(
builtin, [&] { return s.insertBits(builtin->ReturnType()); });
}
break;
case sem::BuiltinType::kSaturate:
if (builtins.saturate) {
polyfill = polyfills.GetOrCreate(
builtin, [&] { return s.saturate(builtin->ReturnType()); });
}
break;
case sem::BuiltinType::kTextureSampleBaseClampToEdge:
if (builtins.texture_sample_base_clamp_to_edge_2d_f32) {
auto& sig = builtin->Signature();
auto* tex = sig.Parameter(sem::ParameterUsage::kTexture);
if (auto* stex = tex->Type()->As<sem::SampledTexture>()) {
if (stex->type()->Is<sem::F32>()) {
polyfill = polyfills.GetOrCreate(builtin, [&] {
return s.textureSampleBaseClampToEdge_2d_f32();
});
}
}
}
break;
case sem::BuiltinType::kQuantizeToF16:
if (builtins.quantize_to_vec_f16) {
if (auto* vec = builtin->ReturnType()->As<sem::Vector>()) {
polyfill = polyfills.GetOrCreate(
builtin, [&] { return s.quantizeToF16(vec); });
}
}
break;
auto* expr = src->Sem().Get<sem::Expression>(node);
if (!expr || expr->Stage() == sem::EvaluationStage::kConstant) {
continue; // Don't polyfill @const expressions
}
default:
break;
}
if (polyfill.IsValid()) {
auto* replacement = s.b.Call(polyfill, ctx.Clone(call->Declaration()->args));
ctx.Replace(call->Declaration(), replacement);
made_changes = true;
}
if (auto* call = expr->As<sem::Call>()) {
auto* builtin = call->Target()->As<sem::Builtin>();
if (!builtin) {
continue;
}
Symbol polyfill;
switch (builtin->Type()) {
case sem::BuiltinType::kAcosh:
if (builtins.acosh != Level::kNone) {
polyfill = builtin_polyfills.GetOrCreate(
builtin, [&] { return s.acosh(builtin->ReturnType()); });
}
break;
case sem::BuiltinType::kAsinh:
if (builtins.asinh) {
polyfill = builtin_polyfills.GetOrCreate(
builtin, [&] { return s.asinh(builtin->ReturnType()); });
}
break;
case sem::BuiltinType::kAtanh:
if (builtins.atanh != Level::kNone) {
polyfill = builtin_polyfills.GetOrCreate(
builtin, [&] { return s.atanh(builtin->ReturnType()); });
}
break;
case sem::BuiltinType::kClamp:
if (builtins.clamp_int) {
auto& sig = builtin->Signature();
if (sig.parameters[0]->Type()->is_integer_scalar_or_vector()) {
polyfill = builtin_polyfills.GetOrCreate(
builtin, [&] { return s.clampInteger(builtin->ReturnType()); });
}
}
break;
case sem::BuiltinType::kCountLeadingZeros:
if (builtins.count_leading_zeros) {
polyfill = builtin_polyfills.GetOrCreate(
builtin, [&] { return s.countLeadingZeros(builtin->ReturnType()); });
}
break;
case sem::BuiltinType::kCountTrailingZeros:
if (builtins.count_trailing_zeros) {
polyfill = builtin_polyfills.GetOrCreate(
builtin, [&] { return s.countTrailingZeros(builtin->ReturnType()); });
}
break;
case sem::BuiltinType::kExtractBits:
if (builtins.extract_bits != Level::kNone) {
polyfill = builtin_polyfills.GetOrCreate(
builtin, [&] { return s.extractBits(builtin->ReturnType()); });
}
break;
case sem::BuiltinType::kFirstLeadingBit:
if (builtins.first_leading_bit) {
polyfill = builtin_polyfills.GetOrCreate(
builtin, [&] { return s.firstLeadingBit(builtin->ReturnType()); });
}
break;
case sem::BuiltinType::kFirstTrailingBit:
if (builtins.first_trailing_bit) {
polyfill = builtin_polyfills.GetOrCreate(
builtin, [&] { return s.firstTrailingBit(builtin->ReturnType()); });
}
break;
case sem::BuiltinType::kInsertBits:
if (builtins.insert_bits != Level::kNone) {
polyfill = builtin_polyfills.GetOrCreate(
builtin, [&] { return s.insertBits(builtin->ReturnType()); });
}
break;
case sem::BuiltinType::kSaturate:
if (builtins.saturate) {
polyfill = builtin_polyfills.GetOrCreate(
builtin, [&] { return s.saturate(builtin->ReturnType()); });
}
break;
case sem::BuiltinType::kTextureSampleBaseClampToEdge:
if (builtins.texture_sample_base_clamp_to_edge_2d_f32) {
auto& sig = builtin->Signature();
auto* tex = sig.Parameter(sem::ParameterUsage::kTexture);
if (auto* stex = tex->Type()->As<sem::SampledTexture>()) {
if (stex->type()->Is<sem::F32>()) {
polyfill = builtin_polyfills.GetOrCreate(builtin, [&] {
return s.textureSampleBaseClampToEdge_2d_f32();
});
}
}
}
break;
case sem::BuiltinType::kQuantizeToF16:
if (builtins.quantize_to_vec_f16) {
if (auto* vec = builtin->ReturnType()->As<sem::Vector>()) {
polyfill = builtin_polyfills.GetOrCreate(
builtin, [&] { return s.quantizeToF16(vec); });
}
}
break;
default:
break;
}
if (polyfill.IsValid()) {
auto* replacement = s.b.Call(polyfill, ctx.Clone(call->Declaration()->args));
ctx.Replace(call->Declaration(), replacement);
made_changes = true;
}
} else if (auto* bin_op = node->As<ast::BinaryExpression>()) {
switch (bin_op->op) {
case ast::BinaryOp::kShiftLeft:
case ast::BinaryOp::kShiftRight:
if (builtins.bitshift_modulo) {
auto* lhs_ty = src->TypeOf(bin_op->lhs)->UnwrapRef();
auto* rhs_ty = src->TypeOf(bin_op->rhs)->UnwrapRef();
auto* lhs_el_ty = sem::Type::DeepestElementOf(lhs_ty);
const ast::Expression* mask = b.Expr(AInt(lhs_el_ty->Size() * 8 - 1));
if (rhs_ty->Is<sem::Vector>()) {
mask = b.Construct(CreateASTTypeFor(ctx, rhs_ty), mask);
}
auto* mod = b.And(ctx.Clone(bin_op->rhs), mask);
ctx.Replace(bin_op->rhs, mod);
made_changes = true;
}
break;
default:
break;
}
}
}

View File

@@ -47,6 +47,8 @@ class BuiltinPolyfill final : public Castable<BuiltinPolyfill, Transform> {
bool asinh = false;
/// What level should `atanh` be polyfilled?
Level atanh = Level::kNone;
/// Should the RHS of `<<` and `>>` be wrapped in a modulo bit-width of LHS?
bool bitshift_modulo = false;
/// Should `clamp()` be polyfilled for integer values (scalar or vector)?
bool clamp_int = false;
/// Should `countLeadingZeros()` be polyfilled?
@@ -66,7 +68,7 @@ class BuiltinPolyfill final : public Castable<BuiltinPolyfill, Transform> {
/// Should `textureSampleBaseClampToEdge()` be polyfilled for texture_2d<f32> textures?
bool texture_sample_base_clamp_to_edge_2d_f32 = false;
/// Should the vector form of `quantizeToF16()` be polyfilled with a scalar implementation?
/// See crbug.com/tint/1741
/// See crbug.com/tint/1741
bool quantize_to_vec_f16 = false;
};

View File

@@ -398,6 +398,145 @@ fn f() {
EXPECT_EQ(expect, str(got));
}
////////////////////////////////////////////////////////////////////////////////
// bitshiftModulo
////////////////////////////////////////////////////////////////////////////////
DataMap polyfillBitshiftModulo() {
BuiltinPolyfill::Builtins builtins;
builtins.bitshift_modulo = true;
DataMap data;
data.Add<BuiltinPolyfill::Config>(builtins);
return data;
}
TEST_F(BuiltinPolyfillTest, ShouldRunBitshiftModulo_shl_scalar) {
auto* src = R"(
fn f() {
let v = 15u;
let r = 1i << v;
}
)";
EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));
EXPECT_TRUE(ShouldRun<BuiltinPolyfill>(src, polyfillBitshiftModulo()));
}
TEST_F(BuiltinPolyfillTest, ShouldRunBitshiftModulo_shl_vector) {
auto* src = R"(
fn f() {
let v = 15u;
let r = vec3(1i) << vec3(v);
}
)";
EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));
EXPECT_TRUE(ShouldRun<BuiltinPolyfill>(src, polyfillBitshiftModulo()));
}
TEST_F(BuiltinPolyfillTest, ShouldRunBitshiftModulo_shr_scalar) {
auto* src = R"(
fn f() {
let v = 15u;
let r = 1i >> v;
}
)";
EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));
EXPECT_TRUE(ShouldRun<BuiltinPolyfill>(src, polyfillBitshiftModulo()));
}
TEST_F(BuiltinPolyfillTest, ShouldRunBitshiftModulo_shr_vector) {
auto* src = R"(
fn f() {
let v = 15u;
let r = vec3(1i) >> vec3(v);
}
)";
EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));
EXPECT_TRUE(ShouldRun<BuiltinPolyfill>(src, polyfillBitshiftModulo()));
}
TEST_F(BuiltinPolyfillTest, BitshiftModulo_shl_scalar) {
auto* src = R"(
fn f() {
let v = 15u;
let r = 1i << v;
}
)";
auto* expect = R"(
fn f() {
let v = 15u;
let r = (1i << (v & 31));
}
)";
auto got = Run<BuiltinPolyfill>(src, polyfillBitshiftModulo());
EXPECT_EQ(expect, str(got));
}
TEST_F(BuiltinPolyfillTest, BitshiftModulo_shl_vector) {
auto* src = R"(
fn f() {
let v = 15u;
let r = vec3(1i) << vec3(v);
}
)";
auto* expect = R"(
fn f() {
let v = 15u;
let r = (vec3(1i) << (vec3(v) & vec3<u32>(31)));
}
)";
auto got = Run<BuiltinPolyfill>(src, polyfillBitshiftModulo());
EXPECT_EQ(expect, str(got));
}
TEST_F(BuiltinPolyfillTest, BitshiftModulo_shr_scalar) {
auto* src = R"(
fn f() {
let v = 15u;
let r = 1i >> v;
}
)";
auto* expect = R"(
fn f() {
let v = 15u;
let r = (1i >> (v & 31));
}
)";
auto got = Run<BuiltinPolyfill>(src, polyfillBitshiftModulo());
EXPECT_EQ(expect, str(got));
}
TEST_F(BuiltinPolyfillTest, BitshiftModulo_shr_vector) {
auto* src = R"(
fn f() {
let v = 15u;
let r = vec3(1i) >> vec3(v);
}
)";
auto* expect = R"(
fn f() {
let v = 15u;
let r = (vec3(1i) >> (vec3(v) & vec3<u32>(31)));
}
)";
auto got = Run<BuiltinPolyfill>(src, polyfillBitshiftModulo());
EXPECT_EQ(expect, str(got));
}
////////////////////////////////////////////////////////////////////////////////
// clampInteger
////////////////////////////////////////////////////////////////////////////////

View File

@@ -186,6 +186,7 @@ SanitizedResult Sanitize(const Program* in,
transform::BuiltinPolyfill::Builtins polyfills;
polyfills.acosh = transform::BuiltinPolyfill::Level::kRangeCheck;
polyfills.atanh = transform::BuiltinPolyfill::Level::kRangeCheck;
polyfills.bitshift_modulo = true;
polyfills.count_leading_zeros = true;
polyfills.count_trailing_zeros = true;
polyfills.extract_bits = transform::BuiltinPolyfill::Level::kClampParameters;

View File

@@ -162,6 +162,7 @@ SanitizedResult Sanitize(const Program* in, const Options& options) {
polyfills.acosh = transform::BuiltinPolyfill::Level::kFull;
polyfills.asinh = true;
polyfills.atanh = transform::BuiltinPolyfill::Level::kFull;
polyfills.bitshift_modulo = true;
polyfills.clamp_int = true;
// TODO(crbug.com/tint/1449): Some of these can map to HLSL's `firstbitlow`
// and `firstbithigh`.

View File

@@ -171,6 +171,7 @@ SanitizedResult Sanitize(const Program* in, const Options& options) {
transform::BuiltinPolyfill::Builtins polyfills;
polyfills.acosh = transform::BuiltinPolyfill::Level::kRangeCheck;
polyfills.atanh = transform::BuiltinPolyfill::Level::kRangeCheck;
polyfills.bitshift_modulo = true; // crbug.com/tint/1543
polyfills.clamp_int = true;
polyfills.extract_bits = transform::BuiltinPolyfill::Level::kClampParameters;
polyfills.first_leading_bit = true;

View File

@@ -52,6 +52,7 @@ SanitizedResult Sanitize(const Program* in, const Options& options) {
transform::BuiltinPolyfill::Builtins polyfills;
polyfills.acosh = transform::BuiltinPolyfill::Level::kRangeCheck;
polyfills.atanh = transform::BuiltinPolyfill::Level::kRangeCheck;
polyfills.bitshift_modulo = true;
polyfills.clamp_int = true;
polyfills.count_leading_zeros = true;
polyfills.count_trailing_zeros = true;