tint: fix HLSL countOneBits and reverseBits for i32 args

These two functions in HLSL only accept and return uint. Thus, if the
result of these calls is passed to a function that wants int, it will
fail, or call the uint overload if one exists. Fixed by casting the
result to int if the arg is int.

Bug: tint:1550
Change-Id: Id4c0970a29ac4c83ee5b78be8d2762e05e4a3f03
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/91001
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
Reviewed-by: Ryan Harrison <rharrison@chromium.org>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
Antonio Maiorano 2022-05-20 01:58:40 +00:00 committed by Dawn LUCI CQ
parent 782577accf
commit ab4c035762
10 changed files with 68 additions and 15 deletions

View File

@ -999,23 +999,25 @@ bool GeneratorImpl::EmitFunctionCall(std::ostream& out,
bool GeneratorImpl::EmitBuiltinCall(std::ostream& out, bool GeneratorImpl::EmitBuiltinCall(std::ostream& out,
const sem::Call* call, const sem::Call* call,
const sem::Builtin* builtin) { const sem::Builtin* builtin) {
const auto type = builtin->Type();
auto* expr = call->Declaration(); auto* expr = call->Declaration();
if (builtin->IsTexture()) { if (builtin->IsTexture()) {
return EmitTextureCall(out, call, builtin); return EmitTextureCall(out, call, builtin);
} }
if (builtin->Type() == sem::BuiltinType::kSelect) { if (type == sem::BuiltinType::kSelect) {
return EmitSelectCall(out, expr); return EmitSelectCall(out, expr);
} }
if (builtin->Type() == sem::BuiltinType::kModf) { if (type == sem::BuiltinType::kModf) {
return EmitModfCall(out, expr, builtin); return EmitModfCall(out, expr, builtin);
} }
if (builtin->Type() == sem::BuiltinType::kFrexp) { if (type == sem::BuiltinType::kFrexp) {
return EmitFrexpCall(out, expr, builtin); return EmitFrexpCall(out, expr, builtin);
} }
if (builtin->Type() == sem::BuiltinType::kDegrees) { if (type == sem::BuiltinType::kDegrees) {
return EmitDegreesCall(out, expr, builtin); return EmitDegreesCall(out, expr, builtin);
} }
if (builtin->Type() == sem::BuiltinType::kRadians) { if (type == sem::BuiltinType::kRadians) {
return EmitRadiansCall(out, expr, builtin); return EmitRadiansCall(out, expr, builtin);
} }
if (builtin->IsDataPacking()) { if (builtin->IsDataPacking()) {
@ -1033,11 +1035,27 @@ bool GeneratorImpl::EmitBuiltinCall(std::ostream& out,
if (builtin->IsDP4a()) { if (builtin->IsDP4a()) {
return EmitDP4aCall(out, expr, builtin); return EmitDP4aCall(out, expr, builtin);
} }
auto name = generate_builtin_name(builtin); auto name = generate_builtin_name(builtin);
if (name.empty()) { if (name.empty()) {
return false; return false;
} }
// Handle single argument builtins that only accept and return uint (not int overload). We need
// to explicitly cast the return value (we also cast the arg for good measure). See
// crbug.com/tint/1550
if (type == sem::BuiltinType::kCountOneBits || type == sem::BuiltinType::kReverseBits) {
auto* arg = call->Arguments()[0];
if (arg->Type()->UnwrapRef()->is_signed_scalar_or_vector()) {
out << "asint(" << name << "(asuint(";
if (!EmitExpression(out, arg->Declaration())) {
return false;
}
out << ")))";
return true;
}
}
out << name << "("; out << name << "(";
bool first = true; bool first = true;
@ -1053,6 +1071,7 @@ bool GeneratorImpl::EmitBuiltinCall(std::ostream& out,
} }
out << ")"; out << ")";
return true; return true;
} }
@ -2546,7 +2565,7 @@ std::string GeneratorImpl::generate_builtin_name(const sem::Builtin* builtin) {
case sem::BuiltinType::kTranspose: case sem::BuiltinType::kTranspose:
case sem::BuiltinType::kTrunc: case sem::BuiltinType::kTrunc:
return builtin->str(); return builtin->str();
case sem::BuiltinType::kCountOneBits: case sem::BuiltinType::kCountOneBits: // uint
return "countbits"; return "countbits";
case sem::BuiltinType::kDpdx: case sem::BuiltinType::kDpdx:
return "ddx"; return "ddx";
@ -2574,7 +2593,7 @@ std::string GeneratorImpl::generate_builtin_name(const sem::Builtin* builtin) {
return "rsqrt"; return "rsqrt";
case sem::BuiltinType::kMix: case sem::BuiltinType::kMix:
return "lerp"; return "lerp";
case sem::BuiltinType::kReverseBits: case sem::BuiltinType::kReverseBits: // uint
return "reversebits"; return "reversebits";
case sem::BuiltinType::kSmoothstep: case sem::BuiltinType::kSmoothstep:
case sem::BuiltinType::kSmoothStep: case sem::BuiltinType::kSmoothStep:

View File

@ -778,5 +778,39 @@ void test_function() {
)"); )");
} }
TEST_F(HlslGeneratorImplTest_Builtin, CountOneBits) {
auto* val = Var("val1", ty.i32());
auto* call = Call("countOneBits", val);
WrapInFunction(val, call);
GeneratorImpl& gen = SanitizeAndBuild();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_EQ(gen.result(), R"([numthreads(1, 1, 1)]
void test_function() {
int val1 = 0;
const int tint_symbol = asint(countbits(asuint(val1)));
return;
}
)");
}
TEST_F(HlslGeneratorImplTest_Builtin, ReverseBits) {
auto* val = Var("val1", ty.i32());
auto* call = Call("reverseBits", val);
WrapInFunction(val, call);
GeneratorImpl& gen = SanitizeAndBuild();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_EQ(gen.result(), R"([numthreads(1, 1, 1)]
void test_function() {
int val1 = 0;
const int tint_symbol = asint(reversebits(asuint(val1)));
return;
}
)");
}
} // namespace } // namespace
} // namespace tint::writer::hlsl } // namespace tint::writer::hlsl

View File

@ -1,5 +1,5 @@
void countOneBits_0f7980() { void countOneBits_0f7980() {
int4 res = countbits(int4(0, 0, 0, 0)); int4 res = asint(countbits(asuint(int4(0, 0, 0, 0))));
} }
struct tint_symbol { struct tint_symbol {

View File

@ -1,5 +1,5 @@
void countOneBits_65d2ae() { void countOneBits_65d2ae() {
int3 res = countbits(int3(0, 0, 0)); int3 res = asint(countbits(asuint(int3(0, 0, 0))));
} }
struct tint_symbol { struct tint_symbol {

View File

@ -1,5 +1,5 @@
void countOneBits_af90e2() { void countOneBits_af90e2() {
int2 res = countbits(int2(0, 0)); int2 res = asint(countbits(asuint(int2(0, 0))));
} }
struct tint_symbol { struct tint_symbol {

View File

@ -1,5 +1,5 @@
void countOneBits_fd88b2() { void countOneBits_fd88b2() {
int res = countbits(1); int res = asint(countbits(asuint(1)));
} }
struct tint_symbol { struct tint_symbol {

View File

@ -1,5 +1,5 @@
void reverseBits_222177() { void reverseBits_222177() {
int2 res = reversebits(int2(0, 0)); int2 res = asint(reversebits(asuint(int2(0, 0))));
} }
struct tint_symbol { struct tint_symbol {

View File

@ -1,5 +1,5 @@
void reverseBits_4dbd6f() { void reverseBits_4dbd6f() {
int4 res = reversebits(int4(0, 0, 0, 0)); int4 res = asint(reversebits(asuint(int4(0, 0, 0, 0))));
} }
struct tint_symbol { struct tint_symbol {

View File

@ -1,5 +1,5 @@
void reverseBits_7c4269() { void reverseBits_7c4269() {
int res = reversebits(1); int res = asint(reversebits(asuint(1)));
} }
struct tint_symbol { struct tint_symbol {

View File

@ -1,5 +1,5 @@
void reverseBits_c21bc1() { void reverseBits_c21bc1() {
int3 res = reversebits(int3(0, 0, 0)); int3 res = asint(reversebits(asuint(int3(0, 0, 0))));
} }
struct tint_symbol { struct tint_symbol {