mirror of
https://github.com/encounter/dawn-cmake.git
synced 2025-12-21 10:49:14 +00:00
tint: const eval of determinant builtin
Bug: tint:1581 Change-Id: Ifed8202ba2346eee435ee4e3d0e82ab614a86255 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/111281 Reviewed-by: Ben Clayton <bclayton@google.com> Reviewed-by: Dan Sinclair <dsinclair@chromium.org> Commit-Queue: Antonio Maiorano <amaiorano@google.com> Kokoro: Kokoro <noreply+kokoro@google.com> Kokoro: Antonio Maiorano <amaiorano@google.com>
This commit is contained in:
committed by
Dawn LUCI CQ
parent
43fe9bec16
commit
05c8daac42
@@ -1932,7 +1932,7 @@ TEST_F(ResolverBuiltinTest, Determinant_NotSquare) {
|
||||
EXPECT_EQ(r()->error(), R"(error: no matching call to determinant(mat2x3<f32>)
|
||||
|
||||
1 candidate function:
|
||||
determinant(matNxN<T>) -> T where: T is f32 or f16
|
||||
determinant(matNxN<T>) -> T where: T is abstract-float, f32 or f16
|
||||
)");
|
||||
}
|
||||
|
||||
@@ -1947,7 +1947,7 @@ TEST_F(ResolverBuiltinTest, Determinant_NotMatrix) {
|
||||
EXPECT_EQ(r()->error(), R"(error: no matching call to determinant(f32)
|
||||
|
||||
1 candidate function:
|
||||
determinant(matNxN<T>) -> T where: T is f32 or f16
|
||||
determinant(matNxN<T>) -> T where: T is abstract-float, f32 or f16
|
||||
)");
|
||||
}
|
||||
|
||||
|
||||
@@ -893,15 +893,22 @@ utils::Result<NumberT> ConstEval::Dot4(const Source& source,
|
||||
|
||||
template <typename NumberT>
|
||||
utils::Result<NumberT> ConstEval::Det2(const Source& source,
|
||||
NumberT a1,
|
||||
NumberT a2,
|
||||
NumberT b1,
|
||||
NumberT b2) {
|
||||
auto r1 = Mul(source, a1, b2);
|
||||
NumberT a,
|
||||
NumberT b,
|
||||
NumberT c,
|
||||
NumberT d) {
|
||||
// | a c |
|
||||
// | b d |
|
||||
//
|
||||
// =
|
||||
//
|
||||
// a * d - c * b
|
||||
|
||||
auto r1 = Mul(source, a, d);
|
||||
if (!r1) {
|
||||
return utils::Failure;
|
||||
}
|
||||
auto r2 = Mul(source, b1, a2);
|
||||
auto r2 = Mul(source, c, b);
|
||||
if (!r2) {
|
||||
return utils::Failure;
|
||||
}
|
||||
@@ -912,6 +919,129 @@ utils::Result<NumberT> ConstEval::Det2(const Source& source,
|
||||
return r;
|
||||
}
|
||||
|
||||
template <typename NumberT>
|
||||
utils::Result<NumberT> ConstEval::Det3(const Source& source,
|
||||
NumberT a,
|
||||
NumberT b,
|
||||
NumberT c,
|
||||
NumberT d,
|
||||
NumberT e,
|
||||
NumberT f,
|
||||
NumberT g,
|
||||
NumberT h,
|
||||
NumberT i) {
|
||||
// | a d g |
|
||||
// | b e h |
|
||||
// | c f i |
|
||||
//
|
||||
// =
|
||||
//
|
||||
// a | e h | - d | b h | + g | b e |
|
||||
// | f i | | c i | | c f |
|
||||
|
||||
auto det1 = Det2(source, e, f, h, i);
|
||||
if (!det1) {
|
||||
return utils::Failure;
|
||||
}
|
||||
auto a_det1 = Mul(source, a, det1.Get());
|
||||
if (!a_det1) {
|
||||
return utils::Failure;
|
||||
}
|
||||
auto det2 = Det2(source, b, c, h, i);
|
||||
if (!det2) {
|
||||
return utils::Failure;
|
||||
}
|
||||
auto d_det2 = Mul(source, d, det2.Get());
|
||||
if (!d_det2) {
|
||||
return utils::Failure;
|
||||
}
|
||||
auto det3 = Det2(source, b, c, e, f);
|
||||
if (!det3) {
|
||||
return utils::Failure;
|
||||
}
|
||||
auto g_det3 = Mul(source, g, det3.Get());
|
||||
if (!g_det3) {
|
||||
return utils::Failure;
|
||||
}
|
||||
auto r = Sub(source, a_det1.Get(), d_det2.Get());
|
||||
if (!r) {
|
||||
return utils::Failure;
|
||||
}
|
||||
return Add(source, r.Get(), g_det3.Get());
|
||||
}
|
||||
|
||||
template <typename NumberT>
|
||||
utils::Result<NumberT> ConstEval::Det4(const Source& source,
|
||||
NumberT a,
|
||||
NumberT b,
|
||||
NumberT c,
|
||||
NumberT d,
|
||||
NumberT e,
|
||||
NumberT f,
|
||||
NumberT g,
|
||||
NumberT h,
|
||||
NumberT i,
|
||||
NumberT j,
|
||||
NumberT k,
|
||||
NumberT l,
|
||||
NumberT m,
|
||||
NumberT n,
|
||||
NumberT o,
|
||||
NumberT p) {
|
||||
// | a e i m |
|
||||
// | b f j n |
|
||||
// | c g k o |
|
||||
// | d h l p |
|
||||
//
|
||||
// =
|
||||
//
|
||||
// a | f j n | - e | b j n | + i | b f n | - m | b f j |
|
||||
// | g k o | | c k o | | c g o | | c g k |
|
||||
// | h l p | | d l p | | d h p | | d h l |
|
||||
|
||||
auto det1 = Det3(source, f, g, h, j, k, l, n, o, p);
|
||||
if (!det1) {
|
||||
return utils::Failure;
|
||||
}
|
||||
auto a_det1 = Mul(source, a, det1.Get());
|
||||
if (!a_det1) {
|
||||
return utils::Failure;
|
||||
}
|
||||
auto det2 = Det3(source, b, c, d, j, k, l, n, o, p);
|
||||
if (!det2) {
|
||||
return utils::Failure;
|
||||
}
|
||||
auto e_det2 = Mul(source, e, det2.Get());
|
||||
if (!e_det2) {
|
||||
return utils::Failure;
|
||||
}
|
||||
auto det3 = Det3(source, b, c, d, f, g, h, n, o, p);
|
||||
if (!det3) {
|
||||
return utils::Failure;
|
||||
}
|
||||
auto i_det3 = Mul(source, i, det3.Get());
|
||||
if (!i_det3) {
|
||||
return utils::Failure;
|
||||
}
|
||||
auto det4 = Det3(source, b, c, d, f, g, h, j, k, l);
|
||||
if (!det4) {
|
||||
return utils::Failure;
|
||||
}
|
||||
auto m_det4 = Mul(source, m, det4.Get());
|
||||
if (!m_det4) {
|
||||
return utils::Failure;
|
||||
}
|
||||
auto r = Sub(source, a_det1.Get(), e_det2.Get());
|
||||
if (!r) {
|
||||
return utils::Failure;
|
||||
}
|
||||
r = Add(source, r.Get(), i_det3.Get());
|
||||
if (!r) {
|
||||
return utils::Failure;
|
||||
}
|
||||
return Sub(source, r.Get(), m_det4.Get());
|
||||
}
|
||||
|
||||
template <typename NumberT>
|
||||
utils::Result<NumberT> ConstEval::Sqrt(const Source& source, NumberT v) {
|
||||
if (v < NumberT(0)) {
|
||||
@@ -1044,6 +1174,26 @@ auto ConstEval::Det2Func(const Source& source, const sem::Type* elem_ty) {
|
||||
};
|
||||
}
|
||||
|
||||
auto ConstEval::Det3Func(const Source& source, const sem::Type* elem_ty) {
|
||||
return
|
||||
[=](auto a, auto b, auto c, auto d, auto e, auto f, auto g, auto h, auto i) -> ImplResult {
|
||||
if (auto r = Det3(source, a, b, c, d, e, f, g, h, i)) {
|
||||
return CreateElement(builder, source, elem_ty, r.Get());
|
||||
}
|
||||
return utils::Failure;
|
||||
};
|
||||
}
|
||||
|
||||
auto ConstEval::Det4Func(const Source& source, const sem::Type* elem_ty) {
|
||||
return [=](auto a, auto b, auto c, auto d, auto e, auto f, auto g, auto h, auto i, auto j,
|
||||
auto k, auto l, auto m, auto n, auto o, auto p) -> ImplResult {
|
||||
if (auto r = Det4(source, a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p)) {
|
||||
return CreateElement(builder, source, elem_ty, r.Get());
|
||||
}
|
||||
return utils::Failure;
|
||||
};
|
||||
}
|
||||
|
||||
ConstEval::Result ConstEval::Literal(const sem::Type* ty, const ast::LiteralExpression* literal) {
|
||||
auto& source = literal->source;
|
||||
return Switch(
|
||||
@@ -2036,6 +2186,41 @@ ConstEval::Result ConstEval::degrees(const sem::Type* ty,
|
||||
return TransformElements(builder, ty, transform, args[0]);
|
||||
}
|
||||
|
||||
ConstEval::Result ConstEval::determinant(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source) {
|
||||
auto calculate = [&]() -> ImplResult {
|
||||
auto* m = args[0];
|
||||
auto* mat_ty = m->Type()->As<sem::Matrix>();
|
||||
auto me = [&](size_t r, size_t c) { return m->Index(c)->Index(r); };
|
||||
switch (mat_ty->rows()) {
|
||||
case 2:
|
||||
return Dispatch_fa_f32_f16(Det2Func(source, ty), //
|
||||
me(0, 0), me(1, 0), //
|
||||
me(0, 1), me(1, 1));
|
||||
|
||||
case 3:
|
||||
return Dispatch_fa_f32_f16(Det3Func(source, ty), //
|
||||
me(0, 0), me(1, 0), me(2, 0), //
|
||||
me(0, 1), me(1, 1), me(2, 1), //
|
||||
me(0, 2), me(1, 2), me(2, 2));
|
||||
|
||||
case 4:
|
||||
return Dispatch_fa_f32_f16(Det4Func(source, ty), //
|
||||
me(0, 0), me(1, 0), me(2, 0), me(3, 0), //
|
||||
me(0, 1), me(1, 1), me(2, 1), me(3, 1), //
|
||||
me(0, 2), me(1, 2), me(2, 2), me(3, 2), //
|
||||
me(0, 3), me(1, 3), me(2, 3), me(3, 3));
|
||||
}
|
||||
TINT_ICE(Resolver, builder.Diagnostics()) << "Unexpected number of matrix rows";
|
||||
return utils::Failure;
|
||||
};
|
||||
auto r = calculate();
|
||||
if (!r) {
|
||||
AddNote("when calculating determinant", source);
|
||||
}
|
||||
return r;
|
||||
}
|
||||
ConstEval::Result ConstEval::dot(const sem::Type*,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source) {
|
||||
|
||||
@@ -539,6 +539,7 @@ class ConstEval {
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source);
|
||||
|
||||
/// degrees builtin
|
||||
/// @param ty the expression type
|
||||
/// @param args the input arguments
|
||||
/// @param source the source location of the conversion
|
||||
@@ -547,6 +548,15 @@ class ConstEval {
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source);
|
||||
|
||||
/// determinant builtin
|
||||
/// @param ty the expression type
|
||||
/// @param args the input arguments
|
||||
/// @param source the source location of the conversion
|
||||
/// @return the result value, or null if the value cannot be calculated
|
||||
Result determinant(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source);
|
||||
|
||||
/// dot builtin
|
||||
/// @param ty the expression type
|
||||
/// @param args the input arguments
|
||||
@@ -1012,18 +1022,87 @@ class ConstEval {
|
||||
NumberT b3,
|
||||
NumberT b4);
|
||||
|
||||
/// Returns the determinant of the 2x2 matrix [(a1, a2), (b1, b2)]
|
||||
/// Returns the determinant of the 2x2 matrix:
|
||||
/// | a c |
|
||||
/// | b d |
|
||||
/// @param source the source location
|
||||
/// @param a1 component 1 of the first column vector
|
||||
/// @param a2 component 2 of the first column vector
|
||||
/// @param b1 component 1 of the second column vector
|
||||
/// @param b2 component 2 of the second column vector
|
||||
/// @param a component 1 of the first column vector
|
||||
/// @param b component 2 of the first column vector
|
||||
/// @param c component 1 of the second column vector
|
||||
/// @param d component 2 of the second column vector
|
||||
template <typename NumberT>
|
||||
utils::Result<NumberT> Det2(const Source& source,
|
||||
NumberT a1,
|
||||
NumberT a2,
|
||||
NumberT b1,
|
||||
NumberT b2);
|
||||
utils::Result<NumberT> Det2(const Source& source, //
|
||||
NumberT a,
|
||||
NumberT b,
|
||||
NumberT c,
|
||||
NumberT d);
|
||||
|
||||
/// Returns the determinant of the 3x3 matrix:
|
||||
/// | a d g |
|
||||
/// | b e h |
|
||||
/// | c f i |
|
||||
/// @param source the source location
|
||||
/// @param a component 1 of the first column vector
|
||||
/// @param b component 2 of the first column vector
|
||||
/// @param c component 3 of the first column vector
|
||||
/// @param d component 1 of the second column vector
|
||||
/// @param e component 2 of the second column vector
|
||||
/// @param f component 3 of the second column vector
|
||||
/// @param g component 1 of the third column vector
|
||||
/// @param h component 2 of the third column vector
|
||||
/// @param i component 3 of the third column vector
|
||||
template <typename NumberT>
|
||||
utils::Result<NumberT> Det3(const Source& source,
|
||||
NumberT a,
|
||||
NumberT b,
|
||||
NumberT c,
|
||||
NumberT d,
|
||||
NumberT e,
|
||||
NumberT f,
|
||||
NumberT g,
|
||||
NumberT h,
|
||||
NumberT i);
|
||||
|
||||
/// Returns the determinant of the 4x4 matrix:
|
||||
/// | a e i m |
|
||||
/// | b f j n |
|
||||
/// | c g k o |
|
||||
/// | d h l p |
|
||||
/// @param source the source location
|
||||
/// @param a component 1 of the first column vector
|
||||
/// @param b component 2 of the first column vector
|
||||
/// @param c component 3 of the first column vector
|
||||
/// @param d component 4 of the first column vector
|
||||
/// @param e component 1 of the second column vector
|
||||
/// @param f component 2 of the second column vector
|
||||
/// @param g component 3 of the second column vector
|
||||
/// @param h component 4 of the second column vector
|
||||
/// @param i component 1 of the third column vector
|
||||
/// @param j component 2 of the third column vector
|
||||
/// @param k component 3 of the third column vector
|
||||
/// @param l component 4 of the third column vector
|
||||
/// @param m component 1 of the fourth column vector
|
||||
/// @param n component 2 of the fourth column vector
|
||||
/// @param o component 3 of the fourth column vector
|
||||
/// @param p component 4 of the fourth column vector
|
||||
template <typename NumberT>
|
||||
utils::Result<NumberT> Det4(const Source& source,
|
||||
NumberT a,
|
||||
NumberT b,
|
||||
NumberT c,
|
||||
NumberT d,
|
||||
NumberT e,
|
||||
NumberT f,
|
||||
NumberT g,
|
||||
NumberT h,
|
||||
NumberT i,
|
||||
NumberT j,
|
||||
NumberT k,
|
||||
NumberT l,
|
||||
NumberT m,
|
||||
NumberT n,
|
||||
NumberT o,
|
||||
NumberT p);
|
||||
|
||||
template <typename NumberT>
|
||||
utils::Result<NumberT> Sqrt(const Source& source, NumberT v);
|
||||
@@ -1093,6 +1172,20 @@ class ConstEval {
|
||||
/// @returns the callable function
|
||||
auto Det2Func(const Source& source, const sem::Type* elem_ty);
|
||||
|
||||
/// Returns a callable that calls Det3, and creates a Constant with its result of type `elem_ty`
|
||||
/// if successful, or returns Failure otherwise.
|
||||
/// @param source the source location
|
||||
/// @param elem_ty the element type of the Constant to create on success
|
||||
/// @returns the callable function
|
||||
auto Det3Func(const Source& source, const sem::Type* elem_ty);
|
||||
|
||||
/// Returns a callable that calls Det4, and creates a Constant with its result of type `elem_ty`
|
||||
/// if successful, or returns Failure otherwise.
|
||||
/// @param source the source location
|
||||
/// @param elem_ty the element type of the Constant to create on success
|
||||
/// @returns the callable function
|
||||
auto Det4Func(const Source& source, const sem::Type* elem_ty);
|
||||
|
||||
/// Returns a callable that calls Clamp, and creates a Constant with its result of type
|
||||
/// `elem_ty` if successful, or returns Failure otherwise.
|
||||
/// @param source the source location
|
||||
|
||||
@@ -853,6 +853,83 @@ INSTANTIATE_TEST_SUITE_P( //
|
||||
DotCases<f32>(), //
|
||||
DotCases<f16>()))));
|
||||
|
||||
template <typename T>
|
||||
std::vector<Case> DeterminantCases() {
|
||||
auto error_msg = [](auto a, const char* op, auto b) {
|
||||
return "12:34 error: " + OverflowErrorMessage(a, op, b) + R"(
|
||||
12:34 note: when calculating determinant)";
|
||||
};
|
||||
|
||||
auto r = std::vector<Case>{
|
||||
// All zero == 0
|
||||
C({Mat({T(0), T(0)}, //
|
||||
{T(0), T(0)})}, //
|
||||
Val(T(0))),
|
||||
|
||||
C({Mat({T(0), T(0), T(0)}, //
|
||||
{T(0), T(0), T(0)}, //
|
||||
{T(0), T(0), T(0)})}, //
|
||||
Val(T(0))),
|
||||
|
||||
C({Mat({T(0), T(0), T(0), T(0)}, //
|
||||
{T(0), T(0), T(0), T(0)}, //
|
||||
{T(0), T(0), T(0), T(0)}, //
|
||||
{T(0), T(0), T(0), T(0)})}, //
|
||||
Val(T(0))),
|
||||
|
||||
// All same == 0
|
||||
C({Mat({T(42), T(42)}, //
|
||||
{T(42), T(42)})}, //
|
||||
Val(T(0))),
|
||||
|
||||
C({Mat({T(42), T(42), T(42)}, //
|
||||
{T(42), T(42), T(42)}, //
|
||||
{T(42), T(42), T(42)})}, //
|
||||
Val(T(0))),
|
||||
|
||||
C({Mat({T(42), T(42), T(42), T(42)}, //
|
||||
{T(42), T(42), T(42), T(42)}, //
|
||||
{T(42), T(42), T(42), T(42)}, //
|
||||
{T(42), T(42), T(42), T(42)})}, //
|
||||
Val(T(0))),
|
||||
|
||||
// Various values
|
||||
C({Mat({-T(2), T(17)}, //
|
||||
{T(5), T(45)})}, //
|
||||
Val(-T(175))),
|
||||
|
||||
C({Mat({T(4), T(6), -T(13)}, //
|
||||
{T(12), T(5), T(8)}, //
|
||||
{T(9), T(17), T(16)})}, //
|
||||
Val(-T(3011))),
|
||||
|
||||
C({Mat({T(2), T(9), T(8), T(1)}, //
|
||||
{-T(4), T(11), -T(3), T(7)}, //
|
||||
{T(6), T(5), T(12), -T(6)}, //
|
||||
{T(3), -T(10), T(4), -T(7)})}, //
|
||||
Val(T(469))),
|
||||
|
||||
// Overflow during multiply
|
||||
E({Mat({T::Highest(), T(0)}, //
|
||||
{T(0), T(2)})}, //
|
||||
error_msg(T::Highest(), "*", T(2))),
|
||||
|
||||
// Overflow during subtract
|
||||
E({Mat({T::Highest(), T::Lowest()}, //
|
||||
{T(1), T(1)})}, //
|
||||
error_msg(T::Highest(), "-", T::Lowest())),
|
||||
};
|
||||
|
||||
return r;
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P( //
|
||||
Determinant,
|
||||
ResolverConstEvalBuiltinTest,
|
||||
testing::Combine(testing::Values(sem::BuiltinType::kDeterminant),
|
||||
testing::ValuesIn(Concat(DeterminantCases<AFloat>(), //
|
||||
DeterminantCases<f32>(), //
|
||||
DeterminantCases<f16>()))));
|
||||
|
||||
template <typename T>
|
||||
std::vector<Case> FirstLeadingBitCases() {
|
||||
using B = BitValues<T>;
|
||||
|
||||
@@ -288,6 +288,16 @@ using Types = std::variant< //
|
||||
Value<builder::mat2x2<f32>>,
|
||||
Value<builder::mat2x2<f16>>,
|
||||
|
||||
Value<builder::mat3x3<AInt>>,
|
||||
Value<builder::mat3x3<AFloat>>,
|
||||
Value<builder::mat3x3<f32>>,
|
||||
Value<builder::mat3x3<f16>>,
|
||||
|
||||
Value<builder::mat4x4<AInt>>,
|
||||
Value<builder::mat4x4<AFloat>>,
|
||||
Value<builder::mat4x4<f32>>,
|
||||
Value<builder::mat4x4<f16>>,
|
||||
|
||||
Value<builder::mat2x3<AInt>>,
|
||||
Value<builder::mat2x3<AFloat>>,
|
||||
Value<builder::mat2x3<f32>>,
|
||||
|
||||
@@ -13566,12 +13566,12 @@ constexpr OverloadInfo kOverloads[] = {
|
||||
/* num parameters */ 1,
|
||||
/* num template types */ 1,
|
||||
/* num template numbers */ 1,
|
||||
/* template types */ &kTemplateTypes[26],
|
||||
/* template types */ &kTemplateTypes[23],
|
||||
/* template numbers */ &kTemplateNumbers[4],
|
||||
/* parameters */ &kParameters[837],
|
||||
/* return matcher indices */ &kMatcherIndices[3],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::determinant,
|
||||
},
|
||||
{
|
||||
/* [438] */
|
||||
@@ -14124,7 +14124,7 @@ constexpr IntrinsicInfo kBuiltins[] = {
|
||||
},
|
||||
{
|
||||
/* [20] */
|
||||
/* fn determinant<N : num, T : f32_f16>(mat<N, N, T>) -> T */
|
||||
/* fn determinant<N : num, T : fa_f32_f16>(mat<N, N, T>) -> T */
|
||||
/* num overloads */ 1,
|
||||
/* overloads */ &kOverloads[437],
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user