[spirv-reader] Add switch-selection

- Avoid redundant switch-break.
  WGSL does an implicit break at the end of a switch case, because
  it has fallthrough.

TODO: Emit fallthrough

Bug: tint:3
Change-Id: Ida44b13181a01a2c1459c0447dac496ba5b97ffc
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/22961
Reviewed-by: dan sinclair <dsinclair@google.com>
This commit is contained in:
David Neto 2020-06-11 20:39:06 +00:00
parent be45ff5081
commit 416be308fc
3 changed files with 1451 additions and 38 deletions

View File

@ -14,6 +14,7 @@
#include "src/reader/spirv/function.h" #include "src/reader/spirv/function.h"
#include <algorithm>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <utility> #include <utility>
@ -28,6 +29,7 @@
#include "src/ast/assignment_statement.h" #include "src/ast/assignment_statement.h"
#include "src/ast/binary_expression.h" #include "src/ast/binary_expression.h"
#include "src/ast/break_statement.h" #include "src/ast/break_statement.h"
#include "src/ast/case_statement.h"
#include "src/ast/continue_statement.h" #include "src/ast/continue_statement.h"
#include "src/ast/else_statement.h" #include "src/ast/else_statement.h"
#include "src/ast/fallthrough_statement.h" #include "src/ast/fallthrough_statement.h"
@ -38,6 +40,7 @@
#include "src/ast/member_accessor_expression.h" #include "src/ast/member_accessor_expression.h"
#include "src/ast/return_statement.h" #include "src/ast/return_statement.h"
#include "src/ast/scalar_constructor_expression.h" #include "src/ast/scalar_constructor_expression.h"
#include "src/ast/sint_literal.h"
#include "src/ast/storage_class.h" #include "src/ast/storage_class.h"
#include "src/ast/switch_statement.h" #include "src/ast/switch_statement.h"
#include "src/ast/uint_literal.h" #include "src/ast/uint_literal.h"
@ -387,7 +390,7 @@ FunctionEmitter::StatementBlock::StatementBlock(
uint32_t end_id, uint32_t end_id,
CompletionAction completion_action, CompletionAction completion_action,
ast::StatementList statements, ast::StatementList statements,
ast::CaseStatementList cases) std::unique_ptr<ast::CaseStatementList> cases)
: construct_(construct), : construct_(construct),
end_id_(end_id), end_id_(end_id),
completion_action_(completion_action), completion_action_(completion_action),
@ -401,9 +404,8 @@ FunctionEmitter::StatementBlock::~StatementBlock() = default;
void FunctionEmitter::PushNewStatementBlock(const Construct* construct, void FunctionEmitter::PushNewStatementBlock(const Construct* construct,
uint32_t end_id, uint32_t end_id,
CompletionAction action) { CompletionAction action) {
statements_stack_.emplace_back(StatementBlock(construct, end_id, action, statements_stack_.emplace_back(
ast::StatementList{}, StatementBlock{construct, end_id, action, ast::StatementList{}, nullptr});
ast::CaseStatementList{}));
} }
const ast::StatementList& FunctionEmitter::ast_body() { const ast::StatementList& FunctionEmitter::ast_body() {
@ -981,7 +983,6 @@ bool FunctionEmitter::FindSwitchCaseHeaders() {
// Process case targets. // Process case targets.
for (uint32_t iarg = 2; iarg + 1 < branch->NumInOperands(); iarg += 2) { for (uint32_t iarg = 2; iarg + 1 < branch->NumInOperands(); iarg += 2) {
const auto o = branch->GetInOperand(iarg);
const auto value = branch->GetInOperand(iarg).AsLiteralUint64(); const auto value = branch->GetInOperand(iarg).AsLiteralUint64();
const auto case_target_id = branch->GetSingleWordInOperand(iarg + 1); const auto case_target_id = branch->GetSingleWordInOperand(iarg + 1);
@ -1715,8 +1716,14 @@ bool FunctionEmitter::EmitBasicBlock(const BlockInfo& block_info) {
break; break;
case Construct::kSwitchSelection: case Construct::kSwitchSelection:
if (!EmitStatementsInBasicBlock(block_info, &emitted)) {
return false;
}
if (!EmitSwitchStart(block_info)) {
return false;
}
has_normal_terminator = false; has_normal_terminator = false;
return Fail() << "unhandled: switch construct"; break;
} }
} }
@ -1827,6 +1834,128 @@ bool FunctionEmitter::EmitIfStart(const BlockInfo& block_info) {
return success(); return success();
} }
bool FunctionEmitter::EmitSwitchStart(const BlockInfo& block_info) {
// The block is the if-header block. So its construct is the if construct.
auto* construct = block_info.construct;
assert(construct->kind == Construct::kSwitchSelection);
assert(construct->begin_id == block_info.id);
const auto* branch = block_info.basic_block->terminator();
auto* const switch_stmt =
AddStatement(std::make_unique<ast::SwitchStatement>())->AsSwitch();
const auto selector_id = branch->GetSingleWordInOperand(0);
// Generate the code for the selector.
auto selector = MakeExpression(selector_id);
switch_stmt->set_condition(std::move(selector.expr));
// First, push the statement block for the entire switch. All the actual
// work is done by completion actions of the case/default clauses.
PushNewStatementBlock(
construct, construct->end_id, [switch_stmt](StatementBlock* s) {
switch_stmt->set_body(std::move(*std::move(s->cases_)));
});
statements_stack_.back().cases_ = std::make_unique<ast::CaseStatementList>();
// Grab a pointer to the case list. It will get buried in the statement block
// stack.
auto* cases = statements_stack_.back().cases_.get();
// We will push statement-blocks onto the stack to gather the statements in
// the default clause and cases clauses. Determine the list of blocks
// that start each clause.
std::vector<const BlockInfo*> clause_heads;
// Collect the case clauses, even if they are just the merge block.
// First the default clause.
const auto default_id = branch->GetSingleWordInOperand(1);
const auto* default_info = GetBlockInfo(default_id);
clause_heads.push_back(default_info);
// Now the case clauses.
for (uint32_t iarg = 2; iarg + 1 < branch->NumInOperands(); iarg += 2) {
const auto case_target_id = branch->GetSingleWordInOperand(iarg + 1);
clause_heads.push_back(GetBlockInfo(case_target_id));
}
std::stable_sort(clause_heads.begin(), clause_heads.end(),
[](const BlockInfo* lhs, const BlockInfo* rhs) {
return lhs->pos < rhs->pos;
});
// Remove duplicates
{
// Use read index r, and write index w.
// Invariant: w <= r;
size_t w = 0;
for (size_t r = 0; r < clause_heads.size(); ++r) {
if (clause_heads[r] != clause_heads[w]) {
++w; // Advance the write cursor.
}
clause_heads[w] = clause_heads[r];
}
// We know it's not empty because it always has at least a default clause.
assert(!clause_heads.empty());
clause_heads.resize(w + 1);
}
// Push them on in reverse order.
const auto last_clause_index = clause_heads.size() - 1;
for (size_t i = last_clause_index;; --i) {
// Create the case clause. Temporarily put it in the wrong order
// on the case statement list.
cases->emplace_back(std::make_unique<ast::CaseStatement>());
auto* clause = cases->back().get();
// Create a list of integer literals for the selector values leading to
// this case clause.
ast::CaseSelectorList selectors;
const auto* values_ptr = clause_heads[i]->case_values.get();
const bool has_selectors = (values_ptr && !values_ptr->empty());
if (has_selectors) {
std::vector<uint64_t> values(values_ptr->begin(), values_ptr->end());
std::stable_sort(values.begin(), values.end());
for (auto value : values) {
// The rest of this module can handle up to 64 bit switch values.
// The Tint AST handles 32-bit values.
const uint32_t value32 = uint32_t(value & 0xFFFFFFFF);
if (selector.type->is_unsigned_scalar_or_vector()) {
selectors.emplace_back(
std::make_unique<ast::UintLiteral>(selector.type, value32));
} else {
selectors.emplace_back(
std::make_unique<ast::SintLiteral>(selector.type, value32));
}
}
clause->set_selectors(std::move(selectors));
}
// Where does this clause end?
const auto end_id = (i + 1 < clause_heads.size()) ? clause_heads[i + 1]->id
: construct->end_id;
PushNewStatementBlock(construct, end_id, [clause](StatementBlock* s) {
clause->set_body(std::move(s->statements_));
});
if ((default_info == clause_heads[i]) && has_selectors &&
construct->ContainsPos(default_info->pos)) {
// Generate a default clause with a just fallthrough.
ast::StatementList stmts;
stmts.emplace_back(std::make_unique<ast::FallthroughStatement>());
auto case_stmt = std::make_unique<ast::CaseStatement>();
case_stmt->set_body(std::move(stmts));
cases->emplace_back(std::move(case_stmt));
}
if (i == 0) {
break;
}
}
// We've listed cases in reverse order in the switch statement. Reorder them
// to match the presentation order in WGSL.
std::reverse(cases->begin(), cases->end());
return success();
}
bool FunctionEmitter::EmitLoopStart(const Construct* construct) { bool FunctionEmitter::EmitLoopStart(const Construct* construct) {
auto* loop = AddStatement(std::make_unique<ast::LoopStatement>())->AsLoop(); auto* loop = AddStatement(std::make_unique<ast::LoopStatement>())->AsLoop();
PushNewStatementBlock( PushNewStatementBlock(
@ -1946,7 +2075,31 @@ std::unique_ptr<ast::Statement> FunctionEmitter::MakeBranch(
case EdgeKind::kBack: case EdgeKind::kBack:
// Nothing to do. The loop backedge is implicit. // Nothing to do. The loop backedge is implicit.
break; break;
case EdgeKind::kSwitchBreak: case EdgeKind::kSwitchBreak: {
// Don't bother with a break at the end of a case.
const auto header = dest_info.header_for_merge;
assert(header != 0);
const auto* exiting_construct = GetBlockInfo(header)->construct;
assert(exiting_construct->kind == Construct::kSwitchSelection);
const auto candidate_next_case_pos = src_info.pos + 1;
// Leaving the last block from the last case?
if (candidate_next_case_pos == dest_info.pos) {
// No break needed.
return nullptr;
}
// Leaving the last block from not-the-last-case?
if (exiting_construct->ContainsPos(candidate_next_case_pos)) {
const auto* candidate_next_case =
GetBlockInfo(block_order_[candidate_next_case_pos]);
if (candidate_next_case->case_head_for == exiting_construct ||
candidate_next_case->default_head_for == exiting_construct) {
// No break needed.
return nullptr;
}
}
// We need a break.
return std::make_unique<ast::BreakStatement>();
}
case EdgeKind::kLoopBreak: case EdgeKind::kLoopBreak:
return std::make_unique<ast::BreakStatement>(); return std::make_unique<ast::BreakStatement>();
case EdgeKind::kLoopContinue: case EdgeKind::kLoopContinue:

View File

@ -289,6 +289,13 @@ class FunctionEmitter {
/// @returns false if emission failed. /// @returns false if emission failed.
bool EmitIfStart(const BlockInfo& block_info); bool EmitIfStart(const BlockInfo& block_info);
/// Emits a SwitchStatement, including its condition expression, and sets
/// up the statement stack to accumulate subsequent basic blocks into
/// the default clause and case clauses.
/// @param block_info the switch-selection header block
/// @returns false if emission failed.
bool EmitSwitchStart(const BlockInfo& block_info);
/// Emits a LoopStatement, and pushes a new StatementBlock to accumulate /// Emits a LoopStatement, and pushes a new StatementBlock to accumulate
/// the remaining instructions in the current block and subsequent blocks /// the remaining instructions in the current block and subsequent blocks
/// in the loop. /// in the loop.
@ -375,7 +382,7 @@ class FunctionEmitter {
/// Gets the block info for a block ID, if any exists /// Gets the block info for a block ID, if any exists
/// @param id the SPIR-V ID of the OpLabel instruction starting the block /// @param id the SPIR-V ID of the OpLabel instruction starting the block
/// @returns the block info for the given ID, if it exists, or nullptr /// @returns the block info for the given ID, if it exists, or nullptr
BlockInfo* GetBlockInfo(uint32_t id) { BlockInfo* GetBlockInfo(uint32_t id) const {
auto where = block_info_.find(id); auto where = block_info_.find(id);
if (where == block_info_.end()) if (where == block_info_.end())
return nullptr; return nullptr;
@ -434,7 +441,7 @@ class FunctionEmitter {
uint32_t end_id, uint32_t end_id,
CompletionAction completion_action, CompletionAction completion_action,
ast::StatementList statements, ast::StatementList statements,
ast::CaseStatementList cases); std::unique_ptr<ast::CaseStatementList> cases);
StatementBlock(StatementBlock&&); StatementBlock(StatementBlock&&);
~StatementBlock(); ~StatementBlock();
@ -449,10 +456,13 @@ class FunctionEmitter {
// Only one of |statements| or |cases| is active. // Only one of |statements| or |cases| is active.
// The list of statements being built. // The list of statements being built, if this construct is not a switch.
ast::StatementList statements_; ast::StatementList statements_;
// The list of cases being built, for a switch. // The list of switch cases being built, if this construct is a switch.
ast::CaseStatementList cases_; // The algorithm will cache a pointer to the vector. We want that pointer
// to be stable no matter how |statements_stack_| is resized. That's
// why we make this a unique_ptr rather than just a plain vector.
std::unique_ptr<ast::CaseStatementList> cases_;
}; };
/// Pushes an empty statement block onto the statements stack. /// Pushes an empty statement block onto the statements stack.

File diff suppressed because it is too large Load Diff