Add const-eval for log and log2.

This CL adds const-eval routines for `log` and `log2`.

Bug: tint:1581
Change-Id: I052b5ddd3bc8bdcd7c0925fa7912dcbe1a60e299
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/111323
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
This commit is contained in:
dan sinclair
2022-11-23 17:34:40 +00:00
committed by Dawn LUCI CQ
parent 7a9fe30b11
commit f2ad5fd260
190 changed files with 4816 additions and 408 deletions

View File

@@ -493,10 +493,10 @@ fn ldexp<T: f32_f16>(T, i32) -> T
fn ldexp<N: num, T: f32_f16>(vec<N, T>, vec<N, i32>) -> vec<N, T>
@const fn length<T: fa_f32_f16>(@test_value(0.0) T) -> T
@const fn length<N: num, T: fa_f32_f16>(@test_value(0.0) vec<N, T>) -> T
fn log<T: f32_f16>(T) -> T
fn log<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
fn log2<T: f32_f16>(T) -> T
fn log2<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
@const fn log<T: fa_f32_f16>(T) -> T
@const fn log<N: num, T: fa_f32_f16>(vec<N, T>) -> vec<N, T>
@const fn log2<T: fa_f32_f16>(T) -> T
@const fn log2<N: num, T: fa_f32_f16>(vec<N, T>) -> vec<N, T>
@const fn max<T: fia_fiu32_f16>(T, T) -> T
@const fn max<N: num, T: fia_fiu32_f16>(vec<N, T>, vec<N, T>) -> vec<N, T>
@const fn min<T: fia_fiu32_f16>(T, T) -> T

View File

@@ -1099,7 +1099,7 @@ TEST_P(FloatAllMatching, Scalar) {
utils::Vector<const ast::Expression*, 8> params;
for (uint32_t i = 0; i < num_params; ++i) {
params.Push(Expr(f32(i)));
params.Push(Expr(f32(i + 1)));
}
auto* builtin = Call(name, params);
Func("func", utils::Empty, ty.void_(),
@@ -1120,7 +1120,7 @@ TEST_P(FloatAllMatching, Vec2) {
utils::Vector<const ast::Expression*, 8> params;
for (uint32_t i = 0; i < num_params; ++i) {
params.Push(vec2<f32>(f32(i), f32(i)));
params.Push(vec2<f32>(f32(i + 1), f32(i + 1)));
}
auto* builtin = Call(name, params);
Func("func", utils::Empty, ty.void_(),
@@ -1141,7 +1141,7 @@ TEST_P(FloatAllMatching, Vec3) {
utils::Vector<const ast::Expression*, 8> params;
for (uint32_t i = 0; i < num_params; ++i) {
params.Push(vec3<f32>(f32(i), f32(i), f32(i)));
params.Push(vec3<f32>(f32(i + 1), f32(i + 1), f32(i + 1)));
}
auto* builtin = Call(name, params);
Func("func", utils::Empty, ty.void_(),
@@ -1162,7 +1162,7 @@ TEST_P(FloatAllMatching, Vec4) {
utils::Vector<const ast::Expression*, 8> params;
for (uint32_t i = 0; i < num_params; ++i) {
params.Push(vec4<f32>(f32(i), f32(i), f32(i), f32(i)));
params.Push(vec4<f32>(f32(i + 1), f32(i + 1), f32(i + 1), f32(i + 1)));
}
auto* builtin = Call(name, params);
Func("func", utils::Empty, ty.void_(),

View File

@@ -2300,6 +2300,40 @@ ConstEval::Result ConstEval::length(const sem::Type* ty,
return r;
}
ConstEval::Result ConstEval::log(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source) {
auto transform = [&](const sem::Constant* c0) {
auto create = [&](auto v) -> ImplResult {
using NumberT = decltype(v);
if (v <= NumberT(0)) {
AddError("log must be called with a value > 0", source);
return utils::Failure;
}
return CreateElement(builder, source, c0->Type(), NumberT(std::log(v)));
};
return Dispatch_fa_f32_f16(create, c0);
};
return TransformElements(builder, ty, transform, args[0]);
}
ConstEval::Result ConstEval::log2(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source) {
auto transform = [&](const sem::Constant* c0) {
auto create = [&](auto v) -> ImplResult {
using NumberT = decltype(v);
if (v <= NumberT(0)) {
AddError("log2 must be called with a value > 0", source);
return utils::Failure;
}
return CreateElement(builder, source, c0->Type(), NumberT(std::log2(v)));
};
return Dispatch_fa_f32_f16(create, c0);
};
return TransformElements(builder, ty, transform, args[0]);
}
ConstEval::Result ConstEval::max(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source) {

View File

@@ -628,6 +628,24 @@ class ConstEval {
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// log builtin
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result log(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// log2 builtin
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result log2(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// max builtin
/// @param ty the expression type
/// @param args the input arguments

View File

@@ -1296,6 +1296,110 @@ INSTANTIATE_TEST_SUITE_P( //
LengthCases<f32>(),
LengthCases<f16>()))));
template <typename T>
std::vector<Case> LogCases() {
auto error_msg = [] { return "12:34 error: log must be called with a value > 0"; };
return std::vector<Case>{C({T(1)}, T(0)), //
C({T(54.598150033)}, T(4)).FloatComp(0.002), //
E({T::Lowest()}, error_msg()), E({T(0)}, error_msg()),
E({-T(0)}, error_msg())};
}
INSTANTIATE_TEST_SUITE_P( //
Log,
ResolverConstEvalBuiltinTest,
testing::Combine(testing::Values(sem::BuiltinType::kLog),
testing::ValuesIn(Concat(LogCases<AFloat>(), //
LogCases<f32>(),
LogCases<f16>()))));
template <typename T>
std::vector<Case> LogF16Cases() {
return std::vector<Case>{
C({T::Highest()}, T(11.085938)).FloatComp(),
};
}
INSTANTIATE_TEST_SUITE_P( //
LogF16,
ResolverConstEvalBuiltinTest,
testing::Combine(testing::Values(sem::BuiltinType::kLog),
testing::ValuesIn(LogF16Cases<f16>())));
template <typename T>
std::vector<Case> LogF32Cases() {
return std::vector<Case>{
C({T::Highest()}, T(88.722839)).FloatComp(),
};
}
INSTANTIATE_TEST_SUITE_P( //
LogF32,
ResolverConstEvalBuiltinTest,
testing::Combine(testing::Values(sem::BuiltinType::kLog),
testing::ValuesIn(LogF32Cases<f32>())));
template <typename T>
std::vector<Case> LogAbstractCases() {
return std::vector<Case>{
C({T::Highest()}, T(709.78271)).FloatComp(),
};
}
INSTANTIATE_TEST_SUITE_P( //
LogAbstract,
ResolverConstEvalBuiltinTest,
testing::Combine(testing::Values(sem::BuiltinType::kLog),
testing::ValuesIn(LogAbstractCases<AFloat>())));
template <typename T>
std::vector<Case> Log2Cases() {
auto error_msg = [] { return "12:34 error: log2 must be called with a value > 0"; };
return std::vector<Case>{
C({T(1)}, T(0)), //
C({T(4)}, T(2)), //
E({T::Lowest()}, error_msg()),
E({T(0)}, error_msg()),
E({-T(0)}, error_msg()),
};
}
INSTANTIATE_TEST_SUITE_P( //
Log2,
ResolverConstEvalBuiltinTest,
testing::Combine(testing::Values(sem::BuiltinType::kLog2),
testing::ValuesIn(Concat(Log2Cases<AFloat>(), //
Log2Cases<f32>(),
Log2Cases<f16>()))));
template <typename T>
std::vector<Case> Log2F16Cases() {
return std::vector<Case>{
C({T::Highest()}, T(15.9922)).FloatComp(),
};
}
INSTANTIATE_TEST_SUITE_P( //
Log2F16,
ResolverConstEvalBuiltinTest,
testing::Combine(testing::Values(sem::BuiltinType::kLog2),
testing::ValuesIn(Log2F16Cases<f16>())));
template <typename T>
std::vector<Case> Log2F32Cases() {
return std::vector<Case>{
C({T::Highest()}, T(128)).FloatComp(),
};
}
INSTANTIATE_TEST_SUITE_P( //
Log2F32,
ResolverConstEvalBuiltinTest,
testing::Combine(testing::Values(sem::BuiltinType::kLog2),
testing::ValuesIn(Log2F32Cases<f32>())));
template <typename T>
std::vector<Case> Log2AbstractCases() {
return std::vector<Case>{
C({T::Highest()}, T(1024)).FloatComp(),
};
}
INSTANTIATE_TEST_SUITE_P( //
Log2Abstract,
ResolverConstEvalBuiltinTest,
testing::Combine(testing::Values(sem::BuiltinType::kLog2),
testing::ValuesIn(Log2AbstractCases<AFloat>())));
template <typename T>
std::vector<Case> MaxCases() {
return {

View File

@@ -12750,48 +12750,48 @@ constexpr OverloadInfo kOverloads[] = {
/* num parameters */ 1,
/* num template types */ 1,
/* num template numbers */ 0,
/* template types */ &kTemplateTypes[26],
/* template types */ &kTemplateTypes[23],
/* template numbers */ &kTemplateNumbers[10],
/* parameters */ &kParameters[874],
/* return matcher indices */ &kMatcherIndices[3],
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::log,
},
{
/* [370] */
/* num parameters */ 1,
/* num template types */ 1,
/* num template numbers */ 1,
/* template types */ &kTemplateTypes[26],
/* template types */ &kTemplateTypes[23],
/* template numbers */ &kTemplateNumbers[4],
/* parameters */ &kParameters[875],
/* return matcher indices */ &kMatcherIndices[30],
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::log,
},
{
/* [371] */
/* num parameters */ 1,
/* num template types */ 1,
/* num template numbers */ 0,
/* template types */ &kTemplateTypes[26],
/* template types */ &kTemplateTypes[23],
/* template numbers */ &kTemplateNumbers[10],
/* parameters */ &kParameters[876],
/* return matcher indices */ &kMatcherIndices[3],
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::log2,
},
{
/* [372] */
/* num parameters */ 1,
/* num template types */ 1,
/* num template numbers */ 1,
/* template types */ &kTemplateTypes[26],
/* template types */ &kTemplateTypes[23],
/* template numbers */ &kTemplateNumbers[4],
/* parameters */ &kParameters[877],
/* return matcher indices */ &kMatcherIndices[30],
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::log2,
},
{
/* [373] */
@@ -14315,15 +14315,15 @@ constexpr IntrinsicInfo kBuiltins[] = {
},
{
/* [48] */
/* fn log<T : f32_f16>(T) -> T */
/* fn log<N : num, T : f32_f16>(vec<N, T>) -> vec<N, T> */
/* fn log<T : fa_f32_f16>(T) -> T */
/* fn log<N : num, T : fa_f32_f16>(vec<N, T>) -> vec<N, T> */
/* num overloads */ 2,
/* overloads */ &kOverloads[369],
},
{
/* [49] */
/* fn log2<T : f32_f16>(T) -> T */
/* fn log2<N : num, T : f32_f16>(vec<N, T>) -> vec<N, T> */
/* fn log2<T : fa_f32_f16>(T) -> T */
/* fn log2<N : num, T : fa_f32_f16>(vec<N, T>) -> vec<N, T> */
/* num overloads */ 2,
/* overloads */ &kOverloads[371],
},