diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 7b96830da9..1c0d736982 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -325,6 +325,7 @@ if(${TINT_BUILD_SPV_READER}) reader/spirv/parser_impl_convert_member_decoration_test.cc reader/spirv/parser_impl_convert_type_test.cc reader/spirv/parser_impl_entry_point_test.cc + reader/spirv/parser_impl_function_decl_test.cc reader/spirv/parser_impl_get_decorations_test.cc reader/spirv/parser_impl_import_test.cc reader/spirv/parser_impl_module_var_test.cc diff --git a/src/reader/spirv/parser_impl.cc b/src/reader/spirv/parser_impl.cc index e843e72299..e1641c69a4 100644 --- a/src/reader/spirv/parser_impl.cc +++ b/src/reader/spirv/parser_impl.cc @@ -19,10 +19,14 @@ #include #include #include +#include +#include #include +#include "source/opt/basic_block.h" #include "source/opt/build_module.h" #include "source/opt/decoration_manager.h" +#include "source/opt/function.h" #include "source/opt/instruction.h" #include "source/opt/module.h" #include "source/opt/type_manager.h" @@ -60,6 +64,53 @@ namespace { const spv_target_env kTargetEnv = SPV_ENV_WEBGPU_0; +// A FunctionTraverser is used to compute an ordering of functions in the +// module such that callees precede callers. +class FunctionTraverser { + public: + explicit FunctionTraverser(const spvtools::opt::Module& module) + : module_(module) {} + + // @returns the functions in the modules such that callees precede callers. + std::vector TopologicallyOrderedFunctions() { + visited_.clear(); + ordered_.clear(); + id_to_func_.clear(); + for (const auto& f : module_) { + id_to_func_[f.result_id()] = &f; + } + for (const auto& f : module_) { + Visit(f); + } + return ordered_; + } + + private: + void Visit(const spvtools::opt::Function& f) { + if (visited_.count(&f)) { + return; + } + visited_.insert(&f); + for (const auto& bb : f) { + for (const auto& inst : bb) { + if (inst.opcode() != SpvOpFunctionCall) { + continue; + } + const auto* callee = id_to_func_[inst.GetSingleWordInOperand(0)]; + if (callee) { + Visit(*callee); + } + } + } + ordered_.push_back(&f); + } + + const spvtools::opt::Module& module_; + std::unordered_set visited_; + std::unordered_map id_to_func_; + std::vector ordered_; +}; + } // namespace ParserImpl::ParserImpl(Context* ctx, const std::vector& spv_binary) @@ -174,7 +225,10 @@ ast::type::Type* ParserImpl::ConvertType(uint32_t type_id) { case spvtools::opt::analysis::Type::kPointer: return save(ConvertType(spirv_type->AsPointer())); case spvtools::opt::analysis::Type::kFunction: - // TODO(dneto). For now return null without erroring out. + // Tint doesn't have a Function type. + // We need to convert the result type and parameter types. + // But the SPIR-V defines those before defining the function + // type. No further work is required here. return nullptr; default: break; @@ -299,8 +353,10 @@ bool ParserImpl::ParseInternalModule() { if (!EmitModuleScopeVariables()) { return false; } - // TODO(dneto): fill in the rest - return true; + if (!EmitFunctions()) { + return false; + } + return success_; } bool ParserImpl::RegisterExtendedInstructionImports() { @@ -635,39 +691,8 @@ bool ParserImpl::EmitModuleScopeVariables() { << var.type_id(); } auto* ast_store_type = ast_type->AsPointer()->type(); - if (!namer_.HasName(var.result_id())) { - namer_.SuggestSanitizedName(var.result_id(), - "x_" + std::to_string(var.result_id())); - } - auto ast_var = std::make_unique( - namer_.GetName(var.result_id()), ast_storage_class, ast_store_type); - - std::vector> ast_decorations; - for (auto& deco : GetDecorationsFor(var.result_id())) { - if (deco.empty()) { - return Fail() << "malformed decoration on ID " << var.result_id() - << ": it is empty"; - } - if (deco[0] == SpvDecorationBuiltIn) { - if (deco.size() == 1) { - return Fail() << "malformed BuiltIn decoration on ID " - << var.result_id() << ": has no operand"; - } - auto ast_builtin = - enum_converter_.ToBuiltin(static_cast(deco[1])); - if (ast_builtin == ast::Builtin::kNone) { - return false; - } - ast_decorations.emplace_back( - std::make_unique(ast_builtin)); - } - } - if (!ast_decorations.empty()) { - auto decorated_var = - std::make_unique(std::move(ast_var)); - decorated_var->set_decorations(std::move(ast_decorations)); - ast_var = std::move(decorated_var); - } + auto ast_var = + MakeVariable(var.result_id(), ast_storage_class, ast_store_type); // TODO(dneto): initializers (a.k.a. constructor expression) ast_module_.AddGlobalVariable(std::move(ast_var)); @@ -675,6 +700,99 @@ bool ParserImpl::EmitModuleScopeVariables() { return success_; } +std::unique_ptr ParserImpl::MakeVariable(uint32_t id, + ast::StorageClass sc, + ast::type::Type* type) { + if (type == nullptr) { + Fail() << "internal error: can't make ast::Variable for null type"; + return nullptr; + } + + auto ast_var = std::make_unique(Name(id), sc, type); + + std::vector> ast_decorations; + for (auto& deco : GetDecorationsFor(id)) { + if (deco.empty()) { + Fail() << "malformed decoration on ID " << id << ": it is empty"; + return nullptr; + } + if (deco[0] == SpvDecorationBuiltIn) { + if (deco.size() == 1) { + Fail() << "malformed BuiltIn decoration on ID " << id + << ": has no operand"; + return nullptr; + } + auto ast_builtin = + enum_converter_.ToBuiltin(static_cast(deco[1])); + if (ast_builtin == ast::Builtin::kNone) { + return nullptr; + } + ast_decorations.emplace_back( + std::make_unique(ast_builtin)); + } + } + if (!ast_decorations.empty()) { + auto decorated_var = + std::make_unique(std::move(ast_var)); + decorated_var->set_decorations(std::move(ast_decorations)); + ast_var = std::move(decorated_var); + } + return ast_var; +} + +bool ParserImpl::EmitFunctions() { + if (!success_) { + return false; + } + for (const auto* f : + FunctionTraverser(*module_).TopologicallyOrderedFunctions()) { + EmitFunction(*f); + } + return success_; +} + +bool ParserImpl::EmitFunction(const spvtools::opt::Function& f) { + if (!success_) { + return false; + } + // We only care about functions with bodies. + if (f.cbegin() == f.cend()) { + return true; + } + + const auto name = Name(f.result_id()); + // Surprisingly, the "type id" on an OpFunction is the result type of the + // function, not the type of the function. This is the one exceptional case + // in SPIR-V where the type ID is not the type of the result ID. + auto* ret_ty = ConvertType(f.type_id()); + if (!success_) { + return false; + } + if (ret_ty == nullptr) { + return Fail() + << "internal error: unregistered return type for function with ID " + << f.result_id(); + } + + std::vector> ast_params; + f.ForEachParam([this, &ast_params](const spvtools::opt::Instruction* param) { + auto* ast_type = ConvertType(param->type_id()); + if (ast_type != nullptr) { + ast_params.emplace_back(std::move(MakeVariable( + param->result_id(), ast::StorageClass::kNone, ast_type))); + } + }); + if (!success_) { + return false; + } + + auto ast_fn = + std::make_unique(name, std::move(ast_params), ret_ty); + ast_module_.AddFunction(std::move(ast_fn)); + + return success_; +} + } // namespace spirv } // namespace reader } // namespace tint diff --git a/src/reader/spirv/parser_impl.h b/src/reader/spirv/parser_impl.h index 2c6a0982e8..f9a6e4553d 100644 --- a/src/reader/spirv/parser_impl.h +++ b/src/reader/spirv/parser_impl.h @@ -178,7 +178,26 @@ class ParserImpl : Reader { /// @returns true if parser is still successful. bool EmitModuleScopeVariables(); + /// Emits functions, with callees preceding their callers. + /// This is a no-op if the parser has already failed. + /// @returns true if parser is still successful. + bool EmitFunctions(); + + /// Emits a single function, if it has a body. + /// This is a no-op if the parser has already failed. + /// @param f the function to emit + /// @returns true if parser is still successful. + bool EmitFunction(const spvtools::opt::Function& f); + private: + /// @returns a name for the given ID. Generates a name if non exists. + std::string Name(uint32_t id) { + if (!namer_.HasName(id)) { + namer_.SuggestSanitizedName(id, "x_" + std::to_string(id)); + } + return namer_.GetName(id); + } + /// Converts a specific SPIR-V type to a Tint type. Integer case ast::type::Type* ConvertType(const spvtools::opt::analysis::Integer* int_ty); /// Converts a specific SPIR-V type to a Tint type. Float case @@ -198,6 +217,16 @@ class ParserImpl : Reader { /// Converts a specific SPIR-V type to a Tint type. Pointer case ast::type::Type* ConvertType(const spvtools::opt::analysis::Pointer* ptr_ty); + /// Creates an AST Variable node for a SPIR-V ID, including any attached + /// decorations. + /// @param id the SPIR-V result ID + /// @param sc the storage class, which can be ast::StorageClass::kNone + /// @param type the type + /// @returns a new Variable node, or null in the error case + std::unique_ptr MakeVariable(uint32_t id, + ast::StorageClass sc, + ast::type::Type* type); + // The SPIR-V binary we're parsing std::vector spv_binary_; diff --git a/src/reader/spirv/parser_impl_function_decl_test.cc b/src/reader/spirv/parser_impl_function_decl_test.cc new file mode 100644 index 0000000000..e95cab8d39 --- /dev/null +++ b/src/reader/spirv/parser_impl_function_decl_test.cc @@ -0,0 +1,224 @@ +// Copyright 2020 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "gmock/gmock.h" +#include "src/reader/spirv/parser_impl.h" +#include "src/reader/spirv/parser_impl_test_helper.h" +#include "src/reader/spirv/spirv_tools_helpers_test.h" + +namespace tint { +namespace reader { +namespace spirv { +namespace { + +using ::testing::HasSubstr; + +/// @returns a SPIR-V assembly segment which assigns debug names +/// to particular IDs. +std::string Names(std::vector ids) { + std::ostringstream outs; + for (auto& id : ids) { + outs << " OpName %" << id << " \"" << id << "\"\n"; + } + return outs.str(); +} + +std::string CommonTypes() { + return R"( + %void = OpTypeVoid + %voidfn = OpTypeFunction %void + %float = OpTypeFloat 32 + %uint = OpTypeInt 32 0 + %int = OpTypeInt 32 1 + %float_0 = OpConstant %float 0.0 + )"; +} + +TEST_F(SpvParserTest, EmitFunctions_NoFunctions) { + auto p = parser(test::Assemble(CommonTypes())); + EXPECT_TRUE(p->BuildAndParseInternalModule()); + EXPECT_TRUE(p->error().empty()); + const auto module_ast = p->module().to_str(); + EXPECT_THAT(module_ast, Not(HasSubstr("Function{"))); +} + +TEST_F(SpvParserTest, EmitFunctions_FunctionWithoutBody) { + auto p = parser(test::Assemble(Names({"main"}) + CommonTypes() + R"( + %main = OpFunction %void None %voidfn + OpFunctionEnd + )")); + EXPECT_TRUE(p->BuildAndParseInternalModule()); + EXPECT_TRUE(p->error().empty()); + const auto module_ast = p->module().to_str(); + EXPECT_THAT(module_ast, Not(HasSubstr("Function{"))); +} + +TEST_F(SpvParserTest, EmitFunctions_VoidFunctionWithoutParams) { + auto p = parser(test::Assemble(Names({"main"}) + CommonTypes() + R"( + %main = OpFunction %void None %voidfn + %entry = OpLabel + OpReturn + OpFunctionEnd + )")); + EXPECT_TRUE(p->BuildAndParseInternalModule()); + EXPECT_TRUE(p->error().empty()); + const auto module_ast = p->module().to_str(); + EXPECT_THAT(module_ast, HasSubstr(R"( + Function main -> __void + () + {)")); +} + +TEST_F(SpvParserTest, EmitFunctions_CalleePrecedesCaller) { + auto p = parser( + test::Assemble(Names({"root", "branch", "leaf"}) + CommonTypes() + R"( + %root = OpFunction %void None %voidfn + %root_entry = OpLabel + %branch_result = OpFunctionCall %void %branch + OpReturn + OpFunctionEnd + + %branch = OpFunction %void None %voidfn + %branch_entry = OpLabel + %leaf_result = OpFunctionCall %void %leaf + OpReturn + OpFunctionEnd + + %leaf = OpFunction %void None %voidfn + %leaf_entry = OpLabel + OpReturn + OpFunctionEnd + )")); + EXPECT_TRUE(p->BuildAndParseInternalModule()); + EXPECT_TRUE(p->error().empty()); + const auto module_ast = p->module().to_str(); + EXPECT_THAT(module_ast, HasSubstr(R"( + Function leaf -> __void + () + { + } + Function branch -> __void + () + { + } + Function root -> __void + () + { + })")); +} + +TEST_F(SpvParserTest, EmitFunctions_NonVoidResultType) { + auto p = parser(test::Assemble(Names({"ret_float"}) + CommonTypes() + R"( + %fn_ret_float = OpTypeFunction %float + + %ret_float = OpFunction %float None %nf_ret_float + %ret_float_entry = OpLabel + OpReturnValue %float_0 + OpFunctionEnd + )")); + EXPECT_TRUE(p->BuildAndParseInternalModule()); + EXPECT_TRUE(p->error().empty()); + const auto module_ast = p->module().to_str(); + EXPECT_THAT(module_ast, HasSubstr(R"( + Function ret_float -> __f32 + () + { + })")); +} + +TEST_F(SpvParserTest, EmitFunctions_MixedParamTypes) { + auto p = parser(test::Assemble(Names({"mixed_params", "a", "b", "c"}) + + CommonTypes() + R"( + %fn_mixed_params = OpTypeFunction %float %uint %float %int + + %mixed_params = OpFunction %void None %fn_mixed_params + %a = OpFunctionParameter %uint + %b = OpFunctionParameter %float + %c = OpFunctionParameter %int + %mixed_entry = OpLabel + OpReturn + OpFunctionEnd + )")); + EXPECT_TRUE(p->BuildAndParseInternalModule()); + EXPECT_TRUE(p->error().empty()); + const auto module_ast = p->module().to_str(); + EXPECT_THAT(module_ast, HasSubstr(R"( + Function mixed_params -> __void + ( + Variable{ + a + none + __u32 + } + Variable{ + b + none + __f32 + } + Variable{ + c + none + __i32 + } + ) + { + })")); +} + +TEST_F(SpvParserTest, EmitFunctions_GenerateParamNames) { + auto p = parser(test::Assemble(Names({"mixed_params"}) + CommonTypes() + R"( + %fn_mixed_params = OpTypeFunction %float %uint %float %int + + %mixed_params = OpFunction %void None %fn_mixed_params + %14 = OpFunctionParameter %uint + %15 = OpFunctionParameter %float + %16 = OpFunctionParameter %int + %mixed_entry = OpLabel + OpReturn + OpFunctionEnd + )")); + EXPECT_TRUE(p->BuildAndParseInternalModule()); + EXPECT_TRUE(p->error().empty()); + const auto module_ast = p->module().to_str(); + EXPECT_THAT(module_ast, HasSubstr(R"( + Function mixed_params -> __void + ( + Variable{ + x_14 + none + __u32 + } + Variable{ + x_15 + none + __f32 + } + Variable{ + x_16 + none + __i32 + } + ) + { + })")); +} + +} // namespace +} // namespace spirv +} // namespace reader +} // namespace tint