tint: Add matrix short names

Fixed: tint:1786
Change-Id: Ifa3acb2fc1792b392ccb4555bde840f5038eef2c
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/114141
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: David Neto <dneto@google.com>
Reviewed-by: James Price <jrprice@google.com>
Kokoro: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton
2022-12-14 11:55:58 +00:00
committed by Dawn LUCI CQ
parent 167a7da051
commit 0335c7d65d
10 changed files with 392 additions and 52 deletions

View File

@@ -138,7 +138,18 @@ TEST_P(ResolverF16ExtensionShortNameTest, Vec2hTypeUsedWithoutExtension) {
INSTANTIATE_TEST_SUITE_P(ResolverF16ExtensionShortNameTest,
ResolverF16ExtensionShortNameTest,
testing::Values("vec2h", "vec3h", "vec4h"));
testing::Values("mat2x2h",
"mat2x3h",
"mat2x4h",
"mat3x2h",
"mat3x3h",
"mat3x4h",
"mat4x2h",
"mat4x3h",
"mat4x4h",
"vec2h",
"vec3h",
"vec4h"));
} // namespace
} // namespace tint::resolver

View File

@@ -2387,25 +2387,67 @@ sem::Call* Resolver::BuiltinCall(const ast::CallExpression* expr,
type::Type* Resolver::ShortName(Symbol sym, const Source& source) const {
auto name = builder_->Symbols().NameFor(sym);
auto& b = *builder_;
auto vec_f32 = [&](uint32_t n) { return b.create<type::Vector>(b.create<type::F32>(), n); };
auto vec_f16 = [&](uint32_t n) { return b.create<type::Vector>(b.create<type::F16>(), n); };
switch (type::ParseShortName(name)) {
case type::ShortName::kMat2X2F:
return b.create<type::Matrix>(vec_f32(2u), 2u);
case type::ShortName::kMat2X3F:
return b.create<type::Matrix>(vec_f32(3u), 2u);
case type::ShortName::kMat2X4F:
return b.create<type::Matrix>(vec_f32(4u), 2u);
case type::ShortName::kMat3X2F:
return b.create<type::Matrix>(vec_f32(2u), 3u);
case type::ShortName::kMat3X3F:
return b.create<type::Matrix>(vec_f32(3u), 3u);
case type::ShortName::kMat3X4F:
return b.create<type::Matrix>(vec_f32(4u), 3u);
case type::ShortName::kMat4X2F:
return b.create<type::Matrix>(vec_f32(2u), 4u);
case type::ShortName::kMat4X3F:
return b.create<type::Matrix>(vec_f32(3u), 4u);
case type::ShortName::kMat4X4F:
return b.create<type::Matrix>(vec_f32(4u), 4u);
case type::ShortName::kMat2X2H:
return validator_.CheckF16Enabled(source) ? b.create<type::Matrix>(vec_f16(2u), 2u)
: nullptr;
case type::ShortName::kMat2X3H:
return validator_.CheckF16Enabled(source) ? b.create<type::Matrix>(vec_f16(3u), 2u)
: nullptr;
case type::ShortName::kMat2X4H:
return validator_.CheckF16Enabled(source) ? b.create<type::Matrix>(vec_f16(4u), 2u)
: nullptr;
case type::ShortName::kMat3X2H:
return validator_.CheckF16Enabled(source) ? b.create<type::Matrix>(vec_f16(2u), 3u)
: nullptr;
case type::ShortName::kMat3X3H:
return validator_.CheckF16Enabled(source) ? b.create<type::Matrix>(vec_f16(3u), 3u)
: nullptr;
case type::ShortName::kMat3X4H:
return validator_.CheckF16Enabled(source) ? b.create<type::Matrix>(vec_f16(4u), 3u)
: nullptr;
case type::ShortName::kMat4X2H:
return validator_.CheckF16Enabled(source) ? b.create<type::Matrix>(vec_f16(2u), 4u)
: nullptr;
case type::ShortName::kMat4X3H:
return validator_.CheckF16Enabled(source) ? b.create<type::Matrix>(vec_f16(3u), 4u)
: nullptr;
case type::ShortName::kMat4X4H:
return validator_.CheckF16Enabled(source) ? b.create<type::Matrix>(vec_f16(4u), 4u)
: nullptr;
case type::ShortName::kVec2F:
return b.create<type::Vector>(b.create<type::F32>(), 2u);
return vec_f32(2u);
case type::ShortName::kVec3F:
return b.create<type::Vector>(b.create<type::F32>(), 3u);
return vec_f32(3u);
case type::ShortName::kVec4F:
return b.create<type::Vector>(b.create<type::F32>(), 4u);
return vec_f32(4u);
case type::ShortName::kVec2H:
return validator_.CheckF16Enabled(source)
? b.create<type::Vector>(b.create<type::F16>(), 2u)
: nullptr;
return validator_.CheckF16Enabled(source) ? vec_f16(2u) : nullptr;
case type::ShortName::kVec3H:
return validator_.CheckF16Enabled(source)
? b.create<type::Vector>(b.create<type::F16>(), 3u)
: nullptr;
return validator_.CheckF16Enabled(source) ? vec_f16(3u) : nullptr;
case type::ShortName::kVec4H:
return validator_.CheckF16Enabled(source)
? b.create<type::Vector>(b.create<type::F16>(), 4u)
: nullptr;
return validator_.CheckF16Enabled(source) ? vec_f16(4u) : nullptr;
case type::ShortName::kVec2I:
return b.create<type::Vector>(b.create<type::I32>(), 2u);
case type::ShortName::kVec3I:

View File

@@ -148,17 +148,23 @@ using mat2x2 = mat<2, 2, T>;
template <typename T>
using mat2x3 = mat<2, 3, T>;
template <typename T>
using mat2x4 = mat<2, 4, T>;
template <typename T>
using mat3x2 = mat<3, 2, T>;
template <typename T>
using mat2x4 = mat<2, 4, T>;
using mat3x3 = mat<3, 3, T>;
template <typename T>
using mat3x4 = mat<3, 4, T>;
template <typename T>
using mat4x2 = mat<4, 2, T>;
template <typename T>
using mat3x3 = mat<3, 3, T>;
using mat4x3 = mat<4, 3, T>;
template <typename T>
using mat4x4 = mat<4, 4, T>;

View File

@@ -39,8 +39,20 @@ using vec4 = builder::vec4<T>;
template <typename T>
using mat2x2 = builder::mat2x2<T>;
template <typename T>
using mat2x3 = builder::mat2x3<T>;
template <typename T>
using mat2x4 = builder::mat2x4<T>;
template <typename T>
using mat3x2 = builder::mat3x2<T>;
template <typename T>
using mat3x3 = builder::mat3x3<T>;
template <typename T>
using mat3x4 = builder::mat3x4<T>;
template <typename T>
using mat4x2 = builder::mat4x2<T>;
template <typename T>
using mat4x3 = builder::mat4x3<T>;
template <typename T>
using mat4x4 = builder::mat4x4<T>;
template <int N, typename T>
using array = builder::array<N, T>;
@@ -1241,6 +1253,7 @@ constexpr Params ParamsFor(uint32_t columns, uint32_t rows) {
using ValidMatrixTypes = ResolverTestWithParam<Params>;
TEST_P(ValidMatrixTypes, Okay) {
// enable f16;
// var a : matNxM<EL_TY>;
auto& params = GetParam();
@@ -1279,6 +1292,7 @@ INSTANTIATE_TEST_SUITE_P(ResolverTypeValidationTest,
using InvalidMatrixElementTypes = ResolverTestWithParam<Params>;
TEST_P(InvalidMatrixElementTypes, InvalidElementType) {
// enable f16;
// var a : matNxM<EL_TY>;
auto& params = GetParam();
@@ -1321,6 +1335,7 @@ constexpr Params ParamsFor(uint32_t width) {
using ValidVectorTypes = ResolverTestWithParam<Params>;
TEST_P(ValidVectorTypes, Okay) {
// enable f16;
// var a : vecN<EL_TY>;
auto& params = GetParam();
@@ -1354,6 +1369,7 @@ INSTANTIATE_TEST_SUITE_P(ResolverTypeValidationTest,
using InvalidVectorElementTypes = ResolverTestWithParam<Params>;
TEST_P(InvalidVectorElementTypes, InvalidElementType) {
// enable f16;
// var a : vecN<EL_TY>;
auto& params = GetParam();
@@ -1390,6 +1406,7 @@ constexpr Params Case(const char* alias) {
using BuiltinTypeAliasTest = ResolverTestWithParam<Params>;
TEST_P(BuiltinTypeAliasTest, CheckEquivalent) {
// enable f16;
// var aliased : vecTN;
// var explicit : vecN<T>;
// explicit = aliased;
@@ -1403,6 +1420,7 @@ TEST_P(BuiltinTypeAliasTest, CheckEquivalent) {
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
TEST_P(BuiltinTypeAliasTest, Construct) {
// enable f16;
// var v : vecN<T> = vecTN();
auto& params = GetParam();
@@ -1413,7 +1431,25 @@ TEST_P(BuiltinTypeAliasTest, Construct) {
}
INSTANTIATE_TEST_SUITE_P(ResolverTypeValidationTest,
BuiltinTypeAliasTest,
testing::Values(Case<vec2<f32>>("vec2f"),
testing::Values(Case<mat2x2<f32>>("mat2x2f"),
Case<mat2x3<f32>>("mat2x3f"),
Case<mat2x4<f32>>("mat2x4f"),
Case<mat3x2<f32>>("mat3x2f"),
Case<mat3x3<f32>>("mat3x3f"),
Case<mat3x4<f32>>("mat3x4f"),
Case<mat4x2<f32>>("mat4x2f"),
Case<mat4x3<f32>>("mat4x3f"),
Case<mat4x4<f32>>("mat4x4f"),
Case<mat2x2<f16>>("mat2x2h"),
Case<mat2x3<f16>>("mat2x3h"),
Case<mat2x4<f16>>("mat2x4h"),
Case<mat3x2<f16>>("mat3x2h"),
Case<mat3x3<f16>>("mat3x3h"),
Case<mat3x4<f16>>("mat3x4h"),
Case<mat4x2<f16>>("mat4x2h"),
Case<mat4x3<f16>>("mat4x3h"),
Case<mat4x4<f16>>("mat4x4h"),
Case<vec2<f32>>("vec2f"),
Case<vec3<f32>>("vec3f"),
Case<vec4<f32>>("vec4f"),
Case<vec2<f16>>("vec2h"),