Add const-eval for exp and exp2.

This CL adds const-eval routines for `exp` and `exp2`.

Bug: tint:1581
Change-Id: I59cc77aee64bfadf6ff10548f27f3d253b68a129
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/111322
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
This commit is contained in:
dan sinclair
2022-11-23 16:37:21 +00:00
committed by Dawn LUCI CQ
parent e9ad15ae7f
commit ae739d6d1c
226 changed files with 6186 additions and 397 deletions

View File

@@ -460,10 +460,10 @@ fn dot4U8Packed(u32, u32) -> u32
@stage("fragment") fn dpdyCoarse<N: num>(vec<N, f32>) -> vec<N, f32>
@stage("fragment") fn dpdyFine(f32) -> f32
@stage("fragment") fn dpdyFine<N: num>(vec<N, f32>) -> vec<N, f32>
fn exp<T: f32_f16>(T) -> T
fn exp<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
fn exp2<T: f32_f16>(T) -> T
fn exp2<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
@const fn exp<T: fa_f32_f16>(T) -> T
@const fn exp<N: num, T: fa_f32_f16>(vec<N, T>) -> vec<N, T>
@const fn exp2<T: fa_f32_f16>(T) -> T
@const fn exp2<N: num, T: fa_f32_f16>(vec<N, T>) -> vec<N, T>
@const fn extractBits<T: iu32>(T, u32, u32) -> T
@const fn extractBits<N: num, T: iu32>(vec<N, T>, u32, u32) -> vec<N, T>
fn faceForward<N: num, T: f32_f16>(vec<N, T>, vec<N, T>, vec<N, T>) -> vec<N, T>

View File

@@ -202,6 +202,15 @@ std::string OverflowErrorMessage(VALUE_TY value, std::string_view target_ty) {
return ss.str();
}
template <typename NumberT>
std::string OverflowExpErrorMessage(std::string_view base, NumberT value) {
std::stringstream ss;
ss << std::setprecision(20);
ss << base << "^" << value << " cannot be represented as "
<< "'" << FriendlyName<NumberT>() << "'";
return ss.str();
}
/// @returns the number of consecutive leading bits in `@p e` set to `@p bit_value_to_count`.
template <typename T>
std::make_unsigned_t<T> CountLeadingBits(T e, T bit_value_to_count) {
@@ -2037,6 +2046,42 @@ ConstEval::Result ConstEval::dot(const sem::Type*,
return r;
}
ConstEval::Result ConstEval::exp(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source) {
auto transform = [&](const sem::Constant* c0) {
auto create = [&](auto e0) -> ImplResult {
using NumberT = decltype(e0);
auto val = NumberT(std::exp(e0));
if (!std::isfinite(val.value)) {
AddError(OverflowExpErrorMessage("e", e0), source);
return utils::Failure;
}
return CreateElement(builder, source, c0->Type(), val);
};
return Dispatch_fa_f32_f16(create, c0);
};
return TransformElements(builder, ty, transform, args[0]);
}
ConstEval::Result ConstEval::exp2(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source) {
auto transform = [&](const sem::Constant* c0) {
auto create = [&](auto e0) -> ImplResult {
using NumberT = decltype(e0);
auto val = NumberT(std::exp2(e0));
if (!std::isfinite(val.value)) {
AddError(OverflowExpErrorMessage("2", e0), source);
return utils::Failure;
}
return CreateElement(builder, source, c0->Type(), val);
};
return Dispatch_fa_f32_f16(create, c0);
};
return TransformElements(builder, ty, transform, args[0]);
}
ConstEval::Result ConstEval::extractBits(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source) {

View File

@@ -556,6 +556,24 @@ class ConstEval {
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// exp builtin
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result exp(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// exp2 builtin
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result exp2(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// extractBits builtin
/// @param ty the expression type
/// @param args the input arguments

View File

@@ -1101,6 +1101,46 @@ INSTANTIATE_TEST_SUITE_P( //
testing::Combine(testing::Values(sem::BuiltinType::kDegrees),
testing::ValuesIn(DegreesF16Cases<f16>())));
template <typename T>
std::vector<Case> ExpCases() {
auto error_msg = [](auto a) { return "12:34 error: " + OverflowExpErrorMessage("e", a); };
return std::vector<Case>{C({T(0)}, T(1)), //
C({-T(0)}, T(1)), //
C({T(2)}, T(7.3890562)).FloatComp(),
C({-T(2)}, T(0.13533528)).FloatComp(), //
C({T::Lowest()}, T(0)),
E({T::Highest()}, error_msg(T::Highest()))};
}
INSTANTIATE_TEST_SUITE_P( //
Exp,
ResolverConstEvalBuiltinTest,
testing::Combine(testing::Values(sem::BuiltinType::kExp),
testing::ValuesIn(Concat(ExpCases<AFloat>(), //
ExpCases<f32>(),
ExpCases<f16>()))));
template <typename T>
std::vector<Case> Exp2Cases() {
auto error_msg = [](auto a) { return "12:34 error: " + OverflowExpErrorMessage("2", a); };
return std::vector<Case>{
C({T(0)}, T(1)), //
C({-T(0)}, T(1)), //
C({T(2)}, T(4.0)),
C({-T(2)}, T(0.25)), //
C({T::Lowest()}, T(0)),
E({T::Highest()}, error_msg(T::Highest())),
};
}
INSTANTIATE_TEST_SUITE_P( //
Exp2,
ResolverConstEvalBuiltinTest,
testing::Combine(testing::Values(sem::BuiltinType::kExp2),
testing::ValuesIn(Concat(Exp2Cases<AFloat>(), //
Exp2Cases<f32>(),
Exp2Cases<f16>()))));
template <typename T>
std::vector<Case> ExtractBitsCases() {
using UT = Number<std::make_unsigned_t<UnwrapNumber<T>>>;

View File

@@ -223,7 +223,7 @@ inline std::string OverflowErrorMessage(NumberT lhs, const char* op, NumberT rhs
return ss.str();
}
/// Returns the overflow error message for converions
/// Returns the overflow error message for conversions
template <typename VALUE_TY>
std::string OverflowErrorMessage(VALUE_TY value, std::string_view target_ty) {
std::stringstream ss;
@@ -233,6 +233,16 @@ std::string OverflowErrorMessage(VALUE_TY value, std::string_view target_ty) {
return ss.str();
}
/// Returns the overflow error message for exponentiation
template <typename NumberT>
std::string OverflowExpErrorMessage(std::string_view base, NumberT value) {
std::stringstream ss;
ss << std::setprecision(20);
ss << base << "^" << value << " cannot be represented as "
<< "'" << FriendlyName<NumberT>() << "'";
return ss.str();
}
using builder::IsValue;
using builder::Mat;
using builder::Val;

View File

@@ -12366,48 +12366,48 @@ constexpr OverloadInfo kOverloads[] = {
/* num parameters */ 1,
/* num template types */ 1,
/* num template numbers */ 0,
/* template types */ &kTemplateTypes[26],
/* template types */ &kTemplateTypes[23],
/* template numbers */ &kTemplateNumbers[10],
/* parameters */ &kParameters[850],
/* return matcher indices */ &kMatcherIndices[3],
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::exp,
},
{
/* [338] */
/* num parameters */ 1,
/* num template types */ 1,
/* num template numbers */ 1,
/* template types */ &kTemplateTypes[26],
/* template types */ &kTemplateTypes[23],
/* template numbers */ &kTemplateNumbers[4],
/* parameters */ &kParameters[851],
/* return matcher indices */ &kMatcherIndices[30],
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::exp,
},
{
/* [339] */
/* num parameters */ 1,
/* num template types */ 1,
/* num template numbers */ 0,
/* template types */ &kTemplateTypes[26],
/* template types */ &kTemplateTypes[23],
/* template numbers */ &kTemplateNumbers[10],
/* parameters */ &kParameters[852],
/* return matcher indices */ &kMatcherIndices[3],
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::exp2,
},
{
/* [340] */
/* num parameters */ 1,
/* num template types */ 1,
/* num template numbers */ 1,
/* template types */ &kTemplateTypes[26],
/* template types */ &kTemplateTypes[23],
/* template numbers */ &kTemplateNumbers[4],
/* parameters */ &kParameters[853],
/* return matcher indices */ &kMatcherIndices[30],
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::exp2,
},
{
/* [341] */
@@ -14197,15 +14197,15 @@ constexpr IntrinsicInfo kBuiltins[] = {
},
{
/* [31] */
/* fn exp<T : f32_f16>(T) -> T */
/* fn exp<N : num, T : f32_f16>(vec<N, T>) -> vec<N, T> */
/* fn exp<T : fa_f32_f16>(T) -> T */
/* fn exp<N : num, T : fa_f32_f16>(vec<N, T>) -> vec<N, T> */
/* num overloads */ 2,
/* overloads */ &kOverloads[337],
},
{
/* [32] */
/* fn exp2<T : f32_f16>(T) -> T */
/* fn exp2<N : num, T : f32_f16>(vec<N, T>) -> vec<N, T> */
/* fn exp2<T : fa_f32_f16>(T) -> T */
/* fn exp2<N : num, T : fa_f32_f16>(vec<N, T>) -> vec<N, T> */
/* num overloads */ 2,
/* overloads */ &kOverloads[339],
},