mirror of
https://github.com/encounter/dawn-cmake.git
synced 2025-12-11 14:41:50 +00:00
spir-writer: handle break continuing block
The continuing block can exit the loop in very constrained ways:
When a break statement is placed such that it would exit from a loop’s
§ 7.3.8 Continuing Statement, then:
- The break statement must appear as either:
- The only statement in the if clause of an if statement that has:
- no else clause or an empty else clause
- no elseif clauses
- The only statement in the else clause of an if statement that has an
empty if clause and no elseif clauses.
- That if statement must appear last in the continuing clause.
By design, this allows a lossless round-trip from SPIR-V to WGSL and
back to SPIR-V. But that requires this special case construct in WGSL
to be translated to an OpBranchConditional with one target being
the loop's megre block (which is where 'break' branches to), and the
other targets the loop header (which is the loop backedge). That
OpBranchConditional takes the place of the normal case of an
unconditional backedge.
Avoids errors like this:
continue construct with the continue target X is not
post dominated by the back-edge block Y
Fixed: 1034
Change-Id: If472a179380b8d77af746a3cd8e279c8a5e56b37
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/59800
Auto-Submit: David Neto <dneto@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: David Neto <dneto@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
This commit is contained in:
@@ -29,6 +29,7 @@
|
||||
#include "src/ast/bitcast_expression.h"
|
||||
#include "src/ast/bool.h"
|
||||
#include "src/ast/bool_literal.h"
|
||||
#include "src/ast/break_statement.h"
|
||||
#include "src/ast/call_expression.h"
|
||||
#include "src/ast/call_statement.h"
|
||||
#include "src/ast/case_statement.h"
|
||||
@@ -1782,6 +1783,17 @@ class ProgramBuilder {
|
||||
return func;
|
||||
}
|
||||
|
||||
/// Creates an ast::BreakStatement
|
||||
/// @param source the source information
|
||||
/// @returns the break statement pointer
|
||||
ast::BreakStatement* Break(const Source& source) {
|
||||
return create<ast::BreakStatement>(source);
|
||||
}
|
||||
|
||||
/// Creates an ast::BreakStatement
|
||||
/// @returns the break statement pointer
|
||||
ast::BreakStatement* Break() { return create<ast::BreakStatement>(); }
|
||||
|
||||
/// Creates an ast::ReturnStatement with no return value
|
||||
/// @param source the source information
|
||||
/// @returns the return statement pointer
|
||||
|
||||
@@ -3447,6 +3447,49 @@ bool Builder::GenerateConditionalBlock(
|
||||
}
|
||||
|
||||
bool Builder::GenerateIfStatement(ast::IfStatement* stmt) {
|
||||
if (!continuing_stack_.empty() &&
|
||||
stmt == continuing_stack_.back().last_statement->As<ast::IfStatement>()) {
|
||||
const ContinuingInfo& ci = continuing_stack_.back();
|
||||
// Match one of two patterns: the break-if and break-unless patterns.
|
||||
//
|
||||
// The break-if pattern:
|
||||
// continuing { ...
|
||||
// if (cond) { break; }
|
||||
// }
|
||||
//
|
||||
// The break-unless pattern:
|
||||
// continuing { ...
|
||||
// if (cond) {} else {break;}
|
||||
// }
|
||||
auto is_just_a_break = [](ast::BlockStatement* block) {
|
||||
return block && (block->size() == 1) &&
|
||||
block->last()->Is<ast::BreakStatement>();
|
||||
};
|
||||
if (is_just_a_break(stmt->body()) && !stmt->has_else_statements()) {
|
||||
// It's a break-if.
|
||||
TINT_ASSERT(Writer, !backedge_stack_.empty());
|
||||
const auto cond_id = GenerateExpression(stmt->condition());
|
||||
backedge_stack_.back() =
|
||||
Backedge(spv::Op::OpBranchConditional,
|
||||
{Operand::Int(cond_id), Operand::Int(ci.break_target_id),
|
||||
Operand::Int(ci.loop_header_id)});
|
||||
return true;
|
||||
} else if (stmt->body()->empty()) {
|
||||
const auto& es = stmt->else_statements();
|
||||
if (es.size() == 1 && !es.back()->HasCondition() &&
|
||||
is_just_a_break(es.back()->body())) {
|
||||
// It's a break-unless.
|
||||
TINT_ASSERT(Writer, !backedge_stack_.empty());
|
||||
const auto cond_id = GenerateExpression(stmt->condition());
|
||||
backedge_stack_.back() =
|
||||
Backedge(spv::Op::OpBranchConditional,
|
||||
{Operand::Int(cond_id), Operand::Int(ci.loop_header_id),
|
||||
Operand::Int(ci.break_target_id)});
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!GenerateConditionalBlock(stmt->condition(), stmt->body(), 0,
|
||||
stmt->else_statements())) {
|
||||
return false;
|
||||
@@ -3603,6 +3646,11 @@ bool Builder::GenerateLoopStatement(ast::LoopStatement* stmt) {
|
||||
continue_stack_.push_back(continue_block_id);
|
||||
merge_stack_.push_back(merge_block_id);
|
||||
|
||||
// Usually, the backedge is a simple branch. This will be modified if the
|
||||
// backedge block in the continuing construct has an exiting edge.
|
||||
backedge_stack_.emplace_back(spv::Op::OpBranch,
|
||||
OperandList{Operand::Int(loop_header_id)});
|
||||
|
||||
if (!push_function_inst(spv::Op::OpBranch, {Operand::Int(body_block_id)})) {
|
||||
return false;
|
||||
}
|
||||
@@ -3630,16 +3678,23 @@ bool Builder::GenerateLoopStatement(ast::LoopStatement* stmt) {
|
||||
return false;
|
||||
}
|
||||
if (stmt->has_continuing()) {
|
||||
continuing_stack_.emplace_back(stmt->continuing()->last(), loop_header_id,
|
||||
merge_block_id);
|
||||
if (!GenerateBlockStatementWithoutScoping(stmt->continuing())) {
|
||||
return false;
|
||||
}
|
||||
continuing_stack_.pop_back();
|
||||
}
|
||||
|
||||
scope_stack_.pop_scope();
|
||||
|
||||
if (!push_function_inst(spv::Op::OpBranch, {Operand::Int(loop_header_id)})) {
|
||||
// Generate the backedge.
|
||||
TINT_ASSERT(Writer, !backedge_stack_.empty());
|
||||
const Backedge& backedge = backedge_stack_.back();
|
||||
if (!push_function_inst(backedge.opcode, backedge.operands)) {
|
||||
return false;
|
||||
}
|
||||
backedge_stack_.pop_back();
|
||||
|
||||
merge_stack_.pop_back();
|
||||
continue_stack_.pop_back();
|
||||
@@ -4260,6 +4315,26 @@ bool Builder::push_function_inst(spv::Op op, const OperandList& operands) {
|
||||
return true;
|
||||
}
|
||||
|
||||
Builder::ContinuingInfo::ContinuingInfo(
|
||||
const ast::Statement* the_last_statement,
|
||||
uint32_t loop_id,
|
||||
uint32_t break_id)
|
||||
: last_statement(the_last_statement),
|
||||
loop_header_id(loop_id),
|
||||
break_target_id(break_id) {
|
||||
TINT_ASSERT(Writer, last_statement != nullptr);
|
||||
TINT_ASSERT(Writer, loop_header_id != 0u);
|
||||
TINT_ASSERT(Writer, break_target_id != 0u);
|
||||
}
|
||||
|
||||
Builder::Backedge::Backedge(spv::Op the_opcode, OperandList the_operands)
|
||||
: opcode(the_opcode), operands(the_operands) {}
|
||||
|
||||
Builder::Backedge::Backedge(const Builder::Backedge& other) = default;
|
||||
Builder::Backedge& Builder::Backedge::operator=(
|
||||
const Builder::Backedge& other) = default;
|
||||
Builder::Backedge::~Backedge() = default;
|
||||
|
||||
} // namespace spirv
|
||||
} // namespace writer
|
||||
} // namespace tint
|
||||
|
||||
@@ -584,6 +584,33 @@ class Builder {
|
||||
std::vector<uint32_t> continue_stack_;
|
||||
std::unordered_set<uint32_t> capability_set_;
|
||||
bool has_overridable_workgroup_size_ = false;
|
||||
|
||||
struct ContinuingInfo {
|
||||
ContinuingInfo(const ast::Statement* last_statement,
|
||||
uint32_t loop_header_id,
|
||||
uint32_t break_target_id);
|
||||
// The last statement in the continiung block.
|
||||
const ast::Statement* const last_statement = nullptr;
|
||||
// The ID of the loop header
|
||||
const uint32_t loop_header_id = 0u;
|
||||
// The ID of the merge block for the loop.
|
||||
const uint32_t break_target_id = 0u;
|
||||
};
|
||||
// Stack of nodes, where each is the last statement in a surrounding
|
||||
// continuing block.
|
||||
std::vector<ContinuingInfo> continuing_stack_;
|
||||
|
||||
// The instruction to emit as the backedge of a loop.
|
||||
struct Backedge {
|
||||
Backedge(spv::Op, OperandList);
|
||||
Backedge(const Backedge&);
|
||||
Backedge& operator=(const Backedge&);
|
||||
~Backedge();
|
||||
|
||||
spv::Op opcode;
|
||||
OperandList operands;
|
||||
};
|
||||
std::vector<Backedge> backedge_stack_;
|
||||
};
|
||||
|
||||
} // namespace spirv
|
||||
|
||||
@@ -219,6 +219,186 @@ OpBranch %1
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_F(BuilderTest, Loop_WithContinuing_BreakIf) {
|
||||
// loop {
|
||||
// continuing {
|
||||
// if (true) { break; }
|
||||
// }
|
||||
// }
|
||||
|
||||
auto* if_stmt = create<ast::IfStatement>(Expr(true), Block(Break()),
|
||||
ast::ElseStatementList{});
|
||||
auto* continuing = Block(if_stmt);
|
||||
auto* loop = Loop(Block(), continuing);
|
||||
WrapInFunction(loop);
|
||||
|
||||
spirv::Builder& b = Build();
|
||||
|
||||
b.push_function(Function{});
|
||||
|
||||
EXPECT_TRUE(b.GenerateLoopStatement(loop)) << b.error();
|
||||
EXPECT_EQ(DumpInstructions(b.types()), R"(%5 = OpTypeBool
|
||||
%6 = OpConstantTrue %5
|
||||
)");
|
||||
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
|
||||
R"(OpBranch %1
|
||||
%1 = OpLabel
|
||||
OpLoopMerge %2 %3 None
|
||||
OpBranch %4
|
||||
%4 = OpLabel
|
||||
OpBranch %3
|
||||
%3 = OpLabel
|
||||
OpBranchConditional %6 %2 %1
|
||||
%2 = OpLabel
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_F(BuilderTest, Loop_WithContinuing_BreakUnless) {
|
||||
// loop {
|
||||
// continuing {
|
||||
// if (true) {} else { break; }
|
||||
// }
|
||||
// }
|
||||
auto* if_stmt = create<ast::IfStatement>(
|
||||
Expr(true), Block(),
|
||||
ast::ElseStatementList{Else(nullptr, Block(Break()))});
|
||||
auto* continuing = Block(if_stmt);
|
||||
auto* loop = Loop(Block(), continuing);
|
||||
WrapInFunction(loop);
|
||||
|
||||
spirv::Builder& b = Build();
|
||||
|
||||
b.push_function(Function{});
|
||||
|
||||
EXPECT_TRUE(b.GenerateLoopStatement(loop)) << b.error();
|
||||
EXPECT_EQ(DumpInstructions(b.types()), R"(%5 = OpTypeBool
|
||||
%6 = OpConstantTrue %5
|
||||
)");
|
||||
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
|
||||
R"(OpBranch %1
|
||||
%1 = OpLabel
|
||||
OpLoopMerge %2 %3 None
|
||||
OpBranch %4
|
||||
%4 = OpLabel
|
||||
OpBranch %3
|
||||
%3 = OpLabel
|
||||
OpBranchConditional %6 %1 %2
|
||||
%2 = OpLabel
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_F(BuilderTest, Loop_WithContinuing_BreakIf_Nested) {
|
||||
// Make sure the right backedge and break target are used.
|
||||
// loop {
|
||||
// continuing {
|
||||
// loop {
|
||||
// continuing {
|
||||
// if (true) { break; }
|
||||
// }
|
||||
// }
|
||||
// if (true) { break; }
|
||||
// }
|
||||
// }
|
||||
|
||||
auto* inner_if_stmt = create<ast::IfStatement>(Expr(true), Block(Break()),
|
||||
ast::ElseStatementList{});
|
||||
auto* inner_continuing = Block(inner_if_stmt);
|
||||
auto* inner_loop = Loop(Block(), inner_continuing);
|
||||
|
||||
auto* outer_if_stmt = create<ast::IfStatement>(Expr(true), Block(Break()),
|
||||
ast::ElseStatementList{});
|
||||
auto* outer_continuing = Block(inner_loop, outer_if_stmt);
|
||||
auto* outer_loop = Loop(Block(), outer_continuing);
|
||||
|
||||
WrapInFunction(outer_loop);
|
||||
|
||||
spirv::Builder& b = Build();
|
||||
|
||||
b.push_function(Function{});
|
||||
|
||||
EXPECT_TRUE(b.GenerateLoopStatement(outer_loop)) << b.error();
|
||||
EXPECT_EQ(DumpInstructions(b.types()), R"(%9 = OpTypeBool
|
||||
%10 = OpConstantTrue %9
|
||||
)");
|
||||
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
|
||||
R"(OpBranch %1
|
||||
%1 = OpLabel
|
||||
OpLoopMerge %2 %3 None
|
||||
OpBranch %4
|
||||
%4 = OpLabel
|
||||
OpBranch %3
|
||||
%3 = OpLabel
|
||||
OpBranch %5
|
||||
%5 = OpLabel
|
||||
OpLoopMerge %6 %7 None
|
||||
OpBranch %8
|
||||
%8 = OpLabel
|
||||
OpBranch %7
|
||||
%7 = OpLabel
|
||||
OpBranchConditional %10 %6 %5
|
||||
%6 = OpLabel
|
||||
OpBranchConditional %10 %2 %1
|
||||
%2 = OpLabel
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_F(BuilderTest, Loop_WithContinuing_BreakUnless_Nested) {
|
||||
// Make sure the right backedge and break target are used.
|
||||
// loop {
|
||||
// continuing {
|
||||
// loop {
|
||||
// continuing {
|
||||
// if (true) {} else { break; }
|
||||
// }
|
||||
// }
|
||||
// if (true) {} else { break; }
|
||||
// }
|
||||
// }
|
||||
|
||||
auto* inner_if_stmt = create<ast::IfStatement>(
|
||||
Expr(true), Block(),
|
||||
ast::ElseStatementList{Else(nullptr, Block(Break()))});
|
||||
auto* inner_continuing = Block(inner_if_stmt);
|
||||
auto* inner_loop = Loop(Block(), inner_continuing);
|
||||
|
||||
auto* outer_if_stmt = create<ast::IfStatement>(
|
||||
Expr(true), Block(),
|
||||
ast::ElseStatementList{Else(nullptr, Block(Break()))});
|
||||
auto* outer_continuing = Block(inner_loop, outer_if_stmt);
|
||||
auto* outer_loop = Loop(Block(), outer_continuing);
|
||||
|
||||
WrapInFunction(outer_loop);
|
||||
|
||||
spirv::Builder& b = Build();
|
||||
|
||||
b.push_function(Function{});
|
||||
|
||||
EXPECT_TRUE(b.GenerateLoopStatement(outer_loop)) << b.error();
|
||||
EXPECT_EQ(DumpInstructions(b.types()), R"(%9 = OpTypeBool
|
||||
%10 = OpConstantTrue %9
|
||||
)");
|
||||
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
|
||||
R"(OpBranch %1
|
||||
%1 = OpLabel
|
||||
OpLoopMerge %2 %3 None
|
||||
OpBranch %4
|
||||
%4 = OpLabel
|
||||
OpBranch %3
|
||||
%3 = OpLabel
|
||||
OpBranch %5
|
||||
%5 = OpLabel
|
||||
OpLoopMerge %6 %7 None
|
||||
OpBranch %8
|
||||
%8 = OpLabel
|
||||
OpBranch %7
|
||||
%7 = OpLabel
|
||||
OpBranchConditional %10 %5 %6
|
||||
%6 = OpLabel
|
||||
OpBranchConditional %10 %1 %2
|
||||
%2 = OpLabel
|
||||
)");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace spirv
|
||||
} // namespace writer
|
||||
|
||||
Reference in New Issue
Block a user