diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 5813e8b19b..1f2c161e7f 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -322,6 +322,7 @@ if(${TINT_BUILD_SPV_READER}) reader/spirv/enum_converter_test.cc reader/spirv/fail_stream_test.cc reader/spirv/namer_test.cc + reader/spirv/parser_impl_convert_member_decoration_test.cc reader/spirv/parser_impl_convert_type_test.cc reader/spirv/parser_impl_entry_point_test.cc reader/spirv/parser_impl_get_decorations_test.cc diff --git a/src/reader/spirv/parser_impl.cc b/src/reader/spirv/parser_impl.cc index 51b94e33e2..27ec8c6257 100644 --- a/src/reader/spirv/parser_impl.cc +++ b/src/reader/spirv/parser_impl.cc @@ -25,12 +25,19 @@ #include "source/opt/instruction.h" #include "source/opt/module.h" #include "source/opt/type_manager.h" +#include "source/opt/types.h" #include "spirv-tools/libspirv.hpp" +#include "src/ast/struct.h" +#include "src/ast/struct_decoration.h" +#include "src/ast/struct_member.h" +#include "src/ast/struct_member_decoration.h" +#include "src/ast/struct_member_offset_decoration.h" #include "src/ast/type/array_type.h" #include "src/ast/type/bool_type.h" #include "src/ast/type/f32_type.h" #include "src/ast/type/i32_type.h" #include "src/ast/type/matrix_type.h" +#include "src/ast/type/struct_type.h" #include "src/ast/type/type.h" #include "src/ast/type/u32_type.h" #include "src/ast/type/vector_type.h" @@ -127,130 +134,38 @@ ast::type::Type* ParserImpl::ConvertType(uint32_t type_id) { return nullptr; } - ast::type::Type* result = nullptr; + auto save = [this, type_id](ast::type::Type* type) { + if (type != nullptr) { + id_to_type_[type_id] = type; + } + return type; + }; switch (spirv_type->kind()) { case spvtools::opt::analysis::Type::kVoid: - result = ctx_.type_mgr().Get(std::make_unique()); - break; + return save(ctx_.type_mgr().Get(std::make_unique())); case spvtools::opt::analysis::Type::kBool: - result = ctx_.type_mgr().Get(std::make_unique()); - break; - case spvtools::opt::analysis::Type::kInteger: { - const auto* int_ty = spirv_type->AsInteger(); - if (int_ty->width() == 32) { - if (int_ty->IsSigned()) { - result = ctx_.type_mgr().Get(std::make_unique()); - } else { - result = ctx_.type_mgr().Get(std::make_unique()); - } - } else { - Fail() << "unhandled integer width: " << int_ty->width(); - } - break; - } - case spvtools::opt::analysis::Type::kFloat: { - const auto* float_ty = spirv_type->AsFloat(); - if (float_ty->width() == 32) { - result = ctx_.type_mgr().Get(std::make_unique()); - } else { - Fail() << "unhandled float width: " << float_ty->width(); - } - break; - } - case spvtools::opt::analysis::Type::kVector: { - const auto* vec_ty = spirv_type->AsVector(); - const auto num_elem = vec_ty->element_count(); - auto* ast_elem_ty = ConvertType(type_mgr_->GetId(vec_ty->element_type())); - if (ast_elem_ty != nullptr) { - result = ctx_.type_mgr().Get( - std::make_unique(ast_elem_ty, num_elem)); - } - // In the error case, we'll already have emitted a diagnostic. - break; - } - case spvtools::opt::analysis::Type::kMatrix: { - const auto* mat_ty = spirv_type->AsMatrix(); - const auto* vec_ty = mat_ty->element_type()->AsVector(); - const auto* scalar_ty = vec_ty->element_type(); - const auto num_rows = vec_ty->element_count(); - const auto num_columns = mat_ty->element_count(); - auto* ast_scalar_ty = ConvertType(type_mgr_->GetId(scalar_ty)); - if (ast_scalar_ty != nullptr) { - result = ctx_.type_mgr().Get(std::make_unique( - ast_scalar_ty, num_rows, num_columns)); - } - // In the error case, we'll already have emitted a diagnostic. - break; - } - case spvtools::opt::analysis::Type::kRuntimeArray: { - // TODO(dneto): Handle ArrayStride. Blocked by crbug.com/tint/30 - const auto* rtarr_ty = spirv_type->AsRuntimeArray(); - auto* ast_elem_ty = - ConvertType(type_mgr_->GetId(rtarr_ty->element_type())); - if (ast_elem_ty != nullptr) { - result = ctx_.type_mgr().Get( - std::make_unique(ast_elem_ty)); - } - // In the error case, we'll already have emitted a diagnostic. - break; - } - case spvtools::opt::analysis::Type::kArray: { - // TODO(dneto): Handle ArrayStride. Blocked by crbug.com/tint/30 - const auto* arr_ty = spirv_type->AsArray(); - auto* ast_elem_ty = ConvertType(type_mgr_->GetId(arr_ty->element_type())); - if (ast_elem_ty == nullptr) { - // In the error case, we'll already have emitted a diagnostic. - break; - } - const auto& length_info = arr_ty->length_info(); - if (length_info.words.empty()) { - // The internal representation is invalid. The discriminant vector - // is mal-formed. - Fail() << "internal error: Array length info is invalid"; - return nullptr; - } - if (length_info.words[0] != - spvtools::opt::analysis::Array::LengthInfo::kConstant) { - Fail() << "Array type " << type_id - << " length is a specialization constant"; - return nullptr; - } - const auto* constant = - constant_mgr_->FindDeclaredConstant(length_info.id); - if (constant == nullptr) { - Fail() << "Array type " << type_id << " length ID " << length_info.id - << " does not name an OpConstant"; - return nullptr; - } - const uint64_t num_elem = constant->GetZeroExtendedValue(); - // For now, limit to only 32bits. - if (num_elem > std::numeric_limits::max()) { - Fail() << "Array type " << type_id - << " has too many elements (more than can fit in 32 bits): " - << num_elem; - return nullptr; - } - result = ctx_.type_mgr().Get(std::make_unique( - ast_elem_ty, static_cast(num_elem))); - break; - } + return save(ctx_.type_mgr().Get(std::make_unique())); + case spvtools::opt::analysis::Type::kInteger: + return save(ConvertType(spirv_type->AsInteger())); + case spvtools::opt::analysis::Type::kFloat: + return save(ConvertType(spirv_type->AsFloat())); + case spvtools::opt::analysis::Type::kVector: + return save(ConvertType(spirv_type->AsVector())); + case spvtools::opt::analysis::Type::kMatrix: + return save(ConvertType(spirv_type->AsMatrix())); + case spvtools::opt::analysis::Type::kRuntimeArray: + return save(ConvertType(spirv_type->AsRuntimeArray())); + case spvtools::opt::analysis::Type::kArray: + return save(ConvertType(spirv_type->AsArray())); + case spvtools::opt::analysis::Type::kStruct: + return save(ConvertType(spirv_type->AsStruct())); default: - // The error diagnostic will be generated below because result is still - // nullptr. break; } - if (result == nullptr) { - if (success_) { - // Only emit a new diagnostic if we haven't already emitted a more - // specific one. - Fail() << "unknown SPIR-V type: " << type_id; - } - } else { - id_to_type_[type_id] = result; - } - return result; + Fail() << "unknown SPIR-V type: " << type_id; + return nullptr; } DecorationList ParserImpl::GetDecorationsFor(uint32_t id) const { @@ -289,6 +204,29 @@ DecorationList ParserImpl::GetDecorationsForMember( return result; } +std::unique_ptr +ParserImpl::ConvertMemberDecoration(const Decoration& decoration) { + if (decoration.empty()) { + Fail() << "malformed SPIR-V decoration: it's empty"; + return nullptr; + } + switch (decoration[0]) { + case SpvDecorationOffset: + if (decoration.size() != 2) { + Fail() + << "malformed Offset decoration: expected 1 literal operand, has " + << decoration.size() - 1; + return nullptr; + } + return std::make_unique(decoration[1]); + default: + // TODO(dneto): Support the remaining member decorations. + break; + } + Fail() << "unhandled member decoration: " << decoration[0]; + return nullptr; +} + bool ParserImpl::BuildInternalModule() { tools_.SetMessageConsumer(message_consumer_); @@ -412,6 +350,153 @@ bool ParserImpl::EmitEntryPoints() { return success_; } +ast::type::Type* ParserImpl::ConvertType( + const spvtools::opt::analysis::Integer* int_ty) { + if (int_ty->width() == 32) { + if (int_ty->IsSigned()) { + return ctx_.type_mgr().Get(std::make_unique()); + } else { + return ctx_.type_mgr().Get(std::make_unique()); + } + } + Fail() << "unhandled integer width: " << int_ty->width(); + return nullptr; +} + +ast::type::Type* ParserImpl::ConvertType( + const spvtools::opt::analysis::Float* float_ty) { + if (float_ty->width() == 32) { + return ctx_.type_mgr().Get(std::make_unique()); + } + Fail() << "unhandled float width: " << float_ty->width(); + return nullptr; +} + +ast::type::Type* ParserImpl::ConvertType( + const spvtools::opt::analysis::Vector* vec_ty) { + const auto num_elem = vec_ty->element_count(); + auto* ast_elem_ty = ConvertType(type_mgr_->GetId(vec_ty->element_type())); + if (ast_elem_ty == nullptr) { + return nullptr; + } + return ctx_.type_mgr().Get( + std::make_unique(ast_elem_ty, num_elem)); +} + +ast::type::Type* ParserImpl::ConvertType( + const spvtools::opt::analysis::Matrix* mat_ty) { + const auto* vec_ty = mat_ty->element_type()->AsVector(); + const auto* scalar_ty = vec_ty->element_type(); + const auto num_rows = vec_ty->element_count(); + const auto num_columns = mat_ty->element_count(); + auto* ast_scalar_ty = ConvertType(type_mgr_->GetId(scalar_ty)); + if (ast_scalar_ty == nullptr) { + return nullptr; + } + return ctx_.type_mgr().Get(std::make_unique( + ast_scalar_ty, num_rows, num_columns)); +} + +ast::type::Type* ParserImpl::ConvertType( + const spvtools::opt::analysis::RuntimeArray* rtarr_ty) { + // TODO(dneto): Handle ArrayStride. Blocked by crbug.com/tint/30 + auto* ast_elem_ty = ConvertType(type_mgr_->GetId(rtarr_ty->element_type())); + if (ast_elem_ty == nullptr) { + return nullptr; + } + return ctx_.type_mgr().Get( + std::make_unique(ast_elem_ty)); +} + +ast::type::Type* ParserImpl::ConvertType( + const spvtools::opt::analysis::Array* arr_ty) { + // TODO(dneto): Handle ArrayStride. Blocked by crbug.com/tint/30 + auto* ast_elem_ty = ConvertType(type_mgr_->GetId(arr_ty->element_type())); + if (ast_elem_ty == nullptr) { + return nullptr; + } + const auto& length_info = arr_ty->length_info(); + if (length_info.words.empty()) { + // The internal representation is invalid. The discriminant vector + // is mal-formed. + Fail() << "internal error: Array length info is invalid"; + return nullptr; + } + if (length_info.words[0] != + spvtools::opt::analysis::Array::LengthInfo::kConstant) { + Fail() << "Array type " << type_mgr_->GetId(arr_ty) + << " length is a specialization constant"; + return nullptr; + } + const auto* constant = constant_mgr_->FindDeclaredConstant(length_info.id); + if (constant == nullptr) { + Fail() << "Array type " << type_mgr_->GetId(arr_ty) << " length ID " + << length_info.id << " does not name an OpConstant"; + return nullptr; + } + const uint64_t num_elem = constant->GetZeroExtendedValue(); + // For now, limit to only 32bits. + if (num_elem > std::numeric_limits::max()) { + Fail() << "Array type " << type_mgr_->GetId(arr_ty) + << " has too many elements (more than can fit in 32 bits): " + << num_elem; + return nullptr; + } + return ctx_.type_mgr().Get(std::make_unique( + ast_elem_ty, static_cast(num_elem))); +} + +ast::type::Type* ParserImpl::ConvertType( + const spvtools::opt::analysis::Struct* struct_ty) { + const auto type_id = type_mgr_->GetId(struct_ty); + // Compute the struct decoration. + auto struct_decorations = this->GetDecorationsFor(type_id); + auto ast_struct_decoration = ast::StructDecoration::kNone; + if (struct_decorations.size() == 1 && + struct_decorations[0][0] == SpvDecorationBlock) { + ast_struct_decoration = ast::StructDecoration::kBlock; + } else if (struct_decorations.size() > 1) { + Fail() << "can't handle a struct with more than one decoration: struct " + << type_id << " has " << struct_decorations.size(); + return nullptr; + } + + // Compute members + std::vector> ast_members; + const auto members = struct_ty->element_types(); + for (size_t member_index = 0; member_index < members.size(); ++member_index) { + auto* ast_member_ty = ConvertType(type_mgr_->GetId(members[member_index])); + if (ast_member_ty == nullptr) { + // Already emitted diagnostics. + return nullptr; + } + std::vector> + ast_member_decorations; + for (auto& deco : GetDecorationsForMember(type_id, member_index)) { + auto ast_member_decoration = ConvertMemberDecoration(deco); + if (ast_member_decoration == nullptr) { + // Already emitted diagnostics. + return nullptr; + } + ast_member_decorations.push_back(std::move(ast_member_decoration)); + } + const auto member_name = namer_.GetMemberName(type_id, member_index); + auto ast_struct_member = std::make_unique( + member_name, ast_member_ty, std::move(ast_member_decorations)); + ast_members.push_back(std::move(ast_struct_member)); + } + + // Now make the struct. + auto ast_struct = std::make_unique(ast_struct_decoration, + std::move(ast_members)); + auto ast_struct_type = + std::make_unique(std::move(ast_struct)); + // The struct might not have a name yet. Suggest one. + namer_.SuggestSanitizedName(type_id, "S"); + ast_struct_type->set_name(namer_.GetName(type_id)); + return ctx_.type_mgr().Get(std::move(ast_struct_type)); +} + } // namespace spirv } // namespace reader } // namespace tint diff --git a/src/reader/spirv/parser_impl.h b/src/reader/spirv/parser_impl.h index 0bb6694e3f..8a7b974936 100644 --- a/src/reader/spirv/parser_impl.h +++ b/src/reader/spirv/parser_impl.h @@ -28,9 +28,11 @@ #include "source/opt/ir_context.h" #include "source/opt/module.h" #include "source/opt/type_manager.h" +#include "source/opt/types.h" #include "spirv-tools/libspirv.hpp" #include "src/ast/import.h" #include "src/ast/module.h" +#include "src/ast/struct_member_decoration.h" #include "src/ast/type/type.h" #include "src/reader/reader.h" #include "src/reader/spirv/enum_converter.h" @@ -91,10 +93,9 @@ class ParserImpl : Reader { return glsl_std_450_imports_; } - /// Converts a SPIR-V type to a Tint type. - /// On failure, logs an error and returns null. - /// This should only be called after the internal - /// representation of the module has been built. + /// Converts a SPIR-V type to a Tint type, and saves it for fast lookup. + /// On failure, logs an error and returns null. This should only be called + /// after the internal representation of the module has been built. /// @param type_id the SPIR-V ID of a type. /// @returns a Tint type, or nullptr ast::type::Type* ConvertType(uint32_t type_id); @@ -118,6 +119,13 @@ class ParserImpl : Reader { DecorationList GetDecorationsForMember(uint32_t id, uint32_t member_index) const; + /// Converts a SPIR-V decoration. On failure, emits a diagnostic and returns + /// nullptr. + /// @param decoration an encoded SPIR-V Decoration + /// @returns the corresponding ast::StructuMemberDecoration + std::unique_ptr ConvertMemberDecoration( + const Decoration& decoration); + private: /// Builds the internal representation of the SPIR-V module. /// Assumes the module is somewhat well-formed. Normally you @@ -145,6 +153,23 @@ class ParserImpl : Reader { /// Emit entry point AST nodes. bool EmitEntryPoints(); + /// Converts a specific SPIR-V type to a Tint type. Integer case + ast::type::Type* ConvertType(const spvtools::opt::analysis::Integer* int_ty); + /// Converts a specific SPIR-V type to a Tint type. Float case + ast::type::Type* ConvertType(const spvtools::opt::analysis::Float* float_ty); + /// Converts a specific SPIR-V type to a Tint type. Vector case + ast::type::Type* ConvertType(const spvtools::opt::analysis::Vector* vec_ty); + /// Converts a specific SPIR-V type to a Tint type. Matrix case + ast::type::Type* ConvertType(const spvtools::opt::analysis::Matrix* mat_ty); + /// Converts a specific SPIR-V type to a Tint type. RuntimeArray case + ast::type::Type* ConvertType( + const spvtools::opt::analysis::RuntimeArray* rtarr_ty); + /// Converts a specific SPIR-V type to a Tint type. Array case + ast::type::Type* ConvertType(const spvtools::opt::analysis::Array* arr_ty); + /// Converts a specific SPIR-V type to a Tint type. Struct case + ast::type::Type* ConvertType( + const spvtools::opt::analysis::Struct* struct_ty); + // The SPIR-V binary we're parsing std::vector spv_binary_; diff --git a/src/reader/spirv/parser_impl_convert_member_decoration_test.cc b/src/reader/spirv/parser_impl_convert_member_decoration_test.cc new file mode 100644 index 0000000000..d8351a80c8 --- /dev/null +++ b/src/reader/spirv/parser_impl_convert_member_decoration_test.cc @@ -0,0 +1,85 @@ +// Copyright 2020 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "gmock/gmock.h" +#include "spirv/unified1/spirv.h" +#include "src/ast/struct_member_decoration.h" +#include "src/ast/struct_member_offset_decoration.h" +#include "src/reader/spirv/parser_impl.h" +#include "src/reader/spirv/parser_impl_test_helper.h" +#include "src/reader/spirv/spirv_tools_helpers_test.h" + +namespace tint { +namespace reader { +namespace spirv { +namespace { + +using ::testing::Eq; + +TEST_F(SpvParserTest, ConvertMemberDecoration_Empty) { + auto p = parser(std::vector{}); + + auto result = p->ConvertMemberDecoration({}); + EXPECT_EQ(result.get(), nullptr); + EXPECT_THAT(p->error(), Eq("malformed SPIR-V decoration: it's empty")); +} + +TEST_F(SpvParserTest, ConvertMemberDecoration_OffsetWithoutOperand) { + auto p = parser(std::vector{}); + + auto result = p->ConvertMemberDecoration({SpvDecorationOffset}); + EXPECT_EQ(result.get(), nullptr); + EXPECT_THAT( + p->error(), + Eq("malformed Offset decoration: expected 1 literal operand, has 0")); +} + +TEST_F(SpvParserTest, ConvertMemberDecoration_OffsetWithTooManyOperands) { + auto p = parser(std::vector{}); + + auto result = p->ConvertMemberDecoration({SpvDecorationOffset, 3, 4}); + EXPECT_EQ(result.get(), nullptr); + EXPECT_THAT( + p->error(), + Eq("malformed Offset decoration: expected 1 literal operand, has 2")); +} + +TEST_F(SpvParserTest, ConvertMemberDecoration_Offset) { + auto p = parser(std::vector{}); + + auto result = p->ConvertMemberDecoration({SpvDecorationOffset, 8}); + ASSERT_NE(result.get(), nullptr); + EXPECT_TRUE(result->IsOffset()); + auto* offset_deco = result->AsOffset(); + ASSERT_NE(offset_deco, nullptr); + EXPECT_EQ(offset_deco->offset(), 8); + EXPECT_TRUE(p->error().empty()); +} + +TEST_F(SpvParserTest, ConvertMemberDecoration_UnhandledDecoration) { + auto p = parser(std::vector{}); + + auto result = p->ConvertMemberDecoration({12345678}); + EXPECT_EQ(result.get(), nullptr); + EXPECT_THAT(p->error(), Eq("unhandled member decoration: 12345678")); +} + +} // namespace +} // namespace spirv +} // namespace reader +} // namespace tint diff --git a/src/reader/spirv/parser_impl_convert_type_test.cc b/src/reader/spirv/parser_impl_convert_type_test.cc index 72c6fc38ff..7dd24dffa5 100644 --- a/src/reader/spirv/parser_impl_convert_type_test.cc +++ b/src/reader/spirv/parser_impl_convert_type_test.cc @@ -17,8 +17,10 @@ #include #include "gmock/gmock.h" +#include "src/ast/struct.h" #include "src/ast/type/array_type.h" #include "src/ast/type/matrix_type.h" +#include "src/ast/type/struct_type.h" #include "src/ast/type/vector_type.h" #include "src/reader/spirv/parser_impl.h" #include "src/reader/spirv/parser_impl_test_helper.h" @@ -415,6 +417,75 @@ TEST_F(SpvParserTest, ConvertType_ArrayBadTooBig) { EXPECT_THAT(p->error(), Eq("unhandled integer width: 64")); } +TEST_F(SpvParserTest, ConvertType_StructTwoMembers) { + auto p = parser(test::Assemble(R"( + %uint = OpTypeInt 32 0 + %float = OpTypeFloat 32 + %10 = OpTypeStruct %uint %float + )")); + EXPECT_TRUE(p->BuildAndParseInternalModule()); + + auto* type = p->ConvertType(10); + ASSERT_NE(type, nullptr) << p->error(); + EXPECT_TRUE(type->IsStruct()); + std::stringstream ss; + type->AsStruct()->impl()->to_str(ss, 0); + EXPECT_THAT(ss.str(), Eq(R"(Struct{ + StructMember{field0: __u32} + StructMember{field1: __f32} +} +)")); +} + +TEST_F(SpvParserTest, ConvertType_StructWithBlockDecoration) { + auto p = parser(test::Assemble(R"( + OpDecorate %10 Block + %uint = OpTypeInt 32 0 + %10 = OpTypeStruct %uint + )")); + EXPECT_TRUE(p->BuildAndParseInternalModule()); + + auto* type = p->ConvertType(10); + ASSERT_NE(type, nullptr); + EXPECT_TRUE(type->IsStruct()); + std::stringstream ss; + type->AsStruct()->impl()->to_str(ss, 0); + EXPECT_THAT(ss.str(), Eq(R"([[block]] Struct{ + StructMember{field0: __u32} +} +)")); +} + +TEST_F(SpvParserTest, ConvertType_StructWithMemberDecorations) { + auto p = parser(test::Assemble(R"( + OpMemberDecorate %10 0 Offset 0 + OpMemberDecorate %10 1 Offset 8 + OpMemberDecorate %10 2 Offset 16 + %float = OpTypeFloat 32 + %vec = OpTypeVector %float 2 + %mat = OpTypeMatrix %vec 2 + %10 = OpTypeStruct %float %vec %mat + )")); + EXPECT_TRUE(p->BuildAndParseInternalModule()); + + auto* type = p->ConvertType(10); + ASSERT_NE(type, nullptr) << p->error(); + EXPECT_TRUE(type->IsStruct()); + std::stringstream ss; + type->AsStruct()->impl()->to_str(ss, 0); + EXPECT_THAT(ss.str(), Eq(R"(Struct{ + StructMember{[[ offset 0 ]] field0: __f32} + StructMember{[[ offset 8 ]] field1: __vec_2__f32} + StructMember{[[ offset 16 ]] field2: __mat_2_2__f32} +} +)")); +} + +// TODO(dneto): Demonstrate other member deocrations. Blocked on +// crbug.com/tint/30 +// TODO(dneto): Demonstrate multiple member deocrations. Blocked on +// crbug.com/tint/30 + } // namespace } // namespace spirv } // namespace reader