mirror of
https://github.com/encounter/dawn-cmake.git
synced 2025-12-09 21:47:47 +00:00
Const eval for distance
This CL adds const-eval for the `distance` builtin. Bug: tint:1581 Change-Id: Iee3af6474ace8e7baa230156f582f0a372f77cb7 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/111583 Kokoro: Kokoro <noreply+kokoro@google.com> Reviewed-by: Antonio Maiorano <amaiorano@google.com> Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
This commit is contained in:
committed by
Dawn LUCI CQ
parent
7736153c15
commit
a2a8895020
@@ -443,8 +443,8 @@ fn arrayLength<T, A: access>(ptr<storage, array<T>, A>) -> u32
|
||||
@const fn degrees<T: fa_f32_f16>(T) -> T
|
||||
@const fn degrees<N: num, T: fa_f32_f16>(vec<N, T>) -> vec<N, T>
|
||||
@const fn determinant<N: num, T: fa_f32_f16>(mat<N, N, T>) -> T
|
||||
fn distance<T: f32_f16>(T, T) -> T
|
||||
fn distance<N: num, T: f32_f16>(vec<N, T>, vec<N, T>) -> T
|
||||
@const fn distance<T: fa_f32_f16>(T, T) -> T
|
||||
@const fn distance<N: num, T: fa_f32_f16>(vec<N, T>, vec<N, T>) -> T
|
||||
@const fn dot<N: num, T: fia_fiu32_f16>(vec<N, T>, vec<N, T>) -> T
|
||||
fn dot4I8Packed(u32, u32) -> i32
|
||||
fn dot4U8Packed(u32, u32) -> u32
|
||||
|
||||
@@ -882,8 +882,8 @@ TEST_F(ResolverBuiltinFloatTest, Distance_TooManyParams) {
|
||||
EXPECT_EQ(r()->error(), R"(error: no matching call to distance(vec3<f32>, vec3<f32>, vec3<f32>)
|
||||
|
||||
2 candidate functions:
|
||||
distance(T, T) -> T where: T is f32 or f16
|
||||
distance(vecN<T>, vecN<T>) -> T where: T is f32 or f16
|
||||
distance(T, T) -> T where: T is abstract-float, f32 or f16
|
||||
distance(vecN<T>, vecN<T>) -> T where: T is abstract-float, f32 or f16
|
||||
)");
|
||||
}
|
||||
|
||||
@@ -896,8 +896,8 @@ TEST_F(ResolverBuiltinFloatTest, Distance_TooFewParams) {
|
||||
EXPECT_EQ(r()->error(), R"(error: no matching call to distance(vec3<f32>)
|
||||
|
||||
2 candidate functions:
|
||||
distance(T, T) -> T where: T is f32 or f16
|
||||
distance(vecN<T>, vecN<T>) -> T where: T is f32 or f16
|
||||
distance(T, T) -> T where: T is abstract-float, f32 or f16
|
||||
distance(vecN<T>, vecN<T>) -> T where: T is abstract-float, f32 or f16
|
||||
)");
|
||||
}
|
||||
|
||||
@@ -910,8 +910,8 @@ TEST_F(ResolverBuiltinFloatTest, Distance_NoParams) {
|
||||
EXPECT_EQ(r()->error(), R"(error: no matching call to distance()
|
||||
|
||||
2 candidate functions:
|
||||
distance(T, T) -> T where: T is f32 or f16
|
||||
distance(vecN<T>, vecN<T>) -> T where: T is f32 or f16
|
||||
distance(T, T) -> T where: T is abstract-float, f32 or f16
|
||||
distance(vecN<T>, vecN<T>) -> T where: T is abstract-float, f32 or f16
|
||||
)");
|
||||
}
|
||||
|
||||
|
||||
@@ -1165,6 +1165,27 @@ ConstEval::Result ConstEval::Dot(const Source& source,
|
||||
return utils::Failure;
|
||||
}
|
||||
|
||||
ConstEval::Result ConstEval::Length(const Source& source,
|
||||
const sem::Type* ty,
|
||||
const sem::Constant* c0) {
|
||||
auto* vec_ty = c0->Type()->As<sem::Vector>();
|
||||
// Evaluates to the absolute value of e if T is scalar.
|
||||
if (vec_ty == nullptr) {
|
||||
auto create = [&](auto e) {
|
||||
using NumberT = decltype(e);
|
||||
return CreateElement(builder, source, ty, NumberT{std::abs(e)});
|
||||
};
|
||||
return Dispatch_fa_f32_f16(create, c0);
|
||||
}
|
||||
|
||||
// Evaluates to sqrt(e[0]^2 + e[1]^2 + ...) if T is a vector type.
|
||||
auto d = Dot(source, c0, c0);
|
||||
if (!d) {
|
||||
return utils::Failure;
|
||||
}
|
||||
return Dispatch_fa_f32_f16(SqrtFunc(source, ty), d.Get());
|
||||
}
|
||||
|
||||
auto ConstEval::Det2Func(const Source& source, const sem::Type* elem_ty) {
|
||||
return [=](auto a, auto b, auto c, auto d) -> ImplResult {
|
||||
if (auto r = Det2(source, a, b, c, d)) {
|
||||
@@ -2221,6 +2242,27 @@ ConstEval::Result ConstEval::determinant(const sem::Type* ty,
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
ConstEval::Result ConstEval::distance(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source) {
|
||||
auto err = [&]() -> ImplResult {
|
||||
AddNote("when calculating distance", source);
|
||||
return utils::Failure;
|
||||
};
|
||||
|
||||
auto minus = OpMinus(args[0]->Type(), args, source);
|
||||
if (!minus) {
|
||||
return err();
|
||||
}
|
||||
|
||||
auto len = Length(source, ty, minus.Get());
|
||||
if (!len) {
|
||||
return err();
|
||||
}
|
||||
return len;
|
||||
}
|
||||
|
||||
ConstEval::Result ConstEval::dot(const sem::Type*,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source) {
|
||||
@@ -2566,26 +2608,7 @@ ConstEval::Result ConstEval::inverseSqrt(const sem::Type* ty,
|
||||
ConstEval::Result ConstEval::length(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source) {
|
||||
auto calculate = [&]() -> ImplResult {
|
||||
auto* vec_ty = args[0]->Type()->As<sem::Vector>();
|
||||
|
||||
// Evaluates to the absolute value of e if T is scalar.
|
||||
if (vec_ty == nullptr) {
|
||||
auto create = [&](auto e) {
|
||||
using NumberT = decltype(e);
|
||||
return CreateElement(builder, source, ty, NumberT{std::abs(e)});
|
||||
};
|
||||
return Dispatch_fa_f32_f16(create, args[0]);
|
||||
}
|
||||
|
||||
// Evaluates to sqrt(e[0]^2 + e[1]^2 + ...) if T is a vector type.
|
||||
auto d = Dot(source, args[0], args[0]);
|
||||
if (!d) {
|
||||
return utils::Failure;
|
||||
}
|
||||
return Dispatch_fa_f32_f16(SqrtFunc(source, ty), d.Get());
|
||||
};
|
||||
auto r = calculate();
|
||||
auto r = Length(source, ty, args[0]);
|
||||
if (!r) {
|
||||
AddNote("when calculating length", source);
|
||||
}
|
||||
|
||||
@@ -557,6 +557,15 @@ class ConstEval {
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source);
|
||||
|
||||
/// distance builtin
|
||||
/// @param ty the expression type
|
||||
/// @param args the input arguments
|
||||
/// @param source the source location of the conversion
|
||||
/// @return the result value, or null if the value cannot be calculated
|
||||
Result distance(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source);
|
||||
|
||||
/// dot builtin
|
||||
/// @param ty the expression type
|
||||
/// @param args the input arguments
|
||||
@@ -1225,6 +1234,13 @@ class ConstEval {
|
||||
/// @returns the dot product
|
||||
Result Dot(const Source& source, const sem::Constant* v1, const sem::Constant* v2);
|
||||
|
||||
/// Return sthe length of c0
|
||||
/// @param source the source location
|
||||
/// @param ty the return type
|
||||
/// @param c0 the constant to calculate the length of
|
||||
/// @returns the length of c0
|
||||
Result Length(const Source& source, const sem::Type* ty, const sem::Constant* c0);
|
||||
|
||||
ProgramBuilder& builder;
|
||||
};
|
||||
|
||||
|
||||
@@ -802,6 +802,33 @@ INSTANTIATE_TEST_SUITE_P( //
|
||||
CrossCases<f32>(), //
|
||||
CrossCases<f16>()))));
|
||||
|
||||
template <typename T>
|
||||
std::vector<Case> DistanceCases() {
|
||||
auto error_msg = [](auto a, const char* op, auto b) {
|
||||
return "12:34 error: " + OverflowErrorMessage(a, op, b) + R"(
|
||||
12:34 note: when calculating distance)";
|
||||
};
|
||||
|
||||
return std::vector<Case>{
|
||||
C({T(0), T(0)}, T(0)),
|
||||
// length(-5) -> 5
|
||||
C({T(30), T(35)}, T(5)),
|
||||
|
||||
C({Vec(T(30), T(20)), Vec(T(25), T(15))}, Val(T(7.0710678119))).FloatComp(),
|
||||
|
||||
E({T::Lowest(), T::Highest()}, error_msg(T::Lowest(), "-", T::Highest())),
|
||||
E({Vec(T::Highest(), T::Highest()), Vec(T(1), T(1))},
|
||||
error_msg(T(T::Highest() - T(1)), "*", T(T::Highest() - T(1)))),
|
||||
};
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P( //
|
||||
Distance,
|
||||
ResolverConstEvalBuiltinTest,
|
||||
testing::Combine(testing::Values(sem::BuiltinType::kDistance),
|
||||
testing::ValuesIn(Concat(DistanceCases<AFloat>(), //
|
||||
DistanceCases<f32>(), //
|
||||
DistanceCases<f16>()))));
|
||||
|
||||
template <typename T>
|
||||
std::vector<Case> DotCases() {
|
||||
auto r = std::vector<Case>{
|
||||
|
||||
@@ -12198,24 +12198,24 @@ constexpr OverloadInfo kOverloads[] = {
|
||||
/* num parameters */ 2,
|
||||
/* num template types */ 1,
|
||||
/* num template numbers */ 0,
|
||||
/* template types */ &kTemplateTypes[26],
|
||||
/* template types */ &kTemplateTypes[23],
|
||||
/* template numbers */ &kTemplateNumbers[10],
|
||||
/* parameters */ &kParameters[600],
|
||||
/* return matcher indices */ &kMatcherIndices[3],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::distance,
|
||||
},
|
||||
{
|
||||
/* [324] */
|
||||
/* num parameters */ 2,
|
||||
/* num template types */ 1,
|
||||
/* num template numbers */ 1,
|
||||
/* template types */ &kTemplateTypes[26],
|
||||
/* template types */ &kTemplateTypes[23],
|
||||
/* template numbers */ &kTemplateNumbers[4],
|
||||
/* parameters */ &kParameters[602],
|
||||
/* return matcher indices */ &kMatcherIndices[3],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::distance,
|
||||
},
|
||||
{
|
||||
/* [325] */
|
||||
@@ -14130,8 +14130,8 @@ constexpr IntrinsicInfo kBuiltins[] = {
|
||||
},
|
||||
{
|
||||
/* [21] */
|
||||
/* fn distance<T : f32_f16>(T, T) -> T */
|
||||
/* fn distance<N : num, T : f32_f16>(vec<N, T>, vec<N, T>) -> T */
|
||||
/* fn distance<T : fa_f32_f16>(T, T) -> T */
|
||||
/* fn distance<N : num, T : fa_f32_f16>(vec<N, T>, vec<N, T>) -> T */
|
||||
/* num overloads */ 2,
|
||||
/* overloads */ &kOverloads[323],
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user