mirror of
https://github.com/encounter/dawn-cmake.git
synced 2025-12-10 05:57:51 +00:00
Add const-eval for exp and exp2.
This CL adds const-eval routines for `exp` and `exp2`. Bug: tint:1581 Change-Id: I59cc77aee64bfadf6ff10548f27f3d253b68a129 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/111322 Reviewed-by: Antonio Maiorano <amaiorano@google.com> Kokoro: Kokoro <noreply+kokoro@google.com> Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
This commit is contained in:
committed by
Dawn LUCI CQ
parent
e9ad15ae7f
commit
ae739d6d1c
@@ -460,10 +460,10 @@ fn dot4U8Packed(u32, u32) -> u32
|
||||
@stage("fragment") fn dpdyCoarse<N: num>(vec<N, f32>) -> vec<N, f32>
|
||||
@stage("fragment") fn dpdyFine(f32) -> f32
|
||||
@stage("fragment") fn dpdyFine<N: num>(vec<N, f32>) -> vec<N, f32>
|
||||
fn exp<T: f32_f16>(T) -> T
|
||||
fn exp<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
|
||||
fn exp2<T: f32_f16>(T) -> T
|
||||
fn exp2<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
|
||||
@const fn exp<T: fa_f32_f16>(T) -> T
|
||||
@const fn exp<N: num, T: fa_f32_f16>(vec<N, T>) -> vec<N, T>
|
||||
@const fn exp2<T: fa_f32_f16>(T) -> T
|
||||
@const fn exp2<N: num, T: fa_f32_f16>(vec<N, T>) -> vec<N, T>
|
||||
@const fn extractBits<T: iu32>(T, u32, u32) -> T
|
||||
@const fn extractBits<N: num, T: iu32>(vec<N, T>, u32, u32) -> vec<N, T>
|
||||
fn faceForward<N: num, T: f32_f16>(vec<N, T>, vec<N, T>, vec<N, T>) -> vec<N, T>
|
||||
|
||||
@@ -202,6 +202,15 @@ std::string OverflowErrorMessage(VALUE_TY value, std::string_view target_ty) {
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
template <typename NumberT>
|
||||
std::string OverflowExpErrorMessage(std::string_view base, NumberT value) {
|
||||
std::stringstream ss;
|
||||
ss << std::setprecision(20);
|
||||
ss << base << "^" << value << " cannot be represented as "
|
||||
<< "'" << FriendlyName<NumberT>() << "'";
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
/// @returns the number of consecutive leading bits in `@p e` set to `@p bit_value_to_count`.
|
||||
template <typename T>
|
||||
std::make_unsigned_t<T> CountLeadingBits(T e, T bit_value_to_count) {
|
||||
@@ -2037,6 +2046,42 @@ ConstEval::Result ConstEval::dot(const sem::Type*,
|
||||
return r;
|
||||
}
|
||||
|
||||
ConstEval::Result ConstEval::exp(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source) {
|
||||
auto transform = [&](const sem::Constant* c0) {
|
||||
auto create = [&](auto e0) -> ImplResult {
|
||||
using NumberT = decltype(e0);
|
||||
auto val = NumberT(std::exp(e0));
|
||||
if (!std::isfinite(val.value)) {
|
||||
AddError(OverflowExpErrorMessage("e", e0), source);
|
||||
return utils::Failure;
|
||||
}
|
||||
return CreateElement(builder, source, c0->Type(), val);
|
||||
};
|
||||
return Dispatch_fa_f32_f16(create, c0);
|
||||
};
|
||||
return TransformElements(builder, ty, transform, args[0]);
|
||||
}
|
||||
|
||||
ConstEval::Result ConstEval::exp2(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source) {
|
||||
auto transform = [&](const sem::Constant* c0) {
|
||||
auto create = [&](auto e0) -> ImplResult {
|
||||
using NumberT = decltype(e0);
|
||||
auto val = NumberT(std::exp2(e0));
|
||||
if (!std::isfinite(val.value)) {
|
||||
AddError(OverflowExpErrorMessage("2", e0), source);
|
||||
return utils::Failure;
|
||||
}
|
||||
return CreateElement(builder, source, c0->Type(), val);
|
||||
};
|
||||
return Dispatch_fa_f32_f16(create, c0);
|
||||
};
|
||||
return TransformElements(builder, ty, transform, args[0]);
|
||||
}
|
||||
|
||||
ConstEval::Result ConstEval::extractBits(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source) {
|
||||
|
||||
@@ -556,6 +556,24 @@ class ConstEval {
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source);
|
||||
|
||||
/// exp builtin
|
||||
/// @param ty the expression type
|
||||
/// @param args the input arguments
|
||||
/// @param source the source location
|
||||
/// @return the result value, or null if the value cannot be calculated
|
||||
Result exp(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source);
|
||||
|
||||
/// exp2 builtin
|
||||
/// @param ty the expression type
|
||||
/// @param args the input arguments
|
||||
/// @param source the source location
|
||||
/// @return the result value, or null if the value cannot be calculated
|
||||
Result exp2(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source);
|
||||
|
||||
/// extractBits builtin
|
||||
/// @param ty the expression type
|
||||
/// @param args the input arguments
|
||||
|
||||
@@ -1101,6 +1101,46 @@ INSTANTIATE_TEST_SUITE_P( //
|
||||
testing::Combine(testing::Values(sem::BuiltinType::kDegrees),
|
||||
testing::ValuesIn(DegreesF16Cases<f16>())));
|
||||
|
||||
template <typename T>
|
||||
std::vector<Case> ExpCases() {
|
||||
auto error_msg = [](auto a) { return "12:34 error: " + OverflowExpErrorMessage("e", a); };
|
||||
return std::vector<Case>{C({T(0)}, T(1)), //
|
||||
C({-T(0)}, T(1)), //
|
||||
C({T(2)}, T(7.3890562)).FloatComp(),
|
||||
C({-T(2)}, T(0.13533528)).FloatComp(), //
|
||||
C({T::Lowest()}, T(0)),
|
||||
|
||||
E({T::Highest()}, error_msg(T::Highest()))};
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P( //
|
||||
Exp,
|
||||
ResolverConstEvalBuiltinTest,
|
||||
testing::Combine(testing::Values(sem::BuiltinType::kExp),
|
||||
testing::ValuesIn(Concat(ExpCases<AFloat>(), //
|
||||
ExpCases<f32>(),
|
||||
ExpCases<f16>()))));
|
||||
|
||||
template <typename T>
|
||||
std::vector<Case> Exp2Cases() {
|
||||
auto error_msg = [](auto a) { return "12:34 error: " + OverflowExpErrorMessage("2", a); };
|
||||
return std::vector<Case>{
|
||||
C({T(0)}, T(1)), //
|
||||
C({-T(0)}, T(1)), //
|
||||
C({T(2)}, T(4.0)),
|
||||
C({-T(2)}, T(0.25)), //
|
||||
C({T::Lowest()}, T(0)),
|
||||
|
||||
E({T::Highest()}, error_msg(T::Highest())),
|
||||
};
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P( //
|
||||
Exp2,
|
||||
ResolverConstEvalBuiltinTest,
|
||||
testing::Combine(testing::Values(sem::BuiltinType::kExp2),
|
||||
testing::ValuesIn(Concat(Exp2Cases<AFloat>(), //
|
||||
Exp2Cases<f32>(),
|
||||
Exp2Cases<f16>()))));
|
||||
|
||||
template <typename T>
|
||||
std::vector<Case> ExtractBitsCases() {
|
||||
using UT = Number<std::make_unsigned_t<UnwrapNumber<T>>>;
|
||||
|
||||
@@ -223,7 +223,7 @@ inline std::string OverflowErrorMessage(NumberT lhs, const char* op, NumberT rhs
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
/// Returns the overflow error message for converions
|
||||
/// Returns the overflow error message for conversions
|
||||
template <typename VALUE_TY>
|
||||
std::string OverflowErrorMessage(VALUE_TY value, std::string_view target_ty) {
|
||||
std::stringstream ss;
|
||||
@@ -233,6 +233,16 @@ std::string OverflowErrorMessage(VALUE_TY value, std::string_view target_ty) {
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
/// Returns the overflow error message for exponentiation
|
||||
template <typename NumberT>
|
||||
std::string OverflowExpErrorMessage(std::string_view base, NumberT value) {
|
||||
std::stringstream ss;
|
||||
ss << std::setprecision(20);
|
||||
ss << base << "^" << value << " cannot be represented as "
|
||||
<< "'" << FriendlyName<NumberT>() << "'";
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
using builder::IsValue;
|
||||
using builder::Mat;
|
||||
using builder::Val;
|
||||
|
||||
@@ -12366,48 +12366,48 @@ constexpr OverloadInfo kOverloads[] = {
|
||||
/* num parameters */ 1,
|
||||
/* num template types */ 1,
|
||||
/* num template numbers */ 0,
|
||||
/* template types */ &kTemplateTypes[26],
|
||||
/* template types */ &kTemplateTypes[23],
|
||||
/* template numbers */ &kTemplateNumbers[10],
|
||||
/* parameters */ &kParameters[850],
|
||||
/* return matcher indices */ &kMatcherIndices[3],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::exp,
|
||||
},
|
||||
{
|
||||
/* [338] */
|
||||
/* num parameters */ 1,
|
||||
/* num template types */ 1,
|
||||
/* num template numbers */ 1,
|
||||
/* template types */ &kTemplateTypes[26],
|
||||
/* template types */ &kTemplateTypes[23],
|
||||
/* template numbers */ &kTemplateNumbers[4],
|
||||
/* parameters */ &kParameters[851],
|
||||
/* return matcher indices */ &kMatcherIndices[30],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::exp,
|
||||
},
|
||||
{
|
||||
/* [339] */
|
||||
/* num parameters */ 1,
|
||||
/* num template types */ 1,
|
||||
/* num template numbers */ 0,
|
||||
/* template types */ &kTemplateTypes[26],
|
||||
/* template types */ &kTemplateTypes[23],
|
||||
/* template numbers */ &kTemplateNumbers[10],
|
||||
/* parameters */ &kParameters[852],
|
||||
/* return matcher indices */ &kMatcherIndices[3],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::exp2,
|
||||
},
|
||||
{
|
||||
/* [340] */
|
||||
/* num parameters */ 1,
|
||||
/* num template types */ 1,
|
||||
/* num template numbers */ 1,
|
||||
/* template types */ &kTemplateTypes[26],
|
||||
/* template types */ &kTemplateTypes[23],
|
||||
/* template numbers */ &kTemplateNumbers[4],
|
||||
/* parameters */ &kParameters[853],
|
||||
/* return matcher indices */ &kMatcherIndices[30],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::exp2,
|
||||
},
|
||||
{
|
||||
/* [341] */
|
||||
@@ -14197,15 +14197,15 @@ constexpr IntrinsicInfo kBuiltins[] = {
|
||||
},
|
||||
{
|
||||
/* [31] */
|
||||
/* fn exp<T : f32_f16>(T) -> T */
|
||||
/* fn exp<N : num, T : f32_f16>(vec<N, T>) -> vec<N, T> */
|
||||
/* fn exp<T : fa_f32_f16>(T) -> T */
|
||||
/* fn exp<N : num, T : fa_f32_f16>(vec<N, T>) -> vec<N, T> */
|
||||
/* num overloads */ 2,
|
||||
/* overloads */ &kOverloads[337],
|
||||
},
|
||||
{
|
||||
/* [32] */
|
||||
/* fn exp2<T : f32_f16>(T) -> T */
|
||||
/* fn exp2<N : num, T : f32_f16>(vec<N, T>) -> vec<N, T> */
|
||||
/* fn exp2<T : fa_f32_f16>(T) -> T */
|
||||
/* fn exp2<N : num, T : fa_f32_f16>(vec<N, T>) -> vec<N, T> */
|
||||
/* num overloads */ 2,
|
||||
/* overloads */ &kOverloads[339],
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user