tint: const eval of transpose builtin

Bug: tint:1581
Change-Id: Ia614647bc4a3d24a53d45981ddcdb1c84ea84608
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/111600
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
Antonio Maiorano
2022-11-23 23:12:56 +00:00
committed by Dawn LUCI CQ
parent dd3fb05af7
commit 9ba5f9e2c6
214 changed files with 5439 additions and 460 deletions

View File

@@ -546,7 +546,7 @@ fn refract<N: num, T: f32_f16>(vec<N, T>, vec<N, T>, T) -> vec<N, T>
@const fn tan<N: num, T: fa_f32_f16>(vec<N, T>) -> vec<N, T>
@const fn tanh<T: fa_f32_f16>(T) -> T
@const fn tanh<N: num, T: fa_f32_f16>(vec<N, T>) -> vec<N, T>
fn transpose<M: num, N: num, T: f32_f16>(mat<M, N, T>) -> mat<N, M, T>
@const fn transpose<M: num, N: num, T: fa_f32_f16>(mat<M, N, T>) -> mat<N, M, T>
@const fn trunc<T: fa_f32_f16>(@test_value(1.5) T) -> T
@const fn trunc<N: num, T: fa_f32_f16>(@test_value(1.5) vec<N, T>) -> vec<N, T>
@const fn unpack2x16float(u32) -> vec2<f32>

View File

@@ -3031,6 +3031,26 @@ ConstEval::Result ConstEval::tanh(const sem::Type* ty,
return TransformElements(builder, ty, transform, args[0]);
}
ConstEval::Result ConstEval::transpose(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source&) {
auto* m = args[0];
auto* mat_ty = m->Type()->As<sem::Matrix>();
auto me = [&](size_t r, size_t c) { return m->Index(c)->Index(r); };
auto* result_mat_ty = ty->As<sem::Matrix>();
// Produce column vectors from each row
utils::Vector<const sem::Constant*, 4> result_mat;
for (size_t r = 0; r < mat_ty->rows(); ++r) {
utils::Vector<const sem::Constant*, 4> new_col_vec;
for (size_t c = 0; c < mat_ty->columns(); ++c) {
new_col_vec.Push(me(r, c));
}
result_mat.Push(CreateComposite(builder, result_mat_ty->ColumnType(), new_col_vec));
}
return CreateComposite(builder, ty, result_mat);
}
ConstEval::Result ConstEval::trunc(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source) {

View File

@@ -863,6 +863,15 @@ class ConstEval {
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// transpose builtin
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result transpose(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// trunc builtin
/// @param ty the expression type
/// @param args the input arguments

View File

@@ -2092,6 +2092,72 @@ INSTANTIATE_TEST_SUITE_P( //
TanhCases<f32>(),
TanhCases<f16>()))));
template <typename T>
std::vector<Case> TransposeCases() {
return {
// 2x2
C({Mat({T(1), T(2)}, //
{T(3), T(4)})}, //
Mat({T(1), T(3)}, //
{T(2), T(4)})),
// 3x3
C({Mat({T(1), T(2), T(3)}, //
{T(4), T(5), T(6)}, //
{T(7), T(8), T(9)})}, //
Mat({T(1), T(4), T(7)}, //
{T(2), T(5), T(8)}, //
{T(3), T(6), T(9)})),
// 4x4
C({Mat({T(1), T(2), T(3), T(4)}, //
{T(5), T(6), T(7), T(8)}, //
{T(9), T(10), T(11), T(12)}, //
{T(13), T(14), T(15), T(16)})}, //
Mat({T(1), T(5), T(9), T(13)}, //
{T(2), T(6), T(10), T(14)}, //
{T(3), T(7), T(11), T(15)}, //
{T(4), T(8), T(12), T(16)})),
// 4x2
C({Mat({T(1), T(2), T(3), T(4)}, //
{T(5), T(6), T(7), T(8)})}, //
Mat({T(1), T(5)}, //
{T(2), T(6)}, //
{T(3), T(7)}, //
{T(4), T(8)})),
// 2x4
C({Mat({T(1), T(2)}, //
{T(3), T(4)}, //
{T(5), T(6)}, //
{T(7), T(8)})}, //
Mat({T(1), T(3), T(5), T(7)}, //
{T(2), T(4), T(6), T(8)})),
// 3x2
C({Mat({T(1), T(2), T(3)}, //
{T(4), T(5), T(6)})}, //
Mat({T(1), T(4)}, //
{T(2), T(5)}, //
{T(3), T(6)})),
// 2x3
C({Mat({T(1), T(2)}, //
{T(3), T(4)}, //
{T(5), T(6)})}, //
Mat({T(1), T(3), T(5)}, //
{T(2), T(4), T(6)})),
};
}
INSTANTIATE_TEST_SUITE_P( //
Transpose,
ResolverConstEvalBuiltinTest,
testing::Combine(testing::Values(sem::BuiltinType::kTranspose),
testing::ValuesIn(Concat(TransposeCases<AFloat>(), //
TransposeCases<f32>(),
TransposeCases<f16>()))));
template <typename T>
std::vector<Case> TruncCases() {
std::vector<Case> cases = {C({T(0)}, T(0)), //

View File

@@ -306,7 +306,17 @@ using Types = std::variant< //
Value<builder::mat3x2<AInt>>,
Value<builder::mat3x2<AFloat>>,
Value<builder::mat3x2<f32>>,
Value<builder::mat3x2<f16>>
Value<builder::mat3x2<f16>>,
Value<builder::mat2x4<AInt>>,
Value<builder::mat2x4<AFloat>>,
Value<builder::mat2x4<f32>>,
Value<builder::mat2x4<f16>>,
Value<builder::mat4x2<AInt>>,
Value<builder::mat4x2<AFloat>>,
Value<builder::mat4x2<f32>>,
Value<builder::mat4x2<f16>>
//
>;

View File

@@ -13734,12 +13734,12 @@ constexpr OverloadInfo kOverloads[] = {
/* num parameters */ 1,
/* num template types */ 1,
/* num template numbers */ 2,
/* template types */ &kTemplateTypes[26],
/* template types */ &kTemplateTypes[23],
/* template numbers */ &kTemplateNumbers[3],
/* parameters */ &kParameters[908],
/* return matcher indices */ &kMatcherIndices[18],
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::transpose,
},
{
/* [452] */
@@ -14518,7 +14518,7 @@ constexpr IntrinsicInfo kBuiltins[] = {
},
{
/* [78] */
/* fn transpose<M : num, N : num, T : f32_f16>(mat<M, N, T>) -> mat<N, M, T> */
/* fn transpose<M : num, N : num, T : fa_f32_f16>(mat<M, N, T>) -> mat<N, M, T> */
/* num overloads */ 1,
/* overloads */ &kOverloads[451],
},

View File

@@ -150,6 +150,12 @@ using mat2x3 = mat<2, 3, T>;
template <typename T>
using mat3x2 = mat<3, 2, T>;
template <typename T>
using mat2x4 = mat<2, 4, T>;
template <typename T>
using mat4x2 = mat<4, 2, T>;
template <typename T>
using mat3x3 = mat<3, 3, T>;