Const eval for inverseSqrt

This CL adds const-eval for the `inverseSqrt` builtin.

Bug: tint:1581
Change-Id: Ieef063416a8033b5fac9396e30c76c20b3360a90
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/111581
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
This commit is contained in:
dan sinclair
2022-11-24 18:16:13 +00:00
committed by Dawn LUCI CQ
parent a6670833cd
commit 7736153c15
115 changed files with 3072 additions and 190 deletions

View File

@@ -487,8 +487,8 @@ fn fract<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
@stage("fragment") fn fwidthFine<N: num>(vec<N, f32>) -> vec<N, f32>
@const fn insertBits<T: iu32>(T, T, u32, u32) -> T
@const fn insertBits<N: num, T: iu32>(vec<N, T>, vec<N, T>, u32, u32) -> vec<N, T>
fn inverseSqrt<T: f32_f16>(T) -> T
fn inverseSqrt<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
@const fn inverseSqrt<T: fa_f32_f16>(T) -> T
@const fn inverseSqrt<N: num, T: fa_f32_f16>(vec<N, T>) -> vec<N, T>
fn ldexp<T: f32_f16>(T, i32) -> T
fn ldexp<N: num, T: f32_f16>(vec<N, T>, vec<N, i32>) -> vec<N, T>
@const fn length<T: fa_f32_f16>(@test_value(0.0) T) -> T

View File

@@ -2529,6 +2529,40 @@ ConstEval::Result ConstEval::insertBits(const sem::Type* ty,
return TransformElements(builder, ty, transform, args[0], args[1]);
}
ConstEval::Result ConstEval::inverseSqrt(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source) {
auto transform = [&](const sem::Constant* c0) {
auto create = [&](auto e) -> ImplResult {
using NumberT = decltype(e);
if (e <= NumberT(0)) {
AddError("inverseSqrt must be called with a value > 0", source);
return utils::Failure;
}
auto err = [&] {
AddNote("when calculating inverseSqrt", source);
return utils::Failure;
};
auto s = Sqrt(source, e);
if (!s) {
return err();
}
auto div = Div(source, NumberT(1), s.Get());
if (!div) {
return err();
}
return CreateElement(builder, source, c0->Type(), div.Get());
};
return Dispatch_fa_f32_f16(create, c0);
};
return TransformElements(builder, ty, transform, args[0]);
}
ConstEval::Result ConstEval::length(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source) {

View File

@@ -638,6 +638,15 @@ class ConstEval {
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// inverseSqrt builtin
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result inverseSqrt(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// length builtin
/// @param ty the expression type
/// @param args the input arguments

View File

@@ -1183,6 +1183,28 @@ INSTANTIATE_TEST_SUITE_P( //
testing::ValuesIn(Concat(InsertBitsCases<i32>(), //
InsertBitsCases<u32>()))));
template <typename T>
std::vector<Case> InverseSqrtCases() {
std::vector<Case> cases = {
C({T(25)}, T(.2)),
// Vector tests
C({Vec(T(25), T(100))}, Vec(T(.2), T(.1))),
E({T(0)}, "12:34 error: inverseSqrt must be called with a value > 0"),
E({-T(0)}, "12:34 error: inverseSqrt must be called with a value > 0"),
E({-T(25)}, "12:34 error: inverseSqrt must be called with a value > 0"),
};
return cases;
}
INSTANTIATE_TEST_SUITE_P( //
InverseSqrt,
ResolverConstEvalBuiltinTest,
testing::Combine(testing::Values(sem::BuiltinType::kInverseSqrt),
testing::ValuesIn(Concat(InverseSqrtCases<AFloat>(), //
InverseSqrtCases<f32>(),
InverseSqrtCases<f16>()))));
template <typename T>
std::vector<Case> DegreesAFloatCases() {
return std::vector<Case>{

View File

@@ -12678,24 +12678,24 @@ constexpr OverloadInfo kOverloads[] = {
/* num parameters */ 1,
/* num template types */ 1,
/* num template numbers */ 0,
/* template types */ &kTemplateTypes[26],
/* template types */ &kTemplateTypes[23],
/* template numbers */ &kTemplateNumbers[10],
/* parameters */ &kParameters[870],
/* return matcher indices */ &kMatcherIndices[3],
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::inverseSqrt,
},
{
/* [364] */
/* num parameters */ 1,
/* num template types */ 1,
/* num template numbers */ 1,
/* template types */ &kTemplateTypes[26],
/* template types */ &kTemplateTypes[23],
/* template numbers */ &kTemplateNumbers[4],
/* parameters */ &kParameters[871],
/* return matcher indices */ &kMatcherIndices[30],
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::inverseSqrt,
},
{
/* [365] */
@@ -14294,8 +14294,8 @@ constexpr IntrinsicInfo kBuiltins[] = {
},
{
/* [45] */
/* fn inverseSqrt<T : f32_f16>(T) -> T */
/* fn inverseSqrt<N : num, T : f32_f16>(vec<N, T>) -> vec<N, T> */
/* fn inverseSqrt<T : fa_f32_f16>(T) -> T */
/* fn inverseSqrt<N : num, T : fa_f32_f16>(vec<N, T>) -> vec<N, T> */
/* num overloads */ 2,
/* overloads */ &kOverloads[363],
},