mirror of
https://github.com/encounter/dawn-cmake.git
synced 2025-12-16 16:37:08 +00:00
tint: const eval of transpose builtin
Bug: tint:1581 Change-Id: Ia614647bc4a3d24a53d45981ddcdb1c84ea84608 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/111600 Commit-Queue: Antonio Maiorano <amaiorano@google.com> Reviewed-by: Ben Clayton <bclayton@google.com> Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
committed by
Dawn LUCI CQ
parent
dd3fb05af7
commit
9ba5f9e2c6
@@ -546,7 +546,7 @@ fn refract<N: num, T: f32_f16>(vec<N, T>, vec<N, T>, T) -> vec<N, T>
|
||||
@const fn tan<N: num, T: fa_f32_f16>(vec<N, T>) -> vec<N, T>
|
||||
@const fn tanh<T: fa_f32_f16>(T) -> T
|
||||
@const fn tanh<N: num, T: fa_f32_f16>(vec<N, T>) -> vec<N, T>
|
||||
fn transpose<M: num, N: num, T: f32_f16>(mat<M, N, T>) -> mat<N, M, T>
|
||||
@const fn transpose<M: num, N: num, T: fa_f32_f16>(mat<M, N, T>) -> mat<N, M, T>
|
||||
@const fn trunc<T: fa_f32_f16>(@test_value(1.5) T) -> T
|
||||
@const fn trunc<N: num, T: fa_f32_f16>(@test_value(1.5) vec<N, T>) -> vec<N, T>
|
||||
@const fn unpack2x16float(u32) -> vec2<f32>
|
||||
|
||||
@@ -3031,6 +3031,26 @@ ConstEval::Result ConstEval::tanh(const sem::Type* ty,
|
||||
return TransformElements(builder, ty, transform, args[0]);
|
||||
}
|
||||
|
||||
ConstEval::Result ConstEval::transpose(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source&) {
|
||||
auto* m = args[0];
|
||||
auto* mat_ty = m->Type()->As<sem::Matrix>();
|
||||
auto me = [&](size_t r, size_t c) { return m->Index(c)->Index(r); };
|
||||
auto* result_mat_ty = ty->As<sem::Matrix>();
|
||||
|
||||
// Produce column vectors from each row
|
||||
utils::Vector<const sem::Constant*, 4> result_mat;
|
||||
for (size_t r = 0; r < mat_ty->rows(); ++r) {
|
||||
utils::Vector<const sem::Constant*, 4> new_col_vec;
|
||||
for (size_t c = 0; c < mat_ty->columns(); ++c) {
|
||||
new_col_vec.Push(me(r, c));
|
||||
}
|
||||
result_mat.Push(CreateComposite(builder, result_mat_ty->ColumnType(), new_col_vec));
|
||||
}
|
||||
return CreateComposite(builder, ty, result_mat);
|
||||
}
|
||||
|
||||
ConstEval::Result ConstEval::trunc(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source) {
|
||||
|
||||
@@ -863,6 +863,15 @@ class ConstEval {
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source);
|
||||
|
||||
/// transpose builtin
|
||||
/// @param ty the expression type
|
||||
/// @param args the input arguments
|
||||
/// @param source the source location
|
||||
/// @return the result value, or null if the value cannot be calculated
|
||||
Result transpose(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source);
|
||||
|
||||
/// trunc builtin
|
||||
/// @param ty the expression type
|
||||
/// @param args the input arguments
|
||||
|
||||
@@ -2092,6 +2092,72 @@ INSTANTIATE_TEST_SUITE_P( //
|
||||
TanhCases<f32>(),
|
||||
TanhCases<f16>()))));
|
||||
|
||||
template <typename T>
|
||||
std::vector<Case> TransposeCases() {
|
||||
return {
|
||||
// 2x2
|
||||
C({Mat({T(1), T(2)}, //
|
||||
{T(3), T(4)})}, //
|
||||
Mat({T(1), T(3)}, //
|
||||
{T(2), T(4)})),
|
||||
|
||||
// 3x3
|
||||
C({Mat({T(1), T(2), T(3)}, //
|
||||
{T(4), T(5), T(6)}, //
|
||||
{T(7), T(8), T(9)})}, //
|
||||
Mat({T(1), T(4), T(7)}, //
|
||||
{T(2), T(5), T(8)}, //
|
||||
{T(3), T(6), T(9)})),
|
||||
|
||||
// 4x4
|
||||
C({Mat({T(1), T(2), T(3), T(4)}, //
|
||||
{T(5), T(6), T(7), T(8)}, //
|
||||
{T(9), T(10), T(11), T(12)}, //
|
||||
{T(13), T(14), T(15), T(16)})}, //
|
||||
Mat({T(1), T(5), T(9), T(13)}, //
|
||||
{T(2), T(6), T(10), T(14)}, //
|
||||
{T(3), T(7), T(11), T(15)}, //
|
||||
{T(4), T(8), T(12), T(16)})),
|
||||
|
||||
// 4x2
|
||||
C({Mat({T(1), T(2), T(3), T(4)}, //
|
||||
{T(5), T(6), T(7), T(8)})}, //
|
||||
Mat({T(1), T(5)}, //
|
||||
{T(2), T(6)}, //
|
||||
{T(3), T(7)}, //
|
||||
{T(4), T(8)})),
|
||||
|
||||
// 2x4
|
||||
C({Mat({T(1), T(2)}, //
|
||||
{T(3), T(4)}, //
|
||||
{T(5), T(6)}, //
|
||||
{T(7), T(8)})}, //
|
||||
Mat({T(1), T(3), T(5), T(7)}, //
|
||||
{T(2), T(4), T(6), T(8)})),
|
||||
|
||||
// 3x2
|
||||
C({Mat({T(1), T(2), T(3)}, //
|
||||
{T(4), T(5), T(6)})}, //
|
||||
Mat({T(1), T(4)}, //
|
||||
{T(2), T(5)}, //
|
||||
{T(3), T(6)})),
|
||||
|
||||
// 2x3
|
||||
C({Mat({T(1), T(2)}, //
|
||||
{T(3), T(4)}, //
|
||||
{T(5), T(6)})}, //
|
||||
Mat({T(1), T(3), T(5)}, //
|
||||
{T(2), T(4), T(6)})),
|
||||
};
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P( //
|
||||
Transpose,
|
||||
ResolverConstEvalBuiltinTest,
|
||||
testing::Combine(testing::Values(sem::BuiltinType::kTranspose),
|
||||
testing::ValuesIn(Concat(TransposeCases<AFloat>(), //
|
||||
TransposeCases<f32>(),
|
||||
TransposeCases<f16>()))));
|
||||
|
||||
template <typename T>
|
||||
std::vector<Case> TruncCases() {
|
||||
std::vector<Case> cases = {C({T(0)}, T(0)), //
|
||||
|
||||
@@ -306,7 +306,17 @@ using Types = std::variant< //
|
||||
Value<builder::mat3x2<AInt>>,
|
||||
Value<builder::mat3x2<AFloat>>,
|
||||
Value<builder::mat3x2<f32>>,
|
||||
Value<builder::mat3x2<f16>>
|
||||
Value<builder::mat3x2<f16>>,
|
||||
|
||||
Value<builder::mat2x4<AInt>>,
|
||||
Value<builder::mat2x4<AFloat>>,
|
||||
Value<builder::mat2x4<f32>>,
|
||||
Value<builder::mat2x4<f16>>,
|
||||
|
||||
Value<builder::mat4x2<AInt>>,
|
||||
Value<builder::mat4x2<AFloat>>,
|
||||
Value<builder::mat4x2<f32>>,
|
||||
Value<builder::mat4x2<f16>>
|
||||
//
|
||||
>;
|
||||
|
||||
|
||||
@@ -13734,12 +13734,12 @@ constexpr OverloadInfo kOverloads[] = {
|
||||
/* num parameters */ 1,
|
||||
/* num template types */ 1,
|
||||
/* num template numbers */ 2,
|
||||
/* template types */ &kTemplateTypes[26],
|
||||
/* template types */ &kTemplateTypes[23],
|
||||
/* template numbers */ &kTemplateNumbers[3],
|
||||
/* parameters */ &kParameters[908],
|
||||
/* return matcher indices */ &kMatcherIndices[18],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::transpose,
|
||||
},
|
||||
{
|
||||
/* [452] */
|
||||
@@ -14518,7 +14518,7 @@ constexpr IntrinsicInfo kBuiltins[] = {
|
||||
},
|
||||
{
|
||||
/* [78] */
|
||||
/* fn transpose<M : num, N : num, T : f32_f16>(mat<M, N, T>) -> mat<N, M, T> */
|
||||
/* fn transpose<M : num, N : num, T : fa_f32_f16>(mat<M, N, T>) -> mat<N, M, T> */
|
||||
/* num overloads */ 1,
|
||||
/* overloads */ &kOverloads[451],
|
||||
},
|
||||
|
||||
@@ -150,6 +150,12 @@ using mat2x3 = mat<2, 3, T>;
|
||||
template <typename T>
|
||||
using mat3x2 = mat<3, 2, T>;
|
||||
|
||||
template <typename T>
|
||||
using mat2x4 = mat<2, 4, T>;
|
||||
|
||||
template <typename T>
|
||||
using mat4x2 = mat<4, 2, T>;
|
||||
|
||||
template <typename T>
|
||||
using mat3x3 = mat<3, 3, T>;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user