diff --git a/src/reader/spirv/parser_impl.cc b/src/reader/spirv/parser_impl.cc index d09f23985a..81aa8d8422 100644 --- a/src/reader/spirv/parser_impl.cc +++ b/src/reader/spirv/parser_impl.cc @@ -27,8 +27,10 @@ #include "src/ast/type/bool_type.h" #include "src/ast/type/f32_type.h" #include "src/ast/type/i32_type.h" +#include "src/ast/type/matrix_type.h" #include "src/ast/type/type.h" #include "src/ast/type/u32_type.h" +#include "src/ast/type/vector_type.h" #include "src/ast/type/void_type.h" #include "src/type_manager.h" @@ -101,13 +103,13 @@ ast::Module ParserImpl::module() { return std::move(ast_module_); } -const ast::type::Type* ParserImpl::ConvertType(uint32_t type_id) { +ast::type::Type* ParserImpl::ConvertType(uint32_t type_id) { if (!success_) { return nullptr; } if (type_mgr_ == nullptr) { - Fail() << "ConvertType called when the internal module has not been built."; + Fail() << "ConvertType called when the internal module has not been built"; return nullptr; } @@ -122,7 +124,7 @@ const ast::type::Type* ParserImpl::ConvertType(uint32_t type_id) { return nullptr; } - const ast::type::Type* result = nullptr; + ast::type::Type* result = nullptr; TypeManager* tint_tm = TypeManager::Instance(); switch (spirv_type->kind()) { @@ -154,6 +156,31 @@ const ast::type::Type* ParserImpl::ConvertType(uint32_t type_id) { } break; } + case spvtools::opt::analysis::Type::kVector: { + const auto* vec_ty = spirv_type->AsVector(); + const auto num_elem = vec_ty->element_count(); + auto* ast_elem_ty = ConvertType(type_mgr_->GetId(vec_ty->element_type())); + if (ast_elem_ty != nullptr) { + result = tint_tm->Get( + std::make_unique(ast_elem_ty, num_elem)); + } + // In the error case, we'll already have emitted a diagnostic. + break; + } + case spvtools::opt::analysis::Type::kMatrix: { + const auto* mat_ty = spirv_type->AsMatrix(); + const auto* vec_ty = mat_ty->element_type()->AsVector(); + const auto* scalar_ty = vec_ty->element_type(); + const auto num_rows = vec_ty->element_count(); + const auto num_columns = mat_ty->element_count(); + auto* ast_scalar_ty = ConvertType(type_mgr_->GetId(scalar_ty)); + if (ast_scalar_ty != nullptr) { + result = tint_tm->Get(std::make_unique( + ast_scalar_ty, num_rows, num_columns)); + } + // In the error case, we'll already have emitted a diagnostic. + break; + } default: // The error diagnostic will be generated below because result is still // nullptr. diff --git a/src/reader/spirv/parser_impl.h b/src/reader/spirv/parser_impl.h index b6f91ce1b8..f5bfc041cb 100644 --- a/src/reader/spirv/parser_impl.h +++ b/src/reader/spirv/parser_impl.h @@ -89,7 +89,7 @@ class ParserImpl : Reader { /// representation of the module has been built. /// @param type_id the SPIR-V ID of a type. /// @returns a Tint type, or nullptr - const ast::type::Type* ConvertType(uint32_t type_id); + ast::type::Type* ConvertType(uint32_t type_id); /// @returns the namer object Namer& namer() { return namer_; } @@ -158,7 +158,7 @@ class ParserImpl : Reader { std::unordered_set glsl_std_450_imports_; // Maps a SPIR-V type ID to a Tint type. - std::unordered_map id_to_type_; + std::unordered_map id_to_type_; }; } // namespace spirv diff --git a/src/reader/spirv/parser_impl_convert_type_test.cc b/src/reader/spirv/parser_impl_convert_type_test.cc index 44acf9aa73..f963ae14ec 100644 --- a/src/reader/spirv/parser_impl_convert_type_test.cc +++ b/src/reader/spirv/parser_impl_convert_type_test.cc @@ -19,7 +19,10 @@ #include #include "gmock/gmock.h" +#include "src/ast/type/matrix_type.h" +#include "src/ast/type/vector_type.h" #include "src/reader/spirv/spirv_tools_helpers_test.h" +#include "src/type_manager.h" namespace tint { namespace reader { @@ -28,21 +31,35 @@ namespace { using ::testing::Eq; -using SpvParserTest_ConvertType = ::testing::Test; +class SpvParserTest_ConvertType : public ::testing::Test { + void TearDown() override { + // Clean up the type manager instance at the end of a single test. + TypeManager::Destroy(); + } +}; TEST_F(SpvParserTest_ConvertType, PreservesExistingFailure) { ParserImpl p(std::vector{}); p.Fail() << "boing"; - const auto* type = p.ConvertType(10); + auto* type = p.ConvertType(10); EXPECT_EQ(type, nullptr); EXPECT_THAT(p.error(), Eq("boing")); } +TEST_F(SpvParserTest_ConvertType, RequiresInternalRepresntation) { + ParserImpl p(std::vector{}); + auto* type = p.ConvertType(10); + EXPECT_EQ(type, nullptr); + EXPECT_THAT( + p.error(), + Eq("ConvertType called when the internal module has not been built")); +} + TEST_F(SpvParserTest_ConvertType, NotAnId) { ParserImpl p(test::Assemble("%1 = OpExtInstImport \"GLSL.std.450\"")); EXPECT_TRUE(p.BuildAndParseInternalModule()) << p.error(); - const auto* type = p.ConvertType(10); + auto* type = p.ConvertType(10); EXPECT_EQ(type, nullptr); EXPECT_EQ(nullptr, type); EXPECT_THAT(p.error(), Eq("ID is not a SPIR-V type: 10")); @@ -52,7 +69,7 @@ TEST_F(SpvParserTest_ConvertType, IdExistsButIsNotAType) { ParserImpl p(test::Assemble("%1 = OpExtInstImport \"GLSL.std.450\"")); EXPECT_TRUE(p.BuildAndParseInternalModule()); - const auto* type = p.ConvertType(1); + auto* type = p.ConvertType(1); EXPECT_EQ(nullptr, type); EXPECT_THAT(p.error(), Eq("ID is not a SPIR-V type: 1")); } @@ -62,7 +79,7 @@ TEST_F(SpvParserTest_ConvertType, UnhandledType) { ParserImpl p(test::Assemble("%70 = OpTypePipe WriteOnly")); EXPECT_TRUE(p.BuildAndParseInternalModule()); - const auto* type = p.ConvertType(70); + auto* type = p.ConvertType(70); EXPECT_EQ(nullptr, type); EXPECT_THAT(p.error(), Eq("unknown SPIR-V type: 70")); } @@ -71,7 +88,7 @@ TEST_F(SpvParserTest_ConvertType, Void) { ParserImpl p(test::Assemble("%1 = OpTypeVoid")); EXPECT_TRUE(p.BuildAndParseInternalModule()); - const auto* type = p.ConvertType(1); + auto* type = p.ConvertType(1); EXPECT_TRUE(type->IsVoid()); EXPECT_TRUE(p.error().empty()); } @@ -80,7 +97,7 @@ TEST_F(SpvParserTest_ConvertType, Bool) { ParserImpl p(test::Assemble("%100 = OpTypeBool")); EXPECT_TRUE(p.BuildAndParseInternalModule()); - const auto* type = p.ConvertType(100); + auto* type = p.ConvertType(100); EXPECT_TRUE(type->IsBool()); EXPECT_TRUE(p.error().empty()); } @@ -89,7 +106,7 @@ TEST_F(SpvParserTest_ConvertType, I32) { ParserImpl p(test::Assemble("%2 = OpTypeInt 32 1")); EXPECT_TRUE(p.BuildAndParseInternalModule()); - const auto* type = p.ConvertType(2); + auto* type = p.ConvertType(2); EXPECT_TRUE(type->IsI32()); EXPECT_TRUE(p.error().empty()); } @@ -98,7 +115,7 @@ TEST_F(SpvParserTest_ConvertType, U32) { ParserImpl p(test::Assemble("%3 = OpTypeInt 32 0")); EXPECT_TRUE(p.BuildAndParseInternalModule()); - const auto* type = p.ConvertType(3); + auto* type = p.ConvertType(3); EXPECT_TRUE(type->IsU32()); EXPECT_TRUE(p.error().empty()); } @@ -107,7 +124,7 @@ TEST_F(SpvParserTest_ConvertType, F32) { ParserImpl p(test::Assemble("%4 = OpTypeFloat 32")); EXPECT_TRUE(p.BuildAndParseInternalModule()); - const auto* type = p.ConvertType(4); + auto* type = p.ConvertType(4); EXPECT_TRUE(type->IsF32()); EXPECT_TRUE(p.error().empty()); } @@ -116,7 +133,7 @@ TEST_F(SpvParserTest_ConvertType, BadIntWidth) { ParserImpl p(test::Assemble("%5 = OpTypeInt 17 1")); EXPECT_TRUE(p.BuildAndParseInternalModule()); - const auto* type = p.ConvertType(5); + auto* type = p.ConvertType(5); EXPECT_EQ(type, nullptr); EXPECT_THAT(p.error(), Eq("unhandled integer width: 17")); } @@ -125,11 +142,195 @@ TEST_F(SpvParserTest_ConvertType, BadFloatWidth) { ParserImpl p(test::Assemble("%6 = OpTypeFloat 19")); EXPECT_TRUE(p.BuildAndParseInternalModule()); - const auto* type = p.ConvertType(6); + auto* type = p.ConvertType(6); EXPECT_EQ(type, nullptr); EXPECT_THAT(p.error(), Eq("unhandled float width: 19")); } +TEST_F(SpvParserTest_ConvertType, InvalidVectorElement) { + ParserImpl p(test::Assemble(R"( + %5 = OpTypePipe ReadOnly + %20 = OpTypeVector %5 2 + )")); + EXPECT_TRUE(p.BuildAndParseInternalModule()); + + auto* type = p.ConvertType(20); + EXPECT_EQ(type, nullptr); + EXPECT_THAT(p.error(), Eq("unknown SPIR-V type: 5")); +} + +TEST_F(SpvParserTest_ConvertType, VecOverF32) { + ParserImpl p(test::Assemble(R"( + %float = OpTypeFloat 32 + %20 = OpTypeVector %float 2 + %30 = OpTypeVector %float 3 + %40 = OpTypeVector %float 4 + )")); + EXPECT_TRUE(p.BuildAndParseInternalModule()); + + auto* v2xf32 = p.ConvertType(20); + EXPECT_TRUE(v2xf32->IsVector()); + EXPECT_TRUE(v2xf32->AsVector()->type()->IsF32()); + EXPECT_EQ(v2xf32->AsVector()->size(), 2u); + + auto* v3xf32 = p.ConvertType(30); + EXPECT_TRUE(v3xf32->IsVector()); + EXPECT_TRUE(v3xf32->AsVector()->type()->IsF32()); + EXPECT_EQ(v3xf32->AsVector()->size(), 3u); + + auto* v4xf32 = p.ConvertType(40); + EXPECT_TRUE(v4xf32->IsVector()); + EXPECT_TRUE(v4xf32->AsVector()->type()->IsF32()); + EXPECT_EQ(v4xf32->AsVector()->size(), 4u); + + EXPECT_TRUE(p.error().empty()); +} + +TEST_F(SpvParserTest_ConvertType, VecOverI32) { + ParserImpl p(test::Assemble(R"( + %int = OpTypeInt 32 1 + %20 = OpTypeVector %int 2 + %30 = OpTypeVector %int 3 + %40 = OpTypeVector %int 4 + )")); + EXPECT_TRUE(p.BuildAndParseInternalModule()); + + auto* v2xi32 = p.ConvertType(20); + EXPECT_TRUE(v2xi32->IsVector()); + EXPECT_TRUE(v2xi32->AsVector()->type()->IsI32()); + EXPECT_EQ(v2xi32->AsVector()->size(), 2u); + + auto* v3xi32 = p.ConvertType(30); + EXPECT_TRUE(v3xi32->IsVector()); + EXPECT_TRUE(v3xi32->AsVector()->type()->IsI32()); + EXPECT_EQ(v3xi32->AsVector()->size(), 3u); + + auto* v4xi32 = p.ConvertType(40); + EXPECT_TRUE(v4xi32->IsVector()); + EXPECT_TRUE(v4xi32->AsVector()->type()->IsI32()); + EXPECT_EQ(v4xi32->AsVector()->size(), 4u); + + EXPECT_TRUE(p.error().empty()); +} + +TEST_F(SpvParserTest_ConvertType, VecOverU32) { + ParserImpl p(test::Assemble(R"( + %uint = OpTypeInt 32 0 + %20 = OpTypeVector %uint 2 + %30 = OpTypeVector %uint 3 + %40 = OpTypeVector %uint 4 + )")); + EXPECT_TRUE(p.BuildAndParseInternalModule()); + + auto* v2xu32 = p.ConvertType(20); + EXPECT_TRUE(v2xu32->IsVector()); + EXPECT_TRUE(v2xu32->AsVector()->type()->IsU32()); + EXPECT_EQ(v2xu32->AsVector()->size(), 2u); + + auto* v3xu32 = p.ConvertType(30); + EXPECT_TRUE(v3xu32->IsVector()); + EXPECT_TRUE(v3xu32->AsVector()->type()->IsU32()); + EXPECT_EQ(v3xu32->AsVector()->size(), 3u); + + auto* v4xu32 = p.ConvertType(40); + EXPECT_TRUE(v4xu32->IsVector()); + EXPECT_TRUE(v4xu32->AsVector()->type()->IsU32()); + EXPECT_EQ(v4xu32->AsVector()->size(), 4u); + + EXPECT_TRUE(p.error().empty()); +} + +TEST_F(SpvParserTest_ConvertType, InvalidMatrixElement) { + ParserImpl p(test::Assemble(R"( + %5 = OpTypePipe ReadOnly + %10 = OpTypeVector %5 2 + %20 = OpTypeMatrix %10 2 + )")); + EXPECT_TRUE(p.BuildAndParseInternalModule()); + + auto* type = p.ConvertType(20); + EXPECT_EQ(type, nullptr); + EXPECT_THAT(p.error(), Eq("unknown SPIR-V type: 5")); +} + +TEST_F(SpvParserTest_ConvertType, MatrixOverF32) { + // Matrices are only defined over floats. + ParserImpl p(test::Assemble(R"( + %float = OpTypeFloat 32 + %v2 = OpTypeVector %float 2 + %v3 = OpTypeVector %float 3 + %v4 = OpTypeVector %float 4 + ; First digit is rows + ; Second digit is columns + %22 = OpTypeMatrix %v2 2 + %23 = OpTypeMatrix %v2 3 + %24 = OpTypeMatrix %v2 4 + %32 = OpTypeMatrix %v3 2 + %33 = OpTypeMatrix %v3 3 + %34 = OpTypeMatrix %v3 4 + %42 = OpTypeMatrix %v4 2 + %43 = OpTypeMatrix %v4 3 + %44 = OpTypeMatrix %v4 4 + )")); + EXPECT_TRUE(p.BuildAndParseInternalModule()); + + auto* m22 = p.ConvertType(22); + EXPECT_TRUE(m22->IsMatrix()); + EXPECT_TRUE(m22->AsMatrix()->type()->IsF32()); + EXPECT_EQ(m22->AsMatrix()->rows(), 2); + EXPECT_EQ(m22->AsMatrix()->columns(), 2); + + auto* m23 = p.ConvertType(23); + EXPECT_TRUE(m23->IsMatrix()); + EXPECT_TRUE(m23->AsMatrix()->type()->IsF32()); + EXPECT_EQ(m23->AsMatrix()->rows(), 2); + EXPECT_EQ(m23->AsMatrix()->columns(), 3); + + auto* m24 = p.ConvertType(24); + EXPECT_TRUE(m24->IsMatrix()); + EXPECT_TRUE(m24->AsMatrix()->type()->IsF32()); + EXPECT_EQ(m24->AsMatrix()->rows(), 2); + EXPECT_EQ(m24->AsMatrix()->columns(), 4); + + auto* m32 = p.ConvertType(32); + EXPECT_TRUE(m32->IsMatrix()); + EXPECT_TRUE(m32->AsMatrix()->type()->IsF32()); + EXPECT_EQ(m32->AsMatrix()->rows(), 3); + EXPECT_EQ(m32->AsMatrix()->columns(), 2); + + auto* m33 = p.ConvertType(33); + EXPECT_TRUE(m33->IsMatrix()); + EXPECT_TRUE(m33->AsMatrix()->type()->IsF32()); + EXPECT_EQ(m33->AsMatrix()->rows(), 3); + EXPECT_EQ(m33->AsMatrix()->columns(), 3); + + auto* m34 = p.ConvertType(34); + EXPECT_TRUE(m34->IsMatrix()); + EXPECT_TRUE(m34->AsMatrix()->type()->IsF32()); + EXPECT_EQ(m34->AsMatrix()->rows(), 3); + EXPECT_EQ(m34->AsMatrix()->columns(), 4); + + auto* m42 = p.ConvertType(42); + EXPECT_TRUE(m42->IsMatrix()); + EXPECT_TRUE(m42->AsMatrix()->type()->IsF32()); + EXPECT_EQ(m42->AsMatrix()->rows(), 4); + EXPECT_EQ(m42->AsMatrix()->columns(), 2); + + auto* m43 = p.ConvertType(43); + EXPECT_TRUE(m43->IsMatrix()); + EXPECT_TRUE(m43->AsMatrix()->type()->IsF32()); + EXPECT_EQ(m43->AsMatrix()->rows(), 4); + EXPECT_EQ(m43->AsMatrix()->columns(), 3); + + auto* m44 = p.ConvertType(44); + EXPECT_TRUE(m44->IsMatrix()); + EXPECT_TRUE(m44->AsMatrix()->type()->IsF32()); + EXPECT_EQ(m44->AsMatrix()->rows(), 4); + EXPECT_EQ(m44->AsMatrix()->columns(), 4); + + EXPECT_TRUE(p.error().empty()); +} + } // namespace } // namespace spirv } // namespace reader