reader/spirv: Remove use of BlockStatement::append()

Introduce `StatementBuilder`s , which may hold mutable state, before being converted into the immutable AST node on completion of the `BlockStatement`.

Bug: tint:396
Bug: tint:390
Change-Id: I0381c4ae7948be0de02bc13e54e0037a72baaf0c
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/35506
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: David Neto <dneto@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton 2020-12-14 19:48:47 +00:00 committed by Commit Bot service account
parent 2353bd0d3d
commit b833f1572b
7 changed files with 315 additions and 142 deletions

View File

@ -24,6 +24,10 @@ namespace ast {
BlockStatement::BlockStatement(const Source& source) : Base(source) {} BlockStatement::BlockStatement(const Source& source) : Base(source) {}
BlockStatement::BlockStatement(const Source& source,
const StatementList& statements)
: Base(source), statements_(std::move(statements)) {}
BlockStatement::BlockStatement(BlockStatement&&) = default; BlockStatement::BlockStatement(BlockStatement&&) = default;
BlockStatement::~BlockStatement() = default; BlockStatement::~BlockStatement() = default;

View File

@ -30,6 +30,10 @@ class BlockStatement : public Castable<BlockStatement, Statement> {
/// Constructor /// Constructor
/// @param source the block statement source /// @param source the block statement source
explicit BlockStatement(const Source& source); explicit BlockStatement(const Source& source);
/// Constructor
/// @param source the block statement source
/// @param statements the block statements
BlockStatement(const Source& source, const StatementList& statements);
/// Move constructor /// Move constructor
BlockStatement(BlockStatement&&); BlockStatement(BlockStatement&&);
~BlockStatement() override; ~BlockStatement() override;

View File

@ -42,6 +42,9 @@ class Statement : public Castable<Statement, Node> {
Statement(const Statement&) = delete; Statement(const Statement&) = delete;
}; };
/// A list of statements
using StatementList = std::vector<Statement*>;
} // namespace ast } // namespace ast
} // namespace tint } // namespace tint

View File

@ -608,6 +608,70 @@ class StructuredTraverser {
std::unordered_set<uint32_t> visited_; std::unordered_set<uint32_t> visited_;
}; };
/// A StatementBuilder for ast::SwitchStatment
/// @see StatementBuilder
struct SwitchStatementBuilder
: public Castable<SwitchStatementBuilder, StatementBuilder> {
/// Constructor
/// @param cond the switch statement condition
explicit SwitchStatementBuilder(ast::Expression* cond) : condition(cond) {}
/// @param mod the ast Module to build into
/// @returns the built ast::SwitchStatement
ast::SwitchStatement* Build(ast::Module* mod) const override {
// We've listed cases in reverse order in the switch statement.
// Reorder them to match the presentation order in WGSL.
auto reversed_cases = cases;
std::reverse(reversed_cases.begin(), reversed_cases.end());
return mod->create<ast::SwitchStatement>(Source{}, condition,
reversed_cases);
}
/// Switch statement condition
ast::Expression* const condition;
/// Switch statement cases
ast::CaseStatementList cases;
};
/// A StatementBuilder for ast::IfStatement
/// @see StatementBuilder
struct IfStatementBuilder
: public Castable<IfStatementBuilder, StatementBuilder> {
/// Constructor
/// @param c the if-statement condition
explicit IfStatementBuilder(ast::Expression* c) : cond(c) {}
/// @param mod the ast Module to build into
/// @returns the built ast::IfStatement
ast::IfStatement* Build(ast::Module* mod) const override {
return mod->create<ast::IfStatement>(Source{}, cond, body, else_stmts);
}
/// If-statement condition
ast::Expression* const cond;
/// If-statement block body
ast::BlockStatement* body = nullptr;
/// Optional if-statement else statements
ast::ElseStatementList else_stmts;
};
/// A StatementBuilder for ast::LoopStatement
/// @see StatementBuilder
struct LoopStatementBuilder
: public Castable<LoopStatementBuilder, StatementBuilder> {
/// @param mod the ast Module to build into
/// @returns the built ast::LoopStatement
ast::LoopStatement* Build(ast::Module* mod) const override {
return mod->create<ast::LoopStatement>(Source{}, body, continuing);
}
/// Loop-statement block body
ast::BlockStatement* body = nullptr;
/// Loop-statement continuing body
ast::BlockStatement* continuing = nullptr;
};
} // namespace } // namespace
BlockInfo::BlockInfo(const spvtools::opt::BasicBlock& bb) BlockInfo::BlockInfo(const spvtools::opt::BasicBlock& bb)
@ -622,6 +686,17 @@ DefInfo::DefInfo(const spvtools::opt::Instruction& def_inst,
DefInfo::~DefInfo() = default; DefInfo::~DefInfo() = default;
bool StatementBuilder::IsValid() const {
return true;
}
ast::Node* StatementBuilder::Clone(ast::CloneContext*) const {
return nullptr;
}
void StatementBuilder::to_str(std::ostream& out, size_t indent) const {
make_indent(out, indent);
out << "StatementBuilder" << std::endl;
}
FunctionEmitter::FunctionEmitter(ParserImpl* pi, FunctionEmitter::FunctionEmitter(ParserImpl* pi,
const spvtools::opt::Function& function, const spvtools::opt::Function& function,
const EntryPointInfo* ep_info) const EntryPointInfo* ep_info)
@ -636,7 +711,7 @@ FunctionEmitter::FunctionEmitter(ParserImpl* pi,
function_(function), function_(function),
i32_(ast_module_.create<ast::type::I32>()), i32_(ast_module_.create<ast::type::I32>()),
ep_info_(ep_info) { ep_info_(ep_info) {
PushNewStatementBlock(nullptr, 0, nullptr, nullptr, nullptr); PushNewStatementBlock(nullptr, 0, nullptr, nullptr);
} }
FunctionEmitter::FunctionEmitter(ParserImpl* pi, FunctionEmitter::FunctionEmitter(ParserImpl* pi,
@ -646,32 +721,62 @@ FunctionEmitter::FunctionEmitter(ParserImpl* pi,
FunctionEmitter::~FunctionEmitter() = default; FunctionEmitter::~FunctionEmitter() = default;
FunctionEmitter::StatementBlock::StatementBlock( FunctionEmitter::StatementBlock::StatementBlock(
const Construct* construct, const spirv::Construct* construct,
uint32_t end_id, uint32_t end_id,
CompletionAction completion_action, FunctionEmitter::CompletionAction completion_action,
ast::BlockStatement* statements,
ast::CaseStatementList* cases) ast::CaseStatementList* cases)
: construct_(construct), : construct_(construct),
end_id_(end_id), end_id_(end_id),
completion_action_(completion_action), completion_action_(completion_action),
statements_(statements),
cases_(cases) {} cases_(cases) {}
FunctionEmitter::StatementBlock::StatementBlock(StatementBlock&&) = default; FunctionEmitter::StatementBlock::StatementBlock(StatementBlock&& other)
: construct_(other.construct_),
end_id_(other.end_id_),
completion_action_(std::move(other.completion_action_)),
statements_(std::move(other.statements_)),
cases_(std::move(other.cases_)) {
other.statements_.clear();
}
FunctionEmitter::StatementBlock::~StatementBlock() = default; FunctionEmitter::StatementBlock::~StatementBlock() {
if (!finalized_) {
// Delete builders that have not been built with Finalize()
for (auto* statement : statements_) {
if (auto* builder = statement->As<StatementBuilder>()) {
delete builder;
}
}
}
}
void FunctionEmitter::StatementBlock::Finalize(ast::Module* mod) {
assert(!finalized_ /* Finalize() must only be called once */);
for (size_t i = 0; i < statements_.size(); i++) {
if (auto* builder = statements_[i]->As<StatementBuilder>()) {
statements_[i] = builder->Build(mod);
delete builder;
}
}
if (completion_action_ != nullptr) {
completion_action_(statements_);
}
finalized_ = true;
}
void FunctionEmitter::StatementBlock::Add(ast::Statement* statement) {
assert(!finalized_ /* Add() must not be called after Finalize() */);
statements_.emplace_back(statement);
}
void FunctionEmitter::PushNewStatementBlock(const Construct* construct, void FunctionEmitter::PushNewStatementBlock(const Construct* construct,
uint32_t end_id, uint32_t end_id,
ast::BlockStatement* block,
ast::CaseStatementList* cases, ast::CaseStatementList* cases,
CompletionAction action) { CompletionAction action) {
if (block == nullptr) {
block = create<ast::BlockStatement>(Source{});
}
statements_stack_.emplace_back( statements_stack_.emplace_back(
StatementBlock{construct, end_id, action, block, cases}); StatementBlock{construct, end_id, action, cases});
} }
void FunctionEmitter::PushGuard(const std::string& guard_name, void FunctionEmitter::PushGuard(const std::string& guard_name,
@ -685,10 +790,12 @@ void FunctionEmitter::PushGuard(const std::string& guard_name,
auto* cond = create<ast::IdentifierExpression>( auto* cond = create<ast::IdentifierExpression>(
Source{}, ast_module_.RegisterSymbol(guard_name), guard_name); Source{}, ast_module_.RegisterSymbol(guard_name), guard_name);
auto* body = create<ast::BlockStatement>(Source{}); auto* builder = AddStatementBuilder<IfStatementBuilder>(cond);
AddStatement(
create<ast::IfStatement>(Source{}, cond, body, ast::ElseStatementList{})); PushNewStatementBlock(
PushNewStatementBlock(top.construct_, end_id, body, nullptr, nullptr); top.Construct(), end_id, nullptr, [=](const ast::StatementList& stmts) {
builder->body = create<ast::BlockStatement>(Source{}, stmts);
});
} }
void FunctionEmitter::PushTrueGuard(uint32_t end_id) { void FunctionEmitter::PushTrueGuard(uint32_t end_id) {
@ -696,31 +803,36 @@ void FunctionEmitter::PushTrueGuard(uint32_t end_id) {
const auto& top = statements_stack_.back(); const auto& top = statements_stack_.back();
auto* cond = MakeTrue(Source{}); auto* cond = MakeTrue(Source{});
auto* body = create<ast::BlockStatement>(Source{}); auto* builder = AddStatementBuilder<IfStatementBuilder>(cond);
AddStatement(
create<ast::IfStatement>(Source{}, cond, body, ast::ElseStatementList{})); PushNewStatementBlock(
PushNewStatementBlock(top.construct_, end_id, body, nullptr, nullptr); top.Construct(), end_id, nullptr, [=](const ast::StatementList& stmts) {
builder->body = create<ast::BlockStatement>(Source{}, stmts);
});
} }
const ast::BlockStatement* FunctionEmitter::ast_body() { const ast::StatementList FunctionEmitter::ast_body() {
assert(!statements_stack_.empty()); assert(!statements_stack_.empty());
return statements_stack_[0].statements_; auto& entry = statements_stack_[0];
entry.Finalize(&ast_module_);
return entry.Statements();
} }
ast::Statement* FunctionEmitter::AddStatement(ast::Statement* statement) { ast::Statement* FunctionEmitter::AddStatement(ast::Statement* statement) {
assert(!statements_stack_.empty()); assert(!statements_stack_.empty());
auto* result = statement; auto* result = statement;
if (result != nullptr) { if (result != nullptr) {
statements_stack_.back().statements_->append(statement); auto& block = statements_stack_.back();
block.Add(statement);
} }
return result; return result;
} }
ast::Statement* FunctionEmitter::LastStatement() { ast::Statement* FunctionEmitter::LastStatement() {
assert(!statements_stack_.empty()); assert(!statements_stack_.empty());
auto* statement_list = statements_stack_.back().statements_; auto& statement_list = statements_stack_.back().Statements();
assert(!statement_list->empty()); assert(!statement_list.empty());
return statement_list->last(); return statement_list.back();
} }
bool FunctionEmitter::Emit() { bool FunctionEmitter::Emit() {
@ -748,7 +860,10 @@ bool FunctionEmitter::Emit() {
<< statements_stack_.size(); << statements_stack_.size();
} }
auto* body = statements_stack_[0].statements_; statements_stack_[0].Finalize(&ast_module_);
auto& statements = statements_stack_[0].Statements();
auto* body = create<ast::BlockStatement>(Source{}, statements);
ast_module_.AddFunction( ast_module_.AddFunction(
create<ast::Function>(decl.source, ast_module_.RegisterSymbol(decl.name), create<ast::Function>(decl.source, ast_module_.RegisterSymbol(decl.name),
decl.name, std::move(decl.params), decl.return_type, decl.name, std::move(decl.params), decl.return_type,
@ -756,7 +871,7 @@ bool FunctionEmitter::Emit() {
// Maintain the invariant by repopulating the one and only element. // Maintain the invariant by repopulating the one and only element.
statements_stack_.clear(); statements_stack_.clear();
PushNewStatementBlock(constructs_[0].get(), 0, nullptr, nullptr, nullptr); PushNewStatementBlock(constructs_[0].get(), 0, nullptr, nullptr);
return success(); return success();
} }
@ -1935,7 +2050,7 @@ bool FunctionEmitter::EmitFunctionBodyStatements() {
// TODO(dneto): refactor how the first construct is created vs. // TODO(dneto): refactor how the first construct is created vs.
// this statements stack entry is populated. // this statements stack entry is populated.
assert(statements_stack_.size() == 1); assert(statements_stack_.size() == 1);
statements_stack_[0].construct_ = function_construct; statements_stack_[0].SetConstruct(function_construct);
for (auto block_id : block_order()) { for (auto block_id : block_order()) {
if (!EmitBasicBlock(*GetBlockInfo(block_id))) { if (!EmitBasicBlock(*GetBlockInfo(block_id))) {
@ -1948,11 +2063,8 @@ bool FunctionEmitter::EmitFunctionBodyStatements() {
bool FunctionEmitter::EmitBasicBlock(const BlockInfo& block_info) { bool FunctionEmitter::EmitBasicBlock(const BlockInfo& block_info) {
// Close off previous constructs. // Close off previous constructs.
while (!statements_stack_.empty() && while (!statements_stack_.empty() &&
(statements_stack_.back().end_id_ == block_info.id)) { (statements_stack_.back().EndId() == block_info.id)) {
StatementBlock& sb = statements_stack_.back(); statements_stack_.back().Finalize(&ast_module_);
if (sb.completion_action_ != nullptr) {
sb.completion_action_();
}
statements_stack_.pop_back(); statements_stack_.pop_back();
} }
if (statements_stack_.empty()) { if (statements_stack_.empty()) {
@ -1965,7 +2077,7 @@ bool FunctionEmitter::EmitBasicBlock(const BlockInfo& block_info) {
std::vector<const Construct*> entering_constructs; // inner most comes first std::vector<const Construct*> entering_constructs; // inner most comes first
{ {
auto* here = block_info.construct; auto* here = block_info.construct;
auto* const top_construct = statements_stack_.back().construct_; auto* const top_construct = statements_stack_.back().Construct();
while (here != top_construct) { while (here != top_construct) {
// Only enter a construct at its header block. // Only enter a construct at its header block.
if (here->begin_id == block_info.id) { if (here->begin_id == block_info.id) {
@ -2152,42 +2264,9 @@ bool FunctionEmitter::EmitIfStart(const BlockInfo& block_info) {
const auto condition_id = const auto condition_id =
block_info.basic_block->terminator()->GetSingleWordInOperand(0); block_info.basic_block->terminator()->GetSingleWordInOperand(0);
auto* cond = MakeExpression(condition_id).expr; auto* cond = MakeExpression(condition_id).expr;
auto* body = create<ast::BlockStatement>(Source{});
// Generate the code for the condition. // Generate the code for the condition.
// Use the IfBuilder to create the if-statement. The IfBuilder is constructed auto* builder = AddStatementBuilder<IfStatementBuilder>(cond);
// as a std::shared_ptr and is captured by the then and else clause
// CompletionAction lambdas, and so will only be destructed when the last
// block is completed. The IfBuilder destructor constructs the IfStatement,
// inserting it at the current insertion point in the current
// ast::BlockStatement.
struct IfBuilder {
IfBuilder(ast::Module* mod,
StatementBlock& statement_block,
tint::ast::Expression* cond,
ast::BlockStatement* body)
: mod_(mod),
dst_block_(statement_block.statements_),
dst_block_insertion_point_(statement_block.statements_->size()),
cond_(cond),
body_(body) {}
~IfBuilder() {
dst_block_->insert(
dst_block_insertion_point_,
mod_->create<ast::IfStatement>(Source{}, cond_, body_, else_stmts_));
}
ast::Module* mod_;
ast::BlockStatement* dst_block_;
size_t dst_block_insertion_point_;
tint::ast::Expression* cond_;
ast::BlockStatement* body_;
ast::ElseStatementList else_stmts_;
};
auto if_builder = std::make_shared<IfBuilder>(
&ast_module_, statements_stack_.back(), cond, body);
// Compute the block IDs that should end the then-clause and the else-clause. // Compute the block IDs that should end the then-clause and the else-clause.
@ -2225,17 +2304,16 @@ bool FunctionEmitter::EmitIfStart(const BlockInfo& block_info) {
// Push statement blocks for the then-clause and the else-clause. // Push statement blocks for the then-clause and the else-clause.
// But make sure we do it in the right order. // But make sure we do it in the right order.
auto push_else = [this, if_builder, else_end, construct]() { auto push_else = [this, builder, else_end, construct]() {
// Push the else clause onto the stack first. // Push the else clause onto the stack first.
auto* else_body = create<ast::BlockStatement>(Source{});
PushNewStatementBlock( PushNewStatementBlock(
construct, else_end, else_body, nullptr, construct, else_end, nullptr, [=](const ast::StatementList& stmts) {
[this, if_builder, else_body]() {
// Only set the else-clause if there are statements to fill it. // Only set the else-clause if there are statements to fill it.
if (!else_body->empty()) { if (!stmts.empty()) {
// The "else" consists of the statement list from the top of // The "else" consists of the statement list from the top of
// statements stack, without an elseif condition. // statements stack, without an elseif condition.
if_builder->else_stmts_.emplace_back( auto* else_body = create<ast::BlockStatement>(Source{}, stmts);
builder->else_stmts.emplace_back(
create<ast::ElseStatement>(Source{}, nullptr, else_body)); create<ast::ElseStatement>(Source{}, nullptr, else_body));
} }
}); });
@ -2275,7 +2353,10 @@ bool FunctionEmitter::EmitIfStart(const BlockInfo& block_info) {
} }
// Push the then clause onto the stack. // Push the then clause onto the stack.
PushNewStatementBlock(construct, then_end, body, nullptr, [if_builder] {}); PushNewStatementBlock(
construct, then_end, nullptr, [=](const ast::StatementList& stmts) {
builder->body = create<ast::BlockStatement>(Source{}, stmts);
});
} }
return success(); return success();
@ -2293,14 +2374,11 @@ bool FunctionEmitter::EmitSwitchStart(const BlockInfo& block_info) {
auto selector = MakeExpression(selector_id); auto selector = MakeExpression(selector_id);
// First, push the statement block for the entire switch. // First, push the statement block for the entire switch.
ast::CaseStatementList case_list; auto* swch = AddStatementBuilder<SwitchStatementBuilder>(selector.expr);
auto* swch = create<ast::SwitchStatement>(Source{}, selector.expr, case_list);
AddStatement(swch)->As<ast::SwitchStatement>();
// Grab a pointer to the case list. It will get buried in the statement block // Grab a pointer to the case list. It will get buried in the statement block
// stack. // stack.
auto* cases = &(swch->body()); PushNewStatementBlock(construct, construct->end_id, &swch->cases, nullptr);
PushNewStatementBlock(construct, construct->end_id, nullptr, cases, nullptr);
// We will push statement-blocks onto the stack to gather the statements in // We will push statement-blocks onto the stack to gather the statements in
// the default clause and cases clauses. Determine the list of blocks // the default clause and cases clauses. Determine the list of blocks
@ -2367,21 +2445,27 @@ bool FunctionEmitter::EmitSwitchStart(const BlockInfo& block_info) {
const auto end_id = (i + 1 < clause_heads.size()) ? clause_heads[i + 1]->id const auto end_id = (i + 1 < clause_heads.size()) ? clause_heads[i + 1]->id
: construct->end_id; : construct->end_id;
// Create the case clause. Temporarily put it in the wrong order // Reserve the case clause slot in swch->cases, push the new statement block
// on the case statement list. // for the case, and fill the case clause once the block is generated.
auto* body = create<ast::BlockStatement>(Source{}); auto case_idx = swch->cases.size();
cases->emplace_back(create<ast::CaseStatement>(Source{}, selectors, body)); swch->cases.emplace_back(nullptr);
PushNewStatementBlock(
PushNewStatementBlock(construct, end_id, body, nullptr, nullptr); construct, end_id, nullptr, [=](const ast::StatementList& stmts) {
auto* body = create<ast::BlockStatement>(Source{}, stmts);
swch->cases[case_idx] =
create<ast::CaseStatement>(Source{}, selectors, body);
});
if ((default_info == clause_heads[i]) && has_selectors && if ((default_info == clause_heads[i]) && has_selectors &&
construct->ContainsPos(default_info->pos)) { construct->ContainsPos(default_info->pos)) {
// Generate a default clause with a just fallthrough. // Generate a default clause with a just fallthrough.
auto* stmts = create<ast::BlockStatement>(Source{}); auto* stmts = create<ast::BlockStatement>(
stmts->append(create<ast::FallthroughStatement>(Source{})); Source{}, ast::StatementList{
create<ast::FallthroughStatement>(Source{}),
});
auto* case_stmt = auto* case_stmt =
create<ast::CaseStatement>(Source{}, ast::CaseSelectorList{}, stmts); create<ast::CaseStatement>(Source{}, ast::CaseSelectorList{}, stmts);
cases->emplace_back(case_stmt); swch->cases.emplace_back(case_stmt);
} }
if (i == 0) { if (i == 0) {
@ -2389,18 +2473,16 @@ bool FunctionEmitter::EmitSwitchStart(const BlockInfo& block_info) {
} }
} }
// We've listed cases in reverse order in the switch statement. Reorder them
// to match the presentation order in WGSL.
std::reverse(cases->begin(), cases->end());
return success(); return success();
} }
bool FunctionEmitter::EmitLoopStart(const Construct* construct) { bool FunctionEmitter::EmitLoopStart(const Construct* construct) {
auto* body = create<ast::BlockStatement>(Source{}); auto* builder = AddStatementBuilder<LoopStatementBuilder>();
AddStatement(create<ast::LoopStatement>( PushNewStatementBlock(construct, construct->end_id, nullptr,
Source{}, body, create<ast::BlockStatement>(Source{}))); [=](const ast::StatementList& stmts) {
PushNewStatementBlock(construct, construct->end_id, body, nullptr, nullptr); builder->body =
create<ast::BlockStatement>(Source{}, stmts);
});
return success(); return success();
} }
@ -2408,13 +2490,16 @@ bool FunctionEmitter::EmitContinuingStart(const Construct* construct) {
// A continue construct has the same depth as its associated loop // A continue construct has the same depth as its associated loop
// construct. Start a continue construct. // construct. Start a continue construct.
auto* loop_candidate = LastStatement(); auto* loop_candidate = LastStatement();
auto* loop = loop_candidate->As<ast::LoopStatement>(); auto* loop = loop_candidate->As<LoopStatementBuilder>();
if (loop == nullptr) { if (loop == nullptr) {
return Fail() << "internal error: starting continue construct, " return Fail() << "internal error: starting continue construct, "
"expected loop on top of stack"; "expected loop on top of stack";
} }
PushNewStatementBlock(construct, construct->end_id, loop->continuing(), PushNewStatementBlock(construct, construct->end_id, nullptr,
nullptr, nullptr); [=](const ast::StatementList& stmts) {
loop->continuing =
create<ast::BlockStatement>(Source{}, stmts);
});
return success(); return success();
} }
@ -2502,7 +2587,7 @@ bool FunctionEmitter::EmitNormalTerminator(const BlockInfo& block_info) {
AddStatement(MakeSimpleIf(cond, true_branch, false_branch)); AddStatement(MakeSimpleIf(cond, true_branch, false_branch));
if (!flow_guard.empty()) { if (!flow_guard.empty()) {
PushGuard(flow_guard, statements_stack_.back().end_id_); PushGuard(flow_guard, statements_stack_.back().EndId());
} }
return true; return true;
} }
@ -2600,17 +2685,18 @@ ast::Statement* FunctionEmitter::MakeSimpleIf(ast::Expression* condition,
} }
ast::ElseStatementList else_stmts; ast::ElseStatementList else_stmts;
if (else_stmt != nullptr) { if (else_stmt != nullptr) {
auto* stmts = create<ast::BlockStatement>(Source{}); ast::StatementList stmts{else_stmt};
stmts->append(else_stmt); else_stmts.emplace_back(create<ast::ElseStatement>(
else_stmts.emplace_back( Source{}, nullptr, create<ast::BlockStatement>(Source{}, stmts)));
create<ast::ElseStatement>(Source{}, nullptr, stmts));
} }
auto* if_block = create<ast::BlockStatement>(Source{}); ast::StatementList if_stmts;
if (then_stmt != nullptr) {
if_stmts.emplace_back(then_stmt);
}
auto* if_block = create<ast::BlockStatement>(Source{}, if_stmts);
auto* if_stmt = auto* if_stmt =
create<ast::IfStatement>(Source{}, condition, if_block, else_stmts); create<ast::IfStatement>(Source{}, condition, if_block, else_stmts);
if (then_stmt != nullptr) {
if_block->append(then_stmt);
}
return if_stmt; return if_stmt;
} }
@ -4285,3 +4371,8 @@ FunctionEmitter::FunctionDeclaration::~FunctionDeclaration() = default;
} // namespace spirv } // namespace spirv
} // namespace reader } // namespace reader
} // namespace tint } // namespace tint
TINT_INSTANTIATE_CLASS_ID(tint::reader::spirv::StatementBuilder);
TINT_INSTANTIATE_CLASS_ID(tint::reader::spirv::SwitchStatementBuilder);
TINT_INSTANTIATE_CLASS_ID(tint::reader::spirv::IfStatementBuilder);
TINT_INSTANTIATE_CLASS_ID(tint::reader::spirv::LoopStatementBuilder);

View File

@ -289,6 +289,29 @@ inline std::ostream& operator<<(std::ostream& o, const DefInfo& di) {
return o; return o;
} }
/// A placeholder Statement that exists for the duration of building a
/// StatementBlock. Once the StatementBlock is built, Build() will be called to
/// construct the final AST node, which will be used in the place of this
/// StatementBuilder.
/// StatementBuilders are used to simplify construction of AST nodes that will
/// become immutable. The builders may hold mutable state while the
/// StatementBlock is being constructed, which becomes an immutable node on
/// StatementBlock::Finalize().
class StatementBuilder : public Castable<StatementBuilder, ast::Statement> {
public:
/// Constructor
StatementBuilder() : Base(Source{}) {}
/// @param mod the ast Module to build into
/// @returns the build AST node
virtual ast::Statement* Build(ast::Module* mod) const = 0;
private:
bool IsValid() const override;
Node* Clone(ast::CloneContext*) const override;
void to_str(std::ostream& out, size_t indent) const override;
};
/// A FunctionEmitter emits a SPIR-V function onto a Tint AST module. /// A FunctionEmitter emits a SPIR-V function onto a Tint AST module.
class FunctionEmitter { class FunctionEmitter {
public: public:
@ -317,10 +340,10 @@ class FunctionEmitter {
/// @returns true if emission has failed. /// @returns true if emission has failed.
bool failed() const { return !success(); } bool failed() const { return !success(); }
/// Returns the body of the function. It is the bottom of the statement /// Finalizes any StatementBuilders returns the body of the function.
/// stack. /// Must only be called once, and to be used only for testing.
/// @returns the body of the function. /// @returns the body of the function.
const ast::BlockStatement* ast_body(); const ast::StatementList ast_body();
/// Records failure. /// Records failure.
/// @returns a FailStream on which to emit diagnostics. /// @returns a FailStream on which to emit diagnostics.
@ -811,6 +834,14 @@ class FunctionEmitter {
/// @returns a pointer to the statement. /// @returns a pointer to the statement.
ast::Statement* AddStatement(ast::Statement* statement); ast::Statement* AddStatement(ast::Statement* statement);
template <typename T, typename... ARGS>
T* AddStatementBuilder(ARGS&&... args) {
// The builder is temporary and is not owned by the module.
auto builder = new T(std::forward<ARGS>(args)...);
AddStatement(builder);
return builder;
}
/// Returns the source record for the given instruction. /// Returns the source record for the given instruction.
/// @param inst the SPIR-V instruction /// @param inst the SPIR-V instruction
/// @return the Source record, or a default one /// @return the Source record, or a default one
@ -819,43 +850,79 @@ class FunctionEmitter {
/// @returns the last statetment in the top of the statement stack. /// @returns the last statetment in the top of the statement stack.
ast::Statement* LastStatement(); ast::Statement* LastStatement();
using CompletionAction = std::function<void()>; using CompletionAction = std::function<void(const ast::StatementList&)>;
// A StatementBlock represents a braced-list of statements while it is being // A StatementBlock represents a braced-list of statements while it is being
// constructed. // constructed.
struct StatementBlock { class StatementBlock {
public:
StatementBlock(const Construct* construct, StatementBlock(const Construct* construct,
uint32_t end_id, uint32_t end_id,
CompletionAction completion_action, CompletionAction completion_action,
ast::BlockStatement* statements,
ast::CaseStatementList* cases); ast::CaseStatementList* cases);
StatementBlock(StatementBlock&&); StatementBlock(StatementBlock&&);
~StatementBlock(); ~StatementBlock();
// The construct to which this construct constributes. StatementBlock(const StatementBlock&) = delete;
const Construct* construct_; StatementBlock& operator=(const StatementBlock&) = delete;
// The ID of the block at which the completion action should be triggerd
// and this statement block discarded. This is often the |end_id| of
// |construct| itself.
uint32_t end_id_;
// The completion action finishes processing this statement block.
CompletionAction completion_action_;
// Only one of |statements| or |cases| is active. /// Replaces any StatementBuilders with the built result, and calls the
/// completion callback (if set). Must only be called once, after all
/// statements have been added with Add().
/// @param mod the module
void Finalize(ast::Module* mod);
// The list of statements being built, if this construct is not a switch. /// Add() adds `statement` to the block.
ast::BlockStatement* statements_ = nullptr; /// Add() must not be called after calling Finalize().
// The list of switch cases being built, if this construct is a switch. void Add(ast::Statement* statement);
/// @param construct the construct which this construct constributes to
void SetConstruct(const Construct* construct) { construct_ = construct; }
/// @return the construct to which this construct constributes
const Construct* Construct() const { return construct_; }
/// @return the ID of the block at which the completion action should be
/// triggerd and this statement block discarded. This is often the `end_id`
/// of `construct` itself.
uint32_t EndId() const { return end_id_; }
/// @return the completion action finishes processing this statement block
CompletionAction CompletionAction() const { return completion_action_; }
/// @return the list of statements being built, if this construct is not a
/// switch.
const ast::StatementList& Statements() const { return statements_; }
/// @return the list of switch cases being built, if this construct is a
/// switch
ast::CaseStatementList* Cases() const { return cases_; }
private:
/// The construct to which this construct constributes.
const spirv::Construct* construct_;
/// The ID of the block at which the completion action should be triggerd
/// and this statement block discarded. This is often the `end_id` of
/// `construct` itself.
uint32_t const end_id_;
/// The completion action finishes processing this statement block.
FunctionEmitter::CompletionAction const completion_action_;
// Only one of `statements` or `cases` is active.
/// The list of statements being built, if this construct is not a switch.
ast::StatementList statements_;
/// The list of switch cases being built, if this construct is a switch.
ast::CaseStatementList* cases_ = nullptr; ast::CaseStatementList* cases_ = nullptr;
/// True if Finalize() has been called.
bool finalized_ = false;
}; };
/// Pushes an empty statement block onto the statements stack. /// Pushes an empty statement block onto the statements stack.
/// @param block the block to push into
/// @param cases the case list to push into /// @param cases the case list to push into
/// @param action the completion action for this block /// @param action the completion action for this block
void PushNewStatementBlock(const Construct* construct, void PushNewStatementBlock(const Construct* construct,
uint32_t end_id, uint32_t end_id,
ast::BlockStatement* block,
ast::CaseStatementList* cases, ast::CaseStatementList* cases,
CompletionAction action); CompletionAction action);
@ -887,6 +954,8 @@ class FunctionEmitter {
return ast_module_.create<T>(std::forward<ARGS>(args)...); return ast_module_.create<T>(std::forward<ARGS>(args)...);
} }
using StatementsStack = std::vector<StatementBlock>;
ParserImpl& parser_impl_; ParserImpl& parser_impl_;
ast::Module& ast_module_; ast::Module& ast_module_;
spvtools::opt::IRContext& ir_context_; spvtools::opt::IRContext& ir_context_;
@ -901,9 +970,9 @@ class FunctionEmitter {
// A stack of statement lists. Each list is contained in a construct in // A stack of statement lists. Each list is contained in a construct in
// the next deeper element of stack. The 0th entry represents the statements // the next deeper element of stack. The 0th entry represents the statements
// for the entire function. This stack is never empty. // for the entire function. This stack is never empty.
// The |construct| member for the 0th element is only valid during the // The `construct` member for the 0th element is only valid during the
// lifetime of the EmitFunctionBodyStatements method. // lifetime of the EmitFunctionBodyStatements method.
std::vector<StatementBlock> statements_stack_; StatementsStack statements_stack_;
// The set of IDs that have already had an identifier name generated for it. // The set of IDs that have already had an identifier name generated for it.
std::unordered_set<uint32_t> identifier_values_; std::unordered_set<uint32_t> identifier_values_;

View File

@ -479,7 +479,8 @@ TEST_F(SpvParserTest_CompositeExtract, Struct_DifferOnlyInMemberName) {
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly; ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100)); FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
EXPECT_TRUE(fe.EmitBody()) << p->error(); EXPECT_TRUE(fe.EmitBody()) << p->error();
EXPECT_THAT(ToString(p->get_module(), fe.ast_body()), HasSubstr(R"( auto got = fe.ast_body();
EXPECT_THAT(ToString(p->get_module(), got), HasSubstr(R"(
VariableConst{ VariableConst{
x_2 x_2
none none
@ -491,8 +492,8 @@ TEST_F(SpvParserTest_CompositeExtract, Struct_DifferOnlyInMemberName) {
} }
} }
})")) })"))
<< ToString(p->get_module(), fe.ast_body()); << ToString(p->get_module(), got);
EXPECT_THAT(ToString(p->get_module(), fe.ast_body()), HasSubstr(R"( EXPECT_THAT(ToString(p->get_module(), got), HasSubstr(R"(
VariableConst{ VariableConst{
x_4 x_4
none none
@ -504,7 +505,7 @@ TEST_F(SpvParserTest_CompositeExtract, Struct_DifferOnlyInMemberName) {
} }
} }
})")) })"))
<< ToString(p->get_module(), fe.ast_body()); << ToString(p->get_module(), got);
} }
TEST_F(SpvParserTest_CompositeExtract, Struct_IndexTooBigError) { TEST_F(SpvParserTest_CompositeExtract, Struct_IndexTooBigError) {

View File

@ -57,13 +57,14 @@ class SpvParserTestBase : public T {
// Use this form when you don't need to template any further. // Use this form when you don't need to template any further.
using SpvParserTest = SpvParserTestBase<::testing::Test>; using SpvParserTest = SpvParserTestBase<::testing::Test>;
/// Returns the string dump of a function body. /// Returns the string dump of a statement list.
/// @param body the statement in the body /// @param mod the module
/// @returnss the string dump of a function body. /// @param stmts the statement list
/// @returns the string dump of a statement list.
inline std::string ToString(const ast::Module& mod, inline std::string ToString(const ast::Module& mod,
const ast::BlockStatement* body) { const ast::StatementList& stmts) {
std::ostringstream outs; std::ostringstream outs;
for (const auto* stmt : *body) { for (const auto* stmt : stmts) {
stmt->to_str(outs, 0); stmt->to_str(outs, 0);
} }
return Demangler().Demangle(mod, outs.str()); return Demangler().Demangle(mod, outs.str());