tint: fix builtin calls and binary ops with abstract args of different type
If a call to atan2 with args of type AFloat and AInt is made, Resolver would correctly select the atan2(AFloat, AFloat) overload, but if the input args were of type (AFloat, AInt), it would attempt to constant evaluate without first converting the AInt arg to AFloat. The same would occur for a binary operation, say AFloat + AInt. Before constant evaluating, the Resolver now converts AInt to AFloat if necessary. Bug: chromium:1350147 Change-Id: I85390c5d7af7e706115278ece34b2b18b8574f9f Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/98543 Reviewed-by: Ben Clayton <bclayton@google.com> Commit-Queue: Antonio Maiorano <amaiorano@google.com>
This commit is contained in:
parent
90d5eb6128
commit
a58d8c9fac
|
@ -3131,27 +3131,27 @@ TEST_F(ResolverConstEvalTest, UnaryNegateLowestAbstract) {
|
|||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
namespace binary_op {
|
||||
|
||||
template <typename T>
|
||||
struct Values {
|
||||
T lhs;
|
||||
T rhs;
|
||||
T expect;
|
||||
using Types = std::variant<AInt, AFloat, u32, i32, f32, f16>;
|
||||
|
||||
struct Case {
|
||||
Types lhs;
|
||||
Types rhs;
|
||||
Types expected;
|
||||
bool is_overflow;
|
||||
};
|
||||
|
||||
struct Case {
|
||||
std::variant<Values<AInt>, Values<AFloat>, Values<u32>, Values<i32>, Values<f32>, Values<f16>>
|
||||
values;
|
||||
};
|
||||
|
||||
static std::ostream& operator<<(std::ostream& o, const Case& c) {
|
||||
std::visit([&](auto&& v) { o << v.lhs << " " << v.rhs; }, c.values);
|
||||
std::visit(
|
||||
[&](auto&& lhs, auto&& rhs, auto&& expected) {
|
||||
o << "lhs: " << lhs << ", rhs: " << rhs << ", expected: " << expected;
|
||||
},
|
||||
c.lhs, c.rhs, c.expected);
|
||||
return o;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Case C(T lhs, T rhs, T expect, bool is_overflow = false) {
|
||||
return Case{Values<T>{lhs, rhs, expect, is_overflow}};
|
||||
template <typename T, typename U, typename V>
|
||||
Case C(T lhs, U rhs, V expected, bool is_overflow = false) {
|
||||
return Case{lhs, rhs, expected, is_overflow};
|
||||
}
|
||||
|
||||
using ResolverConstEvalBinaryOpTest = ResolverTestWithParam<std::tuple<ast::BinaryOp, Case>>;
|
||||
|
@ -3161,16 +3161,16 @@ TEST_P(ResolverConstEvalBinaryOpTest, Test) {
|
|||
auto op = std::get<0>(GetParam());
|
||||
auto c = std::get<1>(GetParam());
|
||||
std::visit(
|
||||
[&](auto&& values) {
|
||||
using T = decltype(values.expect);
|
||||
[&](auto&& lhs, auto&& rhs, auto&& expected) {
|
||||
using T = std::decay_t<decltype(expected)>;
|
||||
|
||||
if constexpr (std::is_same_v<T, AInt> || std::is_same_v<T, AFloat>) {
|
||||
if (values.is_overflow) {
|
||||
if (c.is_overflow) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(op, Expr(values.lhs), Expr(values.rhs));
|
||||
auto* expr = create<ast::BinaryExpression>(op, Expr(lhs), Expr(rhs));
|
||||
GlobalConst("C", nullptr, expr);
|
||||
|
||||
EXPECT_TRUE(r()->Resolve()) << r()->error();
|
||||
|
@ -3179,17 +3179,26 @@ TEST_P(ResolverConstEvalBinaryOpTest, Test) {
|
|||
const sem::Constant* value = sem->ConstantValue();
|
||||
ASSERT_NE(value, nullptr);
|
||||
EXPECT_TYPE(value->Type(), sem->Type());
|
||||
EXPECT_EQ(value->As<T>(), values.expect);
|
||||
EXPECT_EQ(value->As<T>(), expected);
|
||||
|
||||
if constexpr (IsInteger<UnwrapNumber<T>>) {
|
||||
// Check that the constant's integer doesn't contain unexpected data in the MSBs
|
||||
// that are outside of the bit-width of T.
|
||||
EXPECT_EQ(value->As<AInt>(), AInt(values.expect));
|
||||
EXPECT_EQ(value->As<AInt>(), AInt(expected));
|
||||
}
|
||||
},
|
||||
c.values);
|
||||
c.lhs, c.rhs, c.expected);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(MixedAbstractArgs,
|
||||
ResolverConstEvalBinaryOpTest,
|
||||
testing::Combine(testing::Values(ast::BinaryOp::kAdd),
|
||||
testing::ValuesIn(std::vector{
|
||||
// Mixed abstract type args
|
||||
C(1_a, 2.3_a, 3.3_a),
|
||||
C(2.3_a, 1_a, 3.3_a),
|
||||
})));
|
||||
|
||||
template <typename T>
|
||||
std::vector<Case> OpAddIntCases() {
|
||||
static_assert(IsInteger<UnwrapNumber<T>>);
|
||||
|
@ -3225,8 +3234,7 @@ INSTANTIATE_TEST_SUITE_P(Add,
|
|||
OpAddIntCases<u32>(),
|
||||
OpAddFloatCases<AFloat>(),
|
||||
OpAddFloatCases<f32>(),
|
||||
OpAddFloatCases<f16>() //
|
||||
))));
|
||||
OpAddFloatCases<f16>()))));
|
||||
|
||||
TEST_F(ResolverConstEvalTest, BinaryAbstractAddOverflow_AInt) {
|
||||
GlobalConst("c", nullptr, Add(Source{{1, 1}}, Expr(AInt::Highest()), 1_a));
|
||||
|
@ -3254,6 +3262,19 @@ TEST_F(ResolverConstEvalTest, BinaryAbstractAddUnderflow_AFloat) {
|
|||
EXPECT_EQ(r()->error(), "1:1 error: '-inf' cannot be represented as 'abstract-float'");
|
||||
}
|
||||
|
||||
TEST_F(ResolverConstEvalTest, BinaryAbstractMixed) {
|
||||
auto* a = Const("a", nullptr, Expr(1_a));
|
||||
auto* b = Const("b", nullptr, Expr(2.3_a));
|
||||
auto* c = Add(Expr("a"), Expr("b"));
|
||||
WrapInFunction(a, b, c);
|
||||
EXPECT_TRUE(r()->Resolve()) << r()->error();
|
||||
auto* sem = Sem().Get(c);
|
||||
ASSERT_TRUE(sem);
|
||||
ASSERT_TRUE(sem->ConstantValue());
|
||||
auto result = sem->ConstantValue()->As<AFloat>();
|
||||
EXPECT_EQ(result, 3.3f);
|
||||
}
|
||||
|
||||
} // namespace binary_op
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -3262,32 +3283,24 @@ TEST_F(ResolverConstEvalTest, BinaryAbstractAddUnderflow_AFloat) {
|
|||
|
||||
namespace builtin {
|
||||
|
||||
template <typename T>
|
||||
struct Values {
|
||||
utils::Vector<T, 8> args;
|
||||
T result;
|
||||
using Types = std::variant<AInt, AFloat, u32, i32, f32, f16>;
|
||||
|
||||
struct Case {
|
||||
utils::Vector<Types, 8> args;
|
||||
Types result;
|
||||
bool result_pos_or_neg;
|
||||
};
|
||||
|
||||
struct Case {
|
||||
std::variant<Values<AInt>, Values<AFloat>, Values<u32>, Values<i32>, Values<f32>, Values<f16>>
|
||||
values;
|
||||
};
|
||||
|
||||
static std::ostream& operator<<(std::ostream& o, const Case& c) {
|
||||
std::visit(
|
||||
[&](auto&& v) {
|
||||
for (auto& e : v.args) {
|
||||
o << e << ((&e != &v.args.Back()) ? " " : "");
|
||||
}
|
||||
},
|
||||
c.values);
|
||||
for (auto& a : c.args) {
|
||||
std::visit([&](auto&& v) { o << v << ((&a != &c.args.Back()) ? " " : ""); }, a);
|
||||
}
|
||||
return o;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Case C(std::initializer_list<T> args, T result, bool result_pos_or_neg = false) {
|
||||
return Case{Values<T>{std::move(args), result, result_pos_or_neg}};
|
||||
Case C(std::initializer_list<Types> args, T result, bool result_pos_or_neg = false) {
|
||||
return Case{std::move(args), std::move(result), result_pos_or_neg};
|
||||
}
|
||||
|
||||
using ResolverConstEvalBuiltinTest = ResolverTestWithParam<std::tuple<sem::BuiltinType, Case>>;
|
||||
|
@ -3297,12 +3310,15 @@ TEST_P(ResolverConstEvalBuiltinTest, Test) {
|
|||
|
||||
auto builtin = std::get<0>(GetParam());
|
||||
auto c = std::get<1>(GetParam());
|
||||
|
||||
utils::Vector<const ast::Expression*, 8> args;
|
||||
for (auto& a : c.args) {
|
||||
std::visit([&](auto&& v) { args.Push(Expr(v)); }, a);
|
||||
}
|
||||
|
||||
std::visit(
|
||||
[&](auto&& values) {
|
||||
using T = decltype(values.result);
|
||||
auto args = utils::Transform(values.args, [&](auto&& a) {
|
||||
return static_cast<const ast::Expression*>(Expr(a));
|
||||
});
|
||||
[&](auto&& result) {
|
||||
using T = std::decay_t<decltype(result)>;
|
||||
auto* expr = Call(sem::str(builtin), std::move(args));
|
||||
|
||||
GlobalConst("C", nullptr, expr);
|
||||
|
@ -3317,22 +3333,22 @@ TEST_P(ResolverConstEvalBuiltinTest, Test) {
|
|||
auto actual = value->As<T>();
|
||||
|
||||
if constexpr (IsFloatingPoint<UnwrapNumber<T>>) {
|
||||
if (std::isnan(values.result)) {
|
||||
if (std::isnan(result)) {
|
||||
EXPECT_TRUE(std::isnan(actual));
|
||||
} else {
|
||||
EXPECT_FLOAT_EQ(values.result_pos_or_neg ? Abs(actual) : actual, values.result);
|
||||
EXPECT_FLOAT_EQ(c.result_pos_or_neg ? Abs(actual) : actual, result);
|
||||
}
|
||||
} else {
|
||||
EXPECT_EQ(values.result_pos_or_neg ? Abs(actual) : actual, values.result);
|
||||
EXPECT_EQ(c.result_pos_or_neg ? Abs(actual) : actual, result);
|
||||
}
|
||||
|
||||
if constexpr (IsInteger<UnwrapNumber<T>>) {
|
||||
// Check that the constant's integer doesn't contain unexpected data in the MSBs
|
||||
// that are outside of the bit-width of T.
|
||||
EXPECT_EQ(value->As<AInt>(), AInt(values.result));
|
||||
EXPECT_EQ(value->As<AInt>(), AInt(result));
|
||||
}
|
||||
},
|
||||
c.values);
|
||||
c.result);
|
||||
}
|
||||
|
||||
template <typename T, bool finite_only>
|
||||
|
@ -3391,6 +3407,15 @@ std::vector<Case> Atan2Cases() {
|
|||
return cases;
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P( //
|
||||
MixedAbstractArgs,
|
||||
ResolverConstEvalBuiltinTest,
|
||||
testing::Combine(testing::Values(sem::BuiltinType::kAtan2),
|
||||
testing::ValuesIn(std::vector{
|
||||
C({1_a, 1.0_a}, 0.78539819_a),
|
||||
C({1.0_a, 1_a}, 0.78539819_a),
|
||||
})));
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P( //
|
||||
Atan2,
|
||||
ResolverConstEvalBuiltinTest,
|
||||
|
|
|
@ -1456,6 +1456,29 @@ bool Resolver::ShouldMaterializeArgument(const sem::Type* parameter_ty) const {
|
|||
return param_el_ty && !param_el_ty->Is<sem::AbstractNumeric>();
|
||||
}
|
||||
|
||||
bool Resolver::Convert(const sem::Constant*& c, const sem::Type* target_ty, const Source& source) {
|
||||
auto r = const_eval_.Convert(target_ty, c, source);
|
||||
if (!r) {
|
||||
return false;
|
||||
}
|
||||
c = r.Get();
|
||||
return true;
|
||||
}
|
||||
|
||||
template <size_t N>
|
||||
utils::Result<utils::Vector<const sem::Constant*, N>> Resolver::ConvertArguments(
|
||||
const utils::Vector<const sem::Expression*, N>& args,
|
||||
const sem::CallTarget* target) {
|
||||
auto const_args = utils::Transform(args, [](auto* arg) { return arg->ConstantValue(); });
|
||||
for (size_t i = 0, n = std::min(args.Length(), target->Parameters().Length()); i < n; i++) {
|
||||
if (!Convert(const_args[i], target->Parameters()[i]->Type(),
|
||||
args[i]->Declaration()->source)) {
|
||||
return utils::Failure;
|
||||
}
|
||||
}
|
||||
return const_args;
|
||||
}
|
||||
|
||||
sem::Expression* Resolver::IndexAccessor(const ast::IndexAccessorExpression* expr) {
|
||||
auto* idx = Materialize(sem_.Get(expr->index));
|
||||
if (!idx) {
|
||||
|
@ -1893,9 +1916,12 @@ sem::Call* Resolver::BuiltinCall(const ast::CallExpression* expr,
|
|||
// If the builtin is @const, and all arguments have constant values, evaluate the builtin now.
|
||||
const sem::Constant* value = nullptr;
|
||||
if (stage == sem::EvaluationStage::kConstant) {
|
||||
auto const_args = utils::Transform(args, [](auto* arg) { return arg->ConstantValue(); });
|
||||
if (auto r = (const_eval_.*builtin.const_eval_fn)(builtin.sem->ReturnType(), const_args,
|
||||
expr->source)) {
|
||||
auto const_args = ConvertArguments(args, builtin.sem);
|
||||
if (!const_args) {
|
||||
return nullptr;
|
||||
}
|
||||
if (auto r = (const_eval_.*builtin.const_eval_fn)(builtin.sem->ReturnType(),
|
||||
const_args.Get(), expr->source)) {
|
||||
value = r.Get();
|
||||
} else {
|
||||
return nullptr;
|
||||
|
@ -2302,6 +2328,14 @@ sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) {
|
|||
if (stage == sem::EvaluationStage::kConstant) {
|
||||
if (op.const_eval_fn) {
|
||||
auto const_args = utils::Vector{lhs->ConstantValue(), rhs->ConstantValue()};
|
||||
// Implicit conversion (e.g. AInt -> AFloat)
|
||||
if (!Convert(const_args[0], op.result, lhs->Declaration()->source)) {
|
||||
return nullptr;
|
||||
}
|
||||
if (!Convert(const_args[1], op.result, rhs->Declaration()->source)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (auto r = (const_eval_.*op.const_eval_fn)(op.result, const_args, expr->source)) {
|
||||
value = r.Get();
|
||||
} else {
|
||||
|
|
|
@ -220,6 +220,18 @@ class Resolver {
|
|||
/// `parameter_ty` should be materialized.
|
||||
bool ShouldMaterializeArgument(const sem::Type* parameter_ty) const;
|
||||
|
||||
/// Converts `c` to `target_ty`
|
||||
/// @returns true on success, false on failure.
|
||||
bool Convert(const sem::Constant*& c, const sem::Type* target_ty, const Source& source);
|
||||
|
||||
/// Transforms `args` to a vector of constants, and converts each constant to the call target's
|
||||
/// parameter type.
|
||||
/// @returns the vector of constants, `utils::Failure` on failure.
|
||||
template <size_t N>
|
||||
utils::Result<utils::Vector<const sem::Constant*, N>> ConvertArguments(
|
||||
const utils::Vector<const sem::Expression*, N>& args,
|
||||
const sem::CallTarget* target);
|
||||
|
||||
/// @param ty the type that may hold abstract numeric types
|
||||
/// @param target_ty the target type for the expression (variable type, parameter type, etc).
|
||||
/// May be nullptr.
|
||||
|
|
|
@ -0,0 +1,25 @@
|
|||
fn original_clusterfuzz_code() {
|
||||
atan2(1,.1);
|
||||
}
|
||||
|
||||
fn more_tests_that_would_fail() {
|
||||
// Builtin calls with mixed abstract args would fail because AInt would not materialize to AFloat.
|
||||
{
|
||||
let a = atan2(1, 0.1);
|
||||
let b = atan2(0.1, 1);
|
||||
}
|
||||
|
||||
// Same for binary operators
|
||||
{
|
||||
let a = 1 + 1.5;
|
||||
let b = 1.5 + 1;
|
||||
}
|
||||
|
||||
// Once above was fixed, builtin calls without assignment would also fail in backends because
|
||||
// abstract constant value is not handled by backends. These should be removed by RemovePhonies
|
||||
// transform.
|
||||
{
|
||||
atan2(1, 0.1);
|
||||
atan2(0.1, 1);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,20 @@
|
|||
[numthreads(1, 1, 1)]
|
||||
void unused_entry_point() {
|
||||
return;
|
||||
}
|
||||
|
||||
void original_clusterfuzz_code() {
|
||||
}
|
||||
|
||||
void more_tests_that_would_fail() {
|
||||
{
|
||||
const float a = 1.471127629f;
|
||||
const float b = 0.099668652f;
|
||||
}
|
||||
{
|
||||
const float a = 2.5f;
|
||||
const float b = 2.5f;
|
||||
}
|
||||
{
|
||||
}
|
||||
}
|
|
@ -0,0 +1,20 @@
|
|||
[numthreads(1, 1, 1)]
|
||||
void unused_entry_point() {
|
||||
return;
|
||||
}
|
||||
|
||||
void original_clusterfuzz_code() {
|
||||
}
|
||||
|
||||
void more_tests_that_would_fail() {
|
||||
{
|
||||
const float a = 1.471127629f;
|
||||
const float b = 0.099668652f;
|
||||
}
|
||||
{
|
||||
const float a = 2.5f;
|
||||
const float b = 2.5f;
|
||||
}
|
||||
{
|
||||
}
|
||||
}
|
|
@ -0,0 +1,22 @@
|
|||
#version 310 es
|
||||
|
||||
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
|
||||
void unused_entry_point() {
|
||||
return;
|
||||
}
|
||||
void original_clusterfuzz_code() {
|
||||
}
|
||||
|
||||
void more_tests_that_would_fail() {
|
||||
{
|
||||
float a = 1.471127629f;
|
||||
float b = 0.099668652f;
|
||||
}
|
||||
{
|
||||
float a = 2.5f;
|
||||
float b = 2.5f;
|
||||
}
|
||||
{
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,19 @@
|
|||
#include <metal_stdlib>
|
||||
|
||||
using namespace metal;
|
||||
void original_clusterfuzz_code() {
|
||||
}
|
||||
|
||||
void more_tests_that_would_fail() {
|
||||
{
|
||||
float const a = 1.471127629f;
|
||||
float const b = 0.099668652f;
|
||||
}
|
||||
{
|
||||
float const a = 2.5f;
|
||||
float const b = 2.5f;
|
||||
}
|
||||
{
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,30 @@
|
|||
; SPIR-V
|
||||
; Version: 1.3
|
||||
; Generator: Google Tint Compiler; 0
|
||||
; Bound: 13
|
||||
; Schema: 0
|
||||
OpCapability Shader
|
||||
OpMemoryModel Logical GLSL450
|
||||
OpEntryPoint GLCompute %unused_entry_point "unused_entry_point"
|
||||
OpExecutionMode %unused_entry_point LocalSize 1 1 1
|
||||
OpName %unused_entry_point "unused_entry_point"
|
||||
OpName %original_clusterfuzz_code "original_clusterfuzz_code"
|
||||
OpName %more_tests_that_would_fail "more_tests_that_would_fail"
|
||||
%void = OpTypeVoid
|
||||
%1 = OpTypeFunction %void
|
||||
%float = OpTypeFloat 32
|
||||
%float_1_47112763 = OpConstant %float 1.47112763
|
||||
%float_0_0996686518 = OpConstant %float 0.0996686518
|
||||
%float_2_5 = OpConstant %float 2.5
|
||||
%unused_entry_point = OpFunction %void None %1
|
||||
%4 = OpLabel
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
%original_clusterfuzz_code = OpFunction %void None %1
|
||||
%6 = OpLabel
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
%more_tests_that_would_fail = OpFunction %void None %1
|
||||
%8 = OpLabel
|
||||
OpReturn
|
||||
OpFunctionEnd
|
|
@ -0,0 +1,18 @@
|
|||
fn original_clusterfuzz_code() {
|
||||
atan2(1, 0.100000001);
|
||||
}
|
||||
|
||||
fn more_tests_that_would_fail() {
|
||||
{
|
||||
let a = atan2(1, 0.100000001);
|
||||
let b = atan2(0.100000001, 1);
|
||||
}
|
||||
{
|
||||
let a = (1 + 1.5);
|
||||
let b = (1.5 + 1);
|
||||
}
|
||||
{
|
||||
atan2(1, 0.100000001);
|
||||
atan2(0.100000001, 1);
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue