mirror of
https://github.com/encounter/dawn-cmake.git
synced 2025-12-20 18:29:23 +00:00
tint: const eval of pow builtin
Bug: tint:1581 Change-Id: I11999f0adbd4b12d362e8f47772ac0a625b8fa68 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/113821 Kokoro: Kokoro <noreply+kokoro@google.com> Reviewed-by: Ben Clayton <bclayton@google.com>
This commit is contained in:
@@ -204,10 +204,10 @@ std::string OverflowErrorMessage(VALUE_TY value, std::string_view target_ty) {
|
||||
}
|
||||
|
||||
template <typename NumberT>
|
||||
std::string OverflowExpErrorMessage(std::string_view base, NumberT value) {
|
||||
std::string OverflowExpErrorMessage(std::string_view base, NumberT exp) {
|
||||
std::stringstream ss;
|
||||
ss << std::setprecision(20);
|
||||
ss << base << "^" << value << " cannot be represented as "
|
||||
ss << base << "^" << exp << " cannot be represented as "
|
||||
<< "'" << FriendlyName<NumberT>() << "'";
|
||||
return ss.str();
|
||||
}
|
||||
@@ -463,6 +463,7 @@ struct Composite : ImplConstant {
|
||||
/// CreateElement constructs and returns an Element<T>.
|
||||
template <typename T>
|
||||
ImplResult CreateElement(ProgramBuilder& builder, const Source& source, const type::Type* t, T v) {
|
||||
static_assert(IsNumber<T> || std::is_same_v<T, bool>, "T must be a Number or bool");
|
||||
TINT_ASSERT(Resolver, t->is_scalar());
|
||||
|
||||
if constexpr (IsFloatingPoint<T>) {
|
||||
@@ -3075,6 +3076,23 @@ ConstEval::Result ConstEval::pack4x8unorm(const type::Type* ty,
|
||||
return CreateElement(builder, source, ty, ret);
|
||||
}
|
||||
|
||||
ConstEval::Result ConstEval::pow(const type::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source) {
|
||||
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
|
||||
auto create = [&](auto e1, auto e2) -> ImplResult {
|
||||
auto r = CheckedPow(e1, e2);
|
||||
if (!r) {
|
||||
AddError(OverflowErrorMessage(e1, "^", e2), source);
|
||||
return utils::Failure;
|
||||
}
|
||||
return CreateElement(builder, source, c0->Type(), *r);
|
||||
};
|
||||
return Dispatch_fa_f32_f16(create, c0, c1);
|
||||
};
|
||||
return TransformElements(builder, ty, transform, args[0], args[1]);
|
||||
}
|
||||
|
||||
ConstEval::Result ConstEval::radians(const type::Type* ty,
|
||||
utils::VectorRef<const constant::Constant*> args,
|
||||
const Source& source) {
|
||||
|
||||
@@ -833,6 +833,15 @@ class ConstEval {
|
||||
utils::VectorRef<const constant::Constant*> args,
|
||||
const Source& source);
|
||||
|
||||
/// pow builtin
|
||||
/// @param ty the expression type
|
||||
/// @param args the input arguments
|
||||
/// @param source the source location
|
||||
/// @return the result value, or null if the value cannot be calculated
|
||||
Result pow(const type::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source);
|
||||
|
||||
/// radians builtin
|
||||
/// @param ty the expression type
|
||||
/// @param args the input arguments
|
||||
|
||||
@@ -1971,6 +1971,54 @@ INSTANTIATE_TEST_SUITE_P( //
|
||||
testing::Combine(testing::Values(sem::BuiltinType::kPack2X16Unorm),
|
||||
testing::ValuesIn(Pack2x16unormCases())));
|
||||
|
||||
template <typename T>
|
||||
std::vector<Case> PowCases() {
|
||||
auto error_msg = [](auto base, auto exp) {
|
||||
return "12:34 error: " + OverflowErrorMessage(base, "^", exp);
|
||||
};
|
||||
return {
|
||||
C({T(0), T(1)}, T(0)), //
|
||||
C({T(0), T::Highest()}, T(0)), //
|
||||
C({T(1), T(1)}, T(1)), //
|
||||
C({T(1), T::Lowest()}, T(1)), //
|
||||
C({T(2), T(2)}, T(4)), //
|
||||
C({T(2), T(3)}, T(8)), //
|
||||
// Positive base, negative exponent
|
||||
C({T(1), T::Highest()}, T(1)), //
|
||||
C({T(1), -T(1)}, T(1)), //
|
||||
C({T(2), -T(2)}, T(0.25)), //
|
||||
C({T(2), -T(3)}, T(0.125)), //
|
||||
// Decimal values
|
||||
C({T(2.5), T(3)}, T(15.625)), //
|
||||
C({T(2), T(3.5)}, T(11.313708498)).FloatComp(), //
|
||||
C({T(2.5), T(3.5)}, T(24.705294220)).FloatComp(), //
|
||||
C({T(2), -T(3.5)}, T(0.0883883476)).FloatComp(), //
|
||||
|
||||
// Vector tests
|
||||
C({Vec(T(0), T(1), T(2)), Vec(T(2), T(2), T(2))}, Vec(T(0), T(1), T(4))),
|
||||
C({Vec(T(2), T(2), T(2)), Vec(T(2), T(3), T(4))}, Vec(T(4), T(8), T(16))),
|
||||
|
||||
// Error if base < 0
|
||||
E({-T(1), T(1)}, error_msg(-T(1), T(1))),
|
||||
E({-T(1), T::Highest()}, error_msg(-T(1), T::Highest())),
|
||||
E({T::Lowest(), T(1)}, error_msg(T::Lowest(), T(1))),
|
||||
E({T::Lowest(), T::Highest()}, error_msg(T::Lowest(), T::Highest())),
|
||||
E({T::Lowest(), T::Lowest()}, error_msg(T::Lowest(), T::Lowest())),
|
||||
|
||||
// Error if base == 0 and exp <= 0
|
||||
E({T(0), T(0)}, error_msg(T(0), T(0))),
|
||||
E({T(0), -T(1)}, error_msg(T(0), -T(1))),
|
||||
E({T(0), T::Lowest()}, error_msg(T(0), T::Lowest())),
|
||||
};
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P( //
|
||||
Pow,
|
||||
ResolverConstEvalBuiltinTest,
|
||||
testing::Combine(testing::Values(sem::BuiltinType::kPow),
|
||||
testing::ValuesIn(Concat(PowCases<AFloat>(), //
|
||||
PowCases<f32>(), //
|
||||
PowCases<f16>()))));
|
||||
|
||||
template <typename T>
|
||||
std::vector<Case> ReverseBitsCases() {
|
||||
using B = BitValues<T>;
|
||||
|
||||
@@ -238,10 +238,10 @@ std::string OverflowErrorMessage(VALUE_TY value, std::string_view target_ty) {
|
||||
|
||||
/// Returns the overflow error message for exponentiation
|
||||
template <typename NumberT>
|
||||
std::string OverflowExpErrorMessage(std::string_view base, NumberT value) {
|
||||
std::string OverflowExpErrorMessage(std::string_view base, NumberT exp) {
|
||||
std::stringstream ss;
|
||||
ss << std::setprecision(20);
|
||||
ss << base << "^" << value << " cannot be represented as "
|
||||
ss << base << "^" << exp << " cannot be represented as "
|
||||
<< "'" << FriendlyName<NumberT>() << "'";
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
@@ -12797,24 +12797,24 @@ constexpr OverloadInfo kOverloads[] = {
|
||||
/* num parameters */ 2,
|
||||
/* num template types */ 1,
|
||||
/* num template numbers */ 0,
|
||||
/* template types */ &kTemplateTypes[26],
|
||||
/* template types */ &kTemplateTypes[23],
|
||||
/* template numbers */ &kTemplateNumbers[10],
|
||||
/* parameters */ &kParameters[622],
|
||||
/* return matcher indices */ &kMatcherIndices[3],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::pow,
|
||||
},
|
||||
{
|
||||
/* [376] */
|
||||
/* num parameters */ 2,
|
||||
/* num template types */ 1,
|
||||
/* num template numbers */ 1,
|
||||
/* template types */ &kTemplateTypes[26],
|
||||
/* template types */ &kTemplateTypes[23],
|
||||
/* template numbers */ &kTemplateNumbers[4],
|
||||
/* parameters */ &kParameters[624],
|
||||
/* return matcher indices */ &kMatcherIndices[30],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::pow,
|
||||
},
|
||||
{
|
||||
/* [377] */
|
||||
@@ -14345,8 +14345,8 @@ constexpr IntrinsicInfo kBuiltins[] = {
|
||||
},
|
||||
{
|
||||
/* [60] */
|
||||
/* fn pow<T : f32_f16>(T, T) -> T */
|
||||
/* fn pow<N : num, T : f32_f16>(vec<N, T>, vec<N, T>) -> vec<N, T> */
|
||||
/* fn pow<T : fa_f32_f16>(T, T) -> T */
|
||||
/* fn pow<N : num, T : fa_f32_f16>(vec<N, T>, vec<N, T>) -> vec<N, T> */
|
||||
/* num overloads */ 2,
|
||||
/* overloads */ &kOverloads[375],
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user