Add const-eval for sqrt

This CL adds const-eval for `sqrt`.

Bug: tint:1581
Change-Id: I0d109e6439feaf07edb7db4d6a50274c1c4efcae
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/110174
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
This commit is contained in:
dan sinclair
2022-11-16 21:15:49 +00:00
committed by Dawn LUCI CQ
parent 92af2b5693
commit 2d90dedea4
97 changed files with 2381 additions and 190 deletions

View File

@@ -537,8 +537,8 @@ fn round<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
@const fn sinh<N: num, T: fa_f32_f16>(vec<N, T>) -> vec<N, T>
fn smoothstep<T: f32_f16>(T, T, T) -> T
fn smoothstep<N: num, T: f32_f16>(vec<N, T>, vec<N, T>, vec<N, T>) -> vec<N, T>
fn sqrt<T: f32_f16>(T) -> T
fn sqrt<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
@const fn sqrt<T: fa_f32_f16>(T) -> T
@const fn sqrt<N: num, T: fa_f32_f16>(vec<N, T>) -> vec<N, T>
@const fn step<T: fa_f32_f16>(T, T) -> T
@const fn step<N: num, T: fa_f32_f16>(vec<N, T>, vec<N, T>) -> vec<N, T>
@stage("compute") fn storageBarrier()

View File

@@ -2375,6 +2375,24 @@ ConstEval::Result ConstEval::step(const sem::Type* ty,
return TransformElements(builder, ty, transform, args[0], args[1]);
}
ConstEval::Result ConstEval::sqrt(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source) {
auto transform = [&](const sem::Constant* c0) {
auto create = [&](auto i) -> ImplResult {
using NumberT = decltype(i);
if (i < NumberT(0)) {
AddError("sqrt must be called with a value >= 0", source);
return utils::Failure;
}
return CreateElement(builder, c0->Type(), NumberT(std::sqrt(i.value)));
};
return Dispatch_fa_f32_f16(create, c0);
};
return TransformElements(builder, ty, transform, args[0]);
}
ConstEval::Result ConstEval::tan(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source&) {

View File

@@ -719,6 +719,15 @@ class ConstEval {
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// sqrt builtin
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated
Result sqrt(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// tan builtin
/// @param ty the expression type
/// @param args the input arguments

View File

@@ -1650,6 +1650,28 @@ INSTANTIATE_TEST_SUITE_P( //
StepCases<f32>(),
StepCases<f16>()))));
template <typename T>
std::vector<Case> SqrtCases() {
std::vector<Case> cases = {
C({-T(0)}, -T(0)), //
C({T(0)}, T(0)), //
C({T(25)}, T(5)),
// Vector tests
C({Vec(T(25), T(100))}, Vec(T(5), T(10))),
E({-T(25)}, "12:34 error: sqrt must be called with a value >= 0"),
};
return cases;
}
INSTANTIATE_TEST_SUITE_P( //
Sqrt,
ResolverConstEvalBuiltinTest,
testing::Combine(testing::Values(sem::BuiltinType::kSqrt),
testing::ValuesIn(Concat(SqrtCases<AFloat>(), //
SqrtCases<f32>(),
SqrtCases<f16>()))));
template <typename T>
std::vector<Case> TanCases() {
std::vector<Case> cases = {

View File

@@ -12585,24 +12585,24 @@ constexpr OverloadInfo kOverloads[] = {
/* num parameters */ 1,
/* num template types */ 1,
/* num template numbers */ 0,
/* template types */ &kTemplateTypes[25],
/* template types */ &kTemplateTypes[24],
/* template numbers */ &kTemplateNumbers[10],
/* parameters */ &kParameters[912],
/* return matcher indices */ &kMatcherIndices[1],
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::sqrt,
},
{
/* [355] */
/* num parameters */ 1,
/* num template types */ 1,
/* num template numbers */ 1,
/* template types */ &kTemplateTypes[25],
/* template types */ &kTemplateTypes[24],
/* template numbers */ &kTemplateNumbers[5],
/* parameters */ &kParameters[913],
/* return matcher indices */ &kMatcherIndices[30],
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::sqrt,
},
{
/* [356] */
@@ -14511,8 +14511,8 @@ constexpr IntrinsicInfo kBuiltins[] = {
},
{
/* [73] */
/* fn sqrt<T : f32_f16>(T) -> T */
/* fn sqrt<N : num, T : f32_f16>(vec<N, T>) -> vec<N, T> */
/* fn sqrt<T : fa_f32_f16>(T) -> T */
/* fn sqrt<N : num, T : fa_f32_f16>(vec<N, T>) -> vec<N, T> */
/* num overloads */ 2,
/* overloads */ &kOverloads[354],
},