tint: const eval of length builtin

Bug: tint:1581
Change-Id: Ie6dc9da6b48c606af03da023c835ec36a99dd362
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/110981
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
This commit is contained in:
Antonio Maiorano
2022-11-23 00:00:45 +00:00
committed by Dawn LUCI CQ
parent 89f15fc57e
commit 92d858ac3c
167 changed files with 2726 additions and 488 deletions

View File

@@ -491,8 +491,8 @@ fn inverseSqrt<T: f32_f16>(T) -> T
fn inverseSqrt<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
fn ldexp<T: f32_f16>(T, i32) -> T
fn ldexp<N: num, T: f32_f16>(vec<N, T>, vec<N, i32>) -> vec<N, T>
fn length<T: f32_f16>(T) -> T
fn length<N: num, T: f32_f16>(vec<N, T>) -> T
@const fn length<T: fa_f32_f16>(@test_value(0.0) T) -> T
@const fn length<N: num, T: fa_f32_f16>(@test_value(0.0) vec<N, T>) -> T
fn log<T: f32_f16>(T) -> T
fn log<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
fn log2<T: f32_f16>(T) -> T

View File

@@ -1184,8 +1184,8 @@ TEST_F(ResolverBuiltinFloatTest, Length_NoParams) {
EXPECT_EQ(r()->error(), R"(error: no matching call to length()
2 candidate functions:
length(T) -> T where: T is f32 or f16
length(vecN<T>) -> T where: T is f32 or f16
length(T) -> T where: T is abstract-float, f32 or f16
length(vecN<T>) -> T where: T is abstract-float, f32 or f16
)");
}
@@ -1198,8 +1198,8 @@ TEST_F(ResolverBuiltinFloatTest, Length_TooManyParams) {
EXPECT_EQ(r()->error(), R"(error: no matching call to length(f32, f32)
2 candidate functions:
length(T) -> T where: T is f32 or f16
length(vecN<T>) -> T where: T is f32 or f16
length(T) -> T where: T is abstract-float, f32 or f16
length(vecN<T>) -> T where: T is abstract-float, f32 or f16
)");
}

View File

@@ -890,6 +890,24 @@ utils::Result<NumberT> ConstEval::Det2(const Source& source,
return r;
}
template <typename NumberT>
utils::Result<NumberT> ConstEval::Sqrt(const Source& source, NumberT v) {
if (v < NumberT(0)) {
AddError("sqrt must be called with a value >= 0", source);
return utils::Failure;
}
return NumberT{std::sqrt(v)};
}
auto ConstEval::SqrtFunc(const Source& source, const sem::Type* elem_ty) {
return [=](auto v) -> ImplResult {
if (auto r = Sqrt(source, v)) {
return CreateElement(builder, source, elem_ty, r.Get());
}
return utils::Failure;
};
}
template <typename NumberT>
utils::Result<NumberT> ConstEval::Clamp(const Source&, NumberT e, NumberT low, NumberT high) {
return NumberT{std::min(std::max(e, low), high)};
@@ -968,6 +986,33 @@ auto ConstEval::Dot4Func(const Source& source, const sem::Type* elem_ty) {
};
}
ConstEval::Result ConstEval::Dot(const Source& source,
const sem::Constant* v1,
const sem::Constant* v2) {
auto* vec_ty = v1->Type()->As<sem::Vector>();
TINT_ASSERT(Resolver, vec_ty);
auto* elem_ty = vec_ty->type();
switch (vec_ty->Width()) {
case 2:
return Dispatch_fia_fiu32_f16( //
Dot2Func(source, elem_ty), //
v1->Index(0), v1->Index(1), //
v2->Index(0), v2->Index(1));
case 3:
return Dispatch_fia_fiu32_f16( //
Dot3Func(source, elem_ty), //
v1->Index(0), v1->Index(1), v1->Index(2), //
v2->Index(0), v2->Index(1), v2->Index(2));
case 4:
return Dispatch_fia_fiu32_f16( //
Dot4Func(source, elem_ty), //
v1->Index(0), v1->Index(1), v1->Index(2), v1->Index(3), //
v2->Index(0), v2->Index(1), v2->Index(2), v2->Index(3));
}
TINT_ICE(Resolver, builder.Diagnostics()) << "Expected vector";
return utils::Failure;
}
auto ConstEval::Det2Func(const Source& source, const sem::Type* elem_ty) {
return [=](auto a, auto b, auto c, auto d) -> ImplResult {
if (auto r = Det2(source, a, b, c, d)) {
@@ -1969,30 +2014,10 @@ ConstEval::Result ConstEval::degrees(const sem::Type* ty,
return TransformElements(builder, ty, transform, args[0]);
}
ConstEval::Result ConstEval::dot(const sem::Type* ty,
ConstEval::Result ConstEval::dot(const sem::Type*,
utils::VectorRef<const sem::Constant*> args,
const Source& source) {
auto calculate = [&]() -> ImplResult {
auto* v1 = args[0];
auto* v2 = args[1];
auto* vec_ty = v1->Type()->As<sem::Vector>();
switch (vec_ty->Width()) {
case 2:
return Dispatch_fia_fiu32_f16(Dot2Func(source, ty), v1->Index(0), v1->Index(1),
v2->Index(0), v2->Index(1));
case 3:
return Dispatch_fia_fiu32_f16(Dot3Func(source, ty), v1->Index(0), v1->Index(1),
v1->Index(2), v2->Index(0), v2->Index(1),
v2->Index(2));
case 4:
return Dispatch_fia_fiu32_f16(Dot4Func(source, ty), v1->Index(0), v1->Index(1),
v1->Index(2), v1->Index(3), v2->Index(0),
v2->Index(1), v2->Index(2), v2->Index(3));
}
TINT_ICE(Resolver, builder.Diagnostics()) << "Expected scalar or vector";
return utils::Failure;
};
auto r = calculate();
auto r = Dot(source, args[0], args[1]);
if (!r) {
AddNote("when calculating dot", source);
}
@@ -2188,6 +2213,35 @@ ConstEval::Result ConstEval::insertBits(const sem::Type* ty,
return TransformElements(builder, ty, transform, args[0], args[1]);
}
ConstEval::Result ConstEval::length(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source) {
auto calculate = [&]() -> ImplResult {
auto* vec_ty = args[0]->Type()->As<sem::Vector>();
// Evaluates to the absolute value of e if T is scalar.
if (vec_ty == nullptr) {
auto create = [&](auto e) {
using NumberT = decltype(e);
return CreateElement(builder, source, ty, NumberT{std::abs(e)});
};
return Dispatch_fa_f32_f16(create, args[0]);
}
// Evaluates to sqrt(e[0]^2 + e[1]^2 + ...) if T is a vector type.
auto d = Dot(source, args[0], args[0]);
if (!d) {
return utils::Failure;
}
return Dispatch_fa_f32_f16(SqrtFunc(source, ty), d.Get());
};
auto r = calculate();
if (!r) {
AddNote("when calculating length", source);
}
return r;
}
ConstEval::Result ConstEval::max(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source) {
@@ -2561,15 +2615,7 @@ ConstEval::Result ConstEval::sqrt(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source) {
auto transform = [&](const sem::Constant* c0) {
auto create = [&](auto i) -> ImplResult {
using NumberT = decltype(i);
if (i < NumberT(0)) {
AddError("sqrt must be called with a value >= 0", source);
return utils::Failure;
}
return CreateElement(builder, source, c0->Type(), NumberT(std::sqrt(i.value)));
};
return Dispatch_fa_f32_f16(create, c0);
return Dispatch_fa_f32_f16(SqrtFunc(source, c0->Type()), c0);
};
return TransformElements(builder, ty, transform, args[0]);

View File

@@ -601,6 +601,15 @@ class ConstEval {
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// length builtin
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result length(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// max builtin
/// @param ty the expression type
/// @param args the input arguments
@@ -962,6 +971,9 @@ class ConstEval {
NumberT b1,
NumberT b2);
template <typename NumberT>
utils::Result<NumberT> Sqrt(const Source& source, NumberT v);
/// Clamps e between low and high
/// @param source the source location
/// @param e the number to clamp
@@ -1034,6 +1046,20 @@ class ConstEval {
/// @returns the callable function
auto ClampFunc(const Source& source, const sem::Type* elem_ty);
/// Returns a callable that calls SqrtFunc, and creates a Constant with its
/// result of type `elem_ty` if successful, or returns Failure otherwise.
/// @param source the source location
/// @param elem_ty the element type of the Constant to create on success
/// @returns the callable function
auto SqrtFunc(const Source& source, const sem::Type* elem_ty);
/// Returns the dot product of v1 and v2.
/// @param source the source location
/// @param v1 the first vector
/// @param v2 the second vector
/// @returns the dot product
Result Dot(const Source& source, const sem::Constant* v1, const sem::Constant* v2);
ProgramBuilder& builder;
};

View File

@@ -41,10 +41,12 @@ struct Case {
return *this;
}
/// Expected value should be compared using FLOAT_EQ instead of EQ
Case& FloatComp() {
/// Expected value should be compared using EXPECT_FLOAT_EQ instead of EXPECT_EQ.
/// If optional epsilon is passed in, will be compared using EXPECT_NEAR with that epsilon.
Case& FloatComp(std::optional<double> epsilon = {}) {
Success s = expected.Get();
s.flags.float_compare = true;
s.flags.float_compare_epsilon = epsilon;
expected = s;
return *this;
}
@@ -1176,6 +1178,65 @@ INSTANTIATE_TEST_SUITE_P( //
testing::ValuesIn(Concat(ExtractBitsCases<i32>(), //
ExtractBitsCases<u32>()))));
template <typename T>
std::vector<Case> LengthCases() {
const auto kSqrtOfHighest = T(std::sqrt(T::Highest()));
const auto kSqrtOfHighestSquared = T(kSqrtOfHighest * kSqrtOfHighest);
auto error_msg = [](auto a, const char* op, auto b) {
return "12:34 error: " + OverflowErrorMessage(a, op, b) + R"(
12:34 note: when calculating length)";
};
return {
C({T(0)}, T(0)),
C({Vec(T(0), T(0))}, Val(T(0))),
C({Vec(T(0), T(0), T(0))}, Val(T(0))),
C({Vec(T(0), T(0), T(0), T(0))}, Val(T(0))),
C({T(1)}, T(1)),
C({Vec(T(1), T(1))}, Val(T(std::sqrt(2)))),
C({Vec(T(1), T(1), T(1))}, Val(T(std::sqrt(3)))),
C({Vec(T(1), T(1), T(1), T(1))}, Val(T(std::sqrt(4)))),
C({T(2)}, T(2)),
C({Vec(T(2), T(2))}, Val(T(std::sqrt(8)))),
C({Vec(T(2), T(2), T(2))}, Val(T(std::sqrt(12)))),
C({Vec(T(2), T(2), T(2), T(2))}, Val(T(std::sqrt(16)))),
C({Vec(T(2), T(3))}, Val(T(std::sqrt(13)))),
C({Vec(T(2), T(3), T(4))}, Val(T(std::sqrt(29)))),
C({Vec(T(2), T(3), T(4), T(5))}, Val(T(std::sqrt(54)))),
C({T(-5)}, T(5)),
C({T::Highest()}, T::Highest()),
C({T::Lowest()}, T::Highest()),
C({Vec(T(-2), T(-3), T(-4), T(-5))}, Val(T(std::sqrt(54)))),
C({Vec(T(2), T(-3), T(4), T(-5))}, Val(T(std::sqrt(54)))),
C({Vec(T(-2), T(3), T(-4), T(5))}, Val(T(std::sqrt(54)))),
C({Vec(kSqrtOfHighest, T(0))}, Val(kSqrtOfHighest)).FloatComp(0.2),
C({Vec(T(0), kSqrtOfHighest)}, Val(kSqrtOfHighest)).FloatComp(0.2),
C({Vec(-kSqrtOfHighest, T(0))}, Val(kSqrtOfHighest)).FloatComp(0.2),
C({Vec(T(0), -kSqrtOfHighest)}, Val(kSqrtOfHighest)).FloatComp(0.2),
// Overflow when squaring a term
E({Vec(T::Highest(), T(0))}, error_msg(T::Highest(), "*", T::Highest())),
E({Vec(T(0), T::Highest())}, error_msg(T::Highest(), "*", T::Highest())),
// Overflow when adding squared terms
E({Vec(kSqrtOfHighest, kSqrtOfHighest)},
error_msg(kSqrtOfHighestSquared, "+", kSqrtOfHighestSquared)),
};
}
INSTANTIATE_TEST_SUITE_P( //
Length,
ResolverConstEvalBuiltinTest,
testing::Combine(testing::Values(sem::BuiltinType::kLength),
testing::ValuesIn(Concat(LengthCases<AFloat>(), //
LengthCases<f32>(),
LengthCases<f16>()))));
template <typename T>
std::vector<Case> MaxCases() {
return {

View File

@@ -16,6 +16,7 @@
#define SRC_TINT_RESOLVER_CONST_EVAL_TEST_H_
#include <limits>
#include <optional>
#include <string>
#include <utility>
@@ -74,8 +75,11 @@ inline auto Abs(const Number<T>& v) {
struct CheckConstantFlags {
/// Expected value may be positive or negative
bool pos_or_neg = false;
/// Expected value should be compared using FLOAT_EQ instead of EQ
/// Expected value should be compared using EXPECT_FLOAT_EQ instead of EQ, or EXPECT_NEAR if
/// float_compare_epsilon is set.
bool float_compare = false;
/// Expected value should be compared using EXPECT_NEAR if float_compare is set.
std::optional<double> float_compare_epsilon;
};
/// CheckConstant checks that @p got_constant, the result value of
@@ -106,18 +110,16 @@ inline void CheckConstant(const sem::Constant* got_constant,
EXPECT_TRUE(std::isnan(got));
} else {
if (flags.pos_or_neg) {
auto got_abs = Abs(got);
if (flags.float_compare) {
EXPECT_FLOAT_EQ(got_abs, expected);
got = Abs(got);
}
if (flags.float_compare) {
if (flags.float_compare_epsilon) {
EXPECT_NEAR(got, expected, *flags.float_compare_epsilon);
} else {
EXPECT_EQ(got_abs, expected);
EXPECT_FLOAT_EQ(got, expected);
}
} else {
if (flags.float_compare) {
EXPECT_FLOAT_EQ(got, expected);
} else {
EXPECT_EQ(got, expected);
}
EXPECT_EQ(got, expected);
}
}
} else {

View File

@@ -12726,24 +12726,24 @@ constexpr OverloadInfo kOverloads[] = {
/* num parameters */ 1,
/* num template types */ 1,
/* num template numbers */ 0,
/* template types */ &kTemplateTypes[26],
/* template types */ &kTemplateTypes[23],
/* template numbers */ &kTemplateNumbers[10],
/* parameters */ &kParameters[872],
/* return matcher indices */ &kMatcherIndices[3],
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::length,
},
{
/* [368] */
/* num parameters */ 1,
/* num template types */ 1,
/* num template numbers */ 1,
/* template types */ &kTemplateTypes[26],
/* template types */ &kTemplateTypes[23],
/* template numbers */ &kTemplateNumbers[4],
/* parameters */ &kParameters[873],
/* return matcher indices */ &kMatcherIndices[3],
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::length,
},
{
/* [369] */
@@ -14308,8 +14308,8 @@ constexpr IntrinsicInfo kBuiltins[] = {
},
{
/* [47] */
/* fn length<T : f32_f16>(T) -> T */
/* fn length<N : num, T : f32_f16>(vec<N, T>) -> T */
/* fn length<T : fa_f32_f16>(@test_value(0) T) -> T */
/* fn length<N : num, T : fa_f32_f16>(@test_value(0) vec<N, T>) -> T */
/* num overloads */ 2,
/* overloads */ &kOverloads[367],
},