mirror of
https://github.com/encounter/dawn-cmake.git
synced 2025-12-16 08:27:05 +00:00
tint: const eval of length builtin
Bug: tint:1581 Change-Id: Ie6dc9da6b48c606af03da023c835ec36a99dd362 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/110981 Reviewed-by: Dan Sinclair <dsinclair@chromium.org> Kokoro: Kokoro <noreply+kokoro@google.com> Commit-Queue: Antonio Maiorano <amaiorano@google.com>
This commit is contained in:
committed by
Dawn LUCI CQ
parent
89f15fc57e
commit
92d858ac3c
@@ -491,8 +491,8 @@ fn inverseSqrt<T: f32_f16>(T) -> T
|
||||
fn inverseSqrt<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
|
||||
fn ldexp<T: f32_f16>(T, i32) -> T
|
||||
fn ldexp<N: num, T: f32_f16>(vec<N, T>, vec<N, i32>) -> vec<N, T>
|
||||
fn length<T: f32_f16>(T) -> T
|
||||
fn length<N: num, T: f32_f16>(vec<N, T>) -> T
|
||||
@const fn length<T: fa_f32_f16>(@test_value(0.0) T) -> T
|
||||
@const fn length<N: num, T: fa_f32_f16>(@test_value(0.0) vec<N, T>) -> T
|
||||
fn log<T: f32_f16>(T) -> T
|
||||
fn log<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
|
||||
fn log2<T: f32_f16>(T) -> T
|
||||
|
||||
@@ -1184,8 +1184,8 @@ TEST_F(ResolverBuiltinFloatTest, Length_NoParams) {
|
||||
EXPECT_EQ(r()->error(), R"(error: no matching call to length()
|
||||
|
||||
2 candidate functions:
|
||||
length(T) -> T where: T is f32 or f16
|
||||
length(vecN<T>) -> T where: T is f32 or f16
|
||||
length(T) -> T where: T is abstract-float, f32 or f16
|
||||
length(vecN<T>) -> T where: T is abstract-float, f32 or f16
|
||||
)");
|
||||
}
|
||||
|
||||
@@ -1198,8 +1198,8 @@ TEST_F(ResolverBuiltinFloatTest, Length_TooManyParams) {
|
||||
EXPECT_EQ(r()->error(), R"(error: no matching call to length(f32, f32)
|
||||
|
||||
2 candidate functions:
|
||||
length(T) -> T where: T is f32 or f16
|
||||
length(vecN<T>) -> T where: T is f32 or f16
|
||||
length(T) -> T where: T is abstract-float, f32 or f16
|
||||
length(vecN<T>) -> T where: T is abstract-float, f32 or f16
|
||||
)");
|
||||
}
|
||||
|
||||
|
||||
@@ -890,6 +890,24 @@ utils::Result<NumberT> ConstEval::Det2(const Source& source,
|
||||
return r;
|
||||
}
|
||||
|
||||
template <typename NumberT>
|
||||
utils::Result<NumberT> ConstEval::Sqrt(const Source& source, NumberT v) {
|
||||
if (v < NumberT(0)) {
|
||||
AddError("sqrt must be called with a value >= 0", source);
|
||||
return utils::Failure;
|
||||
}
|
||||
return NumberT{std::sqrt(v)};
|
||||
}
|
||||
|
||||
auto ConstEval::SqrtFunc(const Source& source, const sem::Type* elem_ty) {
|
||||
return [=](auto v) -> ImplResult {
|
||||
if (auto r = Sqrt(source, v)) {
|
||||
return CreateElement(builder, source, elem_ty, r.Get());
|
||||
}
|
||||
return utils::Failure;
|
||||
};
|
||||
}
|
||||
|
||||
template <typename NumberT>
|
||||
utils::Result<NumberT> ConstEval::Clamp(const Source&, NumberT e, NumberT low, NumberT high) {
|
||||
return NumberT{std::min(std::max(e, low), high)};
|
||||
@@ -968,6 +986,33 @@ auto ConstEval::Dot4Func(const Source& source, const sem::Type* elem_ty) {
|
||||
};
|
||||
}
|
||||
|
||||
ConstEval::Result ConstEval::Dot(const Source& source,
|
||||
const sem::Constant* v1,
|
||||
const sem::Constant* v2) {
|
||||
auto* vec_ty = v1->Type()->As<sem::Vector>();
|
||||
TINT_ASSERT(Resolver, vec_ty);
|
||||
auto* elem_ty = vec_ty->type();
|
||||
switch (vec_ty->Width()) {
|
||||
case 2:
|
||||
return Dispatch_fia_fiu32_f16( //
|
||||
Dot2Func(source, elem_ty), //
|
||||
v1->Index(0), v1->Index(1), //
|
||||
v2->Index(0), v2->Index(1));
|
||||
case 3:
|
||||
return Dispatch_fia_fiu32_f16( //
|
||||
Dot3Func(source, elem_ty), //
|
||||
v1->Index(0), v1->Index(1), v1->Index(2), //
|
||||
v2->Index(0), v2->Index(1), v2->Index(2));
|
||||
case 4:
|
||||
return Dispatch_fia_fiu32_f16( //
|
||||
Dot4Func(source, elem_ty), //
|
||||
v1->Index(0), v1->Index(1), v1->Index(2), v1->Index(3), //
|
||||
v2->Index(0), v2->Index(1), v2->Index(2), v2->Index(3));
|
||||
}
|
||||
TINT_ICE(Resolver, builder.Diagnostics()) << "Expected vector";
|
||||
return utils::Failure;
|
||||
}
|
||||
|
||||
auto ConstEval::Det2Func(const Source& source, const sem::Type* elem_ty) {
|
||||
return [=](auto a, auto b, auto c, auto d) -> ImplResult {
|
||||
if (auto r = Det2(source, a, b, c, d)) {
|
||||
@@ -1969,30 +2014,10 @@ ConstEval::Result ConstEval::degrees(const sem::Type* ty,
|
||||
return TransformElements(builder, ty, transform, args[0]);
|
||||
}
|
||||
|
||||
ConstEval::Result ConstEval::dot(const sem::Type* ty,
|
||||
ConstEval::Result ConstEval::dot(const sem::Type*,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source) {
|
||||
auto calculate = [&]() -> ImplResult {
|
||||
auto* v1 = args[0];
|
||||
auto* v2 = args[1];
|
||||
auto* vec_ty = v1->Type()->As<sem::Vector>();
|
||||
switch (vec_ty->Width()) {
|
||||
case 2:
|
||||
return Dispatch_fia_fiu32_f16(Dot2Func(source, ty), v1->Index(0), v1->Index(1),
|
||||
v2->Index(0), v2->Index(1));
|
||||
case 3:
|
||||
return Dispatch_fia_fiu32_f16(Dot3Func(source, ty), v1->Index(0), v1->Index(1),
|
||||
v1->Index(2), v2->Index(0), v2->Index(1),
|
||||
v2->Index(2));
|
||||
case 4:
|
||||
return Dispatch_fia_fiu32_f16(Dot4Func(source, ty), v1->Index(0), v1->Index(1),
|
||||
v1->Index(2), v1->Index(3), v2->Index(0),
|
||||
v2->Index(1), v2->Index(2), v2->Index(3));
|
||||
}
|
||||
TINT_ICE(Resolver, builder.Diagnostics()) << "Expected scalar or vector";
|
||||
return utils::Failure;
|
||||
};
|
||||
auto r = calculate();
|
||||
auto r = Dot(source, args[0], args[1]);
|
||||
if (!r) {
|
||||
AddNote("when calculating dot", source);
|
||||
}
|
||||
@@ -2188,6 +2213,35 @@ ConstEval::Result ConstEval::insertBits(const sem::Type* ty,
|
||||
return TransformElements(builder, ty, transform, args[0], args[1]);
|
||||
}
|
||||
|
||||
ConstEval::Result ConstEval::length(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source) {
|
||||
auto calculate = [&]() -> ImplResult {
|
||||
auto* vec_ty = args[0]->Type()->As<sem::Vector>();
|
||||
|
||||
// Evaluates to the absolute value of e if T is scalar.
|
||||
if (vec_ty == nullptr) {
|
||||
auto create = [&](auto e) {
|
||||
using NumberT = decltype(e);
|
||||
return CreateElement(builder, source, ty, NumberT{std::abs(e)});
|
||||
};
|
||||
return Dispatch_fa_f32_f16(create, args[0]);
|
||||
}
|
||||
|
||||
// Evaluates to sqrt(e[0]^2 + e[1]^2 + ...) if T is a vector type.
|
||||
auto d = Dot(source, args[0], args[0]);
|
||||
if (!d) {
|
||||
return utils::Failure;
|
||||
}
|
||||
return Dispatch_fa_f32_f16(SqrtFunc(source, ty), d.Get());
|
||||
};
|
||||
auto r = calculate();
|
||||
if (!r) {
|
||||
AddNote("when calculating length", source);
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
ConstEval::Result ConstEval::max(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source) {
|
||||
@@ -2561,15 +2615,7 @@ ConstEval::Result ConstEval::sqrt(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source) {
|
||||
auto transform = [&](const sem::Constant* c0) {
|
||||
auto create = [&](auto i) -> ImplResult {
|
||||
using NumberT = decltype(i);
|
||||
if (i < NumberT(0)) {
|
||||
AddError("sqrt must be called with a value >= 0", source);
|
||||
return utils::Failure;
|
||||
}
|
||||
return CreateElement(builder, source, c0->Type(), NumberT(std::sqrt(i.value)));
|
||||
};
|
||||
return Dispatch_fa_f32_f16(create, c0);
|
||||
return Dispatch_fa_f32_f16(SqrtFunc(source, c0->Type()), c0);
|
||||
};
|
||||
|
||||
return TransformElements(builder, ty, transform, args[0]);
|
||||
|
||||
@@ -601,6 +601,15 @@ class ConstEval {
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source);
|
||||
|
||||
/// length builtin
|
||||
/// @param ty the expression type
|
||||
/// @param args the input arguments
|
||||
/// @param source the source location
|
||||
/// @return the result value, or null if the value cannot be calculated
|
||||
Result length(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source);
|
||||
|
||||
/// max builtin
|
||||
/// @param ty the expression type
|
||||
/// @param args the input arguments
|
||||
@@ -962,6 +971,9 @@ class ConstEval {
|
||||
NumberT b1,
|
||||
NumberT b2);
|
||||
|
||||
template <typename NumberT>
|
||||
utils::Result<NumberT> Sqrt(const Source& source, NumberT v);
|
||||
|
||||
/// Clamps e between low and high
|
||||
/// @param source the source location
|
||||
/// @param e the number to clamp
|
||||
@@ -1034,6 +1046,20 @@ class ConstEval {
|
||||
/// @returns the callable function
|
||||
auto ClampFunc(const Source& source, const sem::Type* elem_ty);
|
||||
|
||||
/// Returns a callable that calls SqrtFunc, and creates a Constant with its
|
||||
/// result of type `elem_ty` if successful, or returns Failure otherwise.
|
||||
/// @param source the source location
|
||||
/// @param elem_ty the element type of the Constant to create on success
|
||||
/// @returns the callable function
|
||||
auto SqrtFunc(const Source& source, const sem::Type* elem_ty);
|
||||
|
||||
/// Returns the dot product of v1 and v2.
|
||||
/// @param source the source location
|
||||
/// @param v1 the first vector
|
||||
/// @param v2 the second vector
|
||||
/// @returns the dot product
|
||||
Result Dot(const Source& source, const sem::Constant* v1, const sem::Constant* v2);
|
||||
|
||||
ProgramBuilder& builder;
|
||||
};
|
||||
|
||||
|
||||
@@ -41,10 +41,12 @@ struct Case {
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Expected value should be compared using FLOAT_EQ instead of EQ
|
||||
Case& FloatComp() {
|
||||
/// Expected value should be compared using EXPECT_FLOAT_EQ instead of EXPECT_EQ.
|
||||
/// If optional epsilon is passed in, will be compared using EXPECT_NEAR with that epsilon.
|
||||
Case& FloatComp(std::optional<double> epsilon = {}) {
|
||||
Success s = expected.Get();
|
||||
s.flags.float_compare = true;
|
||||
s.flags.float_compare_epsilon = epsilon;
|
||||
expected = s;
|
||||
return *this;
|
||||
}
|
||||
@@ -1176,6 +1178,65 @@ INSTANTIATE_TEST_SUITE_P( //
|
||||
testing::ValuesIn(Concat(ExtractBitsCases<i32>(), //
|
||||
ExtractBitsCases<u32>()))));
|
||||
|
||||
template <typename T>
|
||||
std::vector<Case> LengthCases() {
|
||||
const auto kSqrtOfHighest = T(std::sqrt(T::Highest()));
|
||||
const auto kSqrtOfHighestSquared = T(kSqrtOfHighest * kSqrtOfHighest);
|
||||
|
||||
auto error_msg = [](auto a, const char* op, auto b) {
|
||||
return "12:34 error: " + OverflowErrorMessage(a, op, b) + R"(
|
||||
12:34 note: when calculating length)";
|
||||
};
|
||||
return {
|
||||
C({T(0)}, T(0)),
|
||||
C({Vec(T(0), T(0))}, Val(T(0))),
|
||||
C({Vec(T(0), T(0), T(0))}, Val(T(0))),
|
||||
C({Vec(T(0), T(0), T(0), T(0))}, Val(T(0))),
|
||||
|
||||
C({T(1)}, T(1)),
|
||||
C({Vec(T(1), T(1))}, Val(T(std::sqrt(2)))),
|
||||
C({Vec(T(1), T(1), T(1))}, Val(T(std::sqrt(3)))),
|
||||
C({Vec(T(1), T(1), T(1), T(1))}, Val(T(std::sqrt(4)))),
|
||||
|
||||
C({T(2)}, T(2)),
|
||||
C({Vec(T(2), T(2))}, Val(T(std::sqrt(8)))),
|
||||
C({Vec(T(2), T(2), T(2))}, Val(T(std::sqrt(12)))),
|
||||
C({Vec(T(2), T(2), T(2), T(2))}, Val(T(std::sqrt(16)))),
|
||||
|
||||
C({Vec(T(2), T(3))}, Val(T(std::sqrt(13)))),
|
||||
C({Vec(T(2), T(3), T(4))}, Val(T(std::sqrt(29)))),
|
||||
C({Vec(T(2), T(3), T(4), T(5))}, Val(T(std::sqrt(54)))),
|
||||
|
||||
C({T(-5)}, T(5)),
|
||||
C({T::Highest()}, T::Highest()),
|
||||
C({T::Lowest()}, T::Highest()),
|
||||
|
||||
C({Vec(T(-2), T(-3), T(-4), T(-5))}, Val(T(std::sqrt(54)))),
|
||||
C({Vec(T(2), T(-3), T(4), T(-5))}, Val(T(std::sqrt(54)))),
|
||||
C({Vec(T(-2), T(3), T(-4), T(5))}, Val(T(std::sqrt(54)))),
|
||||
|
||||
C({Vec(kSqrtOfHighest, T(0))}, Val(kSqrtOfHighest)).FloatComp(0.2),
|
||||
C({Vec(T(0), kSqrtOfHighest)}, Val(kSqrtOfHighest)).FloatComp(0.2),
|
||||
|
||||
C({Vec(-kSqrtOfHighest, T(0))}, Val(kSqrtOfHighest)).FloatComp(0.2),
|
||||
C({Vec(T(0), -kSqrtOfHighest)}, Val(kSqrtOfHighest)).FloatComp(0.2),
|
||||
|
||||
// Overflow when squaring a term
|
||||
E({Vec(T::Highest(), T(0))}, error_msg(T::Highest(), "*", T::Highest())),
|
||||
E({Vec(T(0), T::Highest())}, error_msg(T::Highest(), "*", T::Highest())),
|
||||
// Overflow when adding squared terms
|
||||
E({Vec(kSqrtOfHighest, kSqrtOfHighest)},
|
||||
error_msg(kSqrtOfHighestSquared, "+", kSqrtOfHighestSquared)),
|
||||
};
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P( //
|
||||
Length,
|
||||
ResolverConstEvalBuiltinTest,
|
||||
testing::Combine(testing::Values(sem::BuiltinType::kLength),
|
||||
testing::ValuesIn(Concat(LengthCases<AFloat>(), //
|
||||
LengthCases<f32>(),
|
||||
LengthCases<f16>()))));
|
||||
|
||||
template <typename T>
|
||||
std::vector<Case> MaxCases() {
|
||||
return {
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
#define SRC_TINT_RESOLVER_CONST_EVAL_TEST_H_
|
||||
|
||||
#include <limits>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
@@ -74,8 +75,11 @@ inline auto Abs(const Number<T>& v) {
|
||||
struct CheckConstantFlags {
|
||||
/// Expected value may be positive or negative
|
||||
bool pos_or_neg = false;
|
||||
/// Expected value should be compared using FLOAT_EQ instead of EQ
|
||||
/// Expected value should be compared using EXPECT_FLOAT_EQ instead of EQ, or EXPECT_NEAR if
|
||||
/// float_compare_epsilon is set.
|
||||
bool float_compare = false;
|
||||
/// Expected value should be compared using EXPECT_NEAR if float_compare is set.
|
||||
std::optional<double> float_compare_epsilon;
|
||||
};
|
||||
|
||||
/// CheckConstant checks that @p got_constant, the result value of
|
||||
@@ -106,18 +110,16 @@ inline void CheckConstant(const sem::Constant* got_constant,
|
||||
EXPECT_TRUE(std::isnan(got));
|
||||
} else {
|
||||
if (flags.pos_or_neg) {
|
||||
auto got_abs = Abs(got);
|
||||
if (flags.float_compare) {
|
||||
EXPECT_FLOAT_EQ(got_abs, expected);
|
||||
got = Abs(got);
|
||||
}
|
||||
if (flags.float_compare) {
|
||||
if (flags.float_compare_epsilon) {
|
||||
EXPECT_NEAR(got, expected, *flags.float_compare_epsilon);
|
||||
} else {
|
||||
EXPECT_EQ(got_abs, expected);
|
||||
EXPECT_FLOAT_EQ(got, expected);
|
||||
}
|
||||
} else {
|
||||
if (flags.float_compare) {
|
||||
EXPECT_FLOAT_EQ(got, expected);
|
||||
} else {
|
||||
EXPECT_EQ(got, expected);
|
||||
}
|
||||
EXPECT_EQ(got, expected);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
||||
@@ -12726,24 +12726,24 @@ constexpr OverloadInfo kOverloads[] = {
|
||||
/* num parameters */ 1,
|
||||
/* num template types */ 1,
|
||||
/* num template numbers */ 0,
|
||||
/* template types */ &kTemplateTypes[26],
|
||||
/* template types */ &kTemplateTypes[23],
|
||||
/* template numbers */ &kTemplateNumbers[10],
|
||||
/* parameters */ &kParameters[872],
|
||||
/* return matcher indices */ &kMatcherIndices[3],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::length,
|
||||
},
|
||||
{
|
||||
/* [368] */
|
||||
/* num parameters */ 1,
|
||||
/* num template types */ 1,
|
||||
/* num template numbers */ 1,
|
||||
/* template types */ &kTemplateTypes[26],
|
||||
/* template types */ &kTemplateTypes[23],
|
||||
/* template numbers */ &kTemplateNumbers[4],
|
||||
/* parameters */ &kParameters[873],
|
||||
/* return matcher indices */ &kMatcherIndices[3],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::length,
|
||||
},
|
||||
{
|
||||
/* [369] */
|
||||
@@ -14308,8 +14308,8 @@ constexpr IntrinsicInfo kBuiltins[] = {
|
||||
},
|
||||
{
|
||||
/* [47] */
|
||||
/* fn length<T : f32_f16>(T) -> T */
|
||||
/* fn length<N : num, T : f32_f16>(vec<N, T>) -> T */
|
||||
/* fn length<T : fa_f32_f16>(@test_value(0) T) -> T */
|
||||
/* fn length<N : num, T : fa_f32_f16>(@test_value(0) vec<N, T>) -> T */
|
||||
/* num overloads */ 2,
|
||||
/* overloads */ &kOverloads[367],
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user