Add const-eval for cos and cosh

This CL adds const-eval for `cos` and `cosh`.

Bug: tint:1581
Change-Id: I8df8f979a7b351288cadccda88940fdb5a20d18f
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/109561
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
This commit is contained in:
dan sinclair
2022-11-10 15:52:41 +00:00
committed by Dawn LUCI CQ
parent 3c5cabbbec
commit 02d4ea06b9
325 changed files with 5211 additions and 858 deletions

View File

@@ -429,10 +429,10 @@ fn arrayLength<T, A: access>(ptr<storage, array<T>, A>) -> u32
@const fn ceil<N: num, T: fa_f32_f16>(@test_value(1.5) vec<N, T>) -> vec<N, T>
@const fn clamp<T: fia_fiu32_f16>(T, T, T) -> T
@const fn clamp<T: fia_fiu32_f16, N: num>(vec<N, T>, vec<N, T>, vec<N, T>) -> vec<N, T>
fn cos<T: f32_f16>(T) -> T
fn cos<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
fn cosh<T: f32_f16>(T) -> T
fn cosh<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
@const fn cos<T: fa_f32_f16>(@test_value(0) T) -> T
@const fn cos<N: num, T: fa_f32_f16>(@test_value(0) vec<N, T>) -> vec<N, T>
@const fn cosh<T: fa_f32_f16>(@test_value(0) T) -> T
@const fn cosh<N: num, T: fa_f32_f16>(@test_value(0) vec<N, T>) -> vec<N, T>
@const fn countLeadingZeros<T: iu32>(T) -> T
@const fn countLeadingZeros<N: num, T: iu32>(vec<N, T>) -> vec<N, T>
@const fn countOneBits<T: iu32>(T) -> T

View File

@@ -1804,6 +1804,32 @@ ConstEval::Result ConstEval::clamp(const sem::Type* ty,
return TransformElements(builder, ty, transform, args[0], args[1], args[2]);
}
ConstEval::Result ConstEval::cos(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source&) {
auto transform = [&](const sem::Constant* c0) {
auto create = [&](auto i) -> ImplResult {
using NumberT = decltype(i);
return CreateElement(builder, c0->Type(), NumberT(std::cos(i.value)));
};
return Dispatch_fa_f32_f16(create, c0);
};
return TransformElements(builder, ty, transform, args[0]);
}
ConstEval::Result ConstEval::cosh(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source&) {
auto transform = [&](const sem::Constant* c0) {
auto create = [&](auto i) -> ImplResult {
using NumberT = decltype(i);
return CreateElement(builder, c0->Type(), NumberT(std::cosh(i.value)));
};
return Dispatch_fa_f32_f16(create, c0);
};
return TransformElements(builder, ty, transform, args[0]);
}
ConstEval::Result ConstEval::countLeadingZeros(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source&) {

View File

@@ -485,6 +485,24 @@ class ConstEval {
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// cos builtin
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated
Result cos(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// cosh builtin
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated
Result cosh(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// countLeadingZeros builtin
/// @param ty the expression type
/// @param args the input arguments

View File

@@ -685,6 +685,51 @@ INSTANTIATE_TEST_SUITE_P( //
ClampCases<f32>(),
ClampCases<f16>()))));
template <typename T>
std::vector<Case> CosCases() {
std::vector<Case> cases = {
C({-T(0)}, T(1)),
C({T(0)}, T(1)),
C({T(0.75)}, T(0.7316888689)).FloatComp(),
// Vector test
C({Vec(T(0), -T(0), T(0.75))}, Vec(T(1), T(1), T(0.7316888689))).FloatComp(),
};
return cases;
}
INSTANTIATE_TEST_SUITE_P( //
Cos,
ResolverConstEvalBuiltinTest,
testing::Combine(testing::Values(sem::BuiltinType::kCos),
testing::ValuesIn(Concat(CosCases<AFloat>(), //
CosCases<f32>(),
CosCases<f16>()))));
template <typename T>
std::vector<Case> CoshCases() {
std::vector<Case> cases = {
C({T(0)}, T(1)),
C({-T(0)}, T(1)),
C({T(1)}, T(1.5430806348)).FloatComp(),
C({T(.75)}, T(1.2946832847)).FloatComp(),
// Vector tests
C({Vec(T(0), -T(0), T(1))}, Vec(T(1), T(1), T(1.5430806348))).FloatComp(),
};
return cases;
}
INSTANTIATE_TEST_SUITE_P( //
Cosh,
ResolverConstEvalBuiltinTest,
testing::Combine(testing::Values(sem::BuiltinType::kCosh),
testing::ValuesIn(Concat(CoshCases<AFloat>(), //
CoshCases<f32>(),
CoshCases<f16>()))));
template <typename T>
std::vector<Case> CountLeadingZerosCases() {
using B = BitValues<T>;

View File

@@ -12777,48 +12777,48 @@ constexpr OverloadInfo kOverloads[] = {
/* num parameters */ 1,
/* num template types */ 1,
/* num template numbers */ 0,
/* template types */ &kTemplateTypes[25],
/* template types */ &kTemplateTypes[24],
/* template numbers */ &kTemplateNumbers[10],
/* parameters */ &kParameters[856],
/* return matcher indices */ &kMatcherIndices[1],
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::cos,
},
{
/* [371] */
/* num parameters */ 1,
/* num template types */ 1,
/* num template numbers */ 1,
/* template types */ &kTemplateTypes[25],
/* template types */ &kTemplateTypes[24],
/* template numbers */ &kTemplateNumbers[5],
/* parameters */ &kParameters[855],
/* return matcher indices */ &kMatcherIndices[30],
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::cos,
},
{
/* [372] */
/* num parameters */ 1,
/* num template types */ 1,
/* num template numbers */ 0,
/* template types */ &kTemplateTypes[25],
/* template types */ &kTemplateTypes[24],
/* template numbers */ &kTemplateNumbers[10],
/* parameters */ &kParameters[854],
/* return matcher indices */ &kMatcherIndices[1],
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::cosh,
},
{
/* [373] */
/* num parameters */ 1,
/* num template types */ 1,
/* num template numbers */ 1,
/* template types */ &kTemplateTypes[25],
/* template types */ &kTemplateTypes[24],
/* template numbers */ &kTemplateNumbers[5],
/* parameters */ &kParameters[853],
/* return matcher indices */ &kMatcherIndices[30],
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::cosh,
},
{
/* [374] */
@@ -14103,15 +14103,15 @@ constexpr IntrinsicInfo kBuiltins[] = {
},
{
/* [13] */
/* fn cos<T : f32_f16>(T) -> T */
/* fn cos<N : num, T : f32_f16>(vec<N, T>) -> vec<N, T> */
/* fn cos<T : fa_f32_f16>(@test_value(0) T) -> T */
/* fn cos<N : num, T : fa_f32_f16>(@test_value(0) vec<N, T>) -> vec<N, T> */
/* num overloads */ 2,
/* overloads */ &kOverloads[370],
},
{
/* [14] */
/* fn cosh<T : f32_f16>(T) -> T */
/* fn cosh<N : num, T : f32_f16>(vec<N, T>) -> vec<N, T> */
/* fn cosh<T : fa_f32_f16>(@test_value(0) T) -> T */
/* fn cosh<N : num, T : fa_f32_f16>(@test_value(0) vec<N, T>) -> vec<N, T> */
/* num overloads */ 2,
/* overloads */ &kOverloads[372],
},