mirror of
https://github.com/encounter/dawn-cmake.git
synced 2025-12-21 02:39:11 +00:00
tint: const eval of mix builtin
Bug: tint:1581 Change-Id: I3b9f0ff3a58956616daf17b3d4a922979fc30216 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/113680 Commit-Queue: Antonio Maiorano <amaiorano@google.com> Reviewed-by: Ben Clayton <bclayton@google.com> Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
committed by
Dawn LUCI CQ
parent
b5fedb7b5a
commit
7c9e639e35
@@ -2864,6 +2864,49 @@ ConstEval::Result ConstEval::min(const type::Type* ty,
|
||||
return TransformElements(builder, ty, transform, args[0], args[1]);
|
||||
}
|
||||
|
||||
ConstEval::Result ConstEval::mix(const type::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source) {
|
||||
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1, size_t index) {
|
||||
auto create = [&](auto e1, auto e2) -> ImplResult {
|
||||
using NumberT = decltype(e1);
|
||||
// e3 is either a vector or a scalar
|
||||
NumberT e3;
|
||||
auto* c2 = args[2];
|
||||
if (c2->Type()->Is<type::Vector>()) {
|
||||
e3 = c2->Index(index)->As<NumberT>();
|
||||
} else {
|
||||
e3 = c2->As<NumberT>();
|
||||
}
|
||||
// Implement as `e1 * (1 - e3) + e2 * e3)` instead of as `e1 + e3 * (e2 - e1)` to avoid
|
||||
// float precision loss when e1 and e2 significantly differ in magnitude.
|
||||
auto one_sub_e3 = Sub(source, NumberT{1}, e3);
|
||||
if (!one_sub_e3) {
|
||||
return utils::Failure;
|
||||
}
|
||||
auto e1_mul_one_sub_e3 = Mul(source, e1, one_sub_e3.Get());
|
||||
if (!e1_mul_one_sub_e3) {
|
||||
return utils::Failure;
|
||||
}
|
||||
auto e2_mul_e3 = Mul(source, e2, e3);
|
||||
if (!e2_mul_e3) {
|
||||
return utils::Failure;
|
||||
}
|
||||
auto r = Add(source, e1_mul_one_sub_e3.Get(), e2_mul_e3.Get());
|
||||
if (!r) {
|
||||
return utils::Failure;
|
||||
}
|
||||
return CreateElement(builder, source, c0->Type(), r.Get());
|
||||
};
|
||||
return Dispatch_fa_f32_f16(create, c0, c1);
|
||||
};
|
||||
auto r = TransformElements(builder, ty, transform, args[0], args[1]);
|
||||
if (!r) {
|
||||
AddNote("when calculating mix", source);
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
ConstEval::Result ConstEval::modf(const type::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source) {
|
||||
|
||||
@@ -757,6 +757,15 @@ class ConstEval {
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source);
|
||||
|
||||
/// mix builtin
|
||||
/// @param ty the expression type
|
||||
/// @param args the input arguments
|
||||
/// @param source the source location
|
||||
/// @return the result value, or null if the value cannot be calculated
|
||||
Result mix(const type::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source);
|
||||
|
||||
/// modf builtin
|
||||
/// @param ty the expression type
|
||||
/// @param args the input arguments
|
||||
|
||||
@@ -179,7 +179,7 @@ TEST_P(ResolverConstEvalBuiltinTest, Test) {
|
||||
CheckConstant(value, expected_case.values[0], expected_case.flags);
|
||||
}
|
||||
} else {
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
ASSERT_FALSE(r()->Resolve());
|
||||
EXPECT_EQ(r()->error(), c.expected.Failure().error);
|
||||
}
|
||||
}
|
||||
@@ -1733,6 +1733,94 @@ INSTANTIATE_TEST_SUITE_P( //
|
||||
MinCases<AFloat>(),
|
||||
MinCases<f32>(),
|
||||
MinCases<f16>()))));
|
||||
|
||||
template <typename T>
|
||||
std::vector<Case> MixCases() {
|
||||
auto r = std::vector<Case>{
|
||||
C({T(0), T(1), T(0)}, T(0)), //
|
||||
C({T(0), T(1), T(1)}, T(1)), //
|
||||
C({T(0), T(1), T(2)}, T(2)), //
|
||||
C({T(0), T(1), T::Highest()}, T::Highest()), //
|
||||
C({T::Lowest(), T::Highest(), T(1)}, T::Highest()), //
|
||||
C({T::Lowest(), T::Highest(), T(0)}, T::Lowest()), //
|
||||
C({T(0), T(1), T(0.25)}, T(0.25)), //
|
||||
C({T(0), T(1), T(0.5)}, T(0.5)), //
|
||||
C({T(0), T(1), T(0.75)}, T(0.75)), //
|
||||
C({T(0), T(1000), T(0.25)}, T(250)), //
|
||||
C({T(0), T(1000), T(0.5)}, T(500)), //
|
||||
C({T(0), T(1000), T(0.75)}, T(750)), //
|
||||
// Swap e1 and e2//
|
||||
C({T(1), T(0), T(0)}, T(1)), //
|
||||
C({T(1), T(0), T(1)}, T(0)), //
|
||||
C({T(1), T(0), T(2)}, T(-1)), //
|
||||
C({T::Highest(), T::Lowest(), T(1)}, T::Lowest()), //
|
||||
C({T::Highest(), T::Lowest(), T(0)}, T::Highest()), //
|
||||
C({T(1), T(0), T(0.25)}, T(0.75)), //
|
||||
C({T(1), T(0), T(0.5)}, T(0.5)), //
|
||||
C({T(1), T(0), T(0.75)}, T(0.25)), //
|
||||
C({T(1000), T(0), T(0.25)}, T(750)), //
|
||||
C({T(1000), T(0), T(0.5)}, T(500)), //
|
||||
C({T(1000), T(0), T(0.75)}, T(250)),
|
||||
|
||||
// mix(vec, vec, vec) cases
|
||||
C({Vec(T(0), T(0), T(0)), //
|
||||
Vec(T(1), T(1), T(1)), //
|
||||
Vec(T(0), T(1), T(2))},
|
||||
Vec(T(0), T(1), T(2))),
|
||||
|
||||
// mix(vec, vec, scalar) cases
|
||||
C({Vec(T(0), T(1), T(0)), //
|
||||
Vec(T(1), T(0), T(1000)), //
|
||||
Val(T(0.25))},
|
||||
Vec(T(0.25), T(0.75), T(250))),
|
||||
};
|
||||
// Can't interpolate lowest value for f16 because (1 - lowest) is not representable as f16.
|
||||
if constexpr (!std::is_same_v<T, f16>) {
|
||||
ConcatInto(r, std::vector<Case>{
|
||||
C({T(0), T(1), T::Lowest()}, T::Lowest()),
|
||||
C({T(1), T(0), T::Highest()}, T::Lowest()),
|
||||
});
|
||||
}
|
||||
|
||||
auto error_msg = [](auto a, const char* op, auto b) {
|
||||
return "12:34 error: " + OverflowErrorMessage(a, op, b) + R"(
|
||||
12:34 note: when calculating mix)";
|
||||
};
|
||||
auto kLargeValue = T{T::Highest() / 2};
|
||||
// Test f16 separately as it overflows for a different reason at the boundary inputs.
|
||||
// Specifically, (1 - lowest) fails for f16 because the result is not representable.
|
||||
if constexpr (!std::is_same_v<T, f16>) {
|
||||
ConcatInto( //
|
||||
r,
|
||||
std::vector<Case>{
|
||||
E({T(0), T::Highest(), T::Highest()}, error_msg(T::Highest(), "*", T::Highest())),
|
||||
E({T(0), T::Lowest(), T::Lowest()}, error_msg(T::Lowest(), "*", T::Lowest())),
|
||||
E({T::Highest(), T(0), T::Lowest()}, error_msg(T::Highest(), "*", T::Highest())),
|
||||
E({-kLargeValue, kLargeValue, T(2)},
|
||||
error_msg(T{-kLargeValue * T(1 - 2)}, "+", T{kLargeValue * T(2)})),
|
||||
});
|
||||
} else {
|
||||
ConcatInto( //
|
||||
r,
|
||||
std::vector<Case>{
|
||||
E({T(0), T::Highest(), T::Highest()}, error_msg(T::Highest(), "*", T::Highest())),
|
||||
E({T(0), T::Lowest(), T::Lowest()}, error_msg(T(1), "-", T::Lowest())),
|
||||
E({T::Highest(), T(0), T::Lowest()}, error_msg(T(1), "-", T::Lowest())),
|
||||
E({-kLargeValue, kLargeValue, T(2)},
|
||||
error_msg(T{-kLargeValue * T(1 - 2)}, "+", T{kLargeValue * T(2)})),
|
||||
});
|
||||
}
|
||||
|
||||
return r;
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P( //
|
||||
Mix,
|
||||
ResolverConstEvalBuiltinTest,
|
||||
testing::Combine(testing::Values(sem::BuiltinType::kMix),
|
||||
testing::ValuesIn(Concat(MixCases<AFloat>(), //
|
||||
MixCases<f32>(), //
|
||||
MixCases<f16>()))));
|
||||
|
||||
template <typename T>
|
||||
std::vector<Case> ModfCases() {
|
||||
return {
|
||||
|
||||
@@ -101,7 +101,8 @@ inline void CheckConstant(const sem::Constant* got_constant,
|
||||
[&](const auto& expected) {
|
||||
using T = std::decay_t<decltype(expected)>;
|
||||
|
||||
ASSERT_TRUE(std::holds_alternative<T>(got_scalar));
|
||||
ASSERT_TRUE(std::holds_alternative<T>(got_scalar))
|
||||
<< "Scalar variant index: " << got_scalar.index();
|
||||
auto got = std::get<T>(got_scalar);
|
||||
|
||||
if constexpr (std::is_same_v<bool, T>) {
|
||||
|
||||
@@ -11441,36 +11441,36 @@ constexpr OverloadInfo kOverloads[] = {
|
||||
/* num parameters */ 3,
|
||||
/* num template types */ 1,
|
||||
/* num template numbers */ 0,
|
||||
/* template types */ &kTemplateTypes[26],
|
||||
/* template types */ &kTemplateTypes[23],
|
||||
/* template numbers */ &kTemplateNumbers[10],
|
||||
/* parameters */ &kParameters[468],
|
||||
/* return matcher indices */ &kMatcherIndices[3],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::mix,
|
||||
},
|
||||
{
|
||||
/* [263] */
|
||||
/* num parameters */ 3,
|
||||
/* num template types */ 1,
|
||||
/* num template numbers */ 1,
|
||||
/* template types */ &kTemplateTypes[26],
|
||||
/* template types */ &kTemplateTypes[23],
|
||||
/* template numbers */ &kTemplateNumbers[4],
|
||||
/* parameters */ &kParameters[471],
|
||||
/* return matcher indices */ &kMatcherIndices[30],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::mix,
|
||||
},
|
||||
{
|
||||
/* [264] */
|
||||
/* num parameters */ 3,
|
||||
/* num template types */ 1,
|
||||
/* num template numbers */ 1,
|
||||
/* template types */ &kTemplateTypes[26],
|
||||
/* template types */ &kTemplateTypes[23],
|
||||
/* template numbers */ &kTemplateNumbers[4],
|
||||
/* parameters */ &kParameters[474],
|
||||
/* return matcher indices */ &kMatcherIndices[30],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::mix,
|
||||
},
|
||||
{
|
||||
/* [265] */
|
||||
@@ -14294,9 +14294,9 @@ constexpr IntrinsicInfo kBuiltins[] = {
|
||||
},
|
||||
{
|
||||
/* [52] */
|
||||
/* fn mix<T : f32_f16>(T, T, T) -> T */
|
||||
/* fn mix<N : num, T : f32_f16>(vec<N, T>, vec<N, T>, vec<N, T>) -> vec<N, T> */
|
||||
/* fn mix<N : num, T : f32_f16>(vec<N, T>, vec<N, T>, T) -> vec<N, T> */
|
||||
/* fn mix<T : fa_f32_f16>(T, T, T) -> T */
|
||||
/* fn mix<N : num, T : fa_f32_f16>(vec<N, T>, vec<N, T>, vec<N, T>) -> vec<N, T> */
|
||||
/* fn mix<N : num, T : fa_f32_f16>(vec<N, T>, vec<N, T>, T) -> vec<N, T> */
|
||||
/* num overloads */ 3,
|
||||
/* overloads */ &kOverloads[262],
|
||||
},
|
||||
|
||||
@@ -1483,68 +1483,6 @@ Case::Create<u32V, u32V, u32V, u32V, u32V, u32V, AIntV>(
|
||||
// clang-format on
|
||||
));
|
||||
|
||||
struct IntrinsicTableAbstractTernaryTest_NonConstEval : public ResolverTestWithParam<Case> {
|
||||
std::unique_ptr<IntrinsicTable> table = IntrinsicTable::Create(*this);
|
||||
};
|
||||
|
||||
TEST_P(IntrinsicTableAbstractTernaryTest_NonConstEval, MatchMix) {
|
||||
auto* arg_a = GetParam().arg_a(*this);
|
||||
auto* arg_b = GetParam().arg_b(*this);
|
||||
auto* arg_c = GetParam().arg_c(*this);
|
||||
auto builtin = table->Lookup(sem::BuiltinType::kMix, utils::Vector{arg_a, arg_b, arg_c},
|
||||
sem::EvaluationStage::kConstant, Source{{12, 34}});
|
||||
|
||||
bool matched = builtin.sem != nullptr;
|
||||
bool expected_match = GetParam().expected_match;
|
||||
EXPECT_EQ(matched, expected_match) << Diagnostics().str();
|
||||
|
||||
auto* result = builtin.sem ? builtin.sem->ReturnType() : nullptr;
|
||||
auto* expected_result = GetParam().expected_result(*this);
|
||||
EXPECT_TYPE(result, expected_result);
|
||||
|
||||
auto* param_a = builtin.sem ? builtin.sem->Parameters()[0]->Type() : nullptr;
|
||||
auto* expected_param_a = GetParam().expected_param_a(*this);
|
||||
EXPECT_TYPE(param_a, expected_param_a);
|
||||
|
||||
auto* param_b = builtin.sem ? builtin.sem->Parameters()[1]->Type() : nullptr;
|
||||
auto* expected_param_b = GetParam().expected_param_b(*this);
|
||||
EXPECT_TYPE(param_b, expected_param_b);
|
||||
|
||||
auto* param_c = builtin.sem ? builtin.sem->Parameters()[2]->Type() : nullptr;
|
||||
auto* expected_param_c = GetParam().expected_param_c(*this);
|
||||
EXPECT_TYPE(param_c, expected_param_c);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
AFloat_f32,
|
||||
IntrinsicTableAbstractTernaryTest_NonConstEval,
|
||||
testing::Values( // clang-format off
|
||||
// result | param a | param b | param c | arg a | arg b | arg c
|
||||
Case::Create<f32, f32, f32, f32, AFloat, AFloat, AFloat>(),
|
||||
Case::Create<f32, f32, f32, f32, AFloat, AFloat, f32>(),
|
||||
Case::Create<f32, f32, f32, f32, AFloat, f32, AFloat>(),
|
||||
Case::Create<f32, f32, f32, f32, AFloat, f32, f32>(),
|
||||
Case::Create<f32, f32, f32, f32, f32, AFloat, AFloat>(),
|
||||
Case::Create<f32, f32, f32, f32, f32, AFloat, f32>(),
|
||||
Case::Create<f32, f32, f32, f32, f32, f32, AFloat>()
|
||||
// clang-format on
|
||||
));
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
VecAFloat_Vecf32,
|
||||
IntrinsicTableAbstractTernaryTest_NonConstEval,
|
||||
testing::Values( // clang-format off
|
||||
// result | param a | param b | param c | arg a | arg b | arg c
|
||||
Case::Create<f32V, f32V, f32V, f32V, AFloatV, AFloatV, AFloatV>(),
|
||||
Case::Create<f32V, f32V, f32V, f32V, AFloatV, AFloatV, f32V>(),
|
||||
Case::Create<f32V, f32V, f32V, f32V, AFloatV, f32V, AFloatV>(),
|
||||
Case::Create<f32V, f32V, f32V, f32V, AFloatV, f32V, f32V>(),
|
||||
Case::Create<f32V, f32V, f32V, f32V, f32V, AFloatV, AFloatV>(),
|
||||
Case::Create<f32V, f32V, f32V, f32V, f32V, AFloatV, f32V>(),
|
||||
Case::Create<f32V, f32V, f32V, f32V, f32V, f32V, AFloatV> ()
|
||||
// clang-format on
|
||||
));
|
||||
|
||||
} // namespace AbstractTernaryTests
|
||||
|
||||
} // namespace
|
||||
|
||||
Reference in New Issue
Block a user