tint: fix builtin calls and binary ops with abstract args of different type

If a call to atan2 with args of type AFloat and AInt is made, Resolver
would correctly select the atan2(AFloat, AFloat) overload, but if the
input args were of type (AFloat, AInt), it would attempt to constant
evaluate without first converting the AInt arg to AFloat. The same would
occur for a binary operation, say AFloat + AInt. Before constant
evaluating, the Resolver now converts AInt to AFloat if necessary.

Bug: chromium:1350147
Change-Id: I85390c5d7af7e706115278ece34b2b18b8574f9f
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/98543
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
This commit is contained in:
Antonio Maiorano 2022-08-10 20:01:17 +00:00 committed by Dawn LUCI CQ
parent 90d5eb6128
commit a58d8c9fac
10 changed files with 279 additions and 54 deletions

View File

@ -3131,27 +3131,27 @@ TEST_F(ResolverConstEvalTest, UnaryNegateLowestAbstract) {
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
namespace binary_op { namespace binary_op {
template <typename T> using Types = std::variant<AInt, AFloat, u32, i32, f32, f16>;
struct Values {
T lhs; struct Case {
T rhs; Types lhs;
T expect; Types rhs;
Types expected;
bool is_overflow; bool is_overflow;
}; };
struct Case {
std::variant<Values<AInt>, Values<AFloat>, Values<u32>, Values<i32>, Values<f32>, Values<f16>>
values;
};
static std::ostream& operator<<(std::ostream& o, const Case& c) { static std::ostream& operator<<(std::ostream& o, const Case& c) {
std::visit([&](auto&& v) { o << v.lhs << " " << v.rhs; }, c.values); std::visit(
[&](auto&& lhs, auto&& rhs, auto&& expected) {
o << "lhs: " << lhs << ", rhs: " << rhs << ", expected: " << expected;
},
c.lhs, c.rhs, c.expected);
return o; return o;
} }
template <typename T> template <typename T, typename U, typename V>
Case C(T lhs, T rhs, T expect, bool is_overflow = false) { Case C(T lhs, U rhs, V expected, bool is_overflow = false) {
return Case{Values<T>{lhs, rhs, expect, is_overflow}}; return Case{lhs, rhs, expected, is_overflow};
} }
using ResolverConstEvalBinaryOpTest = ResolverTestWithParam<std::tuple<ast::BinaryOp, Case>>; using ResolverConstEvalBinaryOpTest = ResolverTestWithParam<std::tuple<ast::BinaryOp, Case>>;
@ -3161,16 +3161,16 @@ TEST_P(ResolverConstEvalBinaryOpTest, Test) {
auto op = std::get<0>(GetParam()); auto op = std::get<0>(GetParam());
auto c = std::get<1>(GetParam()); auto c = std::get<1>(GetParam());
std::visit( std::visit(
[&](auto&& values) { [&](auto&& lhs, auto&& rhs, auto&& expected) {
using T = decltype(values.expect); using T = std::decay_t<decltype(expected)>;
if constexpr (std::is_same_v<T, AInt> || std::is_same_v<T, AFloat>) { if constexpr (std::is_same_v<T, AInt> || std::is_same_v<T, AFloat>) {
if (values.is_overflow) { if (c.is_overflow) {
return; return;
} }
} }
auto* expr = create<ast::BinaryExpression>(op, Expr(values.lhs), Expr(values.rhs)); auto* expr = create<ast::BinaryExpression>(op, Expr(lhs), Expr(rhs));
GlobalConst("C", nullptr, expr); GlobalConst("C", nullptr, expr);
EXPECT_TRUE(r()->Resolve()) << r()->error(); EXPECT_TRUE(r()->Resolve()) << r()->error();
@ -3179,17 +3179,26 @@ TEST_P(ResolverConstEvalBinaryOpTest, Test) {
const sem::Constant* value = sem->ConstantValue(); const sem::Constant* value = sem->ConstantValue();
ASSERT_NE(value, nullptr); ASSERT_NE(value, nullptr);
EXPECT_TYPE(value->Type(), sem->Type()); EXPECT_TYPE(value->Type(), sem->Type());
EXPECT_EQ(value->As<T>(), values.expect); EXPECT_EQ(value->As<T>(), expected);
if constexpr (IsInteger<UnwrapNumber<T>>) { if constexpr (IsInteger<UnwrapNumber<T>>) {
// Check that the constant's integer doesn't contain unexpected data in the MSBs // Check that the constant's integer doesn't contain unexpected data in the MSBs
// that are outside of the bit-width of T. // that are outside of the bit-width of T.
EXPECT_EQ(value->As<AInt>(), AInt(values.expect)); EXPECT_EQ(value->As<AInt>(), AInt(expected));
} }
}, },
c.values); c.lhs, c.rhs, c.expected);
} }
INSTANTIATE_TEST_SUITE_P(MixedAbstractArgs,
ResolverConstEvalBinaryOpTest,
testing::Combine(testing::Values(ast::BinaryOp::kAdd),
testing::ValuesIn(std::vector{
// Mixed abstract type args
C(1_a, 2.3_a, 3.3_a),
C(2.3_a, 1_a, 3.3_a),
})));
template <typename T> template <typename T>
std::vector<Case> OpAddIntCases() { std::vector<Case> OpAddIntCases() {
static_assert(IsInteger<UnwrapNumber<T>>); static_assert(IsInteger<UnwrapNumber<T>>);
@ -3225,8 +3234,7 @@ INSTANTIATE_TEST_SUITE_P(Add,
OpAddIntCases<u32>(), OpAddIntCases<u32>(),
OpAddFloatCases<AFloat>(), OpAddFloatCases<AFloat>(),
OpAddFloatCases<f32>(), OpAddFloatCases<f32>(),
OpAddFloatCases<f16>() // OpAddFloatCases<f16>()))));
))));
TEST_F(ResolverConstEvalTest, BinaryAbstractAddOverflow_AInt) { TEST_F(ResolverConstEvalTest, BinaryAbstractAddOverflow_AInt) {
GlobalConst("c", nullptr, Add(Source{{1, 1}}, Expr(AInt::Highest()), 1_a)); GlobalConst("c", nullptr, Add(Source{{1, 1}}, Expr(AInt::Highest()), 1_a));
@ -3254,6 +3262,19 @@ TEST_F(ResolverConstEvalTest, BinaryAbstractAddUnderflow_AFloat) {
EXPECT_EQ(r()->error(), "1:1 error: '-inf' cannot be represented as 'abstract-float'"); EXPECT_EQ(r()->error(), "1:1 error: '-inf' cannot be represented as 'abstract-float'");
} }
TEST_F(ResolverConstEvalTest, BinaryAbstractMixed) {
auto* a = Const("a", nullptr, Expr(1_a));
auto* b = Const("b", nullptr, Expr(2.3_a));
auto* c = Add(Expr("a"), Expr("b"));
WrapInFunction(a, b, c);
EXPECT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(c);
ASSERT_TRUE(sem);
ASSERT_TRUE(sem->ConstantValue());
auto result = sem->ConstantValue()->As<AFloat>();
EXPECT_EQ(result, 3.3f);
}
} // namespace binary_op } // namespace binary_op
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
@ -3262,32 +3283,24 @@ TEST_F(ResolverConstEvalTest, BinaryAbstractAddUnderflow_AFloat) {
namespace builtin { namespace builtin {
template <typename T> using Types = std::variant<AInt, AFloat, u32, i32, f32, f16>;
struct Values {
utils::Vector<T, 8> args; struct Case {
T result; utils::Vector<Types, 8> args;
Types result;
bool result_pos_or_neg; bool result_pos_or_neg;
}; };
struct Case {
std::variant<Values<AInt>, Values<AFloat>, Values<u32>, Values<i32>, Values<f32>, Values<f16>>
values;
};
static std::ostream& operator<<(std::ostream& o, const Case& c) { static std::ostream& operator<<(std::ostream& o, const Case& c) {
std::visit( for (auto& a : c.args) {
[&](auto&& v) { std::visit([&](auto&& v) { o << v << ((&a != &c.args.Back()) ? " " : ""); }, a);
for (auto& e : v.args) { }
o << e << ((&e != &v.args.Back()) ? " " : "");
}
},
c.values);
return o; return o;
} }
template <typename T> template <typename T>
Case C(std::initializer_list<T> args, T result, bool result_pos_or_neg = false) { Case C(std::initializer_list<Types> args, T result, bool result_pos_or_neg = false) {
return Case{Values<T>{std::move(args), result, result_pos_or_neg}}; return Case{std::move(args), std::move(result), result_pos_or_neg};
} }
using ResolverConstEvalBuiltinTest = ResolverTestWithParam<std::tuple<sem::BuiltinType, Case>>; using ResolverConstEvalBuiltinTest = ResolverTestWithParam<std::tuple<sem::BuiltinType, Case>>;
@ -3297,12 +3310,15 @@ TEST_P(ResolverConstEvalBuiltinTest, Test) {
auto builtin = std::get<0>(GetParam()); auto builtin = std::get<0>(GetParam());
auto c = std::get<1>(GetParam()); auto c = std::get<1>(GetParam());
utils::Vector<const ast::Expression*, 8> args;
for (auto& a : c.args) {
std::visit([&](auto&& v) { args.Push(Expr(v)); }, a);
}
std::visit( std::visit(
[&](auto&& values) { [&](auto&& result) {
using T = decltype(values.result); using T = std::decay_t<decltype(result)>;
auto args = utils::Transform(values.args, [&](auto&& a) {
return static_cast<const ast::Expression*>(Expr(a));
});
auto* expr = Call(sem::str(builtin), std::move(args)); auto* expr = Call(sem::str(builtin), std::move(args));
GlobalConst("C", nullptr, expr); GlobalConst("C", nullptr, expr);
@ -3317,22 +3333,22 @@ TEST_P(ResolverConstEvalBuiltinTest, Test) {
auto actual = value->As<T>(); auto actual = value->As<T>();
if constexpr (IsFloatingPoint<UnwrapNumber<T>>) { if constexpr (IsFloatingPoint<UnwrapNumber<T>>) {
if (std::isnan(values.result)) { if (std::isnan(result)) {
EXPECT_TRUE(std::isnan(actual)); EXPECT_TRUE(std::isnan(actual));
} else { } else {
EXPECT_FLOAT_EQ(values.result_pos_or_neg ? Abs(actual) : actual, values.result); EXPECT_FLOAT_EQ(c.result_pos_or_neg ? Abs(actual) : actual, result);
} }
} else { } else {
EXPECT_EQ(values.result_pos_or_neg ? Abs(actual) : actual, values.result); EXPECT_EQ(c.result_pos_or_neg ? Abs(actual) : actual, result);
} }
if constexpr (IsInteger<UnwrapNumber<T>>) { if constexpr (IsInteger<UnwrapNumber<T>>) {
// Check that the constant's integer doesn't contain unexpected data in the MSBs // Check that the constant's integer doesn't contain unexpected data in the MSBs
// that are outside of the bit-width of T. // that are outside of the bit-width of T.
EXPECT_EQ(value->As<AInt>(), AInt(values.result)); EXPECT_EQ(value->As<AInt>(), AInt(result));
} }
}, },
c.values); c.result);
} }
template <typename T, bool finite_only> template <typename T, bool finite_only>
@ -3391,6 +3407,15 @@ std::vector<Case> Atan2Cases() {
return cases; return cases;
} }
INSTANTIATE_TEST_SUITE_P( //
MixedAbstractArgs,
ResolverConstEvalBuiltinTest,
testing::Combine(testing::Values(sem::BuiltinType::kAtan2),
testing::ValuesIn(std::vector{
C({1_a, 1.0_a}, 0.78539819_a),
C({1.0_a, 1_a}, 0.78539819_a),
})));
INSTANTIATE_TEST_SUITE_P( // INSTANTIATE_TEST_SUITE_P( //
Atan2, Atan2,
ResolverConstEvalBuiltinTest, ResolverConstEvalBuiltinTest,

View File

@ -1456,6 +1456,29 @@ bool Resolver::ShouldMaterializeArgument(const sem::Type* parameter_ty) const {
return param_el_ty && !param_el_ty->Is<sem::AbstractNumeric>(); return param_el_ty && !param_el_ty->Is<sem::AbstractNumeric>();
} }
bool Resolver::Convert(const sem::Constant*& c, const sem::Type* target_ty, const Source& source) {
auto r = const_eval_.Convert(target_ty, c, source);
if (!r) {
return false;
}
c = r.Get();
return true;
}
template <size_t N>
utils::Result<utils::Vector<const sem::Constant*, N>> Resolver::ConvertArguments(
const utils::Vector<const sem::Expression*, N>& args,
const sem::CallTarget* target) {
auto const_args = utils::Transform(args, [](auto* arg) { return arg->ConstantValue(); });
for (size_t i = 0, n = std::min(args.Length(), target->Parameters().Length()); i < n; i++) {
if (!Convert(const_args[i], target->Parameters()[i]->Type(),
args[i]->Declaration()->source)) {
return utils::Failure;
}
}
return const_args;
}
sem::Expression* Resolver::IndexAccessor(const ast::IndexAccessorExpression* expr) { sem::Expression* Resolver::IndexAccessor(const ast::IndexAccessorExpression* expr) {
auto* idx = Materialize(sem_.Get(expr->index)); auto* idx = Materialize(sem_.Get(expr->index));
if (!idx) { if (!idx) {
@ -1893,9 +1916,12 @@ sem::Call* Resolver::BuiltinCall(const ast::CallExpression* expr,
// If the builtin is @const, and all arguments have constant values, evaluate the builtin now. // If the builtin is @const, and all arguments have constant values, evaluate the builtin now.
const sem::Constant* value = nullptr; const sem::Constant* value = nullptr;
if (stage == sem::EvaluationStage::kConstant) { if (stage == sem::EvaluationStage::kConstant) {
auto const_args = utils::Transform(args, [](auto* arg) { return arg->ConstantValue(); }); auto const_args = ConvertArguments(args, builtin.sem);
if (auto r = (const_eval_.*builtin.const_eval_fn)(builtin.sem->ReturnType(), const_args, if (!const_args) {
expr->source)) { return nullptr;
}
if (auto r = (const_eval_.*builtin.const_eval_fn)(builtin.sem->ReturnType(),
const_args.Get(), expr->source)) {
value = r.Get(); value = r.Get();
} else { } else {
return nullptr; return nullptr;
@ -2302,6 +2328,14 @@ sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) {
if (stage == sem::EvaluationStage::kConstant) { if (stage == sem::EvaluationStage::kConstant) {
if (op.const_eval_fn) { if (op.const_eval_fn) {
auto const_args = utils::Vector{lhs->ConstantValue(), rhs->ConstantValue()}; auto const_args = utils::Vector{lhs->ConstantValue(), rhs->ConstantValue()};
// Implicit conversion (e.g. AInt -> AFloat)
if (!Convert(const_args[0], op.result, lhs->Declaration()->source)) {
return nullptr;
}
if (!Convert(const_args[1], op.result, rhs->Declaration()->source)) {
return nullptr;
}
if (auto r = (const_eval_.*op.const_eval_fn)(op.result, const_args, expr->source)) { if (auto r = (const_eval_.*op.const_eval_fn)(op.result, const_args, expr->source)) {
value = r.Get(); value = r.Get();
} else { } else {

View File

@ -220,6 +220,18 @@ class Resolver {
/// `parameter_ty` should be materialized. /// `parameter_ty` should be materialized.
bool ShouldMaterializeArgument(const sem::Type* parameter_ty) const; bool ShouldMaterializeArgument(const sem::Type* parameter_ty) const;
/// Converts `c` to `target_ty`
/// @returns true on success, false on failure.
bool Convert(const sem::Constant*& c, const sem::Type* target_ty, const Source& source);
/// Transforms `args` to a vector of constants, and converts each constant to the call target's
/// parameter type.
/// @returns the vector of constants, `utils::Failure` on failure.
template <size_t N>
utils::Result<utils::Vector<const sem::Constant*, N>> ConvertArguments(
const utils::Vector<const sem::Expression*, N>& args,
const sem::CallTarget* target);
/// @param ty the type that may hold abstract numeric types /// @param ty the type that may hold abstract numeric types
/// @param target_ty the target type for the expression (variable type, parameter type, etc). /// @param target_ty the target type for the expression (variable type, parameter type, etc).
/// May be nullptr. /// May be nullptr.

View File

@ -0,0 +1,25 @@
fn original_clusterfuzz_code() {
atan2(1,.1);
}
fn more_tests_that_would_fail() {
// Builtin calls with mixed abstract args would fail because AInt would not materialize to AFloat.
{
let a = atan2(1, 0.1);
let b = atan2(0.1, 1);
}
// Same for binary operators
{
let a = 1 + 1.5;
let b = 1.5 + 1;
}
// Once above was fixed, builtin calls without assignment would also fail in backends because
// abstract constant value is not handled by backends. These should be removed by RemovePhonies
// transform.
{
atan2(1, 0.1);
atan2(0.1, 1);
}
}

View File

@ -0,0 +1,20 @@
[numthreads(1, 1, 1)]
void unused_entry_point() {
return;
}
void original_clusterfuzz_code() {
}
void more_tests_that_would_fail() {
{
const float a = 1.471127629f;
const float b = 0.099668652f;
}
{
const float a = 2.5f;
const float b = 2.5f;
}
{
}
}

View File

@ -0,0 +1,20 @@
[numthreads(1, 1, 1)]
void unused_entry_point() {
return;
}
void original_clusterfuzz_code() {
}
void more_tests_that_would_fail() {
{
const float a = 1.471127629f;
const float b = 0.099668652f;
}
{
const float a = 2.5f;
const float b = 2.5f;
}
{
}
}

View File

@ -0,0 +1,22 @@
#version 310 es
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
void unused_entry_point() {
return;
}
void original_clusterfuzz_code() {
}
void more_tests_that_would_fail() {
{
float a = 1.471127629f;
float b = 0.099668652f;
}
{
float a = 2.5f;
float b = 2.5f;
}
{
}
}

View File

@ -0,0 +1,19 @@
#include <metal_stdlib>
using namespace metal;
void original_clusterfuzz_code() {
}
void more_tests_that_would_fail() {
{
float const a = 1.471127629f;
float const b = 0.099668652f;
}
{
float const a = 2.5f;
float const b = 2.5f;
}
{
}
}

View File

@ -0,0 +1,30 @@
; SPIR-V
; Version: 1.3
; Generator: Google Tint Compiler; 0
; Bound: 13
; Schema: 0
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %unused_entry_point "unused_entry_point"
OpExecutionMode %unused_entry_point LocalSize 1 1 1
OpName %unused_entry_point "unused_entry_point"
OpName %original_clusterfuzz_code "original_clusterfuzz_code"
OpName %more_tests_that_would_fail "more_tests_that_would_fail"
%void = OpTypeVoid
%1 = OpTypeFunction %void
%float = OpTypeFloat 32
%float_1_47112763 = OpConstant %float 1.47112763
%float_0_0996686518 = OpConstant %float 0.0996686518
%float_2_5 = OpConstant %float 2.5
%unused_entry_point = OpFunction %void None %1
%4 = OpLabel
OpReturn
OpFunctionEnd
%original_clusterfuzz_code = OpFunction %void None %1
%6 = OpLabel
OpReturn
OpFunctionEnd
%more_tests_that_would_fail = OpFunction %void None %1
%8 = OpLabel
OpReturn
OpFunctionEnd

View File

@ -0,0 +1,18 @@
fn original_clusterfuzz_code() {
atan2(1, 0.100000001);
}
fn more_tests_that_would_fail() {
{
let a = atan2(1, 0.100000001);
let b = atan2(0.100000001, 1);
}
{
let a = (1 + 1.5);
let b = (1.5 + 1);
}
{
atan2(1, 0.100000001);
atan2(0.100000001, 1);
}
}