TD: validate continue statements bypassing body variables

We now keep track of scopes as a tree of BlockInfos that track variables
declared in each scope. For loop scopes, we store the index of the first
variable (if any) that follows the first continue statement. Using this
data structure, when parsing expressions, we validate that used
variables in continuing blocks are not bypassed by a continue statement
in the parent loop block.

Also:
* Validate that continue statements are in a loop in TD. This error is
already caught by the spir-v writer, but better to catch it here.
* Add more utility functions to ProgramBuilder to make it easier to
write tests

Fixed: tint:17
Change-Id: I967bf2cfb63062bac8dcca113d074ba0fe2152e2
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/44120
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
Antonio Maiorano 2021-03-09 10:26:57 +00:00 committed by Commit Bot service account
parent b78251fdcd
commit fd31bbd3f1
5 changed files with 409 additions and 10 deletions

View File

@ -35,6 +35,9 @@ class BlockStatement : public Castable<BlockStatement, Statement> {
BlockStatement(BlockStatement&&);
~BlockStatement() override;
/// @returns the StatementList
const StatementList& list() const { return statements_; }
/// @returns true if the block is empty
bool empty() const { return statements_.empty(); }
/// @returns the number of statements directly in the block

View File

@ -19,12 +19,15 @@
#include <utility>
#include "src/ast/array_accessor_expression.h"
#include "src/ast/assignment_statement.h"
#include "src/ast/binary_expression.h"
#include "src/ast/bool_literal.h"
#include "src/ast/call_expression.h"
#include "src/ast/expression.h"
#include "src/ast/float_literal.h"
#include "src/ast/identifier_expression.h"
#include "src/ast/if_statement.h"
#include "src/ast/loop_statement.h"
#include "src/ast/member_accessor_expression.h"
#include "src/ast/module.h"
#include "src/ast/scalar_constructor_expression.h"
@ -36,6 +39,7 @@
#include "src/ast/type_constructor_expression.h"
#include "src/ast/uint_literal.h"
#include "src/ast/variable.h"
#include "src/ast/variable_decl_statement.h"
#include "src/diagnostic/diagnostic.h"
#include "src/program.h"
#include "src/semantic/info.h"
@ -1051,6 +1055,64 @@ class ProgramBuilder {
});
}
/// Creates a ast::BlockStatement with input statements
/// @param statements statements of block
/// @returns the block statement pointer
template <typename... Statements>
ast::BlockStatement* Block(Statements&&... statements) {
return create<ast::BlockStatement>(
ast::StatementList{std::forward<Statements>(statements)...});
}
/// Creates a ast::ElseStatement with input condition and body
/// @param condition the else condition expression
/// @param body the else body
/// @returns the else statement pointer
ast::ElseStatement* Else(ast::Expression* condition,
ast::BlockStatement* body) {
return create<ast::ElseStatement>(condition, body);
}
/// Creates a ast::IfStatement with input condition, body, and optional
/// variadic else statements
/// @param condition the if statement condition expression
/// @param body the if statement body
/// @param elseStatements optional variadic else statements
/// @returns the if statement pointer
template <typename... ElseStatements>
ast::IfStatement* If(ast::Expression* condition,
ast::BlockStatement* body,
ElseStatements&&... elseStatements) {
return create<ast::IfStatement>(
condition, body,
ast::ElseStatementList{
std::forward<ElseStatements>(elseStatements)...});
}
/// Creates a ast::AssignmentStatement with input lhs and rhs expressions
/// @param lhs the left hand side expression
/// @param rhs the right hand side expression
/// @returns the assignment statement pointer
ast::AssignmentStatement* Assign(ast::Expression* lhs, ast::Expression* rhs) {
return create<ast::AssignmentStatement>(lhs, rhs);
}
/// Creates a ast::LoopStatement with input body and optional continuing
/// @param body the loop body
/// @param continuing the optional continuing block
/// @returns the loop statement pointer
ast::LoopStatement* Loop(ast::BlockStatement* body,
ast::BlockStatement* continuing = nullptr) {
return create<ast::LoopStatement>(body, continuing);
}
/// Creates a ast::VariableDeclStatement for the input variable
/// @param var the variable to wrap in a decl statement
/// @returns the variable decl statement pointer
ast::VariableDeclStatement* Decl(ast::Variable* var) {
return create<ast::VariableDeclStatement>(var);
}
/// Sets the current builder source to `src`
/// @param src the Source used for future create() calls
void SetSource(const Source& src) {

View File

@ -96,6 +96,13 @@ TypeDeterminer::TypeDeterminer(ProgramBuilder* builder)
TypeDeterminer::~TypeDeterminer() = default;
TypeDeterminer::BlockInfo::BlockInfo(TypeDeterminer::BlockInfo::Type type,
TypeDeterminer::BlockInfo* parent,
const ast::BlockStatement* block)
: type(type), parent(parent), block(block) {}
TypeDeterminer::BlockInfo::~BlockInfo() = default;
void TypeDeterminer::set_referenced_from_function_if_needed(VariableInfo* var,
bool local) {
if (current_function_ == nullptr) {
@ -159,7 +166,7 @@ bool TypeDeterminer::DetermineFunction(ast::Function* func) {
variable_stack_.set(param->symbol(), CreateVariableInfo(param));
}
if (!DetermineStatements(func->body())) {
if (!DetermineBlockStatement(func->body())) {
return false;
}
variable_stack_.pop_scope();
@ -173,8 +180,17 @@ bool TypeDeterminer::DetermineFunction(ast::Function* func) {
return true;
}
bool TypeDeterminer::DetermineStatements(const ast::BlockStatement* stmts) {
for (auto* stmt : *stmts) {
bool TypeDeterminer::DetermineBlockStatement(const ast::BlockStatement* stmt) {
auto* block =
block_infos_.Create(BlockInfo::Type::Generic, current_block_, stmt);
block_to_info_[stmt] = block;
ScopedAssignment<BlockInfo*> scope_sa(current_block_, block);
return DetermineStatements(stmt->list());
}
bool TypeDeterminer::DetermineStatements(const ast::StatementList& stmts) {
for (auto* stmt : stmts) {
if (auto* decl = stmt->As<ast::VariableDeclStatement>()) {
if (!ValidateVariableDeclStatement(decl)) {
return false;
@ -231,7 +247,7 @@ bool TypeDeterminer::DetermineResultType(ast::Statement* stmt) {
return DetermineResultType(a->lhs()) && DetermineResultType(a->rhs());
}
if (auto* b = stmt->As<ast::BlockStatement>()) {
return DetermineStatements(b);
return DetermineBlockStatement(b);
}
if (stmt->Is<ast::BreakStatement>()) {
return true;
@ -240,9 +256,21 @@ bool TypeDeterminer::DetermineResultType(ast::Statement* stmt) {
return DetermineResultType(c->expr());
}
if (auto* c = stmt->As<ast::CaseStatement>()) {
return DetermineStatements(c->body());
return DetermineBlockStatement(c->body());
}
if (stmt->Is<ast::ContinueStatement>()) {
// Set if we've hit the first continue statement in our parent loop
if (auto* loop_block =
current_block_->FindFirstParent(BlockInfo::Type::Loop)) {
if (loop_block->first_continue == size_t(~0)) {
loop_block->first_continue = loop_block->decls.size();
}
} else {
diagnostics_.add_error("continue statement must be in a loop",
stmt->source());
return false;
}
return true;
}
if (stmt->Is<ast::DiscardStatement>()) {
@ -250,14 +278,14 @@ bool TypeDeterminer::DetermineResultType(ast::Statement* stmt) {
}
if (auto* e = stmt->As<ast::ElseStatement>()) {
return DetermineResultType(e->condition()) &&
DetermineStatements(e->body());
DetermineBlockStatement(e->body());
}
if (stmt->Is<ast::FallthroughStatement>()) {
return true;
}
if (auto* i = stmt->As<ast::IfStatement>()) {
if (!DetermineResultType(i->condition()) ||
!DetermineStatements(i->body())) {
!DetermineBlockStatement(i->body())) {
return false;
}
@ -269,8 +297,30 @@ bool TypeDeterminer::DetermineResultType(ast::Statement* stmt) {
return true;
}
if (auto* l = stmt->As<ast::LoopStatement>()) {
return DetermineStatements(l->body()) &&
DetermineStatements(l->continuing());
// We don't call DetermineBlockStatement on the body and continuing block as
// these would make their BlockInfo siblings as in the AST, but we want the
// body BlockInfo to parent the continuing BlockInfo for semantics and
// validation. Also, we need to set their types differently.
auto* block =
block_infos_.Create(BlockInfo::Type::Loop, current_block_, l->body());
block_to_info_[l->body()] = block;
ScopedAssignment<BlockInfo*> scope_sa(current_block_, block);
if (!DetermineStatements(l->body()->list())) {
return false;
}
if (l->has_continuing()) {
auto* block = block_infos_.Create(BlockInfo::Type::LoopContinuing,
current_block_, l->continuing());
block_to_info_[l->continuing()] = block;
ScopedAssignment<BlockInfo*> scope_sa(current_block_, block);
if (!DetermineStatements(l->continuing()->list())) {
return false;
}
}
return true;
}
if (auto* r = stmt->As<ast::ReturnStatement>()) {
return DetermineResultType(r->value());
@ -289,6 +339,7 @@ bool TypeDeterminer::DetermineResultType(ast::Statement* stmt) {
if (auto* v = stmt->As<ast::VariableDeclStatement>()) {
variable_stack_.set(v->variable()->symbol(),
variable_to_info_.at(v->variable()));
current_block_->decls.push_back(v->variable());
return DetermineResultType(v->variable()->constructor());
}
@ -552,6 +603,36 @@ bool TypeDeterminer::DetermineIdentifier(ast::IdentifierExpression* expr) {
var->users.push_back(expr);
set_referenced_from_function_if_needed(var, true);
// If identifier is part of a loop continuing block, make sure it doesn't
// refer to a variable that is bypassed by a continue statement in the
// loop's body block.
if (auto* continuing_block =
current_block_->FindFirstParent(BlockInfo::Type::LoopContinuing)) {
auto* loop_block =
continuing_block->FindFirstParent(BlockInfo::Type::Loop);
if (loop_block->first_continue != size_t(~0)) {
auto& decls = loop_block->decls;
// If our identifier is in loop_block->decls, make sure its index is
// less than first_continue
auto iter = std::find_if(
decls.begin(), decls.end(),
[&symbol](auto* var) { return var->symbol() == symbol; });
if (iter != decls.end()) {
auto var_decl_index =
static_cast<size_t>(std::distance(decls.begin(), iter));
if (var_decl_index >= loop_block->first_continue) {
diagnostics_.add_error(
"continue statement bypasses declaration of '" +
builder_->Symbols().NameFor(symbol) +
"' in continuing block",
expr->source());
return false;
}
}
}
}
return true;
}

View File

@ -106,6 +106,43 @@ class TypeDeterminer {
semantic::Statement* statement;
};
/// Structure holding semantic information about a block (i.e. scope), such as
/// parent block and variables declared in the block.
/// Used to validate variable scoping rules.
struct BlockInfo {
enum class Type { Generic, Loop, LoopContinuing };
BlockInfo(Type type, BlockInfo* parent, const ast::BlockStatement* block);
~BlockInfo();
template <typename Pred>
BlockInfo* FindFirstParent(Pred&& pred) {
BlockInfo* curr = this;
while (curr && !pred(curr)) {
curr = curr->parent;
}
return curr;
}
BlockInfo* FindFirstParent(BlockInfo::Type type) {
return FindFirstParent(
[type](auto* block_info) { return block_info->type == type; });
}
const Type type;
BlockInfo* parent;
const ast::BlockStatement* block;
std::vector<const ast::Variable*> decls;
// first_continue is set to the index of the first variable in decls
// declared after the first continue statement in a loop block, if any.
constexpr static size_t kNoContinue = size_t(~0);
size_t first_continue = kNoContinue;
};
std::unordered_map<const ast::BlockStatement*, BlockInfo*> block_to_info_;
BlockAllocator<BlockInfo> block_infos_;
BlockInfo* current_block_ = nullptr;
/// Determines type information for the program, without creating final the
/// semantic nodes.
/// @returns true if the determination was successful
@ -120,10 +157,14 @@ class TypeDeterminer {
/// @param func the function to check
/// @returns true if the determination was successful
bool DetermineFunction(ast::Function* func);
/// Determines the type information for a block statement
/// @param stmt the block statement
/// @returns true if determination was successful
bool DetermineBlockStatement(const ast::BlockStatement* stmt);
/// Determines type information for a set of statements
/// @param stmts the statements to check
/// @returns true if the determination was successful
bool DetermineStatements(const ast::BlockStatement* stmts);
bool DetermineStatements(const ast::StatementList& stmts);
/// Determines type information for a statement
/// @param stmt the statement to check
/// @returns true if the determination was successful

View File

@ -312,6 +312,218 @@ TEST_F(TypeDeterminerTest, Stmt_Loop) {
EXPECT_TRUE(TypeOf(continuing_rhs)->Is<type::F32>());
}
TEST_F(TypeDeterminerTest,
Stmt_Loop_ContinueInLoopBodyBeforeDecl_UsageInContinuing) {
// loop {
// continue; // Bypasses z decl
// var z : i32;
//
// continuing {
// z = 2;
// }
// }
auto error_loc = Source{Source::Location{12, 34}};
auto* body = Block(create<ast::ContinueStatement>(),
Decl(Var("z", ty.i32(), ast::StorageClass::kNone)));
auto* continuing = Block(Assign(Expr(error_loc, "z"), Expr(2)));
auto* loop_stmt = Loop(body, continuing);
WrapInFunction(loop_stmt);
EXPECT_FALSE(td()->Determine()) << td()->error();
EXPECT_EQ(td()->error(),
"12:34 error: continue statement bypasses declaration of 'z' in "
"continuing block");
}
TEST_F(TypeDeterminerTest,
Stmt_Loop_ContinueInLoopBodyBeforeDeclAndAfterDecl_UsageInContinuing) {
// loop {
// continue; // Bypasses z decl
// var z : i32;
// continue; // Ok
//
// continuing {
// z = 2;
// }
// }
auto error_loc = Source{Source::Location{12, 34}};
auto* body = Block(create<ast::ContinueStatement>(),
Decl(Var("z", ty.i32(), ast::StorageClass::kNone)),
create<ast::ContinueStatement>());
auto* continuing = Block(Assign(Expr(error_loc, "z"), Expr(2)));
auto* loop_stmt = Loop(body, continuing);
WrapInFunction(loop_stmt);
EXPECT_FALSE(td()->Determine()) << td()->error();
EXPECT_EQ(td()->error(),
"12:34 error: continue statement bypasses declaration of 'z' in "
"continuing block");
}
TEST_F(TypeDeterminerTest,
Stmt_Loop_ContinueInLoopBodySubscopeBeforeDecl_UsageInContinuing) {
// loop {
// if (true) {
// continue; // Still bypasses z decl (if we reach here)
// }
// var z : i32;
// continuing {
// z = 2;
// }
// }
auto error_loc = Source{Source::Location{12, 34}};
auto* body = Block(If(Expr(true), Block(create<ast::ContinueStatement>())),
Decl(Var("z", ty.i32(), ast::StorageClass::kNone)));
auto* continuing = Block(Assign(Expr(error_loc, "z"), Expr(2)));
auto* loop_stmt = Loop(body, continuing);
WrapInFunction(loop_stmt);
EXPECT_FALSE(td()->Determine()) << td()->error();
EXPECT_EQ(td()->error(),
"12:34 error: continue statement bypasses declaration of 'z' in "
"continuing block");
}
TEST_F(
TypeDeterminerTest,
Stmt_Loop_ContinueInLoopBodySubscopeBeforeDecl_UsageInContinuingSubscope) {
// loop {
// if (true) {
// continue; // Still bypasses z decl (if we reach here)
// }
// var z : i32;
// continuing {
// if (true) {
// z = 2; // Must fail even if z is in a sub-scope
// }
// }
// }
auto error_loc = Source{Source::Location{12, 34}};
auto* body = Block(If(Expr(true), Block(create<ast::ContinueStatement>())),
Decl(Var("z", ty.i32(), ast::StorageClass::kNone)));
auto* continuing =
Block(If(Expr(true), Block(Assign(Expr(error_loc, "z"), Expr(2)))));
auto* loop_stmt = Loop(body, continuing);
WrapInFunction(loop_stmt);
EXPECT_FALSE(td()->Determine()) << td()->error();
EXPECT_EQ(td()->error(),
"12:34 error: continue statement bypasses declaration of 'z' in "
"continuing block");
}
TEST_F(TypeDeterminerTest,
Stmt_Loop_ContinueInLoopBodySubscopeBeforeDecl_UsageInContinuingLoop) {
// loop {
// if (true) {
// continue; // Still bypasses z decl (if we reach here)
// }
// var z : i32;
// continuing {
// loop {
// z = 2; // Must fail even if z is in a sub-scope
// }
// }
// }
auto error_loc = Source{Source::Location{12, 34}};
auto* body = Block(If(Expr(true), Block(create<ast::ContinueStatement>())),
Decl(Var("z", ty.i32(), ast::StorageClass::kNone)));
auto* continuing = Block(Loop(Block(Assign(Expr(error_loc, "z"), Expr(2)))));
auto* loop_stmt = Loop(body, continuing);
WrapInFunction(loop_stmt);
EXPECT_FALSE(td()->Determine()) << td()->error();
EXPECT_EQ(td()->error(),
"12:34 error: continue statement bypasses declaration of 'z' in "
"continuing block");
}
TEST_F(TypeDeterminerTest,
Stmt_Loop_ContinueInNestedLoopBodyBeforeDecl_UsageInContinuing) {
// loop {
// loop {
// continue; // OK: not part of the outer loop
// }
// var z : i32;
//
// continuing {
// z = 2;
// }
// }
auto* inner_loop = Loop(Block(create<ast::ContinueStatement>()));
auto* body =
Block(inner_loop, Decl(Var("z", ty.i32(), ast::StorageClass::kNone)));
auto* continuing = Block(Assign(Expr("z"), Expr(2)));
auto* loop_stmt = Loop(body, continuing);
WrapInFunction(loop_stmt);
EXPECT_TRUE(td()->Determine()) << td()->error();
}
TEST_F(TypeDeterminerTest,
Stmt_Loop_ContinueInNestedLoopBodyBeforeDecl_UsageInContinuingSubscope) {
// loop {
// loop {
// continue; // OK: not part of the outer loop
// }
// var z : i32;
//
// continuing {
// if (true) {
// z = 2;
// }
// }
// }
auto* inner_loop = Loop(Block(create<ast::ContinueStatement>()));
auto* body =
Block(inner_loop, Decl(Var("z", ty.i32(), ast::StorageClass::kNone)));
auto* continuing = Block(If(Expr(true), Block(Assign(Expr("z"), Expr(2)))));
auto* loop_stmt = Loop(body, continuing);
WrapInFunction(loop_stmt);
EXPECT_TRUE(td()->Determine()) << td()->error();
}
TEST_F(TypeDeterminerTest,
Stmt_Loop_ContinueInNestedLoopBodyBeforeDecl_UsageInContinuingLoop) {
// loop {
// loop {
// continue; // OK: not part of the outer loop
// }
// var z : i32;
//
// continuing {
// loop {
// z = 2;
// }
// }
// }
auto* inner_loop = Loop(Block(create<ast::ContinueStatement>()));
auto* body =
Block(inner_loop, Decl(Var("z", ty.i32(), ast::StorageClass::kNone)));
auto* continuing = Block(Loop(Block(Assign(Expr("z"), Expr(2)))));
auto* loop_stmt = Loop(body, continuing);
WrapInFunction(loop_stmt);
EXPECT_TRUE(td()->Determine()) << td()->error();
}
TEST_F(TypeDeterminerTest, Stmt_ContinueNotInLoop) {
WrapInFunction(create<ast::ContinueStatement>());
EXPECT_FALSE(td()->Determine());
EXPECT_EQ(td()->error(), "error: continue statement must be in a loop");
}
TEST_F(TypeDeterminerTest, Stmt_Return) {
auto* cond = Expr(2);