mirror of
https://github.com/encounter/dawn-cmake.git
synced 2025-12-09 21:47:47 +00:00
Add const-eval for log and log2.
This CL adds const-eval routines for `log` and `log2`. Bug: tint:1581 Change-Id: I052b5ddd3bc8bdcd7c0925fa7912dcbe1a60e299 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/111323 Kokoro: Kokoro <noreply+kokoro@google.com> Commit-Queue: Dan Sinclair <dsinclair@chromium.org> Reviewed-by: Antonio Maiorano <amaiorano@google.com>
This commit is contained in:
committed by
Dawn LUCI CQ
parent
7a9fe30b11
commit
f2ad5fd260
@@ -493,10 +493,10 @@ fn ldexp<T: f32_f16>(T, i32) -> T
|
||||
fn ldexp<N: num, T: f32_f16>(vec<N, T>, vec<N, i32>) -> vec<N, T>
|
||||
@const fn length<T: fa_f32_f16>(@test_value(0.0) T) -> T
|
||||
@const fn length<N: num, T: fa_f32_f16>(@test_value(0.0) vec<N, T>) -> T
|
||||
fn log<T: f32_f16>(T) -> T
|
||||
fn log<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
|
||||
fn log2<T: f32_f16>(T) -> T
|
||||
fn log2<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
|
||||
@const fn log<T: fa_f32_f16>(T) -> T
|
||||
@const fn log<N: num, T: fa_f32_f16>(vec<N, T>) -> vec<N, T>
|
||||
@const fn log2<T: fa_f32_f16>(T) -> T
|
||||
@const fn log2<N: num, T: fa_f32_f16>(vec<N, T>) -> vec<N, T>
|
||||
@const fn max<T: fia_fiu32_f16>(T, T) -> T
|
||||
@const fn max<N: num, T: fia_fiu32_f16>(vec<N, T>, vec<N, T>) -> vec<N, T>
|
||||
@const fn min<T: fia_fiu32_f16>(T, T) -> T
|
||||
|
||||
@@ -1099,7 +1099,7 @@ TEST_P(FloatAllMatching, Scalar) {
|
||||
|
||||
utils::Vector<const ast::Expression*, 8> params;
|
||||
for (uint32_t i = 0; i < num_params; ++i) {
|
||||
params.Push(Expr(f32(i)));
|
||||
params.Push(Expr(f32(i + 1)));
|
||||
}
|
||||
auto* builtin = Call(name, params);
|
||||
Func("func", utils::Empty, ty.void_(),
|
||||
@@ -1120,7 +1120,7 @@ TEST_P(FloatAllMatching, Vec2) {
|
||||
|
||||
utils::Vector<const ast::Expression*, 8> params;
|
||||
for (uint32_t i = 0; i < num_params; ++i) {
|
||||
params.Push(vec2<f32>(f32(i), f32(i)));
|
||||
params.Push(vec2<f32>(f32(i + 1), f32(i + 1)));
|
||||
}
|
||||
auto* builtin = Call(name, params);
|
||||
Func("func", utils::Empty, ty.void_(),
|
||||
@@ -1141,7 +1141,7 @@ TEST_P(FloatAllMatching, Vec3) {
|
||||
|
||||
utils::Vector<const ast::Expression*, 8> params;
|
||||
for (uint32_t i = 0; i < num_params; ++i) {
|
||||
params.Push(vec3<f32>(f32(i), f32(i), f32(i)));
|
||||
params.Push(vec3<f32>(f32(i + 1), f32(i + 1), f32(i + 1)));
|
||||
}
|
||||
auto* builtin = Call(name, params);
|
||||
Func("func", utils::Empty, ty.void_(),
|
||||
@@ -1162,7 +1162,7 @@ TEST_P(FloatAllMatching, Vec4) {
|
||||
|
||||
utils::Vector<const ast::Expression*, 8> params;
|
||||
for (uint32_t i = 0; i < num_params; ++i) {
|
||||
params.Push(vec4<f32>(f32(i), f32(i), f32(i), f32(i)));
|
||||
params.Push(vec4<f32>(f32(i + 1), f32(i + 1), f32(i + 1), f32(i + 1)));
|
||||
}
|
||||
auto* builtin = Call(name, params);
|
||||
Func("func", utils::Empty, ty.void_(),
|
||||
|
||||
@@ -2300,6 +2300,40 @@ ConstEval::Result ConstEval::length(const sem::Type* ty,
|
||||
return r;
|
||||
}
|
||||
|
||||
ConstEval::Result ConstEval::log(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source) {
|
||||
auto transform = [&](const sem::Constant* c0) {
|
||||
auto create = [&](auto v) -> ImplResult {
|
||||
using NumberT = decltype(v);
|
||||
if (v <= NumberT(0)) {
|
||||
AddError("log must be called with a value > 0", source);
|
||||
return utils::Failure;
|
||||
}
|
||||
return CreateElement(builder, source, c0->Type(), NumberT(std::log(v)));
|
||||
};
|
||||
return Dispatch_fa_f32_f16(create, c0);
|
||||
};
|
||||
return TransformElements(builder, ty, transform, args[0]);
|
||||
}
|
||||
|
||||
ConstEval::Result ConstEval::log2(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source) {
|
||||
auto transform = [&](const sem::Constant* c0) {
|
||||
auto create = [&](auto v) -> ImplResult {
|
||||
using NumberT = decltype(v);
|
||||
if (v <= NumberT(0)) {
|
||||
AddError("log2 must be called with a value > 0", source);
|
||||
return utils::Failure;
|
||||
}
|
||||
return CreateElement(builder, source, c0->Type(), NumberT(std::log2(v)));
|
||||
};
|
||||
return Dispatch_fa_f32_f16(create, c0);
|
||||
};
|
||||
return TransformElements(builder, ty, transform, args[0]);
|
||||
}
|
||||
|
||||
ConstEval::Result ConstEval::max(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source) {
|
||||
|
||||
@@ -628,6 +628,24 @@ class ConstEval {
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source);
|
||||
|
||||
/// log builtin
|
||||
/// @param ty the expression type
|
||||
/// @param args the input arguments
|
||||
/// @param source the source location
|
||||
/// @return the result value, or null if the value cannot be calculated
|
||||
Result log(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source);
|
||||
|
||||
/// log2 builtin
|
||||
/// @param ty the expression type
|
||||
/// @param args the input arguments
|
||||
/// @param source the source location
|
||||
/// @return the result value, or null if the value cannot be calculated
|
||||
Result log2(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source);
|
||||
|
||||
/// max builtin
|
||||
/// @param ty the expression type
|
||||
/// @param args the input arguments
|
||||
|
||||
@@ -1296,6 +1296,110 @@ INSTANTIATE_TEST_SUITE_P( //
|
||||
LengthCases<f32>(),
|
||||
LengthCases<f16>()))));
|
||||
|
||||
template <typename T>
|
||||
std::vector<Case> LogCases() {
|
||||
auto error_msg = [] { return "12:34 error: log must be called with a value > 0"; };
|
||||
return std::vector<Case>{C({T(1)}, T(0)), //
|
||||
C({T(54.598150033)}, T(4)).FloatComp(0.002), //
|
||||
|
||||
E({T::Lowest()}, error_msg()), E({T(0)}, error_msg()),
|
||||
E({-T(0)}, error_msg())};
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P( //
|
||||
Log,
|
||||
ResolverConstEvalBuiltinTest,
|
||||
testing::Combine(testing::Values(sem::BuiltinType::kLog),
|
||||
testing::ValuesIn(Concat(LogCases<AFloat>(), //
|
||||
LogCases<f32>(),
|
||||
LogCases<f16>()))));
|
||||
template <typename T>
|
||||
std::vector<Case> LogF16Cases() {
|
||||
return std::vector<Case>{
|
||||
C({T::Highest()}, T(11.085938)).FloatComp(),
|
||||
};
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P( //
|
||||
LogF16,
|
||||
ResolverConstEvalBuiltinTest,
|
||||
testing::Combine(testing::Values(sem::BuiltinType::kLog),
|
||||
testing::ValuesIn(LogF16Cases<f16>())));
|
||||
template <typename T>
|
||||
std::vector<Case> LogF32Cases() {
|
||||
return std::vector<Case>{
|
||||
C({T::Highest()}, T(88.722839)).FloatComp(),
|
||||
};
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P( //
|
||||
LogF32,
|
||||
ResolverConstEvalBuiltinTest,
|
||||
testing::Combine(testing::Values(sem::BuiltinType::kLog),
|
||||
testing::ValuesIn(LogF32Cases<f32>())));
|
||||
|
||||
template <typename T>
|
||||
std::vector<Case> LogAbstractCases() {
|
||||
return std::vector<Case>{
|
||||
C({T::Highest()}, T(709.78271)).FloatComp(),
|
||||
};
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P( //
|
||||
LogAbstract,
|
||||
ResolverConstEvalBuiltinTest,
|
||||
testing::Combine(testing::Values(sem::BuiltinType::kLog),
|
||||
testing::ValuesIn(LogAbstractCases<AFloat>())));
|
||||
|
||||
template <typename T>
|
||||
std::vector<Case> Log2Cases() {
|
||||
auto error_msg = [] { return "12:34 error: log2 must be called with a value > 0"; };
|
||||
return std::vector<Case>{
|
||||
C({T(1)}, T(0)), //
|
||||
C({T(4)}, T(2)), //
|
||||
|
||||
E({T::Lowest()}, error_msg()),
|
||||
E({T(0)}, error_msg()),
|
||||
E({-T(0)}, error_msg()),
|
||||
};
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P( //
|
||||
Log2,
|
||||
ResolverConstEvalBuiltinTest,
|
||||
testing::Combine(testing::Values(sem::BuiltinType::kLog2),
|
||||
testing::ValuesIn(Concat(Log2Cases<AFloat>(), //
|
||||
Log2Cases<f32>(),
|
||||
Log2Cases<f16>()))));
|
||||
template <typename T>
|
||||
std::vector<Case> Log2F16Cases() {
|
||||
return std::vector<Case>{
|
||||
C({T::Highest()}, T(15.9922)).FloatComp(),
|
||||
};
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P( //
|
||||
Log2F16,
|
||||
ResolverConstEvalBuiltinTest,
|
||||
testing::Combine(testing::Values(sem::BuiltinType::kLog2),
|
||||
testing::ValuesIn(Log2F16Cases<f16>())));
|
||||
template <typename T>
|
||||
std::vector<Case> Log2F32Cases() {
|
||||
return std::vector<Case>{
|
||||
C({T::Highest()}, T(128)).FloatComp(),
|
||||
};
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P( //
|
||||
Log2F32,
|
||||
ResolverConstEvalBuiltinTest,
|
||||
testing::Combine(testing::Values(sem::BuiltinType::kLog2),
|
||||
testing::ValuesIn(Log2F32Cases<f32>())));
|
||||
template <typename T>
|
||||
std::vector<Case> Log2AbstractCases() {
|
||||
return std::vector<Case>{
|
||||
C({T::Highest()}, T(1024)).FloatComp(),
|
||||
};
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P( //
|
||||
Log2Abstract,
|
||||
ResolverConstEvalBuiltinTest,
|
||||
testing::Combine(testing::Values(sem::BuiltinType::kLog2),
|
||||
testing::ValuesIn(Log2AbstractCases<AFloat>())));
|
||||
|
||||
template <typename T>
|
||||
std::vector<Case> MaxCases() {
|
||||
return {
|
||||
|
||||
@@ -12750,48 +12750,48 @@ constexpr OverloadInfo kOverloads[] = {
|
||||
/* num parameters */ 1,
|
||||
/* num template types */ 1,
|
||||
/* num template numbers */ 0,
|
||||
/* template types */ &kTemplateTypes[26],
|
||||
/* template types */ &kTemplateTypes[23],
|
||||
/* template numbers */ &kTemplateNumbers[10],
|
||||
/* parameters */ &kParameters[874],
|
||||
/* return matcher indices */ &kMatcherIndices[3],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::log,
|
||||
},
|
||||
{
|
||||
/* [370] */
|
||||
/* num parameters */ 1,
|
||||
/* num template types */ 1,
|
||||
/* num template numbers */ 1,
|
||||
/* template types */ &kTemplateTypes[26],
|
||||
/* template types */ &kTemplateTypes[23],
|
||||
/* template numbers */ &kTemplateNumbers[4],
|
||||
/* parameters */ &kParameters[875],
|
||||
/* return matcher indices */ &kMatcherIndices[30],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::log,
|
||||
},
|
||||
{
|
||||
/* [371] */
|
||||
/* num parameters */ 1,
|
||||
/* num template types */ 1,
|
||||
/* num template numbers */ 0,
|
||||
/* template types */ &kTemplateTypes[26],
|
||||
/* template types */ &kTemplateTypes[23],
|
||||
/* template numbers */ &kTemplateNumbers[10],
|
||||
/* parameters */ &kParameters[876],
|
||||
/* return matcher indices */ &kMatcherIndices[3],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::log2,
|
||||
},
|
||||
{
|
||||
/* [372] */
|
||||
/* num parameters */ 1,
|
||||
/* num template types */ 1,
|
||||
/* num template numbers */ 1,
|
||||
/* template types */ &kTemplateTypes[26],
|
||||
/* template types */ &kTemplateTypes[23],
|
||||
/* template numbers */ &kTemplateNumbers[4],
|
||||
/* parameters */ &kParameters[877],
|
||||
/* return matcher indices */ &kMatcherIndices[30],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::log2,
|
||||
},
|
||||
{
|
||||
/* [373] */
|
||||
@@ -14315,15 +14315,15 @@ constexpr IntrinsicInfo kBuiltins[] = {
|
||||
},
|
||||
{
|
||||
/* [48] */
|
||||
/* fn log<T : f32_f16>(T) -> T */
|
||||
/* fn log<N : num, T : f32_f16>(vec<N, T>) -> vec<N, T> */
|
||||
/* fn log<T : fa_f32_f16>(T) -> T */
|
||||
/* fn log<N : num, T : fa_f32_f16>(vec<N, T>) -> vec<N, T> */
|
||||
/* num overloads */ 2,
|
||||
/* overloads */ &kOverloads[369],
|
||||
},
|
||||
{
|
||||
/* [49] */
|
||||
/* fn log2<T : f32_f16>(T) -> T */
|
||||
/* fn log2<N : num, T : f32_f16>(vec<N, T>) -> vec<N, T> */
|
||||
/* fn log2<T : fa_f32_f16>(T) -> T */
|
||||
/* fn log2<N : num, T : fa_f32_f16>(vec<N, T>) -> vec<N, T> */
|
||||
/* num overloads */ 2,
|
||||
/* overloads */ &kOverloads[371],
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user