Add const-eval for tan and tanh

This CL adds const-eval for `tan` and `tanh`.

Bug: tint:1581
Change-Id: I3d3506a6e7462bba1557cb88065d696ddc21b0f6
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/109562
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
dan sinclair
2022-11-10 15:55:00 +00:00
committed by Dawn LUCI CQ
parent 02d4ea06b9
commit b77332edd2
189 changed files with 4791 additions and 396 deletions

View File

@@ -542,10 +542,10 @@ fn sqrt<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
@const fn step<T: fa_f32_f16>(T, T) -> T
@const fn step<N: num, T: fa_f32_f16>(vec<N, T>, vec<N, T>) -> vec<N, T>
@stage("compute") fn storageBarrier()
fn tan<T: f32_f16>(T) -> T
fn tan<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
fn tanh<T: f32_f16>(T) -> T
fn tanh<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
@const fn tan<T: fa_f32_f16>(T) -> T
@const fn tan<N: num, T: fa_f32_f16>(vec<N, T>) -> vec<N, T>
@const fn tanh<T: fa_f32_f16>(T) -> T
@const fn tanh<N: num, T: fa_f32_f16>(vec<N, T>) -> vec<N, T>
fn transpose<M: num, N: num, T: f32_f16>(mat<M, N, T>) -> mat<N, M, T>
fn trunc<T: f32_f16>(T) -> T
fn trunc<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>

View File

@@ -2351,6 +2351,32 @@ ConstEval::Result ConstEval::step(const sem::Type* ty,
return TransformElements(builder, ty, transform, args[0], args[1]);
}
ConstEval::Result ConstEval::tan(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source&) {
auto transform = [&](const sem::Constant* c0) {
auto create = [&](auto i) -> ImplResult {
using NumberT = decltype(i);
return CreateElement(builder, c0->Type(), NumberT(std::tan(i.value)));
};
return Dispatch_fa_f32_f16(create, c0);
};
return TransformElements(builder, ty, transform, args[0]);
}
ConstEval::Result ConstEval::tanh(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source&) {
auto transform = [&](const sem::Constant* c0) {
auto create = [&](auto i) -> ImplResult {
using NumberT = decltype(i);
return CreateElement(builder, c0->Type(), NumberT(std::tanh(i.value)));
};
return Dispatch_fa_f32_f16(create, c0);
};
return TransformElements(builder, ty, transform, args[0]);
}
ConstEval::Result ConstEval::unpack2x16float(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source) {

View File

@@ -701,6 +701,24 @@ class ConstEval {
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// tan builtin
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated
Result tan(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// tanh builtin
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated
Result tanh(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// unpack2x16float builtin
/// @param ty the expression type
/// @param args the input arguments

View File

@@ -1607,6 +1607,49 @@ INSTANTIATE_TEST_SUITE_P( //
StepCases<f32>(),
StepCases<f16>()))));
template <typename T>
std::vector<Case> TanCases() {
std::vector<Case> cases = {
C({-T(0)}, -T(0)),
C({T(0)}, T(0)),
C({T(.75)}, T(0.9315964599)).FloatComp(),
// Vector test
C({Vec(T(0), -T(0), T(.75))}, Vec(T(0), -T(0), T(0.9315964599))).FloatComp(),
};
return cases;
}
INSTANTIATE_TEST_SUITE_P( //
Tan,
ResolverConstEvalBuiltinTest,
testing::Combine(testing::Values(sem::BuiltinType::kTan),
testing::ValuesIn(Concat(TanCases<AFloat>(), //
TanCases<f32>(),
TanCases<f16>()))));
template <typename T>
std::vector<Case> TanhCases() {
std::vector<Case> cases = {
C({T(0)}, T(0)),
C({-T(0)}, -T(0)),
C({T(1)}, T(0.761594156)).FloatComp(),
C({T(-1)}, -T(0.761594156)).FloatComp(),
// Vector tests
C({Vec(T(0), -T(0), T(1))}, Vec(T(0), -T(0), T(0.761594156))).FloatComp(),
};
return cases;
}
INSTANTIATE_TEST_SUITE_P( //
Tanh,
ResolverConstEvalBuiltinTest,
testing::Combine(testing::Values(sem::BuiltinType::kTanh),
testing::ValuesIn(Concat(TanhCases<AFloat>(), //
TanhCases<f32>(),
TanhCases<f16>()))));
std::vector<Case> Unpack4x8snormCases() {
return {
C({Val(u32(0x0000'0000))}, Vec(f32(0), f32(0), f32(0), f32(0))),

View File

@@ -12657,48 +12657,48 @@ constexpr OverloadInfo kOverloads[] = {
/* num parameters */ 1,
/* num template types */ 1,
/* num template numbers */ 0,
/* template types */ &kTemplateTypes[25],
/* template types */ &kTemplateTypes[24],
/* template numbers */ &kTemplateNumbers[10],
/* parameters */ &kParameters[869],
/* return matcher indices */ &kMatcherIndices[1],
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::tan,
},
{
/* [361] */
/* num parameters */ 1,
/* num template types */ 1,
/* num template numbers */ 1,
/* template types */ &kTemplateTypes[25],
/* template types */ &kTemplateTypes[24],
/* template numbers */ &kTemplateNumbers[5],
/* parameters */ &kParameters[999],
/* return matcher indices */ &kMatcherIndices[30],
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::tan,
},
{
/* [362] */
/* num parameters */ 1,
/* num template types */ 1,
/* num template numbers */ 0,
/* template types */ &kTemplateTypes[25],
/* template types */ &kTemplateTypes[24],
/* template numbers */ &kTemplateNumbers[10],
/* parameters */ &kParameters[993],
/* return matcher indices */ &kMatcherIndices[1],
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::tanh,
},
{
/* [363] */
/* num parameters */ 1,
/* num template types */ 1,
/* num template numbers */ 1,
/* template types */ &kTemplateTypes[25],
/* template types */ &kTemplateTypes[24],
/* template numbers */ &kTemplateNumbers[5],
/* parameters */ &kParameters[987],
/* return matcher indices */ &kMatcherIndices[30],
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::tanh,
},
{
/* [364] */
@@ -14531,15 +14531,15 @@ constexpr IntrinsicInfo kBuiltins[] = {
},
{
/* [76] */
/* fn tan<T : f32_f16>(T) -> T */
/* fn tan<N : num, T : f32_f16>(vec<N, T>) -> vec<N, T> */
/* fn tan<T : fa_f32_f16>(T) -> T */
/* fn tan<N : num, T : fa_f32_f16>(vec<N, T>) -> vec<N, T> */
/* num overloads */ 2,
/* overloads */ &kOverloads[360],
},
{
/* [77] */
/* fn tanh<T : f32_f16>(T) -> T */
/* fn tanh<N : num, T : f32_f16>(vec<N, T>) -> vec<N, T> */
/* fn tanh<T : fa_f32_f16>(T) -> T */
/* fn tanh<N : num, T : fa_f32_f16>(vec<N, T>) -> vec<N, T> */
/* num overloads */ 2,
/* overloads */ &kOverloads[362],
},