diff --git a/BUILD.gn b/BUILD.gn index 3d6884789a..c79492cc0e 100644 --- a/BUILD.gn +++ b/BUILD.gn @@ -511,6 +511,7 @@ source_set("tint_unittests_spv_reader_src") { "src/reader/spirv/enum_converter_test.cc", "src/reader/spirv/fail_stream_test.cc", "src/reader/spirv/function_arithmetic_test.cc", + "src/reader/spirv/function_cfg_test.cc", "src/reader/spirv/function_conversion_test.cc", "src/reader/spirv/function_decl_test.cc", "src/reader/spirv/function_logical_test.cc", diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 8a490bea97..1aa773f04c 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -326,6 +326,7 @@ if(${TINT_BUILD_SPV_READER}) reader/spirv/enum_converter_test.cc reader/spirv/fail_stream_test.cc reader/spirv/function_arithmetic_test.cc + reader/spirv/function_cfg_test.cc reader/spirv/function_conversion_test.cc reader/spirv/function_decl_test.cc reader/spirv/function_logical_test.cc diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc index 9e1d8ef687..97c3d33e99 100644 --- a/src/reader/spirv/function.cc +++ b/src/reader/spirv/function.cc @@ -14,7 +14,10 @@ #include "src/reader/spirv/function.h" +#include +#include #include +#include #include "source/opt/basic_block.h" #include "source/opt/function.h" @@ -112,6 +115,97 @@ ast::BinaryOp ConvertBinaryOp(SpvOp opcode) { return ast::BinaryOp::kNone; } +// @returns the merge block ID for the given basic block, or 0 if there is none. +uint32_t MergeFor(const spvtools::opt::BasicBlock& bb) { + // Get the OpSelectionMerge or OpLoopMerge instruction, if any. + auto* inst = bb.GetMergeInst(); + return inst == nullptr ? 0 : inst->GetSingleWordInOperand(0); +} + +// @returns the continue target ID for the given basic block, or 0 if there +// is none. +uint32_t ContinueTargetFor(const spvtools::opt::BasicBlock& bb) { + // Get the OpLoopMerge instruction, if any. + auto* inst = bb.GetLoopMergeInst(); + return inst == nullptr ? 0 : inst->GetSingleWordInOperand(1); +} + +// A structured traverser produces the reverse structured post-order of the +// CFG of a function. The blocks traversed are the transitive closure (minimum +// fixed point) of: +// - the entry block +// - a block reached by a branch from another block in the set +// - a block mentioned as a merge block or continue target for a block in the +// set +class StructuredTraverser { + public: + explicit StructuredTraverser(const spvtools::opt::Function& function) + : function_(function) { + for (auto& block : function_) { + id_to_block_[block.id()] = █ + } + } + + // Returns the reverse postorder traversal of the CFG, where: + // - a merge block always follows its associated constructs + // - a continue target always follows the associated loop construct, if any + // @returns the IDs of blocks in reverse structured post order + std::vector ReverseStructuredPostOrder() { + visit_order_.clear(); + visited_.clear(); + VisitBackward(function_.entry()->id()); + + std::vector order(visit_order_.rbegin(), visit_order_.rend()); + return order; + } + + private: + // Executes a depth first search of the CFG, where right after we visit a + // header, we will visit its merge block, then its continue target (if any). + // Also records the post order ordering. + void VisitBackward(uint32_t id) { + if (id == 0) + return; + if (visited_.count(id)) + return; + visited_.insert(id); + + const spvtools::opt::BasicBlock* bb = + id_to_block_[id]; // non-null for valid modules + VisitBackward(MergeFor(*bb)); + VisitBackward(ContinueTargetFor(*bb)); + + // Visit successors. We will naturally skip the continue target and merge + // blocks. + auto* terminator = bb->terminator(); + auto opcode = terminator->opcode(); + if (opcode == SpvOpBranchConditional) { + // Visit the false branch, then the true branch, to make them come + // out in the natural order for an "if". + VisitBackward(terminator->GetSingleWordInOperand(2)); + VisitBackward(terminator->GetSingleWordInOperand(1)); + } else if (opcode == SpvOpBranch) { + VisitBackward(terminator->GetSingleWordInOperand(0)); + } else if (opcode == SpvOpSwitch) { + // TODO(dneto): Consider visiting the labels in literal-value order. + std::vector successors; + bb->ForEachSuccessorLabel([&successors](const uint32_t succ_id) { + successors.push_back(succ_id); + }); + for (auto succ_id : successors) { + VisitBackward(succ_id); + } + } + + visit_order_.push_back(id); + } + + const spvtools::opt::Function& function_; + std::unordered_map id_to_block_; + std::vector visit_order_; + std::unordered_set visited_; +}; + } // namespace FunctionEmitter::FunctionEmitter(ParserImpl* pi, @@ -213,6 +307,8 @@ ast::type::Type* FunctionEmitter::GetVariableStoreType( } bool FunctionEmitter::EmitBody() { + ComputeBlockOrderAndPositions(); + if (!EmitFunctionVariables()) { return false; } @@ -222,6 +318,18 @@ bool FunctionEmitter::EmitBody() { return success(); } +void FunctionEmitter::ComputeBlockOrderAndPositions() { + for (auto& block : function_) { + block_info_[block.id()] = std::make_unique(block); + } + + rspo_ = StructuredTraverser(function_).ReverseStructuredPostOrder(); + + for (uint32_t i = 0; i < rspo_.size(); ++i) { + GetBlockInfo(rspo_[i])->pos = i; + } +} + bool FunctionEmitter::EmitFunctionVariables() { if (failed()) { return false; diff --git a/src/reader/spirv/function.h b/src/reader/spirv/function.h index 433b078f39..48367a3d53 100644 --- a/src/reader/spirv/function.h +++ b/src/reader/spirv/function.h @@ -36,6 +36,23 @@ namespace tint { namespace reader { namespace spirv { +/// Bookkeeping info for a basic block. +struct BlockInfo { + /// Constructor + /// @param bb internal representation of the basic block + explicit BlockInfo(const spvtools::opt::BasicBlock& bb) + : basic_block(&bb), id(bb.id()) {} + + /// The internal representation of the basic block. + const spvtools::opt::BasicBlock* basic_block; + + /// The ID of the OpLabel instruction that starts this block. + uint32_t id = 0; + + /// The position of this block in the reverse structured post-order. + uint32_t pos = 0; +}; + /// A FunctionEmitter emits a SPIR-V function onto a Tint AST module. class FunctionEmitter { public: @@ -73,6 +90,14 @@ class FunctionEmitter { /// @returns false if emission failed. bool EmitBody(); + /// Determines the output order for the basic blocks in the function. + /// Populates |rspo_| and the |pos| block info member. + void ComputeBlockOrderAndPositions(); + + /// @returns the reverse structured post order of the basic blocks in + /// the function. + const std::vector& rspo() const { return rspo_; } + /// Emits declarations of function variables. /// @returns false if emission failed. bool EmitFunctionVariables(); @@ -116,6 +141,16 @@ class FunctionEmitter { TypedExpression MaybeEmitCombinatorialValue( const spvtools::opt::Instruction& inst); + /// Gets the block info for a block ID, if any exists + /// @param id the SPIR-V ID of the OpLabel instruction starting the block + /// @returns the block info for the given ID, if it exists, or nullptr + BlockInfo* GetBlockInfo(uint32_t id) { + auto where = block_info_.find(id); + if (where == block_info_.end()) + return nullptr; + return where->second.get(); + } + private: /// @returns the store type for the OpVariable instruction, or /// null on failure. @@ -136,6 +171,13 @@ class FunctionEmitter { std::unordered_set identifier_values_; // Mapping from SPIR-V ID that is used at most once, to its AST expression. std::unordered_map singly_used_values_; + + // The IDs of basic blocks, in reverse structured post-order (RSPO). + // This is the output order for the basic blocks. + std::vector rspo_; + + // Mapping from block ID to its bookkeeping info. + std::unordered_map> block_info_; }; } // namespace spirv diff --git a/src/reader/spirv/function_cfg_test.cc b/src/reader/spirv/function_cfg_test.cc new file mode 100644 index 0000000000..d6b6d12a60 --- /dev/null +++ b/src/reader/spirv/function_cfg_test.cc @@ -0,0 +1,413 @@ +// Copyright 2020 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "gmock/gmock.h" +#include "src/reader/spirv/function.h" +#include "src/reader/spirv/parser_impl.h" +#include "src/reader/spirv/parser_impl_test_helper.h" +#include "src/reader/spirv/spirv_tools_helpers_test.h" + +namespace tint { +namespace reader { +namespace spirv { +namespace { + +using ::testing::ElementsAre; + +std::string CommonTypes() { + return R"( + %void = OpTypeVoid + %voidfn = OpTypeFunction %void + + %bool = OpTypeBool + %cond = OpUndef %bool + + %uint = OpTypeInt 32 0 + %selector = OpUndef %uint + )"; +} + +TEST_F(SpvParserTest, ComputeBlockOrder_OneBlock) { + auto* p = parser(test::Assemble(CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + + %42 = OpLabel + OpReturn + + OpFunctionEnd + )")); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error(); + FunctionEmitter fe(p, *spirv_function(100)); + fe.ComputeBlockOrderAndPositions(); + + EXPECT_THAT(fe.rspo(), ElementsAre(42)); +} + +TEST_F(SpvParserTest, ComputeBlockOrder_IgnoreStaticalyUnreachable) { + auto* p = parser(test::Assemble(CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + + %10 = OpLabel + OpBranch %20 + + %15 = OpLabel ; statically dead + OpReturn + + %20 = OpLabel + OpReturn + + OpFunctionEnd + )")); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error(); + FunctionEmitter fe(p, *spirv_function(100)); + fe.ComputeBlockOrderAndPositions(); + + EXPECT_THAT(fe.rspo(), ElementsAre(10, 20)); +} + +TEST_F(SpvParserTest, ComputeBlockOrder_ReorderSequence) { + auto* p = parser(test::Assemble(CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + + %10 = OpLabel + OpBranch %20 + + %30 = OpLabel + OpReturn + + %20 = OpLabel + OpBranch %30 ; backtrack + + OpFunctionEnd + )")); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error(); + FunctionEmitter fe(p, *spirv_function(100)); + fe.ComputeBlockOrderAndPositions(); + + EXPECT_THAT(fe.rspo(), ElementsAre(10, 20, 30)); +} + +TEST_F(SpvParserTest, ComputeBlockOrder_RespectConditionalBranchOrder) { + auto* p = parser(test::Assemble(CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + + %10 = OpLabel + OpSelectionMerge %99 None + OpBranchConditional %cond %20 %30 + + %99 = OpLabel + OpReturn + + %30 = OpLabel + OpReturn + + %20 = OpLabel + OpBranch %99 + + OpFunctionEnd + )")); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error(); + FunctionEmitter fe(p, *spirv_function(100)); + fe.ComputeBlockOrderAndPositions(); + + EXPECT_THAT(fe.rspo(), ElementsAre(10, 20, 30, 99)); +} + +TEST_F(SpvParserTest, ComputeBlockOrder_TrueOnlyBranch) { + auto* p = parser(test::Assemble(CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + + %10 = OpLabel + OpSelectionMerge %99 None + OpBranchConditional %cond %20 %99 + + %99 = OpLabel + OpReturn + + %20 = OpLabel + OpBranch %99 + + OpFunctionEnd + )")); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error(); + FunctionEmitter fe(p, *spirv_function(100)); + fe.ComputeBlockOrderAndPositions(); + + EXPECT_THAT(fe.rspo(), ElementsAre(10, 20, 99)); +} + +TEST_F(SpvParserTest, ComputeBlockOrder_FalseOnlyBranch) { + auto* p = parser(test::Assemble(CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + + %10 = OpLabel + OpSelectionMerge %99 None + OpBranchConditional %cond %99 %20 + + %99 = OpLabel + OpReturn + + %20 = OpLabel + OpBranch %99 + + OpFunctionEnd + )")); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error(); + FunctionEmitter fe(p, *spirv_function(100)); + fe.ComputeBlockOrderAndPositions(); + + EXPECT_THAT(fe.rspo(), ElementsAre(10, 20, 99)); +} + +TEST_F(SpvParserTest, ComputeBlockOrder_SwitchOrderNaturallyReversed) { + auto* p = parser(test::Assemble(CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + + %10 = OpLabel + OpSelectionMerge %99 None + OpSwitch %selector %99 20 %20 30 %30 + + %99 = OpLabel + OpReturn + + %30 = OpLabel + OpReturn + + %20 = OpLabel + OpBranch %99 + + OpFunctionEnd + )")); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error(); + FunctionEmitter fe(p, *spirv_function(100)); + fe.ComputeBlockOrderAndPositions(); + + EXPECT_THAT(fe.rspo(), ElementsAre(10, 30, 20, 99)); +} + +TEST_F(SpvParserTest, + ComputeBlockOrder_SwitchWithDefaultOrderNaturallyReversed) { + auto* p = parser(test::Assemble(CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + + %10 = OpLabel + OpSelectionMerge %99 None + OpSwitch %selector %80 20 %20 30 %30 + + %80 = OpLabel ; the default case + OpBranch %99 + + %99 = OpLabel + OpReturn + + %30 = OpLabel + OpReturn + + %20 = OpLabel + OpBranch %99 + + OpFunctionEnd + )")); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error(); + FunctionEmitter fe(p, *spirv_function(100)); + fe.ComputeBlockOrderAndPositions(); + + EXPECT_THAT(fe.rspo(), ElementsAre(10, 30, 20, 80, 99)); +} + +TEST_F(SpvParserTest, ComputeBlockOrder_RespectSwitchCaseFallthrough) { + auto assembly = CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + + %10 = OpLabel + OpSelectionMerge %99 None + OpSwitch %selector %99 20 %20 30 %30 40 %40 50 %50 + + %50 = OpLabel + OpBranch %99 + + %99 = OpLabel + OpReturn + + %40 = OpLabel + OpBranch %99 + + %30 = OpLabel + OpBranch %50 ; fallthrough + + %20 = OpLabel + OpBranch %40 ; fallthrough + + OpFunctionEnd + )"; + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error(); + FunctionEmitter fe(p, *spirv_function(100)); + fe.ComputeBlockOrderAndPositions(); + + EXPECT_THAT(fe.rspo(), ElementsAre(10, 30, 50, 20, 40, 99)) << assembly; +} + +TEST_F(SpvParserTest, + ComputeBlockOrder_RespectSwitchCaseFallthrough_FromDefault) { + auto assembly = CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + + %10 = OpLabel + OpSelectionMerge %99 None + OpSwitch %selector %80 20 %20 30 %30 40 %40 + + %80 = OpLabel ; the default case + OpBranch %30 ; fallthrough to another case + + %99 = OpLabel + OpReturn + + %40 = OpLabel + OpBranch %99 + + %30 = OpLabel + OpBranch %40 + + %20 = OpLabel + OpBranch %99 + + OpFunctionEnd + )"; + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error(); + FunctionEmitter fe(p, *spirv_function(100)); + fe.ComputeBlockOrderAndPositions(); + + EXPECT_THAT(fe.rspo(), ElementsAre(10, 20, 80, 30, 40, 99)) << assembly; +} + +TEST_F(SpvParserTest, + ComputeBlockOrder_RespectSwitchCaseFallthrough_FromCaseToDefaultToCase) { + auto assembly = CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + + %10 = OpLabel + OpSelectionMerge %99 None + OpSwitch %selector %80 20 %20 30 %30 + + %99 = OpLabel + OpReturn + + %20 = OpLabel + OpBranch %80 ; fallthrough to default + + %80 = OpLabel ; the default case + OpBranch %30 ; fallthrough to 30 + + %30 = OpLabel + OpBranch %99 + + OpFunctionEnd + )"; + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error(); + FunctionEmitter fe(p, *spirv_function(100)); + fe.ComputeBlockOrderAndPositions(); + + EXPECT_THAT(fe.rspo(), ElementsAre(10, 20, 80, 30, 99)) << assembly; +} + +TEST_F(SpvParserTest, + ComputeBlockOrder_SwitchCasesFallthrough_OppositeDirections) { + auto assembly = CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + + %10 = OpLabel + OpSelectionMerge %99 None + OpSwitch %selector %99 20 %20 30 %30 40 %40 50 %50 + + %99 = OpLabel + OpReturn + + %20 = OpLabel + OpBranch %30 ; forward + + %40 = OpLabel + OpBranch %99 + + %30 = OpLabel + OpBranch %99 + + ; SPIR-V doesn't actually allow a fall-through that goes backward in the + ; module. But the block ordering algorithm tolerates it. + %50 = OpLabel + OpBranch %40 ; backward + + OpFunctionEnd + )"; + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error(); + FunctionEmitter fe(p, *spirv_function(100)); + fe.ComputeBlockOrderAndPositions(); + + EXPECT_THAT(fe.rspo(), ElementsAre(10, 50, 40, 20, 30, 99)) << assembly; +} + +TEST_F(SpvParserTest, + ComputeBlockOrder_RespectSwitchCaseFallthrough_Interleaved) { + auto assembly = CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + + %10 = OpLabel + OpSelectionMerge %99 None + OpSwitch %selector %99 20 %20 30 %30 40 %40 50 %50 + + %99 = OpLabel + OpReturn + + %20 = OpLabel + OpBranch %40 + + %30 = OpLabel + OpBranch %50 + + %40 = OpLabel + OpBranch %60 + + %50 = OpLabel + OpBranch %70 + + %60 = OpLabel + OpBranch %99 + + %70 = OpLabel + OpBranch %99 + + OpFunctionEnd + )"; + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error(); + FunctionEmitter fe(p, *spirv_function(100)); + fe.ComputeBlockOrderAndPositions(); + + EXPECT_THAT(fe.rspo(), ElementsAre(10, 30, 50, 70, 20, 40, 60, 99)) + << assembly; +} + +// TODO(dneto): test nesting +// TODO(dneto): test loops + +} // namespace +} // namespace spirv +} // namespace reader +} // namespace tint