sem: Split BlockStatement up into subclasses

Allows us to put block-type-specific data on the specific subtype instead of littering a common base class

Change-Id: If4a327a8ee52d5911308f38b518ec07c3ceebcb7
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/51367
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: David Neto <dneto@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton 2021-05-19 21:22:07 +00:00 committed by Tint LUCI CQ
parent 512d60c207
commit 9a3ba02c36
3 changed files with 114 additions and 62 deletions

View File

@ -1251,8 +1251,8 @@ bool Resolver::Function(ast::Function* func) {
<< "Resolver::Function() called with a current statement"; << "Resolver::Function() called with a current statement";
return false; return false;
} }
sem::BlockStatement* sem_block = builder_->create<sem::BlockStatement>( auto* sem_block =
func->body(), nullptr, sem::BlockStatement::Type::kGeneric); builder_->create<sem::BlockStatement>(func->body(), nullptr);
builder_->Sem().Add(func->body(), sem_block); builder_->Sem().Add(func->body(), sem_block);
TINT_SCOPED_ASSIGNMENT(current_statement_, sem_block); TINT_SCOPED_ASSIGNMENT(current_statement_, sem_block);
if (!BlockScope(func->body(), if (!BlockScope(func->body(),
@ -1379,8 +1379,7 @@ bool Resolver::Statement(ast::Statement* stmt) {
sem::Statement* sem_statement; sem::Statement* sem_statement;
if (stmt->As<ast::BlockStatement>()) { if (stmt->As<ast::BlockStatement>()) {
sem_statement = builder_->create<sem::BlockStatement>( sem_statement = builder_->create<sem::BlockStatement>(
stmt->As<ast::BlockStatement>(), current_statement_, stmt->As<ast::BlockStatement>(), current_statement_);
sem::BlockStatement::Type::kGeneric);
} else { } else {
sem_statement = builder_->create<sem::Statement>(stmt, current_statement_); sem_statement = builder_->create<sem::Statement>(stmt, current_statement_);
} }
@ -1403,9 +1402,8 @@ bool Resolver::Statement(ast::Statement* stmt) {
return BlockScope(b, [&] { return Statements(b->list()); }); return BlockScope(b, [&] { return Statements(b->list()); });
} }
if (stmt->Is<ast::BreakStatement>()) { if (stmt->Is<ast::BreakStatement>()) {
if (!current_block_->FindFirstParent(sem::BlockStatement::Type::kLoop) && if (!current_block_->FindFirstParent<sem::LoopBlockStatement>() &&
!current_block_->FindFirstParent( !current_block_->FindFirstParent<sem::SwitchCaseBlockStatement>()) {
sem::BlockStatement::Type::kSwitchCase)) {
diagnostics_.add_error("break statement must be in a loop or switch case", diagnostics_.add_error("break statement must be in a loop or switch case",
stmt->source()); stmt->source());
return false; return false;
@ -1422,9 +1420,9 @@ bool Resolver::Statement(ast::Statement* stmt) {
if (stmt->Is<ast::ContinueStatement>()) { if (stmt->Is<ast::ContinueStatement>()) {
// Set if we've hit the first continue statement in our parent loop // Set if we've hit the first continue statement in our parent loop
if (auto* loop_block = if (auto* loop_block =
current_block_->FindFirstParent(sem::BlockStatement::Type::kLoop)) { current_block_->FindFirstParent<sem::LoopBlockStatement>()) {
if (loop_block->FirstContinue() == size_t(~0)) { if (loop_block->FirstContinue() == size_t(~0)) {
const_cast<sem::BlockStatement*>(loop_block) const_cast<sem::LoopBlockStatement*>(loop_block)
->SetFirstContinue(loop_block->Decls().size()); ->SetFirstContinue(loop_block->Decls().size());
} }
} else { } else {
@ -1468,8 +1466,8 @@ bool Resolver::CaseStatement(ast::CaseStatement* stmt) {
for (auto* sel : stmt->selectors()) { for (auto* sel : stmt->selectors()) {
Mark(sel); Mark(sel);
} }
sem::BlockStatement* sem_block = builder_->create<sem::BlockStatement>( auto* sem_block = builder_->create<sem::SwitchCaseBlockStatement>(
stmt->body(), current_statement_, sem::BlockStatement::Type::kSwitchCase); stmt->body(), current_statement_);
builder_->Sem().Add(stmt->body(), sem_block); builder_->Sem().Add(stmt->body(), sem_block);
TINT_SCOPED_ASSIGNMENT(current_statement_, sem_block); TINT_SCOPED_ASSIGNMENT(current_statement_, sem_block);
return BlockScope(stmt->body(), return BlockScope(stmt->body(),
@ -1492,8 +1490,8 @@ bool Resolver::IfStatement(ast::IfStatement* stmt) {
Mark(stmt->body()); Mark(stmt->body());
{ {
sem::BlockStatement* sem_block = builder_->create<sem::BlockStatement>( auto* sem_block =
stmt->body(), current_statement_, sem::BlockStatement::Type::kGeneric); builder_->create<sem::BlockStatement>(stmt->body(), current_statement_);
builder_->Sem().Add(stmt->body(), sem_block); builder_->Sem().Add(stmt->body(), sem_block);
TINT_SCOPED_ASSIGNMENT(current_statement_, sem_block); TINT_SCOPED_ASSIGNMENT(current_statement_, sem_block);
if (!BlockScope(stmt->body(), if (!BlockScope(stmt->body(),
@ -1525,9 +1523,8 @@ bool Resolver::IfStatement(ast::IfStatement* stmt) {
} }
Mark(else_stmt->body()); Mark(else_stmt->body());
{ {
sem::BlockStatement* sem_block = builder_->create<sem::BlockStatement>( auto* sem_block = builder_->create<sem::BlockStatement>(
else_stmt->body(), current_statement_, else_stmt->body(), current_statement_);
sem::BlockStatement::Type::kGeneric);
builder_->Sem().Add(else_stmt->body(), sem_block); builder_->Sem().Add(else_stmt->body(), sem_block);
TINT_SCOPED_ASSIGNMENT(current_statement_, sem_block); TINT_SCOPED_ASSIGNMENT(current_statement_, sem_block);
if (!BlockScope(else_stmt->body(), if (!BlockScope(else_stmt->body(),
@ -1546,8 +1543,8 @@ bool Resolver::LoopStatement(ast::LoopStatement* stmt) {
// validation. Also, we need to set their types differently. // validation. Also, we need to set their types differently.
Mark(stmt->body()); Mark(stmt->body());
auto* sem_block_body = builder_->create<sem::BlockStatement>( auto* sem_block_body = builder_->create<sem::LoopBlockStatement>(
stmt->body(), current_statement_, sem::BlockStatement::Type::kLoop); stmt->body(), current_statement_);
builder_->Sem().Add(stmt->body(), sem_block_body); builder_->Sem().Add(stmt->body(), sem_block_body);
TINT_SCOPED_ASSIGNMENT(current_statement_, sem_block_body); TINT_SCOPED_ASSIGNMENT(current_statement_, sem_block_body);
return BlockScope(stmt->body(), [&] { return BlockScope(stmt->body(), [&] {
@ -1558,9 +1555,9 @@ bool Resolver::LoopStatement(ast::LoopStatement* stmt) {
Mark(stmt->continuing()); Mark(stmt->continuing());
} }
if (stmt->has_continuing()) { if (stmt->has_continuing()) {
auto* sem_block_continuing = builder_->create<sem::BlockStatement>( auto* sem_block_continuing =
stmt->continuing(), current_statement_, builder_->create<sem::LoopContinuingBlockStatement>(
sem::BlockStatement::Type::kLoopContinuing); stmt->continuing(), current_statement_);
builder_->Sem().Add(stmt->continuing(), sem_block_continuing); builder_->Sem().Add(stmt->continuing(), sem_block_continuing);
TINT_SCOPED_ASSIGNMENT(current_statement_, sem_block_continuing); TINT_SCOPED_ASSIGNMENT(current_statement_, sem_block_continuing);
if (!BlockScope(stmt->continuing(), if (!BlockScope(stmt->continuing(),
@ -1909,10 +1906,11 @@ bool Resolver::Identifier(ast::IdentifierExpression* expr) {
// If identifier is part of a loop continuing block, make sure it // If identifier is part of a loop continuing block, make sure it
// doesn't refer to a variable that is bypassed by a continue statement // doesn't refer to a variable that is bypassed by a continue statement
// in the loop's body block. // in the loop's body block.
if (auto* continuing_block = current_block_->FindFirstParent( if (auto* continuing_block =
sem::BlockStatement::Type::kLoopContinuing)) { current_block_
->FindFirstParent<sem::LoopContinuingBlockStatement>()) {
auto* loop_block = auto* loop_block =
continuing_block->FindFirstParent(sem::BlockStatement::Type::kLoop); continuing_block->FindFirstParent<sem::LoopBlockStatement>();
if (loop_block->FirstContinue() != size_t(~0)) { if (loop_block->FirstContinue() != size_t(~0)) {
auto& decls = loop_block->Decls(); auto& decls = loop_block->Decls();
// If our identifier is in loop_block->decls, make sure its index is // If our identifier is in loop_block->decls, make sure its index is

View File

@ -17,35 +17,47 @@
#include "src/ast/block_statement.h" #include "src/ast/block_statement.h"
TINT_INSTANTIATE_TYPEINFO(tint::sem::BlockStatement); TINT_INSTANTIATE_TYPEINFO(tint::sem::BlockStatement);
TINT_INSTANTIATE_TYPEINFO(tint::sem::LoopBlockStatement);
TINT_INSTANTIATE_TYPEINFO(tint::sem::LoopContinuingBlockStatement);
TINT_INSTANTIATE_TYPEINFO(tint::sem::SwitchCaseBlockStatement);
namespace tint { namespace tint {
namespace sem { namespace sem {
BlockStatement::BlockStatement(const ast::BlockStatement* declaration, BlockStatement::BlockStatement(const ast::BlockStatement* declaration,
const Statement* parent, const Statement* parent)
Type type) : Base(declaration, parent) {}
: Base(declaration, parent), type_(type) {}
BlockStatement::~BlockStatement() = default; BlockStatement::~BlockStatement() = default;
const BlockStatement* BlockStatement::FindFirstParent(
BlockStatement::Type ty) const {
return FindFirstParent(
[ty](auto* block_info) { return block_info->type_ == ty; });
}
const ast::BlockStatement* BlockStatement::Declaration() const { const ast::BlockStatement* BlockStatement::Declaration() const {
return Base::Declaration()->As<ast::BlockStatement>(); return Base::Declaration()->As<ast::BlockStatement>();
} }
void BlockStatement::SetFirstContinue(size_t first_continue) {
TINT_ASSERT(type_ == Type::kLoop);
first_continue_ = first_continue;
}
void BlockStatement::AddDecl(ast::Variable* var) { void BlockStatement::AddDecl(ast::Variable* var) {
decls_.push_back(var); decls_.push_back(var);
} }
LoopBlockStatement::LoopBlockStatement(const ast::BlockStatement* declaration,
const Statement* parent)
: Base(declaration, parent) {}
LoopBlockStatement::~LoopBlockStatement() = default;
void LoopBlockStatement::SetFirstContinue(size_t first_continue) {
first_continue_ = first_continue;
}
LoopContinuingBlockStatement::LoopContinuingBlockStatement(
const ast::BlockStatement* declaration,
const Statement* parent)
: Base(declaration, parent) {}
LoopContinuingBlockStatement::~LoopContinuingBlockStatement() = default;
SwitchCaseBlockStatement::SwitchCaseBlockStatement(
const ast::BlockStatement* declaration,
const Statement* parent)
: Base(declaration, parent) {}
SwitchCaseBlockStatement::~SwitchCaseBlockStatement() = default;
} // namespace sem } // namespace sem
} // namespace tint } // namespace tint

View File

@ -17,7 +17,6 @@
#include <vector> #include <vector>
#include "src/debug.h"
#include "src/sem/statement.h" #include "src/sem/statement.h"
namespace tint { namespace tint {
@ -34,15 +33,11 @@ namespace sem {
/// declared in the block. /// declared in the block.
class BlockStatement : public Castable<BlockStatement, Statement> { class BlockStatement : public Castable<BlockStatement, Statement> {
public: public:
enum class Type { kGeneric, kLoop, kLoopContinuing, kSwitchCase };
/// Constructor /// Constructor
/// @param declaration the AST node for this block statement /// @param declaration the AST node for this block statement
/// @param parent the owning statement /// @param parent the owning statement
/// @param type the type of block this is
BlockStatement(const ast::BlockStatement* declaration, BlockStatement(const ast::BlockStatement* declaration,
const Statement* parent, const Statement* parent);
Type type);
/// Destructor /// Destructor
~BlockStatement() override; ~BlockStatement() override;
@ -63,21 +58,47 @@ class BlockStatement : public Castable<BlockStatement, Statement> {
return curr; return curr;
} }
/// @returns the closest enclosing block that matches the given type, which /// @returns the statement itself if it matches the template type `T`,
/// may be the block itself, or nullptr if no match is found /// otherwise the nearest enclosing block that matches `T`, or nullptr if
/// @param ty the type of block to be searched for /// there is none.
const BlockStatement* FindFirstParent(BlockStatement::Type ty) const; template <typename T>
const T* FindFirstParent() const {
const BlockStatement* curr = this;
while (curr) {
if (auto* block = curr->As<T>()) {
return block;
}
curr = curr->Block();
}
return nullptr;
}
/// @returns the declarations associated with this block /// @returns the declarations associated with this block
const std::vector<const ast::Variable*>& Decls() const { return decls_; } const std::vector<const ast::Variable*>& Decls() const { return decls_; }
/// Requires that this is a loop block. /// Associates a declaration with this block.
/// @param var a variable declaration to be added to the block
void AddDecl(ast::Variable* var);
private:
std::vector<const ast::Variable*> decls_;
};
/// Holds semantic information about a loop block
class LoopBlockStatement : public Castable<LoopBlockStatement, BlockStatement> {
public:
/// Constructor
/// @param declaration the AST node for this block statement
/// @param parent the owning statement
LoopBlockStatement(const ast::BlockStatement* declaration,
const Statement* parent);
/// Destructor
~LoopBlockStatement() override;
/// @returns the index of the first variable declared after the first continue /// @returns the index of the first variable declared after the first continue
/// statement /// statement
size_t FirstContinue() const { size_t FirstContinue() const { return first_continue_; }
TINT_ASSERT(type_ == Type::kLoop);
return first_continue_;
}
/// Requires that this is a loop block. /// Requires that this is a loop block.
/// Allows the resolver to set the index of the first variable declared after /// Allows the resolver to set the index of the first variable declared after
@ -85,20 +106,41 @@ class BlockStatement : public Castable<BlockStatement, Statement> {
/// @param first_continue index of the relevant variable /// @param first_continue index of the relevant variable
void SetFirstContinue(size_t first_continue); void SetFirstContinue(size_t first_continue);
/// Allows the resolver to associate a declaration with this block.
/// @param var a variable declaration to be added to the block
void AddDecl(ast::Variable* var);
private: private:
Type const type_;
std::vector<const ast::Variable*> decls_;
// first_continue is set to the index of the first variable in decls // first_continue is set to the index of the first variable in decls
// declared after the first continue statement in a loop block, if any. // declared after the first continue statement in a loop block, if any.
constexpr static size_t kNoContinue = size_t(~0); constexpr static size_t kNoContinue = size_t(~0);
size_t first_continue_ = kNoContinue; size_t first_continue_ = kNoContinue;
}; };
/// Holds semantic information about a loop continuing block
class LoopContinuingBlockStatement
: public Castable<LoopContinuingBlockStatement, BlockStatement> {
public:
/// Constructor
/// @param declaration the AST node for this block statement
/// @param parent the owning statement
LoopContinuingBlockStatement(const ast::BlockStatement* declaration,
const Statement* parent);
/// Destructor
~LoopContinuingBlockStatement() override;
};
/// Holds semantic information about a switch case block
class SwitchCaseBlockStatement
: public Castable<SwitchCaseBlockStatement, BlockStatement> {
public:
/// Constructor
/// @param declaration the AST node for this block statement
/// @param parent the owning statement
SwitchCaseBlockStatement(const ast::BlockStatement* declaration,
const Statement* parent);
/// Destructor
~SwitchCaseBlockStatement() override;
};
} // namespace sem } // namespace sem
} // namespace tint } // namespace tint