tint: const eval of pow builtin

Bug: tint:1581
Change-Id: I11999f0adbd4b12d362e8f47772ac0a625b8fa68
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/113821
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
This commit is contained in:
Antonio Maiorano
2022-12-13 16:29:06 +00:00
parent 7f5b9d0b6f
commit be96967778
98 changed files with 2419 additions and 194 deletions

View File

@@ -204,10 +204,10 @@ std::string OverflowErrorMessage(VALUE_TY value, std::string_view target_ty) {
}
template <typename NumberT>
std::string OverflowExpErrorMessage(std::string_view base, NumberT value) {
std::string OverflowExpErrorMessage(std::string_view base, NumberT exp) {
std::stringstream ss;
ss << std::setprecision(20);
ss << base << "^" << value << " cannot be represented as "
ss << base << "^" << exp << " cannot be represented as "
<< "'" << FriendlyName<NumberT>() << "'";
return ss.str();
}
@@ -463,6 +463,7 @@ struct Composite : ImplConstant {
/// CreateElement constructs and returns an Element<T>.
template <typename T>
ImplResult CreateElement(ProgramBuilder& builder, const Source& source, const type::Type* t, T v) {
static_assert(IsNumber<T> || std::is_same_v<T, bool>, "T must be a Number or bool");
TINT_ASSERT(Resolver, t->is_scalar());
if constexpr (IsFloatingPoint<T>) {
@@ -3075,6 +3076,23 @@ ConstEval::Result ConstEval::pack4x8unorm(const type::Type* ty,
return CreateElement(builder, source, ty, ret);
}
ConstEval::Result ConstEval::pow(const type::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source) {
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
auto create = [&](auto e1, auto e2) -> ImplResult {
auto r = CheckedPow(e1, e2);
if (!r) {
AddError(OverflowErrorMessage(e1, "^", e2), source);
return utils::Failure;
}
return CreateElement(builder, source, c0->Type(), *r);
};
return Dispatch_fa_f32_f16(create, c0, c1);
};
return TransformElements(builder, ty, transform, args[0], args[1]);
}
ConstEval::Result ConstEval::radians(const type::Type* ty,
utils::VectorRef<const constant::Constant*> args,
const Source& source) {

View File

@@ -833,6 +833,15 @@ class ConstEval {
utils::VectorRef<const constant::Constant*> args,
const Source& source);
/// pow builtin
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result pow(const type::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// radians builtin
/// @param ty the expression type
/// @param args the input arguments

View File

@@ -1971,6 +1971,54 @@ INSTANTIATE_TEST_SUITE_P( //
testing::Combine(testing::Values(sem::BuiltinType::kPack2X16Unorm),
testing::ValuesIn(Pack2x16unormCases())));
template <typename T>
std::vector<Case> PowCases() {
auto error_msg = [](auto base, auto exp) {
return "12:34 error: " + OverflowErrorMessage(base, "^", exp);
};
return {
C({T(0), T(1)}, T(0)), //
C({T(0), T::Highest()}, T(0)), //
C({T(1), T(1)}, T(1)), //
C({T(1), T::Lowest()}, T(1)), //
C({T(2), T(2)}, T(4)), //
C({T(2), T(3)}, T(8)), //
// Positive base, negative exponent
C({T(1), T::Highest()}, T(1)), //
C({T(1), -T(1)}, T(1)), //
C({T(2), -T(2)}, T(0.25)), //
C({T(2), -T(3)}, T(0.125)), //
// Decimal values
C({T(2.5), T(3)}, T(15.625)), //
C({T(2), T(3.5)}, T(11.313708498)).FloatComp(), //
C({T(2.5), T(3.5)}, T(24.705294220)).FloatComp(), //
C({T(2), -T(3.5)}, T(0.0883883476)).FloatComp(), //
// Vector tests
C({Vec(T(0), T(1), T(2)), Vec(T(2), T(2), T(2))}, Vec(T(0), T(1), T(4))),
C({Vec(T(2), T(2), T(2)), Vec(T(2), T(3), T(4))}, Vec(T(4), T(8), T(16))),
// Error if base < 0
E({-T(1), T(1)}, error_msg(-T(1), T(1))),
E({-T(1), T::Highest()}, error_msg(-T(1), T::Highest())),
E({T::Lowest(), T(1)}, error_msg(T::Lowest(), T(1))),
E({T::Lowest(), T::Highest()}, error_msg(T::Lowest(), T::Highest())),
E({T::Lowest(), T::Lowest()}, error_msg(T::Lowest(), T::Lowest())),
// Error if base == 0 and exp <= 0
E({T(0), T(0)}, error_msg(T(0), T(0))),
E({T(0), -T(1)}, error_msg(T(0), -T(1))),
E({T(0), T::Lowest()}, error_msg(T(0), T::Lowest())),
};
}
INSTANTIATE_TEST_SUITE_P( //
Pow,
ResolverConstEvalBuiltinTest,
testing::Combine(testing::Values(sem::BuiltinType::kPow),
testing::ValuesIn(Concat(PowCases<AFloat>(), //
PowCases<f32>(), //
PowCases<f16>()))));
template <typename T>
std::vector<Case> ReverseBitsCases() {
using B = BitValues<T>;

View File

@@ -238,10 +238,10 @@ std::string OverflowErrorMessage(VALUE_TY value, std::string_view target_ty) {
/// Returns the overflow error message for exponentiation
template <typename NumberT>
std::string OverflowExpErrorMessage(std::string_view base, NumberT value) {
std::string OverflowExpErrorMessage(std::string_view base, NumberT exp) {
std::stringstream ss;
ss << std::setprecision(20);
ss << base << "^" << value << " cannot be represented as "
ss << base << "^" << exp << " cannot be represented as "
<< "'" << FriendlyName<NumberT>() << "'";
return ss.str();
}

View File

@@ -12797,24 +12797,24 @@ constexpr OverloadInfo kOverloads[] = {
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 0,
/* template types */ &kTemplateTypes[26],
/* template types */ &kTemplateTypes[23],
/* template numbers */ &kTemplateNumbers[10],
/* parameters */ &kParameters[622],
/* return matcher indices */ &kMatcherIndices[3],
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::pow,
},
{
/* [376] */
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 1,
/* template types */ &kTemplateTypes[26],
/* template types */ &kTemplateTypes[23],
/* template numbers */ &kTemplateNumbers[4],
/* parameters */ &kParameters[624],
/* return matcher indices */ &kMatcherIndices[30],
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::pow,
},
{
/* [377] */
@@ -14345,8 +14345,8 @@ constexpr IntrinsicInfo kBuiltins[] = {
},
{
/* [60] */
/* fn pow<T : f32_f16>(T, T) -> T */
/* fn pow<N : num, T : f32_f16>(vec<N, T>, vec<N, T>) -> vec<N, T> */
/* fn pow<T : fa_f32_f16>(T, T) -> T */
/* fn pow<N : num, T : fa_f32_f16>(vec<N, T>, vec<N, T>) -> vec<N, T> */
/* num overloads */ 2,
/* overloads */ &kOverloads[375],
},