[spirv-reader] Add fallthrough

Bug: tint:3
Change-Id: Ib2d337156d419ed13ef9c67aa94ac3ee90f79548
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/23041
Reviewed-by: dan sinclair <dsinclair@google.com>
This commit is contained in:
David Neto 2020-06-11 20:53:39 +00:00 committed by dan sinclair
parent 416be308fc
commit 709b62528c
3 changed files with 652 additions and 36 deletions

View File

@ -2040,9 +2040,12 @@ bool FunctionEmitter::EmitNormalTerminator(const BlockInfo& block_info) {
// The fallthrough case is special because WGSL requires the fallthrough
// statement to be last in the case clause.
if (true_kind == EdgeKind::kCaseFallThrough ||
false_kind == EdgeKind::kCaseFallThrough) {
return Fail() << "fallthrough is unhandled";
if (true_kind == EdgeKind::kCaseFallThrough) {
return EmitConditionalCaseFallThrough(block_info, std::move(cond),
false_kind, *false_info, true);
} else if (false_kind == EdgeKind::kCaseFallThrough) {
return EmitConditionalCaseFallThrough(block_info, std::move(cond),
true_kind, *true_info, false);
}
// At this point, at most one edge is kForward or kIfBreak.
@ -2067,16 +2070,21 @@ bool FunctionEmitter::EmitNormalTerminator(const BlockInfo& block_info) {
return success();
}
std::unique_ptr<ast::Statement> FunctionEmitter::MakeBranch(
std::unique_ptr<ast::Statement> FunctionEmitter::MakeBranchInternal(
const BlockInfo& src_info,
const BlockInfo& dest_info) const {
const BlockInfo& dest_info,
bool forced) const {
auto kind = src_info.succ_edge.find(dest_info.id)->second;
switch (kind) {
case EdgeKind::kBack:
// Nothing to do. The loop backedge is implicit.
break;
case EdgeKind::kSwitchBreak: {
// Don't bother with a break at the end of a case.
if (forced) {
return std::make_unique<ast::BreakStatement>();
}
// Unless forced, don't bother with a break at the end of a case/default
// clause.
const auto header = dest_info.header_for_merge;
assert(header != 0);
const auto* exiting_construct = GetBlockInfo(header)->construct;
@ -2148,6 +2156,52 @@ std::unique_ptr<ast::Statement> FunctionEmitter::MakeSimpleIf(
return if_stmt;
}
bool FunctionEmitter::EmitConditionalCaseFallThrough(
const BlockInfo& src_info,
std::unique_ptr<ast::Expression> cond,
EdgeKind other_edge_kind,
const BlockInfo& other_dest,
bool fall_through_is_true_branch) {
// In WGSL, the fallthrough statement must come last in the case clause.
// So we'll emit an if statement for the other branch, and then emit
// the fallthrough.
// We have two distinct destinations. But we only get here if this
// is a normal terminator; in particular the source block is *not* the
// start of an if-selection. So at most one branch is a kForward or
// kCaseFallThrough.
if (other_edge_kind == EdgeKind::kForward) {
return Fail()
<< "internal error: normal terminator OpBranchConditional has "
"both forward and fallthrough edges";
}
if (other_edge_kind == EdgeKind::kIfBreak) {
return Fail()
<< "internal error: normal terminator OpBranchConditional has "
"both IfBreak and fallthrough edges. Violates nesting rule";
}
if (other_edge_kind == EdgeKind::kBack) {
return Fail()
<< "internal error: normal terminator OpBranchConditional has "
"both backedge and fallthrough edges. Violates nesting rule";
}
auto other_branch = MakeForcedBranch(src_info, other_dest);
if (other_branch == nullptr) {
return Fail() << "internal error: expected a branch for edge-kind "
<< int(other_edge_kind);
}
if (fall_through_is_true_branch) {
AddStatement(
MakeSimpleIf(std::move(cond), nullptr, std::move(other_branch)));
} else {
AddStatement(
MakeSimpleIf(std::move(cond), std::move(other_branch), nullptr));
}
AddStatement(std::make_unique<ast::FallthroughStatement>());
return success();
}
bool FunctionEmitter::EmitStatementsInBasicBlock(const BlockInfo& block_info,
bool* already_emitted) {
if (*already_emitted) {

View File

@ -330,12 +330,42 @@ class FunctionEmitter {
/// Returns a new statement to represent the given branch representing a
/// "normal" terminator, as in the sense of EmitNormalTerminator. If no
/// WGSL statement is required, the statement will be nullptr.
/// WGSL statement is required, the statement will be nullptr. This method
/// tries to avoid emitting a 'break' statement when that would be redundant
/// in WGSL due to implicit breaking out of a switch.
/// @param src_info the source block
/// @param dest_info the destination block
/// @returns the new statement, or a null statement
std::unique_ptr<ast::Statement> MakeBranch(const BlockInfo& src_info,
const BlockInfo& dest_info) const;
const BlockInfo& dest_info) const {
return MakeBranchInternal(src_info, dest_info, false);
}
/// Returns a new statement to represent the given branch representing a
/// "normal" terminator, as in the sense of EmitNormalTerminator. If no
/// WGSL statement is required, the statement will be nullptr.
/// @param src_info the source block
/// @param dest_info the destination block
/// @returns the new statement, or a null statement
std::unique_ptr<ast::Statement> MakeForcedBranch(
const BlockInfo& src_info,
const BlockInfo& dest_info) const {
return MakeBranchInternal(src_info, dest_info, true);
}
/// Returns a new statement to represent the given branch representing a
/// "normal" terminator, as in the sense of EmitNormalTerminator. If no
/// WGSL statement is required, the statement will be nullptr. When |forced|
/// is false, this method tries to avoid emitting a 'break' statement when
/// that would be redundant in WGSL due to implicit breaking out of a switch.
/// When |forced| is true, the method won't try to avoid emitting that break.
/// @param src_info the source block
/// @param dest_info the destination block
/// @param forced if true, always emit the branch (if it exists in WGSL)
/// @returns the new statement, or a null statement
std::unique_ptr<ast::Statement> MakeBranchInternal(const BlockInfo& src_info,
const BlockInfo& dest_info,
bool forced) const;
/// Returns a new if statement with the given statements as the then-clause
/// and the else-clause. Either or both clauses might be nullptr. If both
@ -349,6 +379,24 @@ class FunctionEmitter {
std::unique_ptr<ast::Statement> then_stmt,
std::unique_ptr<ast::Statement> else_stmt) const;
/// Emits the statements for an normal-terminator OpBranchConditional
/// where one branch is a case fall through (the true branch if and only
/// if |fall_through_is_true_branch| is true), and the other branch is
/// goes to a different destination, named by |other_dest|.
/// @param src_info the basic block from which we're branching
/// @param cond the branching condition
/// @param other_edge_kind the edge kind from the source block to the other
/// destination
/// @param other_dest the other branching destination
/// @param fall_through_is_true_branch true when the fall-through is the true
/// branch
/// @returns the false if emission fails
bool EmitConditionalCaseFallThrough(const BlockInfo& src_info,
std::unique_ptr<ast::Expression> cond,
EdgeKind other_edge_kind,
const BlockInfo& other_dest,
bool fall_through_is_true_branch);
/// Emits a normal instruction: not a terminator, label, or variable
/// declaration.
/// @param inst the instruction

View File

@ -10237,8 +10237,63 @@ Return{}
)")) << ToString(fe.ast_body());
}
TEST_F(SpvParserTest, DISABLED_EmitBody_Branch_Fallthrough) {
// TODO(dneto): support fallthrough first.
TEST_F(SpvParserTest, EmitBody_Branch_Fallthrough) {
auto* p = parser(test::Assemble(CommonTypes() + R"(
%100 = OpFunction %void None %voidfn
%10 = OpLabel
OpStore %var %uint_1
OpSelectionMerge %99 None
OpSwitch %selector %99 20 %20 30 %30
%20 = OpLabel
OpStore %var %uint_20
OpBranch %30 ; uncondtional fallthrough
%30 = OpLabel
OpStore %var %uint_30
OpBranch %99
%99 = OpLabel
OpStore %var %uint_7
OpReturn
OpFunctionEnd
)"));
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
FunctionEmitter fe(p, *spirv_function(100));
EXPECT_TRUE(fe.EmitBody()) << p->error();
EXPECT_THAT(ToString(fe.ast_body()), Eq(R"(Assignment{
Identifier{var}
ScalarConstructor{1}
}
Switch{
ScalarConstructor{42}
{
Case 20{
Assignment{
Identifier{var}
ScalarConstructor{20}
}
Fallthrough{}
}
Case 30{
Assignment{
Identifier{var}
ScalarConstructor{30}
}
}
Default{
}
}
}
Assignment{
Identifier{var}
ScalarConstructor{7}
}
Return{}
)")) << ToString(fe.ast_body());
}
TEST_F(SpvParserTest, EmitBody_Branch_Forward) {
@ -10299,8 +10354,8 @@ Return{}
// kLoopBreak: dup general case
// kLoopContinue: TESTED
// kIfBreak: invalid: switch and if must have distinct merge blocks
// kCaseFallThrough: TODO(dneto)
// kForward: TESTED
// kCaseFallThrough: not possible, because switch break conflicts with loop
// break kForward: TESTED
//
// kLoopContinue with:
// kBack : symmetry
@ -11039,13 +11094,143 @@ Return{}
}
TEST_F(SpvParserTest,
DISABLED_EmitBody_BranchConditional_SwitchBreak_Fallthrough_OnTrue) {
// TODO(dneto): needs fallthrough support
EmitBody_BranchConditional_SwitchBreak_Fallthrough_OnTrue) {
auto* p = parser(test::Assemble(CommonTypes() + R"(
%100 = OpFunction %void None %voidfn
%10 = OpLabel
OpStore %var %uint_1
OpSelectionMerge %99 None
OpSwitch %selector %99 20 %20 30 %30
%20 = OpLabel
OpStore %var %uint_20
OpBranchConditional %cond %30 %99; fallthrough on true
%30 = OpLabel
OpStore %var %uint_30
OpBranch %99
%99 = OpLabel
OpStore %var %uint_7
OpReturn
OpFunctionEnd
)"));
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
FunctionEmitter fe(p, *spirv_function(100));
EXPECT_TRUE(fe.EmitBody()) << p->error();
EXPECT_THAT(ToString(fe.ast_body()), Eq(R"(Assignment{
Identifier{var}
ScalarConstructor{1}
}
Switch{
ScalarConstructor{42}
{
Case 20{
Assignment{
Identifier{var}
ScalarConstructor{20}
}
If{
(
ScalarConstructor{false}
)
{
}
}
Else{
{
Break{}
}
}
Fallthrough{}
}
Case 30{
Assignment{
Identifier{var}
ScalarConstructor{30}
}
}
Default{
}
}
}
Assignment{
Identifier{var}
ScalarConstructor{7}
}
Return{}
)")) << ToString(fe.ast_body());
}
TEST_F(SpvParserTest,
DISABLED_EmitBody_BranchConditional_SwitchBreak_Fallthrough_OnFalse) {
// TODO(dneto): needs fallthrough support
EmitBody_BranchConditional_SwitchBreak_Fallthrough_OnFalse) {
auto* p = parser(test::Assemble(CommonTypes() + R"(
%100 = OpFunction %void None %voidfn
%10 = OpLabel
OpStore %var %uint_1
OpSelectionMerge %99 None
OpSwitch %selector %99 20 %20 30 %30
%20 = OpLabel
OpStore %var %uint_20
OpBranchConditional %cond %99 %30; fallthrough on false
%30 = OpLabel
OpStore %var %uint_30
OpBranch %99
%99 = OpLabel
OpStore %var %uint_7
OpReturn
OpFunctionEnd
)"));
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
FunctionEmitter fe(p, *spirv_function(100));
EXPECT_TRUE(fe.EmitBody()) << p->error();
EXPECT_THAT(ToString(fe.ast_body()), Eq(R"(Assignment{
Identifier{var}
ScalarConstructor{1}
}
Switch{
ScalarConstructor{42}
{
Case 20{
Assignment{
Identifier{var}
ScalarConstructor{20}
}
If{
(
ScalarConstructor{false}
)
{
Break{}
}
}
Fallthrough{}
}
Case 30{
Assignment{
Identifier{var}
ScalarConstructor{30}
}
}
Default{
}
}
}
Assignment{
Identifier{var}
ScalarConstructor{7}
}
Return{}
)")) << ToString(fe.ast_body());
}
TEST_F(SpvParserTest,
@ -11349,12 +11534,53 @@ Return{}
}
TEST_F(SpvParserTest,
DISABLED_EmitBody_BranchConditional_LoopBreak_Fallthrough_OnTrue) {
// TODO(dneto): needs fallthrough support
}
TEST_F(SpvParserTest,
DISABLED_EmitBody_BranchConditional_LoopBreak_Fallthrough_OnFalse) {
// TODO(dneto): needs fallthrough support
EmitBody_BranchConditional_LoopBreak_Fallthrough_IsError) {
// It's an error because switch break conflicts with loop break.
auto* p = parser(test::Assemble(CommonTypes() + R"(
%100 = OpFunction %void None %voidfn
%10 = OpLabel
OpStore %var %uint_0
OpBranch %20
%20 = OpLabel
OpStore %var %uint_1
OpLoopMerge %99 %80 None
OpBranch %30
%30 = OpLabel
OpSelectionMerge %79 None
OpSwitch %selector %79 40 %40 50 %50
%40 = OpLabel
OpStore %var %uint_40
; error: branch to 99 bypasses switch's merge
OpBranchConditional %cond %99 %50 ; loop break; fall through
%50 = OpLabel
OpStore %var %uint_50
OpBranch %79
%79 = OpLabel ; switch merge
OpBranch %80
%80 = OpLabel ; continue target
OpStore %var %uint_4
OpBranch %20
%99 = OpLabel
OpStore %var %uint_5
OpReturn
OpFunctionEnd
)"));
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
FunctionEmitter fe(p, *spirv_function(100));
EXPECT_FALSE(fe.EmitBody()) << p->error();
EXPECT_THAT(
p->error(),
Eq("Branch from block 40 to block 99 is an invalid exit from construct "
"starting at block 30; branch bypasses merge block 79"));
}
TEST_F(SpvParserTest, EmitBody_BranchConditional_LoopBreak_Forward_OnTrue) {
@ -12057,13 +12283,214 @@ Return{}
)")) << ToString(fe.ast_body());
}
TEST_F(SpvParserTest,
DISABLED_EmitBody_BranchConditional_Continue_Fallthrough_OnTrue) {
// TODO(dneto): needs fallthrough support
TEST_F(SpvParserTest, EmitBody_BranchConditional_Continue_Fallthrough_OnTrue) {
auto* p = parser(test::Assemble(CommonTypes() + R"(
%100 = OpFunction %void None %voidfn
%10 = OpLabel
OpStore %var %uint_0
OpBranch %20
%20 = OpLabel
OpStore %var %uint_1
OpLoopMerge %99 %80 None
OpBranch %30
%30 = OpLabel
OpStore %var %uint_2
OpSelectionMerge %79 None
OpSwitch %selector %79 40 %40 50 %50
%40 = OpLabel
OpStore %var %uint_40
OpBranchConditional %cond %50 %80 ; loop continue; fall through on true
%50 = OpLabel
OpStore %var %uint_50
OpBranch %79
%79 = OpLabel ; switch merge
OpStore %var %uint_3
OpBranch %80
%80 = OpLabel ; continue target
OpStore %var %uint_4
OpBranch %20
%99 = OpLabel
OpStore %var %uint_5
OpReturn
OpFunctionEnd
)"));
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
FunctionEmitter fe(p, *spirv_function(100));
EXPECT_TRUE(fe.EmitBody()) << p->error();
EXPECT_THAT(ToString(fe.ast_body()), Eq(R"(Assignment{
Identifier{var}
ScalarConstructor{0}
}
TEST_F(SpvParserTest,
DISABLED_EmitBody_BranchConditional_Continue_Fallthrough_OnFalse) {
// TODO(dneto): needs fallthrough support
Loop{
Assignment{
Identifier{var}
ScalarConstructor{1}
}
Assignment{
Identifier{var}
ScalarConstructor{2}
}
Switch{
ScalarConstructor{42}
{
Case 40{
Assignment{
Identifier{var}
ScalarConstructor{40}
}
If{
(
ScalarConstructor{false}
)
{
}
}
Else{
{
Continue{}
}
}
Fallthrough{}
}
Case 50{
Assignment{
Identifier{var}
ScalarConstructor{50}
}
}
Default{
}
}
}
Assignment{
Identifier{var}
ScalarConstructor{3}
}
continuing {
Assignment{
Identifier{var}
ScalarConstructor{4}
}
}
}
Assignment{
Identifier{var}
ScalarConstructor{5}
}
Return{}
)")) << ToString(fe.ast_body());
}
TEST_F(SpvParserTest, EmitBody_BranchConditional_Continue_Fallthrough_OnFalse) {
auto* p = parser(test::Assemble(CommonTypes() + R"(
%100 = OpFunction %void None %voidfn
%10 = OpLabel
OpStore %var %uint_0
OpBranch %20
%20 = OpLabel
OpStore %var %uint_1
OpLoopMerge %99 %80 None
OpBranch %30
%30 = OpLabel
OpStore %var %uint_2
OpSelectionMerge %79 None
OpSwitch %selector %79 40 %40 50 %50
%40 = OpLabel
OpStore %var %uint_40
OpBranchConditional %cond %80 %50 ; loop continue; fall through on false
%50 = OpLabel
OpStore %var %uint_50
OpBranch %79
%79 = OpLabel ; switch merge
OpStore %var %uint_3
OpBranch %80
%80 = OpLabel ; continue target
OpStore %var %uint_4
OpBranch %20
%99 = OpLabel
OpStore %var %uint_5
OpReturn
OpFunctionEnd
)"));
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
FunctionEmitter fe(p, *spirv_function(100));
EXPECT_TRUE(fe.EmitBody()) << p->error();
EXPECT_THAT(ToString(fe.ast_body()), Eq(R"(Assignment{
Identifier{var}
ScalarConstructor{0}
}
Loop{
Assignment{
Identifier{var}
ScalarConstructor{1}
}
Assignment{
Identifier{var}
ScalarConstructor{2}
}
Switch{
ScalarConstructor{42}
{
Case 40{
Assignment{
Identifier{var}
ScalarConstructor{40}
}
If{
(
ScalarConstructor{false}
)
{
Continue{}
}
}
Fallthrough{}
}
Case 50{
Assignment{
Identifier{var}
ScalarConstructor{50}
}
}
Default{
}
}
}
Assignment{
Identifier{var}
ScalarConstructor{3}
}
continuing {
Assignment{
Identifier{var}
ScalarConstructor{4}
}
}
}
Assignment{
Identifier{var}
ScalarConstructor{5}
}
Return{}
)")) << ToString(fe.ast_body());
}
TEST_F(SpvParserTest, EmitBody_BranchConditional_Continue_Forward_OnTrue) {
@ -12298,16 +12725,103 @@ TEST_F(SpvParserTest,
"starting at block 20; branch bypasses merge block 89"));
}
TEST_F(SpvParserTest,
DISABLED_EmitBody_BranchConditional_Fallthrough_Fallthrough_Same) {
// Can only be to the same target.
// TODO(dneto): needs fallthrough support
TEST_F(SpvParserTest, EmitBody_BranchConditional_Fallthrough_Fallthrough_Same) {
auto* p = parser(test::Assemble(CommonTypes() + R"(
%100 = OpFunction %void None %voidfn
%10 = OpLabel
OpStore %var %uint_1
OpSelectionMerge %99 None
OpSwitch %selector %99 20 %20 30 %30
%20 = OpLabel
OpStore %var %uint_20
OpBranchConditional %cond %30 %30 ; fallthrough fallthrough
%30 = OpLabel
OpStore %var %uint_30
OpBranch %99
%99 = OpLabel
OpStore %var %uint_7
OpReturn
OpFunctionEnd
)"));
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
FunctionEmitter fe(p, *spirv_function(100));
EXPECT_TRUE(fe.EmitBody()) << p->error();
EXPECT_THAT(ToString(fe.ast_body()), Eq(R"(Assignment{
Identifier{var}
ScalarConstructor{1}
}
Switch{
ScalarConstructor{42}
{
Case 20{
Assignment{
Identifier{var}
ScalarConstructor{20}
}
Fallthrough{}
}
Case 30{
Assignment{
Identifier{var}
ScalarConstructor{30}
}
}
Default{
}
}
}
Assignment{
Identifier{var}
ScalarConstructor{7}
}
Return{}
)")) << ToString(fe.ast_body());
}
TEST_F(
SpvParserTest,
DISABLED_EmitBody_BranchConditional_Fallthrough_Fallthrough_Different_IsError) {
// TODO(dneto): needs fallthrough support
TEST_F(SpvParserTest,
EmitBody_BranchConditional_Fallthrough_NotLastInCase_IsError) {
// See also
// ClassifyCFGEdges_Fallthrough_BranchConditionalWith_Forward_IsError.
auto* p = parser(test::Assemble(CommonTypes() + R"(
%100 = OpFunction %void None %voidfn
%10 = OpLabel
OpSelectionMerge %99 None
OpSwitch %selector %99 20 %20 40 %40
%20 = OpLabel ; case 30
OpSelectionMerge %39 None
OpBranchConditional %cond %40 %30 ; fallthrough and forward
%30 = OpLabel
OpBranch %39
%39 = OpLabel
OpBranch %99
%40 = OpLabel ; case 40
OpBranch %99
%99 = OpLabel
OpReturn
OpFunctionEnd
)"));
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
FunctionEmitter fe(p, *spirv_function(100));
EXPECT_FALSE(fe.EmitBody());
// The weird forward branch pulls in 40 as part of the selection rather than
// as a case.
EXPECT_THAT(fe.block_order(), ElementsAre(10, 20, 40, 30, 39, 99));
EXPECT_THAT(
p->error(),
Eq("Branch from 10 to 40 bypasses header 20 (dominance rule violated)"));
}
TEST_F(SpvParserTest, EmitBody_BranchConditional_Forward_Forward_Same) {