Const eval for distance

This CL adds const-eval for the `distance` builtin.

Bug: tint:1581
Change-Id: Iee3af6474ace8e7baa230156f582f0a372f77cb7
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/111583
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
This commit is contained in:
dan sinclair
2022-11-24 21:48:48 +00:00
committed by Dawn LUCI CQ
parent 7736153c15
commit a2a8895020
98 changed files with 2414 additions and 227 deletions

View File

@@ -443,8 +443,8 @@ fn arrayLength<T, A: access>(ptr<storage, array<T>, A>) -> u32
@const fn degrees<T: fa_f32_f16>(T) -> T
@const fn degrees<N: num, T: fa_f32_f16>(vec<N, T>) -> vec<N, T>
@const fn determinant<N: num, T: fa_f32_f16>(mat<N, N, T>) -> T
fn distance<T: f32_f16>(T, T) -> T
fn distance<N: num, T: f32_f16>(vec<N, T>, vec<N, T>) -> T
@const fn distance<T: fa_f32_f16>(T, T) -> T
@const fn distance<N: num, T: fa_f32_f16>(vec<N, T>, vec<N, T>) -> T
@const fn dot<N: num, T: fia_fiu32_f16>(vec<N, T>, vec<N, T>) -> T
fn dot4I8Packed(u32, u32) -> i32
fn dot4U8Packed(u32, u32) -> u32

View File

@@ -882,8 +882,8 @@ TEST_F(ResolverBuiltinFloatTest, Distance_TooManyParams) {
EXPECT_EQ(r()->error(), R"(error: no matching call to distance(vec3<f32>, vec3<f32>, vec3<f32>)
2 candidate functions:
distance(T, T) -> T where: T is f32 or f16
distance(vecN<T>, vecN<T>) -> T where: T is f32 or f16
distance(T, T) -> T where: T is abstract-float, f32 or f16
distance(vecN<T>, vecN<T>) -> T where: T is abstract-float, f32 or f16
)");
}
@@ -896,8 +896,8 @@ TEST_F(ResolverBuiltinFloatTest, Distance_TooFewParams) {
EXPECT_EQ(r()->error(), R"(error: no matching call to distance(vec3<f32>)
2 candidate functions:
distance(T, T) -> T where: T is f32 or f16
distance(vecN<T>, vecN<T>) -> T where: T is f32 or f16
distance(T, T) -> T where: T is abstract-float, f32 or f16
distance(vecN<T>, vecN<T>) -> T where: T is abstract-float, f32 or f16
)");
}
@@ -910,8 +910,8 @@ TEST_F(ResolverBuiltinFloatTest, Distance_NoParams) {
EXPECT_EQ(r()->error(), R"(error: no matching call to distance()
2 candidate functions:
distance(T, T) -> T where: T is f32 or f16
distance(vecN<T>, vecN<T>) -> T where: T is f32 or f16
distance(T, T) -> T where: T is abstract-float, f32 or f16
distance(vecN<T>, vecN<T>) -> T where: T is abstract-float, f32 or f16
)");
}

View File

@@ -1165,6 +1165,27 @@ ConstEval::Result ConstEval::Dot(const Source& source,
return utils::Failure;
}
ConstEval::Result ConstEval::Length(const Source& source,
const sem::Type* ty,
const sem::Constant* c0) {
auto* vec_ty = c0->Type()->As<sem::Vector>();
// Evaluates to the absolute value of e if T is scalar.
if (vec_ty == nullptr) {
auto create = [&](auto e) {
using NumberT = decltype(e);
return CreateElement(builder, source, ty, NumberT{std::abs(e)});
};
return Dispatch_fa_f32_f16(create, c0);
}
// Evaluates to sqrt(e[0]^2 + e[1]^2 + ...) if T is a vector type.
auto d = Dot(source, c0, c0);
if (!d) {
return utils::Failure;
}
return Dispatch_fa_f32_f16(SqrtFunc(source, ty), d.Get());
}
auto ConstEval::Det2Func(const Source& source, const sem::Type* elem_ty) {
return [=](auto a, auto b, auto c, auto d) -> ImplResult {
if (auto r = Det2(source, a, b, c, d)) {
@@ -2221,6 +2242,27 @@ ConstEval::Result ConstEval::determinant(const sem::Type* ty,
}
return r;
}
ConstEval::Result ConstEval::distance(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source) {
auto err = [&]() -> ImplResult {
AddNote("when calculating distance", source);
return utils::Failure;
};
auto minus = OpMinus(args[0]->Type(), args, source);
if (!minus) {
return err();
}
auto len = Length(source, ty, minus.Get());
if (!len) {
return err();
}
return len;
}
ConstEval::Result ConstEval::dot(const sem::Type*,
utils::VectorRef<const sem::Constant*> args,
const Source& source) {
@@ -2566,26 +2608,7 @@ ConstEval::Result ConstEval::inverseSqrt(const sem::Type* ty,
ConstEval::Result ConstEval::length(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source) {
auto calculate = [&]() -> ImplResult {
auto* vec_ty = args[0]->Type()->As<sem::Vector>();
// Evaluates to the absolute value of e if T is scalar.
if (vec_ty == nullptr) {
auto create = [&](auto e) {
using NumberT = decltype(e);
return CreateElement(builder, source, ty, NumberT{std::abs(e)});
};
return Dispatch_fa_f32_f16(create, args[0]);
}
// Evaluates to sqrt(e[0]^2 + e[1]^2 + ...) if T is a vector type.
auto d = Dot(source, args[0], args[0]);
if (!d) {
return utils::Failure;
}
return Dispatch_fa_f32_f16(SqrtFunc(source, ty), d.Get());
};
auto r = calculate();
auto r = Length(source, ty, args[0]);
if (!r) {
AddNote("when calculating length", source);
}

View File

@@ -557,6 +557,15 @@ class ConstEval {
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// distance builtin
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated
Result distance(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// dot builtin
/// @param ty the expression type
/// @param args the input arguments
@@ -1225,6 +1234,13 @@ class ConstEval {
/// @returns the dot product
Result Dot(const Source& source, const sem::Constant* v1, const sem::Constant* v2);
/// Return sthe length of c0
/// @param source the source location
/// @param ty the return type
/// @param c0 the constant to calculate the length of
/// @returns the length of c0
Result Length(const Source& source, const sem::Type* ty, const sem::Constant* c0);
ProgramBuilder& builder;
};

View File

@@ -802,6 +802,33 @@ INSTANTIATE_TEST_SUITE_P( //
CrossCases<f32>(), //
CrossCases<f16>()))));
template <typename T>
std::vector<Case> DistanceCases() {
auto error_msg = [](auto a, const char* op, auto b) {
return "12:34 error: " + OverflowErrorMessage(a, op, b) + R"(
12:34 note: when calculating distance)";
};
return std::vector<Case>{
C({T(0), T(0)}, T(0)),
// length(-5) -> 5
C({T(30), T(35)}, T(5)),
C({Vec(T(30), T(20)), Vec(T(25), T(15))}, Val(T(7.0710678119))).FloatComp(),
E({T::Lowest(), T::Highest()}, error_msg(T::Lowest(), "-", T::Highest())),
E({Vec(T::Highest(), T::Highest()), Vec(T(1), T(1))},
error_msg(T(T::Highest() - T(1)), "*", T(T::Highest() - T(1)))),
};
}
INSTANTIATE_TEST_SUITE_P( //
Distance,
ResolverConstEvalBuiltinTest,
testing::Combine(testing::Values(sem::BuiltinType::kDistance),
testing::ValuesIn(Concat(DistanceCases<AFloat>(), //
DistanceCases<f32>(), //
DistanceCases<f16>()))));
template <typename T>
std::vector<Case> DotCases() {
auto r = std::vector<Case>{

View File

@@ -12198,24 +12198,24 @@ constexpr OverloadInfo kOverloads[] = {
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 0,
/* template types */ &kTemplateTypes[26],
/* template types */ &kTemplateTypes[23],
/* template numbers */ &kTemplateNumbers[10],
/* parameters */ &kParameters[600],
/* return matcher indices */ &kMatcherIndices[3],
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::distance,
},
{
/* [324] */
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 1,
/* template types */ &kTemplateTypes[26],
/* template types */ &kTemplateTypes[23],
/* template numbers */ &kTemplateNumbers[4],
/* parameters */ &kParameters[602],
/* return matcher indices */ &kMatcherIndices[3],
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::distance,
},
{
/* [325] */
@@ -14130,8 +14130,8 @@ constexpr IntrinsicInfo kBuiltins[] = {
},
{
/* [21] */
/* fn distance<T : f32_f16>(T, T) -> T */
/* fn distance<N : num, T : f32_f16>(vec<N, T>, vec<N, T>) -> T */
/* fn distance<T : fa_f32_f16>(T, T) -> T */
/* fn distance<N : num, T : fa_f32_f16>(vec<N, T>, vec<N, T>) -> T */
/* num overloads */ 2,
/* overloads */ &kOverloads[323],
},