[spirv-reader] Add switch-selection
- Avoid redundant switch-break. WGSL does an implicit break at the end of a switch case, because it has fallthrough. TODO: Emit fallthrough Bug: tint:3 Change-Id: Ida44b13181a01a2c1459c0447dac496ba5b97ffc Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/22961 Reviewed-by: dan sinclair <dsinclair@google.com>
This commit is contained in:
parent
be45ff5081
commit
416be308fc
|
@ -14,6 +14,7 @@
|
||||||
|
|
||||||
#include "src/reader/spirv/function.h"
|
#include "src/reader/spirv/function.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
@ -28,6 +29,7 @@
|
||||||
#include "src/ast/assignment_statement.h"
|
#include "src/ast/assignment_statement.h"
|
||||||
#include "src/ast/binary_expression.h"
|
#include "src/ast/binary_expression.h"
|
||||||
#include "src/ast/break_statement.h"
|
#include "src/ast/break_statement.h"
|
||||||
|
#include "src/ast/case_statement.h"
|
||||||
#include "src/ast/continue_statement.h"
|
#include "src/ast/continue_statement.h"
|
||||||
#include "src/ast/else_statement.h"
|
#include "src/ast/else_statement.h"
|
||||||
#include "src/ast/fallthrough_statement.h"
|
#include "src/ast/fallthrough_statement.h"
|
||||||
|
@ -38,6 +40,7 @@
|
||||||
#include "src/ast/member_accessor_expression.h"
|
#include "src/ast/member_accessor_expression.h"
|
||||||
#include "src/ast/return_statement.h"
|
#include "src/ast/return_statement.h"
|
||||||
#include "src/ast/scalar_constructor_expression.h"
|
#include "src/ast/scalar_constructor_expression.h"
|
||||||
|
#include "src/ast/sint_literal.h"
|
||||||
#include "src/ast/storage_class.h"
|
#include "src/ast/storage_class.h"
|
||||||
#include "src/ast/switch_statement.h"
|
#include "src/ast/switch_statement.h"
|
||||||
#include "src/ast/uint_literal.h"
|
#include "src/ast/uint_literal.h"
|
||||||
|
@ -387,7 +390,7 @@ FunctionEmitter::StatementBlock::StatementBlock(
|
||||||
uint32_t end_id,
|
uint32_t end_id,
|
||||||
CompletionAction completion_action,
|
CompletionAction completion_action,
|
||||||
ast::StatementList statements,
|
ast::StatementList statements,
|
||||||
ast::CaseStatementList cases)
|
std::unique_ptr<ast::CaseStatementList> cases)
|
||||||
: construct_(construct),
|
: construct_(construct),
|
||||||
end_id_(end_id),
|
end_id_(end_id),
|
||||||
completion_action_(completion_action),
|
completion_action_(completion_action),
|
||||||
|
@ -401,9 +404,8 @@ FunctionEmitter::StatementBlock::~StatementBlock() = default;
|
||||||
void FunctionEmitter::PushNewStatementBlock(const Construct* construct,
|
void FunctionEmitter::PushNewStatementBlock(const Construct* construct,
|
||||||
uint32_t end_id,
|
uint32_t end_id,
|
||||||
CompletionAction action) {
|
CompletionAction action) {
|
||||||
statements_stack_.emplace_back(StatementBlock(construct, end_id, action,
|
statements_stack_.emplace_back(
|
||||||
ast::StatementList{},
|
StatementBlock{construct, end_id, action, ast::StatementList{}, nullptr});
|
||||||
ast::CaseStatementList{}));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const ast::StatementList& FunctionEmitter::ast_body() {
|
const ast::StatementList& FunctionEmitter::ast_body() {
|
||||||
|
@ -981,7 +983,6 @@ bool FunctionEmitter::FindSwitchCaseHeaders() {
|
||||||
|
|
||||||
// Process case targets.
|
// Process case targets.
|
||||||
for (uint32_t iarg = 2; iarg + 1 < branch->NumInOperands(); iarg += 2) {
|
for (uint32_t iarg = 2; iarg + 1 < branch->NumInOperands(); iarg += 2) {
|
||||||
const auto o = branch->GetInOperand(iarg);
|
|
||||||
const auto value = branch->GetInOperand(iarg).AsLiteralUint64();
|
const auto value = branch->GetInOperand(iarg).AsLiteralUint64();
|
||||||
const auto case_target_id = branch->GetSingleWordInOperand(iarg + 1);
|
const auto case_target_id = branch->GetSingleWordInOperand(iarg + 1);
|
||||||
|
|
||||||
|
@ -1715,8 +1716,14 @@ bool FunctionEmitter::EmitBasicBlock(const BlockInfo& block_info) {
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case Construct::kSwitchSelection:
|
case Construct::kSwitchSelection:
|
||||||
|
if (!EmitStatementsInBasicBlock(block_info, &emitted)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (!EmitSwitchStart(block_info)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
has_normal_terminator = false;
|
has_normal_terminator = false;
|
||||||
return Fail() << "unhandled: switch construct";
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1827,6 +1834,128 @@ bool FunctionEmitter::EmitIfStart(const BlockInfo& block_info) {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool FunctionEmitter::EmitSwitchStart(const BlockInfo& block_info) {
|
||||||
|
// The block is the if-header block. So its construct is the if construct.
|
||||||
|
auto* construct = block_info.construct;
|
||||||
|
assert(construct->kind == Construct::kSwitchSelection);
|
||||||
|
assert(construct->begin_id == block_info.id);
|
||||||
|
const auto* branch = block_info.basic_block->terminator();
|
||||||
|
|
||||||
|
auto* const switch_stmt =
|
||||||
|
AddStatement(std::make_unique<ast::SwitchStatement>())->AsSwitch();
|
||||||
|
const auto selector_id = branch->GetSingleWordInOperand(0);
|
||||||
|
// Generate the code for the selector.
|
||||||
|
auto selector = MakeExpression(selector_id);
|
||||||
|
switch_stmt->set_condition(std::move(selector.expr));
|
||||||
|
|
||||||
|
// First, push the statement block for the entire switch. All the actual
|
||||||
|
// work is done by completion actions of the case/default clauses.
|
||||||
|
PushNewStatementBlock(
|
||||||
|
construct, construct->end_id, [switch_stmt](StatementBlock* s) {
|
||||||
|
switch_stmt->set_body(std::move(*std::move(s->cases_)));
|
||||||
|
});
|
||||||
|
statements_stack_.back().cases_ = std::make_unique<ast::CaseStatementList>();
|
||||||
|
// Grab a pointer to the case list. It will get buried in the statement block
|
||||||
|
// stack.
|
||||||
|
auto* cases = statements_stack_.back().cases_.get();
|
||||||
|
|
||||||
|
// We will push statement-blocks onto the stack to gather the statements in
|
||||||
|
// the default clause and cases clauses. Determine the list of blocks
|
||||||
|
// that start each clause.
|
||||||
|
std::vector<const BlockInfo*> clause_heads;
|
||||||
|
|
||||||
|
// Collect the case clauses, even if they are just the merge block.
|
||||||
|
// First the default clause.
|
||||||
|
const auto default_id = branch->GetSingleWordInOperand(1);
|
||||||
|
const auto* default_info = GetBlockInfo(default_id);
|
||||||
|
clause_heads.push_back(default_info);
|
||||||
|
// Now the case clauses.
|
||||||
|
for (uint32_t iarg = 2; iarg + 1 < branch->NumInOperands(); iarg += 2) {
|
||||||
|
const auto case_target_id = branch->GetSingleWordInOperand(iarg + 1);
|
||||||
|
clause_heads.push_back(GetBlockInfo(case_target_id));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::stable_sort(clause_heads.begin(), clause_heads.end(),
|
||||||
|
[](const BlockInfo* lhs, const BlockInfo* rhs) {
|
||||||
|
return lhs->pos < rhs->pos;
|
||||||
|
});
|
||||||
|
// Remove duplicates
|
||||||
|
{
|
||||||
|
// Use read index r, and write index w.
|
||||||
|
// Invariant: w <= r;
|
||||||
|
size_t w = 0;
|
||||||
|
for (size_t r = 0; r < clause_heads.size(); ++r) {
|
||||||
|
if (clause_heads[r] != clause_heads[w]) {
|
||||||
|
++w; // Advance the write cursor.
|
||||||
|
}
|
||||||
|
clause_heads[w] = clause_heads[r];
|
||||||
|
}
|
||||||
|
// We know it's not empty because it always has at least a default clause.
|
||||||
|
assert(!clause_heads.empty());
|
||||||
|
clause_heads.resize(w + 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Push them on in reverse order.
|
||||||
|
const auto last_clause_index = clause_heads.size() - 1;
|
||||||
|
for (size_t i = last_clause_index;; --i) {
|
||||||
|
// Create the case clause. Temporarily put it in the wrong order
|
||||||
|
// on the case statement list.
|
||||||
|
cases->emplace_back(std::make_unique<ast::CaseStatement>());
|
||||||
|
auto* clause = cases->back().get();
|
||||||
|
|
||||||
|
// Create a list of integer literals for the selector values leading to
|
||||||
|
// this case clause.
|
||||||
|
ast::CaseSelectorList selectors;
|
||||||
|
const auto* values_ptr = clause_heads[i]->case_values.get();
|
||||||
|
const bool has_selectors = (values_ptr && !values_ptr->empty());
|
||||||
|
if (has_selectors) {
|
||||||
|
std::vector<uint64_t> values(values_ptr->begin(), values_ptr->end());
|
||||||
|
std::stable_sort(values.begin(), values.end());
|
||||||
|
for (auto value : values) {
|
||||||
|
// The rest of this module can handle up to 64 bit switch values.
|
||||||
|
// The Tint AST handles 32-bit values.
|
||||||
|
const uint32_t value32 = uint32_t(value & 0xFFFFFFFF);
|
||||||
|
if (selector.type->is_unsigned_scalar_or_vector()) {
|
||||||
|
selectors.emplace_back(
|
||||||
|
std::make_unique<ast::UintLiteral>(selector.type, value32));
|
||||||
|
} else {
|
||||||
|
selectors.emplace_back(
|
||||||
|
std::make_unique<ast::SintLiteral>(selector.type, value32));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
clause->set_selectors(std::move(selectors));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Where does this clause end?
|
||||||
|
const auto end_id = (i + 1 < clause_heads.size()) ? clause_heads[i + 1]->id
|
||||||
|
: construct->end_id;
|
||||||
|
|
||||||
|
PushNewStatementBlock(construct, end_id, [clause](StatementBlock* s) {
|
||||||
|
clause->set_body(std::move(s->statements_));
|
||||||
|
});
|
||||||
|
|
||||||
|
if ((default_info == clause_heads[i]) && has_selectors &&
|
||||||
|
construct->ContainsPos(default_info->pos)) {
|
||||||
|
// Generate a default clause with a just fallthrough.
|
||||||
|
ast::StatementList stmts;
|
||||||
|
stmts.emplace_back(std::make_unique<ast::FallthroughStatement>());
|
||||||
|
auto case_stmt = std::make_unique<ast::CaseStatement>();
|
||||||
|
case_stmt->set_body(std::move(stmts));
|
||||||
|
cases->emplace_back(std::move(case_stmt));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (i == 0) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// We've listed cases in reverse order in the switch statement. Reorder them
|
||||||
|
// to match the presentation order in WGSL.
|
||||||
|
std::reverse(cases->begin(), cases->end());
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
bool FunctionEmitter::EmitLoopStart(const Construct* construct) {
|
bool FunctionEmitter::EmitLoopStart(const Construct* construct) {
|
||||||
auto* loop = AddStatement(std::make_unique<ast::LoopStatement>())->AsLoop();
|
auto* loop = AddStatement(std::make_unique<ast::LoopStatement>())->AsLoop();
|
||||||
PushNewStatementBlock(
|
PushNewStatementBlock(
|
||||||
|
@ -1946,7 +2075,31 @@ std::unique_ptr<ast::Statement> FunctionEmitter::MakeBranch(
|
||||||
case EdgeKind::kBack:
|
case EdgeKind::kBack:
|
||||||
// Nothing to do. The loop backedge is implicit.
|
// Nothing to do. The loop backedge is implicit.
|
||||||
break;
|
break;
|
||||||
case EdgeKind::kSwitchBreak:
|
case EdgeKind::kSwitchBreak: {
|
||||||
|
// Don't bother with a break at the end of a case.
|
||||||
|
const auto header = dest_info.header_for_merge;
|
||||||
|
assert(header != 0);
|
||||||
|
const auto* exiting_construct = GetBlockInfo(header)->construct;
|
||||||
|
assert(exiting_construct->kind == Construct::kSwitchSelection);
|
||||||
|
const auto candidate_next_case_pos = src_info.pos + 1;
|
||||||
|
// Leaving the last block from the last case?
|
||||||
|
if (candidate_next_case_pos == dest_info.pos) {
|
||||||
|
// No break needed.
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
// Leaving the last block from not-the-last-case?
|
||||||
|
if (exiting_construct->ContainsPos(candidate_next_case_pos)) {
|
||||||
|
const auto* candidate_next_case =
|
||||||
|
GetBlockInfo(block_order_[candidate_next_case_pos]);
|
||||||
|
if (candidate_next_case->case_head_for == exiting_construct ||
|
||||||
|
candidate_next_case->default_head_for == exiting_construct) {
|
||||||
|
// No break needed.
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// We need a break.
|
||||||
|
return std::make_unique<ast::BreakStatement>();
|
||||||
|
}
|
||||||
case EdgeKind::kLoopBreak:
|
case EdgeKind::kLoopBreak:
|
||||||
return std::make_unique<ast::BreakStatement>();
|
return std::make_unique<ast::BreakStatement>();
|
||||||
case EdgeKind::kLoopContinue:
|
case EdgeKind::kLoopContinue:
|
||||||
|
|
|
@ -289,6 +289,13 @@ class FunctionEmitter {
|
||||||
/// @returns false if emission failed.
|
/// @returns false if emission failed.
|
||||||
bool EmitIfStart(const BlockInfo& block_info);
|
bool EmitIfStart(const BlockInfo& block_info);
|
||||||
|
|
||||||
|
/// Emits a SwitchStatement, including its condition expression, and sets
|
||||||
|
/// up the statement stack to accumulate subsequent basic blocks into
|
||||||
|
/// the default clause and case clauses.
|
||||||
|
/// @param block_info the switch-selection header block
|
||||||
|
/// @returns false if emission failed.
|
||||||
|
bool EmitSwitchStart(const BlockInfo& block_info);
|
||||||
|
|
||||||
/// Emits a LoopStatement, and pushes a new StatementBlock to accumulate
|
/// Emits a LoopStatement, and pushes a new StatementBlock to accumulate
|
||||||
/// the remaining instructions in the current block and subsequent blocks
|
/// the remaining instructions in the current block and subsequent blocks
|
||||||
/// in the loop.
|
/// in the loop.
|
||||||
|
@ -375,7 +382,7 @@ class FunctionEmitter {
|
||||||
/// Gets the block info for a block ID, if any exists
|
/// Gets the block info for a block ID, if any exists
|
||||||
/// @param id the SPIR-V ID of the OpLabel instruction starting the block
|
/// @param id the SPIR-V ID of the OpLabel instruction starting the block
|
||||||
/// @returns the block info for the given ID, if it exists, or nullptr
|
/// @returns the block info for the given ID, if it exists, or nullptr
|
||||||
BlockInfo* GetBlockInfo(uint32_t id) {
|
BlockInfo* GetBlockInfo(uint32_t id) const {
|
||||||
auto where = block_info_.find(id);
|
auto where = block_info_.find(id);
|
||||||
if (where == block_info_.end())
|
if (where == block_info_.end())
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -434,7 +441,7 @@ class FunctionEmitter {
|
||||||
uint32_t end_id,
|
uint32_t end_id,
|
||||||
CompletionAction completion_action,
|
CompletionAction completion_action,
|
||||||
ast::StatementList statements,
|
ast::StatementList statements,
|
||||||
ast::CaseStatementList cases);
|
std::unique_ptr<ast::CaseStatementList> cases);
|
||||||
StatementBlock(StatementBlock&&);
|
StatementBlock(StatementBlock&&);
|
||||||
~StatementBlock();
|
~StatementBlock();
|
||||||
|
|
||||||
|
@ -449,10 +456,13 @@ class FunctionEmitter {
|
||||||
|
|
||||||
// Only one of |statements| or |cases| is active.
|
// Only one of |statements| or |cases| is active.
|
||||||
|
|
||||||
// The list of statements being built.
|
// The list of statements being built, if this construct is not a switch.
|
||||||
ast::StatementList statements_;
|
ast::StatementList statements_;
|
||||||
// The list of cases being built, for a switch.
|
// The list of switch cases being built, if this construct is a switch.
|
||||||
ast::CaseStatementList cases_;
|
// The algorithm will cache a pointer to the vector. We want that pointer
|
||||||
|
// to be stable no matter how |statements_stack_| is resized. That's
|
||||||
|
// why we make this a unique_ptr rather than just a plain vector.
|
||||||
|
std::unique_ptr<ast::CaseStatementList> cases_;
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Pushes an empty statement block onto the statements stack.
|
/// Pushes an empty statement block onto the statements stack.
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue