spirv-writer: Fix termination of basic blocks

There are a few places where a branch or return is created,
conditionally on whether a terminator was the last thing seen.
The goal is to generate a SPIR-V basic block terminator exactly
when needed, and to avoid generating a branch or return immediately
after a prior terminator.

Previously, the decision was based on the last thing seen in the AST.
But we should instead check the emitted SPIR-V instead.

This fixes cases such as a break or return inside an else-if.
That's because an if/elseif is actually a selection inside a selection.
Looking at the AST only works when trying to terminate the *inside*
selection.  In the outer recursive call, the last AST node is
no longer a terminator, and we would skip generating the branch
to the merge block.

Fixed: tint:1315
Change-Id: I6b886ce85d1d681f2063997e469e0c1b4e5973a2
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/73480
Kokoro: Kokoro <noreply+kokoro@google.com>
Auto-Submit: David Neto <dneto@google.com>
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: James Price <jrprice@google.com>
This commit is contained in:
David Neto 2021-12-20 16:46:55 +00:00 committed by Tint LUCI CQ
parent a9d6c34d86
commit 66e7569e15
4 changed files with 124 additions and 24 deletions

View File

@ -96,23 +96,6 @@ bool LastIsFallthrough(const ast::BlockStatement* stmts) {
return !stmts->Empty() && stmts->Last()->Is<ast::FallthroughStatement>();
}
// A terminator is anything which will cause a SPIR-V terminator to be emitted.
// This means things like breaks, fallthroughs and continues which all emit an
// OpBranch or return for the OpReturn emission.
bool LastIsTerminator(const ast::BlockStatement* stmts) {
if (IsAnyOf<ast::BreakStatement, ast::ContinueStatement,
ast::DiscardStatement, ast::ReturnStatement,
ast::FallthroughStatement>(stmts->Last())) {
return true;
}
if (auto* block = As<ast::BlockStatement>(stmts->Last())) {
return LastIsTerminator(block);
}
return false;
}
/// Returns the matrix type that is `type` or that is wrapped by
/// one or more levels of an arrays inside of `type`.
/// @param type the given type, which must not be null
@ -650,7 +633,7 @@ bool Builder::GenerateFunction(const ast::Function* func_ast) {
}
}
if (!LastIsTerminator(func_ast->body)) {
if (InsideBasicBlock()) {
if (func->ReturnType()->Is<sem::Void>()) {
push_function_inst(spv::Op::OpReturn, {});
} else {
@ -3500,7 +3483,7 @@ bool Builder::GenerateConditionalBlock(
return false;
}
// We only branch if the last element of the body didn't already branch.
if (!LastIsTerminator(true_body)) {
if (InsideBasicBlock()) {
if (!push_function_inst(spv::Op::OpBranch,
{Operand::Int(merge_block_id)})) {
return false;
@ -3525,7 +3508,7 @@ bool Builder::GenerateConditionalBlock(
return false;
}
}
if (!LastIsTerminator(else_stmt->body)) {
if (InsideBasicBlock()) {
if (!push_function_inst(spv::Op::OpBranch,
{Operand::Int(merge_block_id)})) {
return false;
@ -3673,7 +3656,7 @@ bool Builder::GenerateSwitchStatement(const ast::SwitchStatement* stmt) {
{Operand::Int(case_ids[i + 1])})) {
return false;
}
} else if (!LastIsTerminator(item->body)) {
} else if (InsideBasicBlock()) {
if (!push_function_inst(spv::Op::OpBranch,
{Operand::Int(merge_block_id)})) {
return false;
@ -3765,7 +3748,7 @@ bool Builder::GenerateLoopStatement(const ast::LoopStatement* stmt) {
}
// We only branch if the last element of the body didn't already branch.
if (!LastIsTerminator(stmt->body)) {
if (InsideBasicBlock()) {
if (!push_function_inst(spv::Op::OpBranch,
{Operand::Int(continue_block_id)})) {
return false;
@ -4414,6 +4397,34 @@ bool Builder::push_function_inst(spv::Op op, const OperandList& operands) {
return true;
}
bool Builder::InsideBasicBlock() const {
if (functions_.empty()) {
return false;
}
const auto& instructions = functions_.back().instructions();
if (instructions.empty()) {
// The Function object does not explicitly represent its entry block
// label. So return *true* because an empty list means the only
// thing in the function is that entry block label.
return true;
}
const auto& inst = instructions.back();
switch (inst.opcode()) {
case spv::Op::OpBranch:
case spv::Op::OpBranchConditional:
case spv::Op::OpSwitch:
case spv::Op::OpReturn:
case spv::Op::OpReturnValue:
case spv::Op::OpUnreachable:
case spv::Op::OpKill:
case spv::Op::OpTerminateInvocation:
return false;
default:
break;
}
return true;
}
Builder::ContinuingInfo::ContinuingInfo(
const ast::Statement* the_last_statement,
uint32_t loop_id,

View File

@ -216,6 +216,10 @@ class Builder {
functions_.back().push_var(operands);
}
/// @returns true if the current instruction insertion point is
/// inside a basic block.
bool InsideBasicBlock() const;
/// Converts a storage class to a SPIR-V storage class.
/// @param klass the storage class to convert
/// @returns the SPIR-V storage class or SpvStorageClassMax on error.

View File

@ -603,6 +603,91 @@ OpReturn
)");
}
TEST_F(BuilderTest, If_ElseIf_WithReturn) {
// crbug.com/tint/1315
// if (false) {
// } else if (true) {
// return;
// }
auto* if_stmt = If(Expr(false), Block(),
ast::ElseStatementList{Else(Expr(true), Block(Return()))});
auto* fn = Func("f", {}, ty.void_(), {if_stmt});
spirv::Builder& b = Build();
EXPECT_TRUE(b.GenerateFunction(fn)) << b.error();
EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeVoid
%1 = OpTypeFunction %2
%5 = OpTypeBool
%6 = OpConstantFalse %5
%10 = OpConstantTrue %5
)");
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
R"(OpSelectionMerge %7 None
OpBranchConditional %6 %8 %9
%8 = OpLabel
OpBranch %7
%9 = OpLabel
OpSelectionMerge %11 None
OpBranchConditional %10 %12 %11
%12 = OpLabel
OpReturn
%11 = OpLabel
OpBranch %7
%7 = OpLabel
OpReturn
)");
}
TEST_F(BuilderTest, Loop_If_ElseIf_WithBreak) {
// crbug.com/tint/1315
// loop {
// if (false) {
// } else if (true) {
// break;
// }
// }
auto* if_stmt = If(Expr(false), Block(),
ast::ElseStatementList{Else(Expr(true), Block(Break()))});
auto* fn = Func("f", {}, ty.void_(), {Loop(Block(if_stmt))});
spirv::Builder& b = Build();
EXPECT_TRUE(b.GenerateFunction(fn)) << b.error();
EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeVoid
%1 = OpTypeFunction %2
%9 = OpTypeBool
%10 = OpConstantFalse %9
%14 = OpConstantTrue %9
)");
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
R"(OpBranch %5
%5 = OpLabel
OpLoopMerge %6 %7 None
OpBranch %8
%8 = OpLabel
OpSelectionMerge %11 None
OpBranchConditional %10 %12 %13
%12 = OpLabel
OpBranch %11
%13 = OpLabel
OpSelectionMerge %15 None
OpBranchConditional %14 %16 %15
%16 = OpLabel
OpBranch %6
%15 = OpLabel
OpBranch %11
%11 = OpLabel
OpBranch %7
%7 = OpLabel
OpBranch %5
%6 = OpLabel
OpReturn
)");
}
} // namespace
} // namespace spirv
} // namespace writer

View File

@ -32,7 +32,7 @@ class Function {
/// Constructor
/// @param declaration the function declaration
/// @param label_op the operand for the initial function label
/// @param label_op the operand for function's entry block label
/// @param params the function parameters
Function(const Instruction& declaration,
const Operand& label_op,
@ -49,7 +49,7 @@ class Function {
/// @returns the declaration
const Instruction& declaration() const { return declaration_; }
/// @returns the function label id
/// @returns the label ID for the function entry block
uint32_t label_id() const { return label_op_.to_i(); }
/// Adds an instruction to the instruction list