[spirv-reader] Convert struct types

Handle as many member decorations as the Tint AST can express right now.
See crbug.com/tint/30

Bug: tint:3
Change-Id: I6d04f1beb438b3d952a76886fbd9c6b7ea701d81
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/18160
Reviewed-by: dan sinclair <dsinclair@google.com>
This commit is contained in:
David Neto 2020-03-30 21:14:28 +00:00
parent 4da4c696e2
commit e68c9b4075
5 changed files with 387 additions and 120 deletions

View File

@ -322,6 +322,7 @@ if(${TINT_BUILD_SPV_READER})
reader/spirv/enum_converter_test.cc
reader/spirv/fail_stream_test.cc
reader/spirv/namer_test.cc
reader/spirv/parser_impl_convert_member_decoration_test.cc
reader/spirv/parser_impl_convert_type_test.cc
reader/spirv/parser_impl_entry_point_test.cc
reader/spirv/parser_impl_get_decorations_test.cc

View File

@ -25,12 +25,19 @@
#include "source/opt/instruction.h"
#include "source/opt/module.h"
#include "source/opt/type_manager.h"
#include "source/opt/types.h"
#include "spirv-tools/libspirv.hpp"
#include "src/ast/struct.h"
#include "src/ast/struct_decoration.h"
#include "src/ast/struct_member.h"
#include "src/ast/struct_member_decoration.h"
#include "src/ast/struct_member_offset_decoration.h"
#include "src/ast/type/array_type.h"
#include "src/ast/type/bool_type.h"
#include "src/ast/type/f32_type.h"
#include "src/ast/type/i32_type.h"
#include "src/ast/type/matrix_type.h"
#include "src/ast/type/struct_type.h"
#include "src/ast/type/type.h"
#include "src/ast/type/u32_type.h"
#include "src/ast/type/vector_type.h"
@ -127,130 +134,38 @@ ast::type::Type* ParserImpl::ConvertType(uint32_t type_id) {
return nullptr;
}
ast::type::Type* result = nullptr;
auto save = [this, type_id](ast::type::Type* type) {
if (type != nullptr) {
id_to_type_[type_id] = type;
}
return type;
};
switch (spirv_type->kind()) {
case spvtools::opt::analysis::Type::kVoid:
result = ctx_.type_mgr().Get(std::make_unique<ast::type::VoidType>());
break;
return save(ctx_.type_mgr().Get(std::make_unique<ast::type::VoidType>()));
case spvtools::opt::analysis::Type::kBool:
result = ctx_.type_mgr().Get(std::make_unique<ast::type::BoolType>());
break;
case spvtools::opt::analysis::Type::kInteger: {
const auto* int_ty = spirv_type->AsInteger();
if (int_ty->width() == 32) {
if (int_ty->IsSigned()) {
result = ctx_.type_mgr().Get(std::make_unique<ast::type::I32Type>());
} else {
result = ctx_.type_mgr().Get(std::make_unique<ast::type::U32Type>());
}
} else {
Fail() << "unhandled integer width: " << int_ty->width();
}
break;
}
case spvtools::opt::analysis::Type::kFloat: {
const auto* float_ty = spirv_type->AsFloat();
if (float_ty->width() == 32) {
result = ctx_.type_mgr().Get(std::make_unique<ast::type::F32Type>());
} else {
Fail() << "unhandled float width: " << float_ty->width();
}
break;
}
case spvtools::opt::analysis::Type::kVector: {
const auto* vec_ty = spirv_type->AsVector();
const auto num_elem = vec_ty->element_count();
auto* ast_elem_ty = ConvertType(type_mgr_->GetId(vec_ty->element_type()));
if (ast_elem_ty != nullptr) {
result = ctx_.type_mgr().Get(
std::make_unique<ast::type::VectorType>(ast_elem_ty, num_elem));
}
// In the error case, we'll already have emitted a diagnostic.
break;
}
case spvtools::opt::analysis::Type::kMatrix: {
const auto* mat_ty = spirv_type->AsMatrix();
const auto* vec_ty = mat_ty->element_type()->AsVector();
const auto* scalar_ty = vec_ty->element_type();
const auto num_rows = vec_ty->element_count();
const auto num_columns = mat_ty->element_count();
auto* ast_scalar_ty = ConvertType(type_mgr_->GetId(scalar_ty));
if (ast_scalar_ty != nullptr) {
result = ctx_.type_mgr().Get(std::make_unique<ast::type::MatrixType>(
ast_scalar_ty, num_rows, num_columns));
}
// In the error case, we'll already have emitted a diagnostic.
break;
}
case spvtools::opt::analysis::Type::kRuntimeArray: {
// TODO(dneto): Handle ArrayStride. Blocked by crbug.com/tint/30
const auto* rtarr_ty = spirv_type->AsRuntimeArray();
auto* ast_elem_ty =
ConvertType(type_mgr_->GetId(rtarr_ty->element_type()));
if (ast_elem_ty != nullptr) {
result = ctx_.type_mgr().Get(
std::make_unique<ast::type::ArrayType>(ast_elem_ty));
}
// In the error case, we'll already have emitted a diagnostic.
break;
}
case spvtools::opt::analysis::Type::kArray: {
// TODO(dneto): Handle ArrayStride. Blocked by crbug.com/tint/30
const auto* arr_ty = spirv_type->AsArray();
auto* ast_elem_ty = ConvertType(type_mgr_->GetId(arr_ty->element_type()));
if (ast_elem_ty == nullptr) {
// In the error case, we'll already have emitted a diagnostic.
break;
}
const auto& length_info = arr_ty->length_info();
if (length_info.words.empty()) {
// The internal representation is invalid. The discriminant vector
// is mal-formed.
Fail() << "internal error: Array length info is invalid";
return nullptr;
}
if (length_info.words[0] !=
spvtools::opt::analysis::Array::LengthInfo::kConstant) {
Fail() << "Array type " << type_id
<< " length is a specialization constant";
return nullptr;
}
const auto* constant =
constant_mgr_->FindDeclaredConstant(length_info.id);
if (constant == nullptr) {
Fail() << "Array type " << type_id << " length ID " << length_info.id
<< " does not name an OpConstant";
return nullptr;
}
const uint64_t num_elem = constant->GetZeroExtendedValue();
// For now, limit to only 32bits.
if (num_elem > std::numeric_limits<uint32_t>::max()) {
Fail() << "Array type " << type_id
<< " has too many elements (more than can fit in 32 bits): "
<< num_elem;
return nullptr;
}
result = ctx_.type_mgr().Get(std::make_unique<ast::type::ArrayType>(
ast_elem_ty, static_cast<uint32_t>(num_elem)));
break;
}
return save(ctx_.type_mgr().Get(std::make_unique<ast::type::BoolType>()));
case spvtools::opt::analysis::Type::kInteger:
return save(ConvertType(spirv_type->AsInteger()));
case spvtools::opt::analysis::Type::kFloat:
return save(ConvertType(spirv_type->AsFloat()));
case spvtools::opt::analysis::Type::kVector:
return save(ConvertType(spirv_type->AsVector()));
case spvtools::opt::analysis::Type::kMatrix:
return save(ConvertType(spirv_type->AsMatrix()));
case spvtools::opt::analysis::Type::kRuntimeArray:
return save(ConvertType(spirv_type->AsRuntimeArray()));
case spvtools::opt::analysis::Type::kArray:
return save(ConvertType(spirv_type->AsArray()));
case spvtools::opt::analysis::Type::kStruct:
return save(ConvertType(spirv_type->AsStruct()));
default:
// The error diagnostic will be generated below because result is still
// nullptr.
break;
}
if (result == nullptr) {
if (success_) {
// Only emit a new diagnostic if we haven't already emitted a more
// specific one.
Fail() << "unknown SPIR-V type: " << type_id;
}
} else {
id_to_type_[type_id] = result;
}
return result;
Fail() << "unknown SPIR-V type: " << type_id;
return nullptr;
}
DecorationList ParserImpl::GetDecorationsFor(uint32_t id) const {
@ -289,6 +204,29 @@ DecorationList ParserImpl::GetDecorationsForMember(
return result;
}
std::unique_ptr<ast::StructMemberDecoration>
ParserImpl::ConvertMemberDecoration(const Decoration& decoration) {
if (decoration.empty()) {
Fail() << "malformed SPIR-V decoration: it's empty";
return nullptr;
}
switch (decoration[0]) {
case SpvDecorationOffset:
if (decoration.size() != 2) {
Fail()
<< "malformed Offset decoration: expected 1 literal operand, has "
<< decoration.size() - 1;
return nullptr;
}
return std::make_unique<ast::StructMemberOffsetDecoration>(decoration[1]);
default:
// TODO(dneto): Support the remaining member decorations.
break;
}
Fail() << "unhandled member decoration: " << decoration[0];
return nullptr;
}
bool ParserImpl::BuildInternalModule() {
tools_.SetMessageConsumer(message_consumer_);
@ -412,6 +350,153 @@ bool ParserImpl::EmitEntryPoints() {
return success_;
}
ast::type::Type* ParserImpl::ConvertType(
const spvtools::opt::analysis::Integer* int_ty) {
if (int_ty->width() == 32) {
if (int_ty->IsSigned()) {
return ctx_.type_mgr().Get(std::make_unique<ast::type::I32Type>());
} else {
return ctx_.type_mgr().Get(std::make_unique<ast::type::U32Type>());
}
}
Fail() << "unhandled integer width: " << int_ty->width();
return nullptr;
}
ast::type::Type* ParserImpl::ConvertType(
const spvtools::opt::analysis::Float* float_ty) {
if (float_ty->width() == 32) {
return ctx_.type_mgr().Get(std::make_unique<ast::type::F32Type>());
}
Fail() << "unhandled float width: " << float_ty->width();
return nullptr;
}
ast::type::Type* ParserImpl::ConvertType(
const spvtools::opt::analysis::Vector* vec_ty) {
const auto num_elem = vec_ty->element_count();
auto* ast_elem_ty = ConvertType(type_mgr_->GetId(vec_ty->element_type()));
if (ast_elem_ty == nullptr) {
return nullptr;
}
return ctx_.type_mgr().Get(
std::make_unique<ast::type::VectorType>(ast_elem_ty, num_elem));
}
ast::type::Type* ParserImpl::ConvertType(
const spvtools::opt::analysis::Matrix* mat_ty) {
const auto* vec_ty = mat_ty->element_type()->AsVector();
const auto* scalar_ty = vec_ty->element_type();
const auto num_rows = vec_ty->element_count();
const auto num_columns = mat_ty->element_count();
auto* ast_scalar_ty = ConvertType(type_mgr_->GetId(scalar_ty));
if (ast_scalar_ty == nullptr) {
return nullptr;
}
return ctx_.type_mgr().Get(std::make_unique<ast::type::MatrixType>(
ast_scalar_ty, num_rows, num_columns));
}
ast::type::Type* ParserImpl::ConvertType(
const spvtools::opt::analysis::RuntimeArray* rtarr_ty) {
// TODO(dneto): Handle ArrayStride. Blocked by crbug.com/tint/30
auto* ast_elem_ty = ConvertType(type_mgr_->GetId(rtarr_ty->element_type()));
if (ast_elem_ty == nullptr) {
return nullptr;
}
return ctx_.type_mgr().Get(
std::make_unique<ast::type::ArrayType>(ast_elem_ty));
}
ast::type::Type* ParserImpl::ConvertType(
const spvtools::opt::analysis::Array* arr_ty) {
// TODO(dneto): Handle ArrayStride. Blocked by crbug.com/tint/30
auto* ast_elem_ty = ConvertType(type_mgr_->GetId(arr_ty->element_type()));
if (ast_elem_ty == nullptr) {
return nullptr;
}
const auto& length_info = arr_ty->length_info();
if (length_info.words.empty()) {
// The internal representation is invalid. The discriminant vector
// is mal-formed.
Fail() << "internal error: Array length info is invalid";
return nullptr;
}
if (length_info.words[0] !=
spvtools::opt::analysis::Array::LengthInfo::kConstant) {
Fail() << "Array type " << type_mgr_->GetId(arr_ty)
<< " length is a specialization constant";
return nullptr;
}
const auto* constant = constant_mgr_->FindDeclaredConstant(length_info.id);
if (constant == nullptr) {
Fail() << "Array type " << type_mgr_->GetId(arr_ty) << " length ID "
<< length_info.id << " does not name an OpConstant";
return nullptr;
}
const uint64_t num_elem = constant->GetZeroExtendedValue();
// For now, limit to only 32bits.
if (num_elem > std::numeric_limits<uint32_t>::max()) {
Fail() << "Array type " << type_mgr_->GetId(arr_ty)
<< " has too many elements (more than can fit in 32 bits): "
<< num_elem;
return nullptr;
}
return ctx_.type_mgr().Get(std::make_unique<ast::type::ArrayType>(
ast_elem_ty, static_cast<uint32_t>(num_elem)));
}
ast::type::Type* ParserImpl::ConvertType(
const spvtools::opt::analysis::Struct* struct_ty) {
const auto type_id = type_mgr_->GetId(struct_ty);
// Compute the struct decoration.
auto struct_decorations = this->GetDecorationsFor(type_id);
auto ast_struct_decoration = ast::StructDecoration::kNone;
if (struct_decorations.size() == 1 &&
struct_decorations[0][0] == SpvDecorationBlock) {
ast_struct_decoration = ast::StructDecoration::kBlock;
} else if (struct_decorations.size() > 1) {
Fail() << "can't handle a struct with more than one decoration: struct "
<< type_id << " has " << struct_decorations.size();
return nullptr;
}
// Compute members
std::vector<std::unique_ptr<ast::StructMember>> ast_members;
const auto members = struct_ty->element_types();
for (size_t member_index = 0; member_index < members.size(); ++member_index) {
auto* ast_member_ty = ConvertType(type_mgr_->GetId(members[member_index]));
if (ast_member_ty == nullptr) {
// Already emitted diagnostics.
return nullptr;
}
std::vector<std::unique_ptr<ast::StructMemberDecoration>>
ast_member_decorations;
for (auto& deco : GetDecorationsForMember(type_id, member_index)) {
auto ast_member_decoration = ConvertMemberDecoration(deco);
if (ast_member_decoration == nullptr) {
// Already emitted diagnostics.
return nullptr;
}
ast_member_decorations.push_back(std::move(ast_member_decoration));
}
const auto member_name = namer_.GetMemberName(type_id, member_index);
auto ast_struct_member = std::make_unique<ast::StructMember>(
member_name, ast_member_ty, std::move(ast_member_decorations));
ast_members.push_back(std::move(ast_struct_member));
}
// Now make the struct.
auto ast_struct = std::make_unique<ast::Struct>(ast_struct_decoration,
std::move(ast_members));
auto ast_struct_type =
std::make_unique<ast::type::StructType>(std::move(ast_struct));
// The struct might not have a name yet. Suggest one.
namer_.SuggestSanitizedName(type_id, "S");
ast_struct_type->set_name(namer_.GetName(type_id));
return ctx_.type_mgr().Get(std::move(ast_struct_type));
}
} // namespace spirv
} // namespace reader
} // namespace tint

View File

@ -28,9 +28,11 @@
#include "source/opt/ir_context.h"
#include "source/opt/module.h"
#include "source/opt/type_manager.h"
#include "source/opt/types.h"
#include "spirv-tools/libspirv.hpp"
#include "src/ast/import.h"
#include "src/ast/module.h"
#include "src/ast/struct_member_decoration.h"
#include "src/ast/type/type.h"
#include "src/reader/reader.h"
#include "src/reader/spirv/enum_converter.h"
@ -91,10 +93,9 @@ class ParserImpl : Reader {
return glsl_std_450_imports_;
}
/// Converts a SPIR-V type to a Tint type.
/// On failure, logs an error and returns null.
/// This should only be called after the internal
/// representation of the module has been built.
/// Converts a SPIR-V type to a Tint type, and saves it for fast lookup.
/// On failure, logs an error and returns null. This should only be called
/// after the internal representation of the module has been built.
/// @param type_id the SPIR-V ID of a type.
/// @returns a Tint type, or nullptr
ast::type::Type* ConvertType(uint32_t type_id);
@ -118,6 +119,13 @@ class ParserImpl : Reader {
DecorationList GetDecorationsForMember(uint32_t id,
uint32_t member_index) const;
/// Converts a SPIR-V decoration. On failure, emits a diagnostic and returns
/// nullptr.
/// @param decoration an encoded SPIR-V Decoration
/// @returns the corresponding ast::StructuMemberDecoration
std::unique_ptr<ast::StructMemberDecoration> ConvertMemberDecoration(
const Decoration& decoration);
private:
/// Builds the internal representation of the SPIR-V module.
/// Assumes the module is somewhat well-formed. Normally you
@ -145,6 +153,23 @@ class ParserImpl : Reader {
/// Emit entry point AST nodes.
bool EmitEntryPoints();
/// Converts a specific SPIR-V type to a Tint type. Integer case
ast::type::Type* ConvertType(const spvtools::opt::analysis::Integer* int_ty);
/// Converts a specific SPIR-V type to a Tint type. Float case
ast::type::Type* ConvertType(const spvtools::opt::analysis::Float* float_ty);
/// Converts a specific SPIR-V type to a Tint type. Vector case
ast::type::Type* ConvertType(const spvtools::opt::analysis::Vector* vec_ty);
/// Converts a specific SPIR-V type to a Tint type. Matrix case
ast::type::Type* ConvertType(const spvtools::opt::analysis::Matrix* mat_ty);
/// Converts a specific SPIR-V type to a Tint type. RuntimeArray case
ast::type::Type* ConvertType(
const spvtools::opt::analysis::RuntimeArray* rtarr_ty);
/// Converts a specific SPIR-V type to a Tint type. Array case
ast::type::Type* ConvertType(const spvtools::opt::analysis::Array* arr_ty);
/// Converts a specific SPIR-V type to a Tint type. Struct case
ast::type::Type* ConvertType(
const spvtools::opt::analysis::Struct* struct_ty);
// The SPIR-V binary we're parsing
std::vector<uint32_t> spv_binary_;

View File

@ -0,0 +1,85 @@
// Copyright 2020 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cstdint>
#include <memory>
#include <vector>
#include "gmock/gmock.h"
#include "spirv/unified1/spirv.h"
#include "src/ast/struct_member_decoration.h"
#include "src/ast/struct_member_offset_decoration.h"
#include "src/reader/spirv/parser_impl.h"
#include "src/reader/spirv/parser_impl_test_helper.h"
#include "src/reader/spirv/spirv_tools_helpers_test.h"
namespace tint {
namespace reader {
namespace spirv {
namespace {
using ::testing::Eq;
TEST_F(SpvParserTest, ConvertMemberDecoration_Empty) {
auto p = parser(std::vector<uint32_t>{});
auto result = p->ConvertMemberDecoration({});
EXPECT_EQ(result.get(), nullptr);
EXPECT_THAT(p->error(), Eq("malformed SPIR-V decoration: it's empty"));
}
TEST_F(SpvParserTest, ConvertMemberDecoration_OffsetWithoutOperand) {
auto p = parser(std::vector<uint32_t>{});
auto result = p->ConvertMemberDecoration({SpvDecorationOffset});
EXPECT_EQ(result.get(), nullptr);
EXPECT_THAT(
p->error(),
Eq("malformed Offset decoration: expected 1 literal operand, has 0"));
}
TEST_F(SpvParserTest, ConvertMemberDecoration_OffsetWithTooManyOperands) {
auto p = parser(std::vector<uint32_t>{});
auto result = p->ConvertMemberDecoration({SpvDecorationOffset, 3, 4});
EXPECT_EQ(result.get(), nullptr);
EXPECT_THAT(
p->error(),
Eq("malformed Offset decoration: expected 1 literal operand, has 2"));
}
TEST_F(SpvParserTest, ConvertMemberDecoration_Offset) {
auto p = parser(std::vector<uint32_t>{});
auto result = p->ConvertMemberDecoration({SpvDecorationOffset, 8});
ASSERT_NE(result.get(), nullptr);
EXPECT_TRUE(result->IsOffset());
auto* offset_deco = result->AsOffset();
ASSERT_NE(offset_deco, nullptr);
EXPECT_EQ(offset_deco->offset(), 8);
EXPECT_TRUE(p->error().empty());
}
TEST_F(SpvParserTest, ConvertMemberDecoration_UnhandledDecoration) {
auto p = parser(std::vector<uint32_t>{});
auto result = p->ConvertMemberDecoration({12345678});
EXPECT_EQ(result.get(), nullptr);
EXPECT_THAT(p->error(), Eq("unhandled member decoration: 12345678"));
}
} // namespace
} // namespace spirv
} // namespace reader
} // namespace tint

View File

@ -17,8 +17,10 @@
#include <vector>
#include "gmock/gmock.h"
#include "src/ast/struct.h"
#include "src/ast/type/array_type.h"
#include "src/ast/type/matrix_type.h"
#include "src/ast/type/struct_type.h"
#include "src/ast/type/vector_type.h"
#include "src/reader/spirv/parser_impl.h"
#include "src/reader/spirv/parser_impl_test_helper.h"
@ -415,6 +417,75 @@ TEST_F(SpvParserTest, ConvertType_ArrayBadTooBig) {
EXPECT_THAT(p->error(), Eq("unhandled integer width: 64"));
}
TEST_F(SpvParserTest, ConvertType_StructTwoMembers) {
auto p = parser(test::Assemble(R"(
%uint = OpTypeInt 32 0
%float = OpTypeFloat 32
%10 = OpTypeStruct %uint %float
)"));
EXPECT_TRUE(p->BuildAndParseInternalModule());
auto* type = p->ConvertType(10);
ASSERT_NE(type, nullptr) << p->error();
EXPECT_TRUE(type->IsStruct());
std::stringstream ss;
type->AsStruct()->impl()->to_str(ss, 0);
EXPECT_THAT(ss.str(), Eq(R"(Struct{
StructMember{field0: __u32}
StructMember{field1: __f32}
}
)"));
}
TEST_F(SpvParserTest, ConvertType_StructWithBlockDecoration) {
auto p = parser(test::Assemble(R"(
OpDecorate %10 Block
%uint = OpTypeInt 32 0
%10 = OpTypeStruct %uint
)"));
EXPECT_TRUE(p->BuildAndParseInternalModule());
auto* type = p->ConvertType(10);
ASSERT_NE(type, nullptr);
EXPECT_TRUE(type->IsStruct());
std::stringstream ss;
type->AsStruct()->impl()->to_str(ss, 0);
EXPECT_THAT(ss.str(), Eq(R"([[block]] Struct{
StructMember{field0: __u32}
}
)"));
}
TEST_F(SpvParserTest, ConvertType_StructWithMemberDecorations) {
auto p = parser(test::Assemble(R"(
OpMemberDecorate %10 0 Offset 0
OpMemberDecorate %10 1 Offset 8
OpMemberDecorate %10 2 Offset 16
%float = OpTypeFloat 32
%vec = OpTypeVector %float 2
%mat = OpTypeMatrix %vec 2
%10 = OpTypeStruct %float %vec %mat
)"));
EXPECT_TRUE(p->BuildAndParseInternalModule());
auto* type = p->ConvertType(10);
ASSERT_NE(type, nullptr) << p->error();
EXPECT_TRUE(type->IsStruct());
std::stringstream ss;
type->AsStruct()->impl()->to_str(ss, 0);
EXPECT_THAT(ss.str(), Eq(R"(Struct{
StructMember{[[ offset 0 ]] field0: __f32}
StructMember{[[ offset 8 ]] field1: __vec_2__f32}
StructMember{[[ offset 16 ]] field2: __mat_2_2__f32}
}
)"));
}
// TODO(dneto): Demonstrate other member deocrations. Blocked on
// crbug.com/tint/30
// TODO(dneto): Demonstrate multiple member deocrations. Blocked on
// crbug.com/tint/30
} // namespace
} // namespace spirv
} // namespace reader