Resolver: Validate usage of break

Also remove unused fields of Resolver (block_to_info_, block_infos_). We can put them back when they're actually needed.

Fixed: tint:190
Change-Id: I1a02a24eca7fba32b8e1120abb88040138a39c6a
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/44051
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
This commit is contained in:
Ben Clayton 2021-03-09 15:06:37 +00:00 committed by Commit Bot service account
parent 5e9eb2037d
commit 9430cb4775
3 changed files with 83 additions and 44 deletions

View File

@ -66,9 +66,8 @@ Resolver::Resolver(ProgramBuilder* builder)
Resolver::~Resolver() = default;
Resolver::BlockInfo::BlockInfo(Resolver::BlockInfo::Type ty,
Resolver::BlockInfo* p,
const ast::BlockStatement* b)
: type(ty), parent(p), block(b) {}
Resolver::BlockInfo* p)
: type(ty), parent(p) {}
Resolver::BlockInfo::~BlockInfo() = default;
@ -150,12 +149,8 @@ bool Resolver::Function(ast::Function* func) {
}
bool Resolver::BlockStatement(const ast::BlockStatement* stmt) {
auto* block =
block_infos_.Create(BlockInfo::Type::Generic, current_block_, stmt);
block_to_info_[stmt] = block;
ScopedAssignment<BlockInfo*> scope_sa(current_block_, block);
return Statements(stmt->list());
return BlockScope(BlockInfo::Type::kGeneric,
[&] { return Statements(stmt->list()); });
}
bool Resolver::Statements(const ast::StatementList& stmts) {
@ -219,18 +214,24 @@ bool Resolver::Statement(ast::Statement* stmt) {
return BlockStatement(b);
}
if (stmt->Is<ast::BreakStatement>()) {
if (!current_block_->FindFirstParent(BlockInfo::Type::kLoop) &&
!current_block_->FindFirstParent(BlockInfo::Type::kSwitchCase)) {
diagnostics_.add_error("break statement must be in a loop or switch case",
stmt->source());
return false;
}
return true;
}
if (auto* c = stmt->As<ast::CallStatement>()) {
return Expression(c->expr());
}
if (auto* c = stmt->As<ast::CaseStatement>()) {
return BlockStatement(c->body());
return CaseStatement(c);
}
if (stmt->Is<ast::ContinueStatement>()) {
// Set if we've hit the first continue statement in our parent loop
if (auto* loop_block =
current_block_->FindFirstParent(BlockInfo::Type::Loop)) {
current_block_->FindFirstParent(BlockInfo::Type::kLoop)) {
if (loop_block->first_continue == size_t(~0)) {
loop_block->first_continue = loop_block->decls.size();
}
@ -268,26 +269,20 @@ bool Resolver::Statement(ast::Statement* stmt) {
// these would make their BlockInfo siblings as in the AST, but we want the
// body BlockInfo to parent the continuing BlockInfo for semantics and
// validation. Also, we need to set their types differently.
auto* block =
block_infos_.Create(BlockInfo::Type::Loop, current_block_, l->body());
block_to_info_[l->body()] = block;
ScopedAssignment<BlockInfo*> scope_sa(current_block_, block);
return BlockScope(BlockInfo::Type::kLoop, [&] {
if (!Statements(l->body()->list())) {
return false;
}
if (l->has_continuing()) {
auto* cont_block = block_infos_.Create(BlockInfo::Type::LoopContinuing,
current_block_, l->continuing());
block_to_info_[l->continuing()] = cont_block;
ScopedAssignment<BlockInfo*> scope_sa2(current_block_, cont_block);
if (!Statements(l->continuing()->list())) {
if (!BlockScope(BlockInfo::Type::kLoopContinuing,
[&] { return Statements(l->continuing()->list()); })) {
return false;
}
}
return true;
});
}
if (auto* r = stmt->As<ast::ReturnStatement>()) {
return Expression(r->value());
@ -297,7 +292,7 @@ bool Resolver::Statement(ast::Statement* stmt) {
return false;
}
for (auto* case_stmt : s->body()) {
if (!Statement(case_stmt)) {
if (!CaseStatement(case_stmt)) {
return false;
}
}
@ -316,6 +311,11 @@ bool Resolver::Statement(ast::Statement* stmt) {
return false;
}
bool Resolver::CaseStatement(ast::CaseStatement* stmt) {
return BlockScope(BlockInfo::Type::kSwitchCase,
[&] { return Statements(stmt->body()->list()); });
}
bool Resolver::Expressions(const ast::ExpressionList& list) {
for (auto* expr : list) {
if (!Expression(expr)) {
@ -395,8 +395,7 @@ bool Resolver::ArrayAccessor(ast::ArrayAccessorExpression* expr) {
} else if (auto* arr = parent_type->As<type::Array>()) {
if (!arr->type()->is_scalar()) {
// If we extract a non-scalar from an array then we also get a pointer. We
// will generate a Function storage class variable to store this
// into.
// will generate a Function storage class variable to store this into.
ret = builder_->create<type::Pointer>(ret, ast::StorageClass::kFunction);
}
}
@ -573,9 +572,9 @@ bool Resolver::Identifier(ast::IdentifierExpression* expr) {
// refer to a variable that is bypassed by a continue statement in the
// loop's body block.
if (auto* continuing_block =
current_block_->FindFirstParent(BlockInfo::Type::LoopContinuing)) {
current_block_->FindFirstParent(BlockInfo::Type::kLoopContinuing)) {
auto* loop_block =
continuing_block->FindFirstParent(BlockInfo::Type::Loop);
continuing_block->FindFirstParent(BlockInfo::Type::kLoop);
if (loop_block->first_continue != size_t(~0)) {
auto& decls = loop_block->decls;
// If our identifier is in loop_block->decls, make sure its index is
@ -946,6 +945,13 @@ void Resolver::CreateSemanticNodes() const {
}
}
template <typename F>
bool Resolver::BlockScope(BlockInfo::Type type, F&& callback) {
BlockInfo block_info(type, current_block_);
ScopedAssignment<BlockInfo*> sa(current_block_, &block_info);
return callback();
}
Resolver::VariableInfo::VariableInfo(ast::Variable* decl)
: declaration(decl), storage_class(decl->declared_storage_class()) {}

View File

@ -33,6 +33,7 @@ class ArrayAccessorExpression;
class BinaryExpression;
class BitcastExpression;
class CallExpression;
class CaseStatement;
class ConstructorExpression;
class Function;
class IdentifierExpression;
@ -105,9 +106,9 @@ class Resolver {
/// parent block and variables declared in the block.
/// Used to validate variable scoping rules.
struct BlockInfo {
enum class Type { Generic, Loop, LoopContinuing };
enum class Type { kGeneric, kLoop, kLoopContinuing, kSwitchCase };
BlockInfo(Type type, BlockInfo* parent, const ast::BlockStatement* block);
BlockInfo(Type type, BlockInfo* parent);
~BlockInfo();
template <typename Pred>
@ -124,9 +125,8 @@ class Resolver {
[ty](auto* block_info) { return block_info->type == ty; });
}
const Type type;
BlockInfo* parent;
const ast::BlockStatement* block;
Type const type;
BlockInfo* const parent;
std::vector<const ast::Variable*> decls;
// first_continue is set to the index of the first variable in decls
@ -134,9 +134,6 @@ class Resolver {
constexpr static size_t kNoContinue = size_t(~0);
size_t first_continue = kNoContinue;
};
std::unordered_map<const ast::BlockStatement*, BlockInfo*> block_to_info_;
BlockAllocator<BlockInfo> block_infos_;
BlockInfo* current_block_ = nullptr;
/// Resolves the program, without creating final the semantic nodes.
/// @returns true on success, false on error
@ -200,6 +197,7 @@ class Resolver {
bool Binary(ast::BinaryExpression* expr);
bool Bitcast(ast::BitcastExpression* expr);
bool Call(ast::CallExpression* expr);
bool CaseStatement(ast::CaseStatement* stmt);
bool Constructor(ast::ConstructorExpression* expr);
bool Identifier(ast::IdentifierExpression* expr);
bool IntrinsicCall(ast::CallExpression* call,
@ -221,9 +219,16 @@ class Resolver {
/// @param type the resolved type
void SetType(ast::Expression* expr, type::Type* type);
/// Constructs a new BlockInfo with the given type and with #current_block_ as
/// its parent, assigns this to #current_block_, and then calls `callback`.
/// The original #current_block_ is restored on exit.
template <typename F>
bool BlockScope(BlockInfo::Type type, F&& callback);
ProgramBuilder* const builder_;
std::unique_ptr<IntrinsicTable> const intrinsic_table_;
diag::List diagnostics_;
BlockInfo* current_block_ = nullptr;
ScopeStack<VariableInfo*> variable_stack_;
std::unordered_map<Symbol, FunctionInfo*> symbol_to_function_;
std::unordered_map<ast::Function*, FunctionInfo*> function_to_info_;

View File

@ -17,6 +17,7 @@
#include "gmock/gmock.h"
#include "src/ast/assignment_statement.h"
#include "src/ast/bitcast_expression.h"
#include "src/ast/break_statement.h"
#include "src/ast/call_statement.h"
#include "src/ast/continue_statement.h"
#include "src/ast/if_statement.h"
@ -476,10 +477,37 @@ TEST_F(ResolverTest,
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverTest, Stmt_ContinueInLoop) {
WrapInFunction(Loop(Block(create<ast::ContinueStatement>(Source{{12, 34}}))));
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverTest, Stmt_ContinueNotInLoop) {
WrapInFunction(create<ast::ContinueStatement>());
WrapInFunction(create<ast::ContinueStatement>(Source{{12, 34}}));
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "error: continue statement must be in a loop");
EXPECT_EQ(r()->error(), "12:34 error: continue statement must be in a loop");
}
TEST_F(ResolverTest, Stmt_BreakInLoop) {
WrapInFunction(Loop(Block(create<ast::BreakStatement>(Source{{12, 34}}))));
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverTest, Stmt_BreakInSwitch) {
WrapInFunction(Loop(Block(create<ast::SwitchStatement>(
Expr(1), ast::CaseStatementList{
create<ast::CaseStatement>(
ast::CaseSelectorList{Literal(1)},
Block(create<ast::BreakStatement>(Source{{12, 34}}))),
}))));
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverTest, Stmt_BreakNotInLoopOrSwitch) {
WrapInFunction(create<ast::BreakStatement>(Source{{12, 34}}));
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: break statement must be in a loop or switch case");
}
TEST_F(ResolverTest, Stmt_Return) {