diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc index 723f35174e..a5cefcbf50 100644 --- a/src/reader/spirv/function.cc +++ b/src/reader/spirv/function.cc @@ -2040,9 +2040,12 @@ bool FunctionEmitter::EmitNormalTerminator(const BlockInfo& block_info) { // The fallthrough case is special because WGSL requires the fallthrough // statement to be last in the case clause. - if (true_kind == EdgeKind::kCaseFallThrough || - false_kind == EdgeKind::kCaseFallThrough) { - return Fail() << "fallthrough is unhandled"; + if (true_kind == EdgeKind::kCaseFallThrough) { + return EmitConditionalCaseFallThrough(block_info, std::move(cond), + false_kind, *false_info, true); + } else if (false_kind == EdgeKind::kCaseFallThrough) { + return EmitConditionalCaseFallThrough(block_info, std::move(cond), + true_kind, *true_info, false); } // At this point, at most one edge is kForward or kIfBreak. @@ -2067,16 +2070,21 @@ bool FunctionEmitter::EmitNormalTerminator(const BlockInfo& block_info) { return success(); } -std::unique_ptr FunctionEmitter::MakeBranch( +std::unique_ptr FunctionEmitter::MakeBranchInternal( const BlockInfo& src_info, - const BlockInfo& dest_info) const { + const BlockInfo& dest_info, + bool forced) const { auto kind = src_info.succ_edge.find(dest_info.id)->second; switch (kind) { case EdgeKind::kBack: // Nothing to do. The loop backedge is implicit. break; case EdgeKind::kSwitchBreak: { - // Don't bother with a break at the end of a case. + if (forced) { + return std::make_unique(); + } + // Unless forced, don't bother with a break at the end of a case/default + // clause. const auto header = dest_info.header_for_merge; assert(header != 0); const auto* exiting_construct = GetBlockInfo(header)->construct; @@ -2148,6 +2156,52 @@ std::unique_ptr FunctionEmitter::MakeSimpleIf( return if_stmt; } +bool FunctionEmitter::EmitConditionalCaseFallThrough( + const BlockInfo& src_info, + std::unique_ptr cond, + EdgeKind other_edge_kind, + const BlockInfo& other_dest, + bool fall_through_is_true_branch) { + // In WGSL, the fallthrough statement must come last in the case clause. + // So we'll emit an if statement for the other branch, and then emit + // the fallthrough. + + // We have two distinct destinations. But we only get here if this + // is a normal terminator; in particular the source block is *not* the + // start of an if-selection. So at most one branch is a kForward or + // kCaseFallThrough. + if (other_edge_kind == EdgeKind::kForward) { + return Fail() + << "internal error: normal terminator OpBranchConditional has " + "both forward and fallthrough edges"; + } + if (other_edge_kind == EdgeKind::kIfBreak) { + return Fail() + << "internal error: normal terminator OpBranchConditional has " + "both IfBreak and fallthrough edges. Violates nesting rule"; + } + if (other_edge_kind == EdgeKind::kBack) { + return Fail() + << "internal error: normal terminator OpBranchConditional has " + "both backedge and fallthrough edges. Violates nesting rule"; + } + auto other_branch = MakeForcedBranch(src_info, other_dest); + if (other_branch == nullptr) { + return Fail() << "internal error: expected a branch for edge-kind " + << int(other_edge_kind); + } + if (fall_through_is_true_branch) { + AddStatement( + MakeSimpleIf(std::move(cond), nullptr, std::move(other_branch))); + } else { + AddStatement( + MakeSimpleIf(std::move(cond), std::move(other_branch), nullptr)); + } + AddStatement(std::make_unique()); + + return success(); +} + bool FunctionEmitter::EmitStatementsInBasicBlock(const BlockInfo& block_info, bool* already_emitted) { if (*already_emitted) { diff --git a/src/reader/spirv/function.h b/src/reader/spirv/function.h index a5a53471eb..1719defdef 100644 --- a/src/reader/spirv/function.h +++ b/src/reader/spirv/function.h @@ -330,12 +330,42 @@ class FunctionEmitter { /// Returns a new statement to represent the given branch representing a /// "normal" terminator, as in the sense of EmitNormalTerminator. If no - /// WGSL statement is required, the statement will be nullptr. + /// WGSL statement is required, the statement will be nullptr. This method + /// tries to avoid emitting a 'break' statement when that would be redundant + /// in WGSL due to implicit breaking out of a switch. /// @param src_info the source block /// @param dest_info the destination block /// @returns the new statement, or a null statement std::unique_ptr MakeBranch(const BlockInfo& src_info, - const BlockInfo& dest_info) const; + const BlockInfo& dest_info) const { + return MakeBranchInternal(src_info, dest_info, false); + } + + /// Returns a new statement to represent the given branch representing a + /// "normal" terminator, as in the sense of EmitNormalTerminator. If no + /// WGSL statement is required, the statement will be nullptr. + /// @param src_info the source block + /// @param dest_info the destination block + /// @returns the new statement, or a null statement + std::unique_ptr MakeForcedBranch( + const BlockInfo& src_info, + const BlockInfo& dest_info) const { + return MakeBranchInternal(src_info, dest_info, true); + } + + /// Returns a new statement to represent the given branch representing a + /// "normal" terminator, as in the sense of EmitNormalTerminator. If no + /// WGSL statement is required, the statement will be nullptr. When |forced| + /// is false, this method tries to avoid emitting a 'break' statement when + /// that would be redundant in WGSL due to implicit breaking out of a switch. + /// When |forced| is true, the method won't try to avoid emitting that break. + /// @param src_info the source block + /// @param dest_info the destination block + /// @param forced if true, always emit the branch (if it exists in WGSL) + /// @returns the new statement, or a null statement + std::unique_ptr MakeBranchInternal(const BlockInfo& src_info, + const BlockInfo& dest_info, + bool forced) const; /// Returns a new if statement with the given statements as the then-clause /// and the else-clause. Either or both clauses might be nullptr. If both @@ -349,6 +379,24 @@ class FunctionEmitter { std::unique_ptr then_stmt, std::unique_ptr else_stmt) const; + /// Emits the statements for an normal-terminator OpBranchConditional + /// where one branch is a case fall through (the true branch if and only + /// if |fall_through_is_true_branch| is true), and the other branch is + /// goes to a different destination, named by |other_dest|. + /// @param src_info the basic block from which we're branching + /// @param cond the branching condition + /// @param other_edge_kind the edge kind from the source block to the other + /// destination + /// @param other_dest the other branching destination + /// @param fall_through_is_true_branch true when the fall-through is the true + /// branch + /// @returns the false if emission fails + bool EmitConditionalCaseFallThrough(const BlockInfo& src_info, + std::unique_ptr cond, + EdgeKind other_edge_kind, + const BlockInfo& other_dest, + bool fall_through_is_true_branch); + /// Emits a normal instruction: not a terminator, label, or variable /// declaration. /// @param inst the instruction diff --git a/src/reader/spirv/function_cfg_test.cc b/src/reader/spirv/function_cfg_test.cc index 4cf029dc29..820f9695cc 100644 --- a/src/reader/spirv/function_cfg_test.cc +++ b/src/reader/spirv/function_cfg_test.cc @@ -10237,8 +10237,63 @@ Return{} )")) << ToString(fe.ast_body()); } -TEST_F(SpvParserTest, DISABLED_EmitBody_Branch_Fallthrough) { - // TODO(dneto): support fallthrough first. +TEST_F(SpvParserTest, EmitBody_Branch_Fallthrough) { + auto* p = parser(test::Assemble(CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + + %10 = OpLabel + OpStore %var %uint_1 + OpSelectionMerge %99 None + OpSwitch %selector %99 20 %20 30 %30 + + %20 = OpLabel + OpStore %var %uint_20 + OpBranch %30 ; uncondtional fallthrough + + %30 = OpLabel + OpStore %var %uint_30 + OpBranch %99 + + %99 = OpLabel + OpStore %var %uint_7 + OpReturn + + OpFunctionEnd + )")); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error(); + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + + EXPECT_THAT(ToString(fe.ast_body()), Eq(R"(Assignment{ + Identifier{var} + ScalarConstructor{1} +} +Switch{ + ScalarConstructor{42} + { + Case 20{ + Assignment{ + Identifier{var} + ScalarConstructor{20} + } + Fallthrough{} + } + Case 30{ + Assignment{ + Identifier{var} + ScalarConstructor{30} + } + } + Default{ + } + } +} +Assignment{ + Identifier{var} + ScalarConstructor{7} +} +Return{} +)")) << ToString(fe.ast_body()); } TEST_F(SpvParserTest, EmitBody_Branch_Forward) { @@ -10299,8 +10354,8 @@ Return{} // kLoopBreak: dup general case // kLoopContinue: TESTED // kIfBreak: invalid: switch and if must have distinct merge blocks -// kCaseFallThrough: TODO(dneto) -// kForward: TESTED +// kCaseFallThrough: not possible, because switch break conflicts with loop +// break kForward: TESTED // // kLoopContinue with: // kBack : symmetry @@ -11039,13 +11094,143 @@ Return{} } TEST_F(SpvParserTest, - DISABLED_EmitBody_BranchConditional_SwitchBreak_Fallthrough_OnTrue) { - // TODO(dneto): needs fallthrough support + EmitBody_BranchConditional_SwitchBreak_Fallthrough_OnTrue) { + auto* p = parser(test::Assemble(CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + + %10 = OpLabel + OpStore %var %uint_1 + OpSelectionMerge %99 None + OpSwitch %selector %99 20 %20 30 %30 + + %20 = OpLabel + OpStore %var %uint_20 + OpBranchConditional %cond %30 %99; fallthrough on true + + %30 = OpLabel + OpStore %var %uint_30 + OpBranch %99 + + %99 = OpLabel + OpStore %var %uint_7 + OpReturn + + OpFunctionEnd + )")); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error(); + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + + EXPECT_THAT(ToString(fe.ast_body()), Eq(R"(Assignment{ + Identifier{var} + ScalarConstructor{1} +} +Switch{ + ScalarConstructor{42} + { + Case 20{ + Assignment{ + Identifier{var} + ScalarConstructor{20} + } + If{ + ( + ScalarConstructor{false} + ) + { + } + } + Else{ + { + Break{} + } + } + Fallthrough{} + } + Case 30{ + Assignment{ + Identifier{var} + ScalarConstructor{30} + } + } + Default{ + } + } +} +Assignment{ + Identifier{var} + ScalarConstructor{7} +} +Return{} +)")) << ToString(fe.ast_body()); } TEST_F(SpvParserTest, - DISABLED_EmitBody_BranchConditional_SwitchBreak_Fallthrough_OnFalse) { - // TODO(dneto): needs fallthrough support + EmitBody_BranchConditional_SwitchBreak_Fallthrough_OnFalse) { + auto* p = parser(test::Assemble(CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + + %10 = OpLabel + OpStore %var %uint_1 + OpSelectionMerge %99 None + OpSwitch %selector %99 20 %20 30 %30 + + %20 = OpLabel + OpStore %var %uint_20 + OpBranchConditional %cond %99 %30; fallthrough on false + + %30 = OpLabel + OpStore %var %uint_30 + OpBranch %99 + + %99 = OpLabel + OpStore %var %uint_7 + OpReturn + + OpFunctionEnd + )")); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error(); + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + + EXPECT_THAT(ToString(fe.ast_body()), Eq(R"(Assignment{ + Identifier{var} + ScalarConstructor{1} +} +Switch{ + ScalarConstructor{42} + { + Case 20{ + Assignment{ + Identifier{var} + ScalarConstructor{20} + } + If{ + ( + ScalarConstructor{false} + ) + { + Break{} + } + } + Fallthrough{} + } + Case 30{ + Assignment{ + Identifier{var} + ScalarConstructor{30} + } + } + Default{ + } + } +} +Assignment{ + Identifier{var} + ScalarConstructor{7} +} +Return{} +)")) << ToString(fe.ast_body()); } TEST_F(SpvParserTest, @@ -11349,12 +11534,53 @@ Return{} } TEST_F(SpvParserTest, - DISABLED_EmitBody_BranchConditional_LoopBreak_Fallthrough_OnTrue) { - // TODO(dneto): needs fallthrough support -} -TEST_F(SpvParserTest, - DISABLED_EmitBody_BranchConditional_LoopBreak_Fallthrough_OnFalse) { - // TODO(dneto): needs fallthrough support + EmitBody_BranchConditional_LoopBreak_Fallthrough_IsError) { + // It's an error because switch break conflicts with loop break. + auto* p = parser(test::Assemble(CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + + %10 = OpLabel + OpStore %var %uint_0 + OpBranch %20 + + %20 = OpLabel + OpStore %var %uint_1 + OpLoopMerge %99 %80 None + OpBranch %30 + + %30 = OpLabel + OpSelectionMerge %79 None + OpSwitch %selector %79 40 %40 50 %50 + + %40 = OpLabel + OpStore %var %uint_40 + ; error: branch to 99 bypasses switch's merge + OpBranchConditional %cond %99 %50 ; loop break; fall through + + %50 = OpLabel + OpStore %var %uint_50 + OpBranch %79 + + %79 = OpLabel ; switch merge + OpBranch %80 + + %80 = OpLabel ; continue target + OpStore %var %uint_4 + OpBranch %20 + + %99 = OpLabel + OpStore %var %uint_5 + OpReturn + + OpFunctionEnd + )")); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error(); + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_FALSE(fe.EmitBody()) << p->error(); + EXPECT_THAT( + p->error(), + Eq("Branch from block 40 to block 99 is an invalid exit from construct " + "starting at block 30; branch bypasses merge block 79")); } TEST_F(SpvParserTest, EmitBody_BranchConditional_LoopBreak_Forward_OnTrue) { @@ -12057,13 +12283,214 @@ Return{} )")) << ToString(fe.ast_body()); } -TEST_F(SpvParserTest, - DISABLED_EmitBody_BranchConditional_Continue_Fallthrough_OnTrue) { - // TODO(dneto): needs fallthrough support +TEST_F(SpvParserTest, EmitBody_BranchConditional_Continue_Fallthrough_OnTrue) { + auto* p = parser(test::Assemble(CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + + %10 = OpLabel + OpStore %var %uint_0 + OpBranch %20 + + %20 = OpLabel + OpStore %var %uint_1 + OpLoopMerge %99 %80 None + OpBranch %30 + + %30 = OpLabel + OpStore %var %uint_2 + OpSelectionMerge %79 None + OpSwitch %selector %79 40 %40 50 %50 + + %40 = OpLabel + OpStore %var %uint_40 + OpBranchConditional %cond %50 %80 ; loop continue; fall through on true + + %50 = OpLabel + OpStore %var %uint_50 + OpBranch %79 + + %79 = OpLabel ; switch merge + OpStore %var %uint_3 + OpBranch %80 + + %80 = OpLabel ; continue target + OpStore %var %uint_4 + OpBranch %20 + + %99 = OpLabel + OpStore %var %uint_5 + OpReturn + + OpFunctionEnd + )")); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error(); + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + EXPECT_THAT(ToString(fe.ast_body()), Eq(R"(Assignment{ + Identifier{var} + ScalarConstructor{0} } -TEST_F(SpvParserTest, - DISABLED_EmitBody_BranchConditional_Continue_Fallthrough_OnFalse) { - // TODO(dneto): needs fallthrough support +Loop{ + Assignment{ + Identifier{var} + ScalarConstructor{1} + } + Assignment{ + Identifier{var} + ScalarConstructor{2} + } + Switch{ + ScalarConstructor{42} + { + Case 40{ + Assignment{ + Identifier{var} + ScalarConstructor{40} + } + If{ + ( + ScalarConstructor{false} + ) + { + } + } + Else{ + { + Continue{} + } + } + Fallthrough{} + } + Case 50{ + Assignment{ + Identifier{var} + ScalarConstructor{50} + } + } + Default{ + } + } + } + Assignment{ + Identifier{var} + ScalarConstructor{3} + } + continuing { + Assignment{ + Identifier{var} + ScalarConstructor{4} + } + } +} +Assignment{ + Identifier{var} + ScalarConstructor{5} +} +Return{} +)")) << ToString(fe.ast_body()); +} + +TEST_F(SpvParserTest, EmitBody_BranchConditional_Continue_Fallthrough_OnFalse) { + auto* p = parser(test::Assemble(CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + + %10 = OpLabel + OpStore %var %uint_0 + OpBranch %20 + + %20 = OpLabel + OpStore %var %uint_1 + OpLoopMerge %99 %80 None + OpBranch %30 + + %30 = OpLabel + OpStore %var %uint_2 + OpSelectionMerge %79 None + OpSwitch %selector %79 40 %40 50 %50 + + %40 = OpLabel + OpStore %var %uint_40 + OpBranchConditional %cond %80 %50 ; loop continue; fall through on false + + %50 = OpLabel + OpStore %var %uint_50 + OpBranch %79 + + %79 = OpLabel ; switch merge + OpStore %var %uint_3 + OpBranch %80 + + %80 = OpLabel ; continue target + OpStore %var %uint_4 + OpBranch %20 + + %99 = OpLabel + OpStore %var %uint_5 + OpReturn + + OpFunctionEnd + )")); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error(); + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + EXPECT_THAT(ToString(fe.ast_body()), Eq(R"(Assignment{ + Identifier{var} + ScalarConstructor{0} +} +Loop{ + Assignment{ + Identifier{var} + ScalarConstructor{1} + } + Assignment{ + Identifier{var} + ScalarConstructor{2} + } + Switch{ + ScalarConstructor{42} + { + Case 40{ + Assignment{ + Identifier{var} + ScalarConstructor{40} + } + If{ + ( + ScalarConstructor{false} + ) + { + Continue{} + } + } + Fallthrough{} + } + Case 50{ + Assignment{ + Identifier{var} + ScalarConstructor{50} + } + } + Default{ + } + } + } + Assignment{ + Identifier{var} + ScalarConstructor{3} + } + continuing { + Assignment{ + Identifier{var} + ScalarConstructor{4} + } + } +} +Assignment{ + Identifier{var} + ScalarConstructor{5} +} +Return{} +)")) << ToString(fe.ast_body()); } TEST_F(SpvParserTest, EmitBody_BranchConditional_Continue_Forward_OnTrue) { @@ -12298,16 +12725,103 @@ TEST_F(SpvParserTest, "starting at block 20; branch bypasses merge block 89")); } -TEST_F(SpvParserTest, - DISABLED_EmitBody_BranchConditional_Fallthrough_Fallthrough_Same) { - // Can only be to the same target. - // TODO(dneto): needs fallthrough support +TEST_F(SpvParserTest, EmitBody_BranchConditional_Fallthrough_Fallthrough_Same) { + auto* p = parser(test::Assemble(CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + + %10 = OpLabel + OpStore %var %uint_1 + OpSelectionMerge %99 None + OpSwitch %selector %99 20 %20 30 %30 + + %20 = OpLabel + OpStore %var %uint_20 + OpBranchConditional %cond %30 %30 ; fallthrough fallthrough + + %30 = OpLabel + OpStore %var %uint_30 + OpBranch %99 + + %99 = OpLabel + OpStore %var %uint_7 + OpReturn + + OpFunctionEnd + )")); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error(); + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + + EXPECT_THAT(ToString(fe.ast_body()), Eq(R"(Assignment{ + Identifier{var} + ScalarConstructor{1} +} +Switch{ + ScalarConstructor{42} + { + Case 20{ + Assignment{ + Identifier{var} + ScalarConstructor{20} + } + Fallthrough{} + } + Case 30{ + Assignment{ + Identifier{var} + ScalarConstructor{30} + } + } + Default{ + } + } +} +Assignment{ + Identifier{var} + ScalarConstructor{7} +} +Return{} +)")) << ToString(fe.ast_body()); } -TEST_F( - SpvParserTest, - DISABLED_EmitBody_BranchConditional_Fallthrough_Fallthrough_Different_IsError) { - // TODO(dneto): needs fallthrough support +TEST_F(SpvParserTest, + EmitBody_BranchConditional_Fallthrough_NotLastInCase_IsError) { + // See also + // ClassifyCFGEdges_Fallthrough_BranchConditionalWith_Forward_IsError. + auto* p = parser(test::Assemble(CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + + %10 = OpLabel + OpSelectionMerge %99 None + OpSwitch %selector %99 20 %20 40 %40 + + %20 = OpLabel ; case 30 + OpSelectionMerge %39 None + OpBranchConditional %cond %40 %30 ; fallthrough and forward + + %30 = OpLabel + OpBranch %39 + + %39 = OpLabel + OpBranch %99 + + %40 = OpLabel ; case 40 + OpBranch %99 + + %99 = OpLabel + OpReturn + + OpFunctionEnd + )")); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error(); + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_FALSE(fe.EmitBody()); + // The weird forward branch pulls in 40 as part of the selection rather than + // as a case. + EXPECT_THAT(fe.block_order(), ElementsAre(10, 20, 40, 30, 39, 99)); + EXPECT_THAT( + p->error(), + Eq("Branch from 10 to 40 bypasses header 20 (dominance rule violated)")); } TEST_F(SpvParserTest, EmitBody_BranchConditional_Forward_Forward_Same) {