tint: Implement signed-int overloads of sign()

Bug: tint:1581
Fixed: tint:1782
Change-Id: Ia029bf9d1ce1d978c5cabc3016cb8ad1b4bac06a
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/113243
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
Ben Clayton
2022-12-07 19:52:49 +00:00
committed by Dawn LUCI CQ
parent 8e56c8baa7
commit 7017ec264c
121 changed files with 4614 additions and 90 deletions

View File

@@ -548,8 +548,8 @@ fn pow<N: num, T: f32_f16>(vec<N, T>, vec<N, T>) -> vec<N, T>
@const("select_bool") fn select<T: scalar>(T, T, bool) -> T
@const("select_bool") fn select<T: scalar, N: num>(vec<N, T>, vec<N, T>, bool) -> vec<N, T>
@const("select_boolvec") fn select<N: num, T: scalar>(vec<N, T>, vec<N, T>, vec<N, bool>) -> vec<N, T>
@const fn sign<T: fa_f32_f16>(T) -> T
@const fn sign<N: num, T: fa_f32_f16>(vec<N, T>) -> vec<N, T>
@const fn sign<T: fia_fi32_f16>(T) -> T
@const fn sign<N: num, T: fia_fi32_f16>(vec<N, T>) -> vec<N, T>
@const fn sin<T: fa_f32_f16>(T) -> T
@const fn sin<N: num, T: fa_f32_f16>(vec<N, T>) -> vec<N, T>
@const fn sinh<T: fa_f32_f16>(T) -> T

View File

@@ -3298,7 +3298,7 @@ ConstEval::Result ConstEval::sign(const sem::Type* ty,
}
return CreateElement(builder, source, c0->Type(), result);
};
return Dispatch_fa_f32_f16(create, c0);
return Dispatch_fia_fi32_f16(create, c0);
};
return TransformElements(builder, ty, transform, args[0]);
}

View File

@@ -2200,27 +2200,40 @@ INSTANTIATE_TEST_SUITE_P( //
template <typename T>
std::vector<Case> SignCases() {
return {
C({-T(1)}, -T(1)),
C({-T(0.5)}, -T(1)),
std::vector<Case> cases = {
C({T(0)}, T(0)),
C({-T(0)}, T(0)),
C({T(0.5)}, T(1)),
C({-T(1)}, -T(1)),
C({-T(10)}, -T(1)),
C({-T(100)}, -T(1)),
C({T(1)}, T(1)),
C({T(10)}, T(1)),
C({T(100)}, T(1)),
C({T::Highest()}, T(1.0)),
C({T::Lowest()}, -T(1.0)),
// Vector tests
C({Vec(-T(0.5), T(0), T(0.5))}, Vec(-T(1.0), T(0.0), T(1.0))),
C({Vec(T::Highest(), T::Lowest())}, Vec(T(1.0), -T(1.0))),
};
ConcatIntoIf<IsFloatingPoint<T>>(
cases, std::vector<Case>{
C({-T(0.5)}, -T(1)),
C({T(0.5)}, T(1)),
C({Vec(-T(0.5), T(0), T(0.5))}, Vec(-T(1.0), T(0.0), T(1.0))),
});
return cases;
}
INSTANTIATE_TEST_SUITE_P( //
Sign,
ResolverConstEvalBuiltinTest,
testing::Combine(testing::Values(sem::BuiltinType::kSign),
testing::ValuesIn(Concat(SignCases<AFloat>(), //
testing::ValuesIn(Concat(SignCases<AInt>(), //
SignCases<i32>(),
SignCases<AFloat>(),
SignCases<f32>(),
SignCases<f16>()))));

View File

@@ -8194,12 +8194,12 @@ constexpr TemplateTypeInfo kTemplateTypes[] = {
{
/* [28] */
/* name */ "T",
/* matcher index */ 64,
/* matcher index */ 60,
},
{
/* [29] */
/* name */ "T",
/* matcher index */ 60,
/* matcher index */ 64,
},
{
/* [30] */
@@ -11369,7 +11369,7 @@ constexpr OverloadInfo kOverloads[] = {
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 0,
/* template types */ &kTemplateTypes[28],
/* template types */ &kTemplateTypes[29],
/* template numbers */ &kTemplateNumbers[10],
/* parameters */ &kParameters[734],
/* return matcher indices */ &kMatcherIndices[3],
@@ -11381,7 +11381,7 @@ constexpr OverloadInfo kOverloads[] = {
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 1,
/* template types */ &kTemplateTypes[28],
/* template types */ &kTemplateTypes[29],
/* template numbers */ &kTemplateNumbers[4],
/* parameters */ &kParameters[736],
/* return matcher indices */ &kMatcherIndices[30],
@@ -11417,7 +11417,7 @@ constexpr OverloadInfo kOverloads[] = {
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 0,
/* template types */ &kTemplateTypes[28],
/* template types */ &kTemplateTypes[29],
/* template numbers */ &kTemplateNumbers[10],
/* parameters */ &kParameters[742],
/* return matcher indices */ &kMatcherIndices[3],
@@ -11429,7 +11429,7 @@ constexpr OverloadInfo kOverloads[] = {
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 1,
/* template types */ &kTemplateTypes[28],
/* template types */ &kTemplateTypes[29],
/* template numbers */ &kTemplateNumbers[4],
/* parameters */ &kParameters[744],
/* return matcher indices */ &kMatcherIndices[30],
@@ -12941,7 +12941,7 @@ constexpr OverloadInfo kOverloads[] = {
/* num parameters */ 1,
/* num template types */ 1,
/* num template numbers */ 0,
/* template types */ &kTemplateTypes[23],
/* template types */ &kTemplateTypes[28],
/* template numbers */ &kTemplateNumbers[10],
/* parameters */ &kParameters[892],
/* return matcher indices */ &kMatcherIndices[3],
@@ -12953,7 +12953,7 @@ constexpr OverloadInfo kOverloads[] = {
/* num parameters */ 1,
/* num template types */ 1,
/* num template numbers */ 1,
/* template types */ &kTemplateTypes[23],
/* template types */ &kTemplateTypes[28],
/* template numbers */ &kTemplateNumbers[4],
/* parameters */ &kParameters[893],
/* return matcher indices */ &kMatcherIndices[30],
@@ -13229,7 +13229,7 @@ constexpr OverloadInfo kOverloads[] = {
/* num parameters */ 1,
/* num template types */ 1,
/* num template numbers */ 0,
/* template types */ &kTemplateTypes[28],
/* template types */ &kTemplateTypes[29],
/* template numbers */ &kTemplateNumbers[10],
/* parameters */ &kParameters[950],
/* return matcher indices */ &kMatcherIndices[3],
@@ -13241,7 +13241,7 @@ constexpr OverloadInfo kOverloads[] = {
/* num parameters */ 1,
/* num template types */ 1,
/* num template numbers */ 1,
/* template types */ &kTemplateTypes[28],
/* template types */ &kTemplateTypes[29],
/* template numbers */ &kTemplateNumbers[4],
/* parameters */ &kParameters[951],
/* return matcher indices */ &kMatcherIndices[30],
@@ -13253,7 +13253,7 @@ constexpr OverloadInfo kOverloads[] = {
/* num parameters */ 1,
/* num template types */ 1,
/* num template numbers */ 0,
/* template types */ &kTemplateTypes[29],
/* template types */ &kTemplateTypes[28],
/* template numbers */ &kTemplateNumbers[10],
/* parameters */ &kParameters[952],
/* return matcher indices */ &kMatcherIndices[3],
@@ -13265,7 +13265,7 @@ constexpr OverloadInfo kOverloads[] = {
/* num parameters */ 1,
/* num template types */ 1,
/* num template numbers */ 1,
/* template types */ &kTemplateTypes[29],
/* template types */ &kTemplateTypes[28],
/* template numbers */ &kTemplateNumbers[4],
/* parameters */ &kParameters[953],
/* return matcher indices */ &kMatcherIndices[30],
@@ -13277,7 +13277,7 @@ constexpr OverloadInfo kOverloads[] = {
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 0,
/* template types */ &kTemplateTypes[28],
/* template types */ &kTemplateTypes[29],
/* template numbers */ &kTemplateNumbers[10],
/* parameters */ &kParameters[726],
/* return matcher indices */ &kMatcherIndices[3],
@@ -13289,7 +13289,7 @@ constexpr OverloadInfo kOverloads[] = {
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 1,
/* template types */ &kTemplateTypes[28],
/* template types */ &kTemplateTypes[29],
/* template numbers */ &kTemplateNumbers[4],
/* parameters */ &kParameters[728],
/* return matcher indices */ &kMatcherIndices[30],
@@ -13445,7 +13445,7 @@ constexpr OverloadInfo kOverloads[] = {
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 0,
/* template types */ &kTemplateTypes[28],
/* template types */ &kTemplateTypes[29],
/* template numbers */ &kTemplateNumbers[10],
/* parameters */ &kParameters[774],
/* return matcher indices */ &kMatcherIndices[3],
@@ -13457,7 +13457,7 @@ constexpr OverloadInfo kOverloads[] = {
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 1,
/* template types */ &kTemplateTypes[28],
/* template types */ &kTemplateTypes[29],
/* template numbers */ &kTemplateNumbers[4],
/* parameters */ &kParameters[776],
/* return matcher indices */ &kMatcherIndices[30],
@@ -13469,7 +13469,7 @@ constexpr OverloadInfo kOverloads[] = {
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 0,
/* template types */ &kTemplateTypes[28],
/* template types */ &kTemplateTypes[29],
/* template numbers */ &kTemplateNumbers[10],
/* parameters */ &kParameters[778],
/* return matcher indices */ &kMatcherIndices[3],
@@ -13481,7 +13481,7 @@ constexpr OverloadInfo kOverloads[] = {
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 1,
/* template types */ &kTemplateTypes[28],
/* template types */ &kTemplateTypes[29],
/* template numbers */ &kTemplateNumbers[4],
/* parameters */ &kParameters[780],
/* return matcher indices */ &kMatcherIndices[30],
@@ -14407,8 +14407,8 @@ constexpr IntrinsicInfo kBuiltins[] = {
},
{
/* [69] */
/* fn sign<T : fa_f32_f16>(T) -> T */
/* fn sign<N : num, T : fa_f32_f16>(vec<N, T>) -> vec<N, T> */
/* fn sign<T : fia_fi32_f16>(T) -> T */
/* fn sign<N : num, T : fia_fi32_f16>(vec<N, T>) -> vec<N, T> */
/* num overloads */ 2,
/* overloads */ &kOverloads[387],
},

View File

@@ -590,6 +590,32 @@ struct BuiltinPolyfill::State {
return name;
}
/// Builds the polyfill function for the `sign` builtin when the element type is integer
/// @param ty the parameter and return type for the function
/// @return the polyfill function name
Symbol sign_int(const sem::Type* ty) {
const uint32_t width = WidthOf(ty);
auto zero = [&] { return ScalarOrVector(width, 0_a); };
// pos_or_neg_one = (v > 0) ? 1 : -1
auto pos_or_neg_one = b.Call("select", //
ScalarOrVector(width, -1_a), //
ScalarOrVector(width, 1_a), //
b.GreaterThan("v", zero()));
auto name = b.Symbols().New("tint_sign");
b.Func(name,
utils::Vector{
b.Param("v", T(ty)),
},
T(ty),
utils::Vector{
b.Return(b.Call("select", pos_or_neg_one, zero(), b.Equal("v", zero()))),
});
return name;
}
/// Builds the polyfill function for the `textureSampleBaseClampToEdge` builtin, when the
/// texture type is texture_2d<f32>.
/// @return the polyfill function name
@@ -855,6 +881,15 @@ Transform::ApplyResult BuiltinPolyfill::Apply(const Program* src,
builtin, [&] { return s.saturate(builtin->ReturnType()); });
}
break;
case sem::BuiltinType::kSign:
if (polyfill.sign_int) {
auto* ty = builtin->ReturnType();
if (ty->is_signed_integer_scalar_or_vector()) {
fn = builtin_polyfills.GetOrCreate(builtin,
[&] { return s.sign_int(ty); });
}
}
break;
case sem::BuiltinType::kTextureSampleBaseClampToEdge:
if (polyfill.texture_sample_base_clamp_to_edge_2d_f32) {
auto& sig = builtin->Signature();

View File

@@ -68,6 +68,8 @@ class BuiltinPolyfill final : public Castable<BuiltinPolyfill, Transform> {
bool int_div_mod = false;
/// Should `saturate()` be polyfilled?
bool saturate = false;
/// Should `sign()` be polyfilled for integer types?
bool sign_int = false;
/// Should `textureSampleBaseClampToEdge()` be polyfilled for texture_2d<f32> textures?
bool texture_sample_base_clamp_to_edge_2d_f32 = false;
/// Should the vector form of `quantizeToF16()` be polyfilled with a scalar implementation?

View File

@@ -34,11 +34,7 @@ TEST_F(BuiltinPolyfillTest, ShouldRunEmptyModule) {
TEST_F(BuiltinPolyfillTest, EmptyModule) {
auto* src = R"()";
auto* expect = src;
auto got = Run<BuiltinPolyfill>(src);
EXPECT_EQ(expect, str(got));
EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));
}
////////////////////////////////////////////////////////////////////////////////
@@ -73,11 +69,7 @@ fn f() {
}
)";
auto* expect = src;
auto got = Run<BuiltinPolyfill>(src, polyfillAcosh(Level::kFull));
EXPECT_EQ(expect, str(got));
EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src, polyfillAcosh(Level::kFull)));
}
TEST_F(BuiltinPolyfillTest, Acosh_Full_f32) {
@@ -206,11 +198,7 @@ fn f() {
}
)";
auto* expect = src;
auto got = Run<BuiltinPolyfill>(src, polyfillSinh());
EXPECT_EQ(expect, str(got));
EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src, polyfillSinh()));
}
TEST_F(BuiltinPolyfillTest, Asinh_f32) {
@@ -293,11 +281,7 @@ fn f() {
}
)";
auto* expect = src;
auto got = Run<BuiltinPolyfill>(src, polyfillAtanh(Level::kFull));
EXPECT_EQ(expect, str(got));
EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src, polyfillAtanh(Level::kFull)));
}
TEST_F(BuiltinPolyfillTest, Atanh_Full_f32) {
@@ -603,11 +587,7 @@ fn f() {
}
)";
auto* expect = src;
auto got = Run<BuiltinPolyfill>(src, polyfillClampInteger());
EXPECT_EQ(expect, str(got));
EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src, polyfillClampInteger()));
}
TEST_F(BuiltinPolyfillTest, ClampInteger_i32) {
@@ -732,11 +712,7 @@ fn f() {
}
)";
auto* expect = src;
auto got = Run<BuiltinPolyfill>(src, polyfillCountLeadingZeros());
EXPECT_EQ(expect, str(got));
EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src, polyfillCountLeadingZeros()));
}
TEST_F(BuiltinPolyfillTest, CountLeadingZeros_i32) {
@@ -909,11 +885,7 @@ fn f() {
}
)";
auto* expect = src;
auto got = Run<BuiltinPolyfill>(src, polyfillCountTrailingZeros());
EXPECT_EQ(expect, str(got));
EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src, polyfillCountTrailingZeros()));
}
TEST_F(BuiltinPolyfillTest, CountTrailingZeros_i32) {
@@ -1088,11 +1060,7 @@ fn f() {
}
)";
auto* expect = src;
auto got = Run<BuiltinPolyfill>(src, polyfillExtractBits(Level::kFull));
EXPECT_EQ(expect, str(got));
EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src, polyfillExtractBits(Level::kFull)));
}
TEST_F(BuiltinPolyfillTest, ExtractBits_Full_i32) {
@@ -1345,11 +1313,7 @@ fn f() {
}
)";
auto* expect = src;
auto got = Run<BuiltinPolyfill>(src, polyfillFirstLeadingBit());
EXPECT_EQ(expect, str(got));
EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src, polyfillFirstLeadingBit()));
}
TEST_F(BuiltinPolyfillTest, FirstLeadingBit_i32) {
@@ -1522,11 +1486,7 @@ fn f() {
}
)";
auto* expect = src;
auto got = Run<BuiltinPolyfill>(src, polyfillFirstTrailingBit());
EXPECT_EQ(expect, str(got));
EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src, polyfillFirstTrailingBit()));
}
TEST_F(BuiltinPolyfillTest, FirstTrailingBit_i32) {
@@ -1701,11 +1661,7 @@ fn f() {
}
)";
auto* expect = src;
auto got = Run<BuiltinPolyfill>(src, polyfillInsertBits(Level::kFull));
EXPECT_EQ(expect, str(got));
EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src, polyfillInsertBits(Level::kFull)));
}
TEST_F(BuiltinPolyfillTest, InsertBits_Full_i32) {
@@ -2715,11 +2671,7 @@ fn f() {
}
)";
auto* expect = src;
auto got = Run<BuiltinPolyfill>(src, polyfillSaturate());
EXPECT_EQ(expect, str(got));
EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src, polyfillSaturate()));
}
TEST_F(BuiltinPolyfillTest, Saturate_f32) {
@@ -2826,6 +2778,99 @@ fn f() {
EXPECT_EQ(expect, str(got));
}
////////////////////////////////////////////////////////////////////////////////
// sign_int
////////////////////////////////////////////////////////////////////////////////
DataMap polyfillSignInt() {
BuiltinPolyfill::Builtins builtins;
builtins.sign_int = true;
DataMap data;
data.Add<BuiltinPolyfill::Config>(builtins);
return data;
}
TEST_F(BuiltinPolyfillTest, ShouldRunSign_i32) {
auto* src = R"(
fn f() {
let v = 1i;
sign(v);
}
)";
EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));
EXPECT_TRUE(ShouldRun<BuiltinPolyfill>(src, polyfillSignInt()));
}
TEST_F(BuiltinPolyfillTest, ShouldRunSign_f32) {
auto* src = R"(
fn f() {
let v = 1f;
sign(v);
}
)";
EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src));
EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src, polyfillSignInt()));
}
TEST_F(BuiltinPolyfillTest, SignInt_ConstantExpression) {
auto* src = R"(
fn f() {
let r : i32 = sign(1i);
}
)";
EXPECT_FALSE(ShouldRun<BuiltinPolyfill>(src, polyfillSignInt()));
}
TEST_F(BuiltinPolyfillTest, SignInt_i32) {
auto* src = R"(
fn f() {
let v = 1i;
let r : i32 = sign(v);
}
)";
auto* expect = R"(
fn tint_sign(v : i32) -> i32 {
return select(select(-1, 1, (v > 0)), 0, (v == 0));
}
fn f() {
let v = 1i;
let r : i32 = tint_sign(v);
}
)";
auto got = Run<BuiltinPolyfill>(src, polyfillSignInt());
EXPECT_EQ(expect, str(got));
}
TEST_F(BuiltinPolyfillTest, SignInt_vec3_i32) {
auto* src = R"(
fn f() {
let v = 1i;
let r : vec3<i32> = sign(vec3<i32>(v));
}
)";
auto* expect = R"(
fn tint_sign(v : vec3<i32>) -> vec3<i32> {
return select(select(vec3(-1), vec3(1), (v > vec3(0))), vec3(0), (v == vec3(0)));
}
fn f() {
let v = 1i;
let r : vec3<i32> = tint_sign(vec3<i32>(v));
}
)";
auto got = Run<BuiltinPolyfill>(src, polyfillSignInt());
EXPECT_EQ(expect, str(got));
}
////////////////////////////////////////////////////////////////////////////////
// textureSampleBaseClampToEdge
////////////////////////////////////////////////////////////////////////////////

View File

@@ -181,6 +181,7 @@ SanitizedResult Sanitize(const Program* in, const Options& options) {
polyfills.first_trailing_bit = true;
polyfills.insert_bits = transform::BuiltinPolyfill::Level::kClampParameters;
polyfills.int_div_mod = true;
polyfills.sign_int = true;
polyfills.texture_sample_base_clamp_to_edge_2d_f32 = true;
data.Add<transform::BuiltinPolyfill::Config>(polyfills);
manager.Add<transform::BuiltinPolyfill>();

View File

@@ -201,7 +201,11 @@ uint32_t builtin_to_glsl_method(const sem::Builtin* builtin) {
case BuiltinType::kRound:
return GLSLstd450RoundEven;
case BuiltinType::kSign:
return GLSLstd450FSign;
if (builtin->ReturnType()->is_signed_integer_scalar_or_vector()) {
return GLSLstd450SSign;
} else {
return GLSLstd450FSign;
}
case BuiltinType::kSin:
return GLSLstd450Sin;
case BuiltinType::kSinh: