tint: const eval of mix builtin

Bug: tint:1581
Change-Id: I3b9f0ff3a58956616daf17b3d4a922979fc30216
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/113680
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
Antonio Maiorano
2022-12-12 15:31:21 +00:00
committed by Dawn LUCI CQ
parent b5fedb7b5a
commit 7c9e639e35
131 changed files with 2658 additions and 416 deletions

View File

@@ -2864,6 +2864,49 @@ ConstEval::Result ConstEval::min(const type::Type* ty,
return TransformElements(builder, ty, transform, args[0], args[1]);
}
ConstEval::Result ConstEval::mix(const type::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source) {
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1, size_t index) {
auto create = [&](auto e1, auto e2) -> ImplResult {
using NumberT = decltype(e1);
// e3 is either a vector or a scalar
NumberT e3;
auto* c2 = args[2];
if (c2->Type()->Is<type::Vector>()) {
e3 = c2->Index(index)->As<NumberT>();
} else {
e3 = c2->As<NumberT>();
}
// Implement as `e1 * (1 - e3) + e2 * e3)` instead of as `e1 + e3 * (e2 - e1)` to avoid
// float precision loss when e1 and e2 significantly differ in magnitude.
auto one_sub_e3 = Sub(source, NumberT{1}, e3);
if (!one_sub_e3) {
return utils::Failure;
}
auto e1_mul_one_sub_e3 = Mul(source, e1, one_sub_e3.Get());
if (!e1_mul_one_sub_e3) {
return utils::Failure;
}
auto e2_mul_e3 = Mul(source, e2, e3);
if (!e2_mul_e3) {
return utils::Failure;
}
auto r = Add(source, e1_mul_one_sub_e3.Get(), e2_mul_e3.Get());
if (!r) {
return utils::Failure;
}
return CreateElement(builder, source, c0->Type(), r.Get());
};
return Dispatch_fa_f32_f16(create, c0, c1);
};
auto r = TransformElements(builder, ty, transform, args[0], args[1]);
if (!r) {
AddNote("when calculating mix", source);
}
return r;
}
ConstEval::Result ConstEval::modf(const type::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source) {

View File

@@ -757,6 +757,15 @@ class ConstEval {
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// mix builtin
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result mix(const type::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// modf builtin
/// @param ty the expression type
/// @param args the input arguments

View File

@@ -179,7 +179,7 @@ TEST_P(ResolverConstEvalBuiltinTest, Test) {
CheckConstant(value, expected_case.values[0], expected_case.flags);
}
} else {
EXPECT_FALSE(r()->Resolve());
ASSERT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), c.expected.Failure().error);
}
}
@@ -1733,6 +1733,94 @@ INSTANTIATE_TEST_SUITE_P( //
MinCases<AFloat>(),
MinCases<f32>(),
MinCases<f16>()))));
template <typename T>
std::vector<Case> MixCases() {
auto r = std::vector<Case>{
C({T(0), T(1), T(0)}, T(0)), //
C({T(0), T(1), T(1)}, T(1)), //
C({T(0), T(1), T(2)}, T(2)), //
C({T(0), T(1), T::Highest()}, T::Highest()), //
C({T::Lowest(), T::Highest(), T(1)}, T::Highest()), //
C({T::Lowest(), T::Highest(), T(0)}, T::Lowest()), //
C({T(0), T(1), T(0.25)}, T(0.25)), //
C({T(0), T(1), T(0.5)}, T(0.5)), //
C({T(0), T(1), T(0.75)}, T(0.75)), //
C({T(0), T(1000), T(0.25)}, T(250)), //
C({T(0), T(1000), T(0.5)}, T(500)), //
C({T(0), T(1000), T(0.75)}, T(750)), //
// Swap e1 and e2//
C({T(1), T(0), T(0)}, T(1)), //
C({T(1), T(0), T(1)}, T(0)), //
C({T(1), T(0), T(2)}, T(-1)), //
C({T::Highest(), T::Lowest(), T(1)}, T::Lowest()), //
C({T::Highest(), T::Lowest(), T(0)}, T::Highest()), //
C({T(1), T(0), T(0.25)}, T(0.75)), //
C({T(1), T(0), T(0.5)}, T(0.5)), //
C({T(1), T(0), T(0.75)}, T(0.25)), //
C({T(1000), T(0), T(0.25)}, T(750)), //
C({T(1000), T(0), T(0.5)}, T(500)), //
C({T(1000), T(0), T(0.75)}, T(250)),
// mix(vec, vec, vec) cases
C({Vec(T(0), T(0), T(0)), //
Vec(T(1), T(1), T(1)), //
Vec(T(0), T(1), T(2))},
Vec(T(0), T(1), T(2))),
// mix(vec, vec, scalar) cases
C({Vec(T(0), T(1), T(0)), //
Vec(T(1), T(0), T(1000)), //
Val(T(0.25))},
Vec(T(0.25), T(0.75), T(250))),
};
// Can't interpolate lowest value for f16 because (1 - lowest) is not representable as f16.
if constexpr (!std::is_same_v<T, f16>) {
ConcatInto(r, std::vector<Case>{
C({T(0), T(1), T::Lowest()}, T::Lowest()),
C({T(1), T(0), T::Highest()}, T::Lowest()),
});
}
auto error_msg = [](auto a, const char* op, auto b) {
return "12:34 error: " + OverflowErrorMessage(a, op, b) + R"(
12:34 note: when calculating mix)";
};
auto kLargeValue = T{T::Highest() / 2};
// Test f16 separately as it overflows for a different reason at the boundary inputs.
// Specifically, (1 - lowest) fails for f16 because the result is not representable.
if constexpr (!std::is_same_v<T, f16>) {
ConcatInto( //
r,
std::vector<Case>{
E({T(0), T::Highest(), T::Highest()}, error_msg(T::Highest(), "*", T::Highest())),
E({T(0), T::Lowest(), T::Lowest()}, error_msg(T::Lowest(), "*", T::Lowest())),
E({T::Highest(), T(0), T::Lowest()}, error_msg(T::Highest(), "*", T::Highest())),
E({-kLargeValue, kLargeValue, T(2)},
error_msg(T{-kLargeValue * T(1 - 2)}, "+", T{kLargeValue * T(2)})),
});
} else {
ConcatInto( //
r,
std::vector<Case>{
E({T(0), T::Highest(), T::Highest()}, error_msg(T::Highest(), "*", T::Highest())),
E({T(0), T::Lowest(), T::Lowest()}, error_msg(T(1), "-", T::Lowest())),
E({T::Highest(), T(0), T::Lowest()}, error_msg(T(1), "-", T::Lowest())),
E({-kLargeValue, kLargeValue, T(2)},
error_msg(T{-kLargeValue * T(1 - 2)}, "+", T{kLargeValue * T(2)})),
});
}
return r;
}
INSTANTIATE_TEST_SUITE_P( //
Mix,
ResolverConstEvalBuiltinTest,
testing::Combine(testing::Values(sem::BuiltinType::kMix),
testing::ValuesIn(Concat(MixCases<AFloat>(), //
MixCases<f32>(), //
MixCases<f16>()))));
template <typename T>
std::vector<Case> ModfCases() {
return {

View File

@@ -101,7 +101,8 @@ inline void CheckConstant(const sem::Constant* got_constant,
[&](const auto& expected) {
using T = std::decay_t<decltype(expected)>;
ASSERT_TRUE(std::holds_alternative<T>(got_scalar));
ASSERT_TRUE(std::holds_alternative<T>(got_scalar))
<< "Scalar variant index: " << got_scalar.index();
auto got = std::get<T>(got_scalar);
if constexpr (std::is_same_v<bool, T>) {

View File

@@ -11441,36 +11441,36 @@ constexpr OverloadInfo kOverloads[] = {
/* num parameters */ 3,
/* num template types */ 1,
/* num template numbers */ 0,
/* template types */ &kTemplateTypes[26],
/* template types */ &kTemplateTypes[23],
/* template numbers */ &kTemplateNumbers[10],
/* parameters */ &kParameters[468],
/* return matcher indices */ &kMatcherIndices[3],
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::mix,
},
{
/* [263] */
/* num parameters */ 3,
/* num template types */ 1,
/* num template numbers */ 1,
/* template types */ &kTemplateTypes[26],
/* template types */ &kTemplateTypes[23],
/* template numbers */ &kTemplateNumbers[4],
/* parameters */ &kParameters[471],
/* return matcher indices */ &kMatcherIndices[30],
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::mix,
},
{
/* [264] */
/* num parameters */ 3,
/* num template types */ 1,
/* num template numbers */ 1,
/* template types */ &kTemplateTypes[26],
/* template types */ &kTemplateTypes[23],
/* template numbers */ &kTemplateNumbers[4],
/* parameters */ &kParameters[474],
/* return matcher indices */ &kMatcherIndices[30],
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::mix,
},
{
/* [265] */
@@ -14294,9 +14294,9 @@ constexpr IntrinsicInfo kBuiltins[] = {
},
{
/* [52] */
/* fn mix<T : f32_f16>(T, T, T) -> T */
/* fn mix<N : num, T : f32_f16>(vec<N, T>, vec<N, T>, vec<N, T>) -> vec<N, T> */
/* fn mix<N : num, T : f32_f16>(vec<N, T>, vec<N, T>, T) -> vec<N, T> */
/* fn mix<T : fa_f32_f16>(T, T, T) -> T */
/* fn mix<N : num, T : fa_f32_f16>(vec<N, T>, vec<N, T>, vec<N, T>) -> vec<N, T> */
/* fn mix<N : num, T : fa_f32_f16>(vec<N, T>, vec<N, T>, T) -> vec<N, T> */
/* num overloads */ 3,
/* overloads */ &kOverloads[262],
},

View File

@@ -1483,68 +1483,6 @@ Case::Create<u32V, u32V, u32V, u32V, u32V, u32V, AIntV>(
// clang-format on
));
struct IntrinsicTableAbstractTernaryTest_NonConstEval : public ResolverTestWithParam<Case> {
std::unique_ptr<IntrinsicTable> table = IntrinsicTable::Create(*this);
};
TEST_P(IntrinsicTableAbstractTernaryTest_NonConstEval, MatchMix) {
auto* arg_a = GetParam().arg_a(*this);
auto* arg_b = GetParam().arg_b(*this);
auto* arg_c = GetParam().arg_c(*this);
auto builtin = table->Lookup(sem::BuiltinType::kMix, utils::Vector{arg_a, arg_b, arg_c},
sem::EvaluationStage::kConstant, Source{{12, 34}});
bool matched = builtin.sem != nullptr;
bool expected_match = GetParam().expected_match;
EXPECT_EQ(matched, expected_match) << Diagnostics().str();
auto* result = builtin.sem ? builtin.sem->ReturnType() : nullptr;
auto* expected_result = GetParam().expected_result(*this);
EXPECT_TYPE(result, expected_result);
auto* param_a = builtin.sem ? builtin.sem->Parameters()[0]->Type() : nullptr;
auto* expected_param_a = GetParam().expected_param_a(*this);
EXPECT_TYPE(param_a, expected_param_a);
auto* param_b = builtin.sem ? builtin.sem->Parameters()[1]->Type() : nullptr;
auto* expected_param_b = GetParam().expected_param_b(*this);
EXPECT_TYPE(param_b, expected_param_b);
auto* param_c = builtin.sem ? builtin.sem->Parameters()[2]->Type() : nullptr;
auto* expected_param_c = GetParam().expected_param_c(*this);
EXPECT_TYPE(param_c, expected_param_c);
}
INSTANTIATE_TEST_SUITE_P(
AFloat_f32,
IntrinsicTableAbstractTernaryTest_NonConstEval,
testing::Values( // clang-format off
// result | param a | param b | param c | arg a | arg b | arg c
Case::Create<f32, f32, f32, f32, AFloat, AFloat, AFloat>(),
Case::Create<f32, f32, f32, f32, AFloat, AFloat, f32>(),
Case::Create<f32, f32, f32, f32, AFloat, f32, AFloat>(),
Case::Create<f32, f32, f32, f32, AFloat, f32, f32>(),
Case::Create<f32, f32, f32, f32, f32, AFloat, AFloat>(),
Case::Create<f32, f32, f32, f32, f32, AFloat, f32>(),
Case::Create<f32, f32, f32, f32, f32, f32, AFloat>()
// clang-format on
));
INSTANTIATE_TEST_SUITE_P(
VecAFloat_Vecf32,
IntrinsicTableAbstractTernaryTest_NonConstEval,
testing::Values( // clang-format off
// result | param a | param b | param c | arg a | arg b | arg c
Case::Create<f32V, f32V, f32V, f32V, AFloatV, AFloatV, AFloatV>(),
Case::Create<f32V, f32V, f32V, f32V, AFloatV, AFloatV, f32V>(),
Case::Create<f32V, f32V, f32V, f32V, AFloatV, f32V, AFloatV>(),
Case::Create<f32V, f32V, f32V, f32V, AFloatV, f32V, f32V>(),
Case::Create<f32V, f32V, f32V, f32V, f32V, AFloatV, AFloatV>(),
Case::Create<f32V, f32V, f32V, f32V, f32V, AFloatV, f32V>(),
Case::Create<f32V, f32V, f32V, f32V, f32V, f32V, AFloatV> ()
// clang-format on
));
} // namespace AbstractTernaryTests
} // namespace