tint: const eval of determinant builtin

Bug: tint:1581
Change-Id: Ifed8202ba2346eee435ee4e3d0e82ab614a86255
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/111281
Reviewed-by: Ben Clayton <bclayton@google.com>
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Kokoro: Antonio Maiorano <amaiorano@google.com>
This commit is contained in:
Antonio Maiorano
2022-11-23 19:16:15 +00:00
committed by Dawn LUCI CQ
parent 43fe9bec16
commit 05c8daac42
63 changed files with 1482 additions and 229 deletions

View File

@@ -1932,7 +1932,7 @@ TEST_F(ResolverBuiltinTest, Determinant_NotSquare) {
EXPECT_EQ(r()->error(), R"(error: no matching call to determinant(mat2x3<f32>)
1 candidate function:
determinant(matNxN<T>) -> T where: T is f32 or f16
determinant(matNxN<T>) -> T where: T is abstract-float, f32 or f16
)");
}
@@ -1947,7 +1947,7 @@ TEST_F(ResolverBuiltinTest, Determinant_NotMatrix) {
EXPECT_EQ(r()->error(), R"(error: no matching call to determinant(f32)
1 candidate function:
determinant(matNxN<T>) -> T where: T is f32 or f16
determinant(matNxN<T>) -> T where: T is abstract-float, f32 or f16
)");
}

View File

@@ -893,15 +893,22 @@ utils::Result<NumberT> ConstEval::Dot4(const Source& source,
template <typename NumberT>
utils::Result<NumberT> ConstEval::Det2(const Source& source,
NumberT a1,
NumberT a2,
NumberT b1,
NumberT b2) {
auto r1 = Mul(source, a1, b2);
NumberT a,
NumberT b,
NumberT c,
NumberT d) {
// | a c |
// | b d |
//
// =
//
// a * d - c * b
auto r1 = Mul(source, a, d);
if (!r1) {
return utils::Failure;
}
auto r2 = Mul(source, b1, a2);
auto r2 = Mul(source, c, b);
if (!r2) {
return utils::Failure;
}
@@ -912,6 +919,129 @@ utils::Result<NumberT> ConstEval::Det2(const Source& source,
return r;
}
template <typename NumberT>
utils::Result<NumberT> ConstEval::Det3(const Source& source,
NumberT a,
NumberT b,
NumberT c,
NumberT d,
NumberT e,
NumberT f,
NumberT g,
NumberT h,
NumberT i) {
// | a d g |
// | b e h |
// | c f i |
//
// =
//
// a | e h | - d | b h | + g | b e |
// | f i | | c i | | c f |
auto det1 = Det2(source, e, f, h, i);
if (!det1) {
return utils::Failure;
}
auto a_det1 = Mul(source, a, det1.Get());
if (!a_det1) {
return utils::Failure;
}
auto det2 = Det2(source, b, c, h, i);
if (!det2) {
return utils::Failure;
}
auto d_det2 = Mul(source, d, det2.Get());
if (!d_det2) {
return utils::Failure;
}
auto det3 = Det2(source, b, c, e, f);
if (!det3) {
return utils::Failure;
}
auto g_det3 = Mul(source, g, det3.Get());
if (!g_det3) {
return utils::Failure;
}
auto r = Sub(source, a_det1.Get(), d_det2.Get());
if (!r) {
return utils::Failure;
}
return Add(source, r.Get(), g_det3.Get());
}
template <typename NumberT>
utils::Result<NumberT> ConstEval::Det4(const Source& source,
NumberT a,
NumberT b,
NumberT c,
NumberT d,
NumberT e,
NumberT f,
NumberT g,
NumberT h,
NumberT i,
NumberT j,
NumberT k,
NumberT l,
NumberT m,
NumberT n,
NumberT o,
NumberT p) {
// | a e i m |
// | b f j n |
// | c g k o |
// | d h l p |
//
// =
//
// a | f j n | - e | b j n | + i | b f n | - m | b f j |
// | g k o | | c k o | | c g o | | c g k |
// | h l p | | d l p | | d h p | | d h l |
auto det1 = Det3(source, f, g, h, j, k, l, n, o, p);
if (!det1) {
return utils::Failure;
}
auto a_det1 = Mul(source, a, det1.Get());
if (!a_det1) {
return utils::Failure;
}
auto det2 = Det3(source, b, c, d, j, k, l, n, o, p);
if (!det2) {
return utils::Failure;
}
auto e_det2 = Mul(source, e, det2.Get());
if (!e_det2) {
return utils::Failure;
}
auto det3 = Det3(source, b, c, d, f, g, h, n, o, p);
if (!det3) {
return utils::Failure;
}
auto i_det3 = Mul(source, i, det3.Get());
if (!i_det3) {
return utils::Failure;
}
auto det4 = Det3(source, b, c, d, f, g, h, j, k, l);
if (!det4) {
return utils::Failure;
}
auto m_det4 = Mul(source, m, det4.Get());
if (!m_det4) {
return utils::Failure;
}
auto r = Sub(source, a_det1.Get(), e_det2.Get());
if (!r) {
return utils::Failure;
}
r = Add(source, r.Get(), i_det3.Get());
if (!r) {
return utils::Failure;
}
return Sub(source, r.Get(), m_det4.Get());
}
template <typename NumberT>
utils::Result<NumberT> ConstEval::Sqrt(const Source& source, NumberT v) {
if (v < NumberT(0)) {
@@ -1044,6 +1174,26 @@ auto ConstEval::Det2Func(const Source& source, const sem::Type* elem_ty) {
};
}
auto ConstEval::Det3Func(const Source& source, const sem::Type* elem_ty) {
return
[=](auto a, auto b, auto c, auto d, auto e, auto f, auto g, auto h, auto i) -> ImplResult {
if (auto r = Det3(source, a, b, c, d, e, f, g, h, i)) {
return CreateElement(builder, source, elem_ty, r.Get());
}
return utils::Failure;
};
}
auto ConstEval::Det4Func(const Source& source, const sem::Type* elem_ty) {
return [=](auto a, auto b, auto c, auto d, auto e, auto f, auto g, auto h, auto i, auto j,
auto k, auto l, auto m, auto n, auto o, auto p) -> ImplResult {
if (auto r = Det4(source, a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p)) {
return CreateElement(builder, source, elem_ty, r.Get());
}
return utils::Failure;
};
}
ConstEval::Result ConstEval::Literal(const sem::Type* ty, const ast::LiteralExpression* literal) {
auto& source = literal->source;
return Switch(
@@ -2036,6 +2186,41 @@ ConstEval::Result ConstEval::degrees(const sem::Type* ty,
return TransformElements(builder, ty, transform, args[0]);
}
ConstEval::Result ConstEval::determinant(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source) {
auto calculate = [&]() -> ImplResult {
auto* m = args[0];
auto* mat_ty = m->Type()->As<sem::Matrix>();
auto me = [&](size_t r, size_t c) { return m->Index(c)->Index(r); };
switch (mat_ty->rows()) {
case 2:
return Dispatch_fa_f32_f16(Det2Func(source, ty), //
me(0, 0), me(1, 0), //
me(0, 1), me(1, 1));
case 3:
return Dispatch_fa_f32_f16(Det3Func(source, ty), //
me(0, 0), me(1, 0), me(2, 0), //
me(0, 1), me(1, 1), me(2, 1), //
me(0, 2), me(1, 2), me(2, 2));
case 4:
return Dispatch_fa_f32_f16(Det4Func(source, ty), //
me(0, 0), me(1, 0), me(2, 0), me(3, 0), //
me(0, 1), me(1, 1), me(2, 1), me(3, 1), //
me(0, 2), me(1, 2), me(2, 2), me(3, 2), //
me(0, 3), me(1, 3), me(2, 3), me(3, 3));
}
TINT_ICE(Resolver, builder.Diagnostics()) << "Unexpected number of matrix rows";
return utils::Failure;
};
auto r = calculate();
if (!r) {
AddNote("when calculating determinant", source);
}
return r;
}
ConstEval::Result ConstEval::dot(const sem::Type*,
utils::VectorRef<const sem::Constant*> args,
const Source& source) {

View File

@@ -539,6 +539,7 @@ class ConstEval {
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// degrees builtin
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location of the conversion
@@ -547,6 +548,15 @@ class ConstEval {
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// determinant builtin
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated
Result determinant(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// dot builtin
/// @param ty the expression type
/// @param args the input arguments
@@ -1012,18 +1022,87 @@ class ConstEval {
NumberT b3,
NumberT b4);
/// Returns the determinant of the 2x2 matrix [(a1, a2), (b1, b2)]
/// Returns the determinant of the 2x2 matrix:
/// | a c |
/// | b d |
/// @param source the source location
/// @param a1 component 1 of the first column vector
/// @param a2 component 2 of the first column vector
/// @param b1 component 1 of the second column vector
/// @param b2 component 2 of the second column vector
/// @param a component 1 of the first column vector
/// @param b component 2 of the first column vector
/// @param c component 1 of the second column vector
/// @param d component 2 of the second column vector
template <typename NumberT>
utils::Result<NumberT> Det2(const Source& source,
NumberT a1,
NumberT a2,
NumberT b1,
NumberT b2);
utils::Result<NumberT> Det2(const Source& source, //
NumberT a,
NumberT b,
NumberT c,
NumberT d);
/// Returns the determinant of the 3x3 matrix:
/// | a d g |
/// | b e h |
/// | c f i |
/// @param source the source location
/// @param a component 1 of the first column vector
/// @param b component 2 of the first column vector
/// @param c component 3 of the first column vector
/// @param d component 1 of the second column vector
/// @param e component 2 of the second column vector
/// @param f component 3 of the second column vector
/// @param g component 1 of the third column vector
/// @param h component 2 of the third column vector
/// @param i component 3 of the third column vector
template <typename NumberT>
utils::Result<NumberT> Det3(const Source& source,
NumberT a,
NumberT b,
NumberT c,
NumberT d,
NumberT e,
NumberT f,
NumberT g,
NumberT h,
NumberT i);
/// Returns the determinant of the 4x4 matrix:
/// | a e i m |
/// | b f j n |
/// | c g k o |
/// | d h l p |
/// @param source the source location
/// @param a component 1 of the first column vector
/// @param b component 2 of the first column vector
/// @param c component 3 of the first column vector
/// @param d component 4 of the first column vector
/// @param e component 1 of the second column vector
/// @param f component 2 of the second column vector
/// @param g component 3 of the second column vector
/// @param h component 4 of the second column vector
/// @param i component 1 of the third column vector
/// @param j component 2 of the third column vector
/// @param k component 3 of the third column vector
/// @param l component 4 of the third column vector
/// @param m component 1 of the fourth column vector
/// @param n component 2 of the fourth column vector
/// @param o component 3 of the fourth column vector
/// @param p component 4 of the fourth column vector
template <typename NumberT>
utils::Result<NumberT> Det4(const Source& source,
NumberT a,
NumberT b,
NumberT c,
NumberT d,
NumberT e,
NumberT f,
NumberT g,
NumberT h,
NumberT i,
NumberT j,
NumberT k,
NumberT l,
NumberT m,
NumberT n,
NumberT o,
NumberT p);
template <typename NumberT>
utils::Result<NumberT> Sqrt(const Source& source, NumberT v);
@@ -1093,6 +1172,20 @@ class ConstEval {
/// @returns the callable function
auto Det2Func(const Source& source, const sem::Type* elem_ty);
/// Returns a callable that calls Det3, and creates a Constant with its result of type `elem_ty`
/// if successful, or returns Failure otherwise.
/// @param source the source location
/// @param elem_ty the element type of the Constant to create on success
/// @returns the callable function
auto Det3Func(const Source& source, const sem::Type* elem_ty);
/// Returns a callable that calls Det4, and creates a Constant with its result of type `elem_ty`
/// if successful, or returns Failure otherwise.
/// @param source the source location
/// @param elem_ty the element type of the Constant to create on success
/// @returns the callable function
auto Det4Func(const Source& source, const sem::Type* elem_ty);
/// Returns a callable that calls Clamp, and creates a Constant with its result of type
/// `elem_ty` if successful, or returns Failure otherwise.
/// @param source the source location

View File

@@ -853,6 +853,83 @@ INSTANTIATE_TEST_SUITE_P( //
DotCases<f32>(), //
DotCases<f16>()))));
template <typename T>
std::vector<Case> DeterminantCases() {
auto error_msg = [](auto a, const char* op, auto b) {
return "12:34 error: " + OverflowErrorMessage(a, op, b) + R"(
12:34 note: when calculating determinant)";
};
auto r = std::vector<Case>{
// All zero == 0
C({Mat({T(0), T(0)}, //
{T(0), T(0)})}, //
Val(T(0))),
C({Mat({T(0), T(0), T(0)}, //
{T(0), T(0), T(0)}, //
{T(0), T(0), T(0)})}, //
Val(T(0))),
C({Mat({T(0), T(0), T(0), T(0)}, //
{T(0), T(0), T(0), T(0)}, //
{T(0), T(0), T(0), T(0)}, //
{T(0), T(0), T(0), T(0)})}, //
Val(T(0))),
// All same == 0
C({Mat({T(42), T(42)}, //
{T(42), T(42)})}, //
Val(T(0))),
C({Mat({T(42), T(42), T(42)}, //
{T(42), T(42), T(42)}, //
{T(42), T(42), T(42)})}, //
Val(T(0))),
C({Mat({T(42), T(42), T(42), T(42)}, //
{T(42), T(42), T(42), T(42)}, //
{T(42), T(42), T(42), T(42)}, //
{T(42), T(42), T(42), T(42)})}, //
Val(T(0))),
// Various values
C({Mat({-T(2), T(17)}, //
{T(5), T(45)})}, //
Val(-T(175))),
C({Mat({T(4), T(6), -T(13)}, //
{T(12), T(5), T(8)}, //
{T(9), T(17), T(16)})}, //
Val(-T(3011))),
C({Mat({T(2), T(9), T(8), T(1)}, //
{-T(4), T(11), -T(3), T(7)}, //
{T(6), T(5), T(12), -T(6)}, //
{T(3), -T(10), T(4), -T(7)})}, //
Val(T(469))),
// Overflow during multiply
E({Mat({T::Highest(), T(0)}, //
{T(0), T(2)})}, //
error_msg(T::Highest(), "*", T(2))),
// Overflow during subtract
E({Mat({T::Highest(), T::Lowest()}, //
{T(1), T(1)})}, //
error_msg(T::Highest(), "-", T::Lowest())),
};
return r;
}
INSTANTIATE_TEST_SUITE_P( //
Determinant,
ResolverConstEvalBuiltinTest,
testing::Combine(testing::Values(sem::BuiltinType::kDeterminant),
testing::ValuesIn(Concat(DeterminantCases<AFloat>(), //
DeterminantCases<f32>(), //
DeterminantCases<f16>()))));
template <typename T>
std::vector<Case> FirstLeadingBitCases() {
using B = BitValues<T>;

View File

@@ -288,6 +288,16 @@ using Types = std::variant< //
Value<builder::mat2x2<f32>>,
Value<builder::mat2x2<f16>>,
Value<builder::mat3x3<AInt>>,
Value<builder::mat3x3<AFloat>>,
Value<builder::mat3x3<f32>>,
Value<builder::mat3x3<f16>>,
Value<builder::mat4x4<AInt>>,
Value<builder::mat4x4<AFloat>>,
Value<builder::mat4x4<f32>>,
Value<builder::mat4x4<f16>>,
Value<builder::mat2x3<AInt>>,
Value<builder::mat2x3<AFloat>>,
Value<builder::mat2x3<f32>>,

View File

@@ -13566,12 +13566,12 @@ constexpr OverloadInfo kOverloads[] = {
/* num parameters */ 1,
/* num template types */ 1,
/* num template numbers */ 1,
/* template types */ &kTemplateTypes[26],
/* template types */ &kTemplateTypes[23],
/* template numbers */ &kTemplateNumbers[4],
/* parameters */ &kParameters[837],
/* return matcher indices */ &kMatcherIndices[3],
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::determinant,
},
{
/* [438] */
@@ -14124,7 +14124,7 @@ constexpr IntrinsicInfo kBuiltins[] = {
},
{
/* [20] */
/* fn determinant<N : num, T : f32_f16>(mat<N, N, T>) -> T */
/* fn determinant<N : num, T : fa_f32_f16>(mat<N, N, T>) -> T */
/* num overloads */ 1,
/* overloads */ &kOverloads[437],
},