From 66e7569e15a627eb4ed1849c9a4fe8bb435c3b15 Mon Sep 17 00:00:00 2001 From: David Neto Date: Mon, 20 Dec 2021 16:46:55 +0000 Subject: [PATCH] spirv-writer: Fix termination of basic blocks There are a few places where a branch or return is created, conditionally on whether a terminator was the last thing seen. The goal is to generate a SPIR-V basic block terminator exactly when needed, and to avoid generating a branch or return immediately after a prior terminator. Previously, the decision was based on the last thing seen in the AST. But we should instead check the emitted SPIR-V instead. This fixes cases such as a break or return inside an else-if. That's because an if/elseif is actually a selection inside a selection. Looking at the AST only works when trying to terminate the *inside* selection. In the outer recursive call, the last AST node is no longer a terminator, and we would skip generating the branch to the merge block. Fixed: tint:1315 Change-Id: I6b886ce85d1d681f2063997e469e0c1b4e5973a2 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/73480 Kokoro: Kokoro Auto-Submit: David Neto Reviewed-by: James Price Commit-Queue: James Price --- src/writer/spirv/builder.cc | 55 +++++++++++-------- src/writer/spirv/builder.h | 4 ++ src/writer/spirv/builder_if_test.cc | 85 +++++++++++++++++++++++++++++ src/writer/spirv/function.h | 4 +- 4 files changed, 124 insertions(+), 24 deletions(-) diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc index 04334a20c1..1f6246ef64 100644 --- a/src/writer/spirv/builder.cc +++ b/src/writer/spirv/builder.cc @@ -96,23 +96,6 @@ bool LastIsFallthrough(const ast::BlockStatement* stmts) { return !stmts->Empty() && stmts->Last()->Is(); } -// A terminator is anything which will cause a SPIR-V terminator to be emitted. -// This means things like breaks, fallthroughs and continues which all emit an -// OpBranch or return for the OpReturn emission. -bool LastIsTerminator(const ast::BlockStatement* stmts) { - if (IsAnyOf(stmts->Last())) { - return true; - } - - if (auto* block = As(stmts->Last())) { - return LastIsTerminator(block); - } - - return false; -} - /// Returns the matrix type that is `type` or that is wrapped by /// one or more levels of an arrays inside of `type`. /// @param type the given type, which must not be null @@ -650,7 +633,7 @@ bool Builder::GenerateFunction(const ast::Function* func_ast) { } } - if (!LastIsTerminator(func_ast->body)) { + if (InsideBasicBlock()) { if (func->ReturnType()->Is()) { push_function_inst(spv::Op::OpReturn, {}); } else { @@ -3500,7 +3483,7 @@ bool Builder::GenerateConditionalBlock( return false; } // We only branch if the last element of the body didn't already branch. - if (!LastIsTerminator(true_body)) { + if (InsideBasicBlock()) { if (!push_function_inst(spv::Op::OpBranch, {Operand::Int(merge_block_id)})) { return false; @@ -3525,7 +3508,7 @@ bool Builder::GenerateConditionalBlock( return false; } } - if (!LastIsTerminator(else_stmt->body)) { + if (InsideBasicBlock()) { if (!push_function_inst(spv::Op::OpBranch, {Operand::Int(merge_block_id)})) { return false; @@ -3673,7 +3656,7 @@ bool Builder::GenerateSwitchStatement(const ast::SwitchStatement* stmt) { {Operand::Int(case_ids[i + 1])})) { return false; } - } else if (!LastIsTerminator(item->body)) { + } else if (InsideBasicBlock()) { if (!push_function_inst(spv::Op::OpBranch, {Operand::Int(merge_block_id)})) { return false; @@ -3765,7 +3748,7 @@ bool Builder::GenerateLoopStatement(const ast::LoopStatement* stmt) { } // We only branch if the last element of the body didn't already branch. - if (!LastIsTerminator(stmt->body)) { + if (InsideBasicBlock()) { if (!push_function_inst(spv::Op::OpBranch, {Operand::Int(continue_block_id)})) { return false; @@ -4414,6 +4397,34 @@ bool Builder::push_function_inst(spv::Op op, const OperandList& operands) { return true; } +bool Builder::InsideBasicBlock() const { + if (functions_.empty()) { + return false; + } + const auto& instructions = functions_.back().instructions(); + if (instructions.empty()) { + // The Function object does not explicitly represent its entry block + // label. So return *true* because an empty list means the only + // thing in the function is that entry block label. + return true; + } + const auto& inst = instructions.back(); + switch (inst.opcode()) { + case spv::Op::OpBranch: + case spv::Op::OpBranchConditional: + case spv::Op::OpSwitch: + case spv::Op::OpReturn: + case spv::Op::OpReturnValue: + case spv::Op::OpUnreachable: + case spv::Op::OpKill: + case spv::Op::OpTerminateInvocation: + return false; + default: + break; + } + return true; +} + Builder::ContinuingInfo::ContinuingInfo( const ast::Statement* the_last_statement, uint32_t loop_id, diff --git a/src/writer/spirv/builder.h b/src/writer/spirv/builder.h index d2b5237100..fdfc3b30e9 100644 --- a/src/writer/spirv/builder.h +++ b/src/writer/spirv/builder.h @@ -216,6 +216,10 @@ class Builder { functions_.back().push_var(operands); } + /// @returns true if the current instruction insertion point is + /// inside a basic block. + bool InsideBasicBlock() const; + /// Converts a storage class to a SPIR-V storage class. /// @param klass the storage class to convert /// @returns the SPIR-V storage class or SpvStorageClassMax on error. diff --git a/src/writer/spirv/builder_if_test.cc b/src/writer/spirv/builder_if_test.cc index 6273bee90c..2fda0e1d62 100644 --- a/src/writer/spirv/builder_if_test.cc +++ b/src/writer/spirv/builder_if_test.cc @@ -603,6 +603,91 @@ OpReturn )"); } +TEST_F(BuilderTest, If_ElseIf_WithReturn) { + // crbug.com/tint/1315 + // if (false) { + // } else if (true) { + // return; + // } + + auto* if_stmt = If(Expr(false), Block(), + ast::ElseStatementList{Else(Expr(true), Block(Return()))}); + auto* fn = Func("f", {}, ty.void_(), {if_stmt}); + + spirv::Builder& b = Build(); + + EXPECT_TRUE(b.GenerateFunction(fn)) << b.error(); + EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeVoid +%1 = OpTypeFunction %2 +%5 = OpTypeBool +%6 = OpConstantFalse %5 +%10 = OpConstantTrue %5 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(OpSelectionMerge %7 None +OpBranchConditional %6 %8 %9 +%8 = OpLabel +OpBranch %7 +%9 = OpLabel +OpSelectionMerge %11 None +OpBranchConditional %10 %12 %11 +%12 = OpLabel +OpReturn +%11 = OpLabel +OpBranch %7 +%7 = OpLabel +OpReturn +)"); +} + +TEST_F(BuilderTest, Loop_If_ElseIf_WithBreak) { + // crbug.com/tint/1315 + // loop { + // if (false) { + // } else if (true) { + // break; + // } + // } + + auto* if_stmt = If(Expr(false), Block(), + ast::ElseStatementList{Else(Expr(true), Block(Break()))}); + auto* fn = Func("f", {}, ty.void_(), {Loop(Block(if_stmt))}); + + spirv::Builder& b = Build(); + + EXPECT_TRUE(b.GenerateFunction(fn)) << b.error(); + EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeVoid +%1 = OpTypeFunction %2 +%9 = OpTypeBool +%10 = OpConstantFalse %9 +%14 = OpConstantTrue %9 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(OpBranch %5 +%5 = OpLabel +OpLoopMerge %6 %7 None +OpBranch %8 +%8 = OpLabel +OpSelectionMerge %11 None +OpBranchConditional %10 %12 %13 +%12 = OpLabel +OpBranch %11 +%13 = OpLabel +OpSelectionMerge %15 None +OpBranchConditional %14 %16 %15 +%16 = OpLabel +OpBranch %6 +%15 = OpLabel +OpBranch %11 +%11 = OpLabel +OpBranch %7 +%7 = OpLabel +OpBranch %5 +%6 = OpLabel +OpReturn +)"); +} + } // namespace } // namespace spirv } // namespace writer diff --git a/src/writer/spirv/function.h b/src/writer/spirv/function.h index 926f33f312..df747ec5ab 100644 --- a/src/writer/spirv/function.h +++ b/src/writer/spirv/function.h @@ -32,7 +32,7 @@ class Function { /// Constructor /// @param declaration the function declaration - /// @param label_op the operand for the initial function label + /// @param label_op the operand for function's entry block label /// @param params the function parameters Function(const Instruction& declaration, const Operand& label_op, @@ -49,7 +49,7 @@ class Function { /// @returns the declaration const Instruction& declaration() const { return declaration_; } - /// @returns the function label id + /// @returns the label ID for the function entry block uint32_t label_id() const { return label_op_.to_i(); } /// Adds an instruction to the instruction list