resolver: Refactor Statement handling

Break up Resolver::Statement() into multiple resolver functions.
Move simple statement validation out to resolver_validation.cc

Change-Id: Ifa29433af0a9afa39a66ac3e4f7ca376351adfbf
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/71102
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
This commit is contained in:
Ben Clayton 2021-11-26 16:26:42 +00:00 committed by Tint LUCI CQ
parent c5835eb4a0
commit 8c30d752a0
7 changed files with 358 additions and 214 deletions

View File

@ -654,9 +654,9 @@ sem::Function* Resolver::Function(const ast::Function* decl) {
<< "Resolver::Function() called with a current compound statement"; << "Resolver::Function() called with a current compound statement";
return nullptr; return nullptr;
} }
auto* sem_block = builder_->create<sem::FunctionBlockStatement>(func); if (!StatementScope(decl->body,
builder_->Sem().Add(decl->body, sem_block); builder_->create<sem::FunctionBlockStatement>(func),
if (!Scope(sem_block, [&] { return Statements(decl->body->statements); })) { [&] { return Statements(decl->body->statements); })) {
return nullptr; return nullptr;
} }
} }
@ -796,7 +796,8 @@ bool Resolver::WorkgroupSize(const ast::Function* func) {
bool Resolver::Statements(const ast::StatementList& stmts) { bool Resolver::Statements(const ast::StatementList& stmts) {
for (auto* stmt : stmts) { for (auto* stmt : stmts) {
Mark(stmt); Mark(stmt);
if (!Statement(stmt)) { auto* sem = Statement(stmt);
if (!sem) {
return false; return false;
} }
} }
@ -807,18 +808,18 @@ bool Resolver::Statements(const ast::StatementList& stmts) {
return true; return true;
} }
bool Resolver::Statement(const ast::Statement* stmt) { sem::Statement* Resolver::Statement(const ast::Statement* stmt) {
if (stmt->Is<ast::CaseStatement>()) { if (stmt->Is<ast::CaseStatement>()) {
AddError("case statement can only be used inside a switch statement", AddError("case statement can only be used inside a switch statement",
stmt->source); stmt->source);
return false; return nullptr;
} }
if (stmt->Is<ast::ElseStatement>()) { if (stmt->Is<ast::ElseStatement>()) {
TINT_ICE(Resolver, diagnostics_) TINT_ICE(Resolver, diagnostics_)
<< "Resolver::Statement() encountered an Else statement. Else " << "Resolver::Statement() encountered an Else statement. Else "
"statements are embedded in If statements, so should never be " "statements are embedded in If statements, so should never be "
"encountered as top-level statements"; "encountered as top-level statements";
return false; return nullptr;
} }
// Compound statements. These create their own sem::CompoundStatement // Compound statements. These create their own sem::CompoundStatement
@ -840,69 +841,26 @@ bool Resolver::Statement(const ast::Statement* stmt) {
} }
// Non-Compound statements // Non-Compound statements
sem::Statement* sem_statement = builder_->create<sem::Statement>(
stmt, current_compound_statement_, current_function_);
builder_->Sem().Add(stmt, sem_statement);
TINT_SCOPED_ASSIGNMENT(current_statement_, sem_statement);
if (auto* a = stmt->As<ast::AssignmentStatement>()) { if (auto* a = stmt->As<ast::AssignmentStatement>()) {
return Assignment(a); return AssignmentStatement(a);
} }
if (stmt->Is<ast::BreakStatement>()) { if (auto* b = stmt->As<ast::BreakStatement>()) {
if (!sem_statement->FindFirstParent<sem::LoopBlockStatement>() && return BreakStatement(b);
!sem_statement->FindFirstParent<sem::SwitchCaseBlockStatement>()) {
AddError("break statement must be in a loop or switch case",
stmt->source);
return false;
}
return true;
} }
if (auto* c = stmt->As<ast::CallStatement>()) { if (auto* c = stmt->As<ast::CallStatement>()) {
if (!Expression(c->expr)) { return CallStatement(c);
return false;
}
return true;
} }
if (auto* c = stmt->As<ast::ContinueStatement>()) { if (auto* c = stmt->As<ast::ContinueStatement>()) {
// Set if we've hit the first continue statement in our parent loop return ContinueStatement(c);
if (auto* block =
current_block_->FindFirstParent<
sem::LoopBlockStatement, sem::LoopContinuingBlockStatement>()) {
if (auto* loop_block = block->As<sem::LoopBlockStatement>()) {
if (!loop_block->FirstContinue()) {
const_cast<sem::LoopBlockStatement*>(loop_block)
->SetFirstContinue(c, loop_block->Decls().size());
} }
} else { if (auto* d = stmt->As<ast::DiscardStatement>()) {
AddError("continuing blocks must not contain a continue statement", return DiscardStatement(d);
stmt->source);
return false;
} }
} else { if (auto* f = stmt->As<ast::FallthroughStatement>()) {
AddError("continue statement must be in a loop", stmt->source); return FallthroughStatement(f);
return false;
}
return true;
}
if (stmt->Is<ast::DiscardStatement>()) {
if (auto* continuing =
sem_statement
->FindFirstParent<sem::LoopContinuingBlockStatement>()) {
AddError("continuing blocks must not contain a discard statement",
stmt->source);
if (continuing != sem_statement->Parent()) {
AddNote("see continuing block here", continuing->Declaration()->source);
}
return false;
}
current_function_->SetHasDiscard();
return true;
}
if (stmt->Is<ast::FallthroughStatement>()) {
return true;
} }
if (auto* r = stmt->As<ast::ReturnStatement>()) { if (auto* r = stmt->As<ast::ReturnStatement>()) {
return Return(r); return ReturnStatement(r);
} }
if (auto* v = stmt->As<ast::VariableDeclStatement>()) { if (auto* v = stmt->As<ast::VariableDeclStatement>()) {
return VariableDeclStatement(v); return VariableDeclStatement(v);
@ -910,43 +868,38 @@ bool Resolver::Statement(const ast::Statement* stmt) {
AddError("unknown statement type: " + std::string(stmt->TypeInfo().name), AddError("unknown statement type: " + std::string(stmt->TypeInfo().name),
stmt->source); stmt->source);
return false; return nullptr;
} }
bool Resolver::CaseStatement(const ast::CaseStatement* stmt) { sem::SwitchCaseBlockStatement* Resolver::CaseStatement(
const ast::CaseStatement* stmt) {
auto* sem = builder_->create<sem::SwitchCaseBlockStatement>( auto* sem = builder_->create<sem::SwitchCaseBlockStatement>(
stmt->body, current_compound_statement_, current_function_); stmt->body, current_compound_statement_, current_function_);
builder_->Sem().Add(stmt, sem); return StatementScope(stmt, sem, [&] {
builder_->Sem().Add(stmt->body, sem); builder_->Sem().Add(stmt->body, sem);
Mark(stmt->body); Mark(stmt->body);
for (auto* sel : stmt->selectors) { for (auto* sel : stmt->selectors) {
Mark(sel); Mark(sel);
} }
return Scope(sem, [&] { return Statements(stmt->body->statements); }); return Statements(stmt->body->statements);
});
} }
bool Resolver::IfStatement(const ast::IfStatement* stmt) { sem::IfStatement* Resolver::IfStatement(const ast::IfStatement* stmt) {
auto* sem = builder_->create<sem::IfStatement>( auto* sem = builder_->create<sem::IfStatement>(
stmt, current_compound_statement_, current_function_); stmt, current_compound_statement_, current_function_);
builder_->Sem().Add(stmt, sem); return StatementScope(stmt, sem, [&] {
return Scope(sem, [&] { auto* cond = Expression(stmt->condition);
if (!Expression(stmt->condition)) { if (!cond) {
return false;
}
auto* cond_type = TypeOf(stmt->condition)->UnwrapRef();
if (!cond_type->Is<sem::Bool>()) {
AddError(
"if statement condition must be bool, got " + TypeNameOf(cond_type),
stmt->condition->source);
return false; return false;
} }
sem->SetCondition(cond);
Mark(stmt->body); Mark(stmt->body);
auto* body = builder_->create<sem::BlockStatement>( auto* body = builder_->create<sem::BlockStatement>(
stmt->body, current_compound_statement_, current_function_); stmt->body, current_compound_statement_, current_function_);
builder_->Sem().Add(stmt->body, body); if (!StatementScope(stmt->body, body,
if (!Scope(body, [&] { return Statements(stmt->body->statements); })) { [&] { return Statements(stmt->body->statements); })) {
return false; return false;
} }
@ -956,59 +909,56 @@ bool Resolver::IfStatement(const ast::IfStatement* stmt) {
return false; return false;
} }
} }
return true;
return ValidateIfStatement(sem);
}); });
} }
bool Resolver::ElseStatement(const ast::ElseStatement* stmt) { sem::ElseStatement* Resolver::ElseStatement(const ast::ElseStatement* stmt) {
auto* sem = builder_->create<sem::ElseStatement>( auto* sem = builder_->create<sem::ElseStatement>(
stmt, current_compound_statement_, current_function_); stmt, current_compound_statement_, current_function_);
builder_->Sem().Add(stmt, sem); return StatementScope(stmt, sem, [&] {
return Scope(sem, [&] { if (auto* cond_expr = stmt->condition) {
if (auto* cond = stmt->condition) { auto* cond = Expression(cond_expr);
if (!Expression(cond)) { if (!cond) {
return false;
}
auto* else_cond_type = TypeOf(cond)->UnwrapRef();
if (!else_cond_type->Is<sem::Bool>()) {
AddError("else statement condition must be bool, got " +
TypeNameOf(else_cond_type),
cond->source);
return false; return false;
} }
sem->SetCondition(cond);
} }
Mark(stmt->body); Mark(stmt->body);
auto* body = builder_->create<sem::BlockStatement>( auto* body = builder_->create<sem::BlockStatement>(
stmt->body, current_compound_statement_, current_function_); stmt->body, current_compound_statement_, current_function_);
builder_->Sem().Add(stmt->body, body); if (!StatementScope(stmt->body, body,
return Scope(body, [&] { return Statements(stmt->body->statements); }); [&] { return Statements(stmt->body->statements); })) {
return false;
}
return ValidateElseStatement(sem);
}); });
} }
bool Resolver::BlockStatement(const ast::BlockStatement* stmt) { sem::BlockStatement* Resolver::BlockStatement(const ast::BlockStatement* stmt) {
auto* sem = builder_->create<sem::BlockStatement>( auto* sem = builder_->create<sem::BlockStatement>(
stmt->As<ast::BlockStatement>(), current_compound_statement_, stmt->As<ast::BlockStatement>(), current_compound_statement_,
current_function_); current_function_);
builder_->Sem().Add(stmt, sem); return StatementScope(stmt, sem,
return Scope(sem, [&] { return Statements(stmt->statements); }); [&] { return Statements(stmt->statements); });
} }
bool Resolver::LoopStatement(const ast::LoopStatement* stmt) { sem::LoopStatement* Resolver::LoopStatement(const ast::LoopStatement* stmt) {
auto* sem = builder_->create<sem::LoopStatement>( auto* sem = builder_->create<sem::LoopStatement>(
stmt, current_compound_statement_, current_function_); stmt, current_compound_statement_, current_function_);
builder_->Sem().Add(stmt, sem); return StatementScope(stmt, sem, [&] {
return Scope(sem, [&] {
Mark(stmt->body); Mark(stmt->body);
auto* body = builder_->create<sem::LoopBlockStatement>( auto* body = builder_->create<sem::LoopBlockStatement>(
stmt->body, current_compound_statement_, current_function_); stmt->body, current_compound_statement_, current_function_);
builder_->Sem().Add(stmt->body, body); return StatementScope(stmt->body, body, [&] {
return Scope(body, [&] {
if (!Statements(stmt->body->statements)) { if (!Statements(stmt->body->statements)) {
return false; return false;
} }
if (stmt->continuing) { if (stmt->continuing) {
Mark(stmt->continuing); Mark(stmt->continuing);
if (!stmt->continuing->Empty()) { if (!stmt->continuing->Empty()) {
@ -1016,24 +966,22 @@ bool Resolver::LoopStatement(const ast::LoopStatement* stmt) {
builder_->create<sem::LoopContinuingBlockStatement>( builder_->create<sem::LoopContinuingBlockStatement>(
stmt->continuing, current_compound_statement_, stmt->continuing, current_compound_statement_,
current_function_); current_function_);
builder_->Sem().Add(stmt->continuing, continuing); return StatementScope(stmt->continuing, continuing, [&] {
if (!Scope(continuing, [&] {
return Statements(stmt->continuing->statements); return Statements(stmt->continuing->statements);
})) { }) != nullptr;
return false;
}
} }
} }
return true; return true;
}); });
}); });
} }
bool Resolver::ForLoopStatement(const ast::ForLoopStatement* stmt) { sem::ForLoopStatement* Resolver::ForLoopStatement(
const ast::ForLoopStatement* stmt) {
auto* sem = builder_->create<sem::ForLoopStatement>( auto* sem = builder_->create<sem::ForLoopStatement>(
stmt, current_compound_statement_, current_function_); stmt, current_compound_statement_, current_function_);
builder_->Sem().Add(stmt, sem); return StatementScope(stmt, sem, [&] {
return Scope(sem, [&] {
if (auto* initializer = stmt->initializer) { if (auto* initializer = stmt->initializer) {
Mark(initializer); Mark(initializer);
if (!Statement(initializer)) { if (!Statement(initializer)) {
@ -1041,17 +989,12 @@ bool Resolver::ForLoopStatement(const ast::ForLoopStatement* stmt) {
} }
} }
if (auto* condition = stmt->condition) { if (auto* cond_expr = stmt->condition) {
if (!Expression(condition)) { auto* cond = Expression(cond_expr);
return false; if (!cond) {
}
auto* cond_ty = TypeOf(condition)->UnwrapRef();
if (!cond_ty->Is<sem::Bool>()) {
AddError("for-loop condition must be bool, got " + TypeNameOf(cond_ty),
condition->source);
return false; return false;
} }
sem->SetCondition(cond);
} }
if (auto* continuing = stmt->continuing) { if (auto* continuing = stmt->continuing) {
@ -1065,8 +1008,12 @@ bool Resolver::ForLoopStatement(const ast::ForLoopStatement* stmt) {
auto* body = builder_->create<sem::LoopBlockStatement>( auto* body = builder_->create<sem::LoopBlockStatement>(
stmt->body, current_compound_statement_, current_function_); stmt->body, current_compound_statement_, current_function_);
builder_->Sem().Add(stmt->body, body); if (!StatementScope(stmt->body, body,
return Scope(body, [&] { return Statements(stmt->body->statements); }); [&] { return Statements(stmt->body->statements); })) {
return false;
}
return ValidateForLoopStatement(sem);
}); });
} }
@ -1930,33 +1877,6 @@ sem::Expression* Resolver::UnaryOp(const ast::UnaryOpExpression* unary) {
return builder_->create<sem::Expression>(unary, ty, current_statement_, val); return builder_->create<sem::Expression>(unary, ty, current_statement_, val);
} }
bool Resolver::VariableDeclStatement(const ast::VariableDeclStatement* stmt) {
Mark(stmt->variable);
auto* var = Variable(stmt->variable, VariableKind::kLocal);
if (!var) {
return false;
}
for (auto* deco : stmt->variable->decorations) {
Mark(deco);
if (!deco->Is<ast::InternalDecoration>()) {
AddError("decorations are not valid on local variables", deco->source);
return false;
}
}
if (current_block_) { // Not all statements are inside a block
current_block_->AddDecl(stmt->variable);
}
if (!ValidateVariable(var)) {
return false;
}
return true;
}
sem::Type* Resolver::TypeDecl(const ast::TypeDecl* named_type) { sem::Type* Resolver::TypeDecl(const ast::TypeDecl* named_type) {
sem::Type* result = nullptr; sem::Type* result = nullptr;
if (auto* alias = named_type->As<ast::Alias>()) { if (auto* alias = named_type->As<ast::Alias>()) {
@ -2318,8 +2238,11 @@ sem::Struct* Resolver::Structure(const ast::Struct* str) {
return out; return out;
} }
bool Resolver::Return(const ast::ReturnStatement* ret) { sem::Statement* Resolver::ReturnStatement(const ast::ReturnStatement* stmt) {
if (auto* value = ret->value) { auto* sem = builder_->create<sem::Statement>(
stmt, current_compound_statement_, current_function_);
return StatementScope(stmt, sem, [&] {
if (auto* value = stmt->value) {
if (!Expression(value)) { if (!Expression(value)) {
return false; return false;
} }
@ -2327,36 +2250,115 @@ bool Resolver::Return(const ast::ReturnStatement* ret) {
// Validate after processing the return value expression so that its type is // Validate after processing the return value expression so that its type is
// available for validation. // available for validation.
return ValidateReturn(ret); return ValidateReturn(stmt);
});
} }
bool Resolver::SwitchStatement(const ast::SwitchStatement* stmt) { sem::SwitchStatement* Resolver::SwitchStatement(
const ast::SwitchStatement* stmt) {
auto* sem = builder_->create<sem::SwitchStatement>( auto* sem = builder_->create<sem::SwitchStatement>(
stmt, current_compound_statement_, current_function_); stmt, current_compound_statement_, current_function_);
builder_->Sem().Add(stmt, sem); return StatementScope(stmt, sem, [&] {
return Scope(sem, [&] {
if (!Expression(stmt->condition)) { if (!Expression(stmt->condition)) {
return false; return false;
} }
for (auto* case_stmt : stmt->body) { for (auto* case_stmt : stmt->body) {
Mark(case_stmt); Mark(case_stmt);
if (!CaseStatement(case_stmt)) { if (!CaseStatement(case_stmt)) {
return false; return false;
} }
} }
if (!ValidateSwitch(stmt)) {
return false; return ValidateSwitch(stmt);
}
return true;
}); });
} }
bool Resolver::Assignment(const ast::AssignmentStatement* a) { sem::Statement* Resolver::VariableDeclStatement(
if (!Expression(a->lhs) || !Expression(a->rhs)) { const ast::VariableDeclStatement* stmt) {
auto* sem = builder_->create<sem::Statement>(
stmt, current_compound_statement_, current_function_);
return StatementScope(stmt, sem, [&] {
Mark(stmt->variable);
auto* var = Variable(stmt->variable, VariableKind::kLocal);
if (!var) {
return false; return false;
} }
return ValidateAssignment(a); for (auto* deco : stmt->variable->decorations) {
Mark(deco);
if (!deco->Is<ast::InternalDecoration>()) {
AddError("decorations are not valid on local variables", deco->source);
return false;
}
}
if (current_block_) { // Not all statements are inside a block
current_block_->AddDecl(stmt->variable);
}
return ValidateVariable(var);
});
}
sem::Statement* Resolver::AssignmentStatement(
const ast::AssignmentStatement* stmt) {
auto* sem = builder_->create<sem::Statement>(
stmt, current_compound_statement_, current_function_);
return StatementScope(stmt, sem, [&] {
if (!Expression(stmt->lhs) || !Expression(stmt->rhs)) {
return false;
}
return ValidateAssignment(stmt);
});
}
sem::Statement* Resolver::BreakStatement(const ast::BreakStatement* stmt) {
auto* sem = builder_->create<sem::Statement>(
stmt, current_compound_statement_, current_function_);
return StatementScope(stmt, sem, [&] { return ValidateBreakStatement(sem); });
}
sem::Statement* Resolver::CallStatement(const ast::CallStatement* stmt) {
auto* sem = builder_->create<sem::Statement>(
stmt, current_compound_statement_, current_function_);
return StatementScope(stmt, sem, [&] { return Expression(stmt->expr); });
}
sem::Statement* Resolver::ContinueStatement(
const ast::ContinueStatement* stmt) {
auto* sem = builder_->create<sem::Statement>(
stmt, current_compound_statement_, current_function_);
return StatementScope(stmt, sem, [&] {
// Set if we've hit the first continue statement in our parent loop
if (auto* block = sem->FindFirstParent<sem::LoopBlockStatement>()) {
if (!block->FirstContinue()) {
const_cast<sem::LoopBlockStatement*>(block)->SetFirstContinue(
stmt, block->Decls().size());
}
}
return ValidateContinueStatement(sem);
});
}
sem::Statement* Resolver::DiscardStatement(const ast::DiscardStatement* stmt) {
auto* sem = builder_->create<sem::Statement>(
stmt, current_compound_statement_, current_function_);
return StatementScope(stmt, sem, [&] {
current_function_->SetHasDiscard();
return ValidateDiscardStatement(sem);
});
}
sem::Statement* Resolver::FallthroughStatement(
const ast::FallthroughStatement* stmt) {
auto* sem = builder_->create<sem::Statement>(
stmt, current_compound_statement_, current_function_);
return StatementScope(stmt, sem, [&] { return true; });
} }
bool Resolver::ApplyStorageClassUsageToType(ast::StorageClass sc, bool Resolver::ApplyStorageClassUsageToType(ast::StorageClass sc,
@ -2399,22 +2401,28 @@ bool Resolver::ApplyStorageClassUsageToType(ast::StorageClass sc,
return true; return true;
} }
template <typename F> template <typename SEM, typename F>
bool Resolver::Scope(sem::CompoundStatement* stmt, F&& callback) { SEM* Resolver::StatementScope(const ast::Statement* ast,
auto* prev_current_statement = current_statement_; SEM* sem,
auto* prev_current_compound_statement = current_compound_statement_; F&& callback) {
auto* prev_current_block = current_block_; builder_->Sem().Add(ast, sem);
current_statement_ = stmt;
current_compound_statement_ = stmt;
current_block_ = stmt->As<sem::BlockStatement>();
TINT_DEFER({ auto* as_compound =
current_block_ = prev_current_block; As<sem::CompoundStatement, CastFlags::kDontErrorOnImpossibleCast>(sem);
current_compound_statement_ = prev_current_compound_statement; auto* as_block =
current_statement_ = prev_current_statement; As<sem::BlockStatement, CastFlags::kDontErrorOnImpossibleCast>(sem);
});
return callback(); TINT_SCOPED_ASSIGNMENT(current_statement_, sem);
TINT_SCOPED_ASSIGNMENT(
current_compound_statement_,
as_compound ? as_compound : current_compound_statement_);
TINT_SCOPED_ASSIGNMENT(current_block_, as_block ? as_block : current_block_);
if (!callback()) {
return nullptr;
}
return sem;
} }
std::string Resolver::VectorPretty(uint32_t size, std::string Resolver::VectorPretty(uint32_t size,

View File

@ -59,8 +59,15 @@ class Variable;
namespace sem { namespace sem {
class Array; class Array;
class Atomic; class Atomic;
class BlockStatement;
class ElseStatement;
class ForLoopStatement;
class IfStatement;
class Intrinsic; class Intrinsic;
class LoopStatement;
class Statement; class Statement;
class SwitchCaseBlockStatement;
class SwitchStatement;
class TypeConstructor; class TypeConstructor;
} // namespace sem } // namespace sem
@ -198,20 +205,26 @@ class Resolver {
// Statement resolving methods // Statement resolving methods
// Each return true on success, false on failure. // Each return true on success, false on failure.
bool Assignment(const ast::AssignmentStatement* a); sem::Statement* AssignmentStatement(const ast::AssignmentStatement*);
bool BlockStatement(const ast::BlockStatement*); sem::BlockStatement* BlockStatement(const ast::BlockStatement*);
bool CaseStatement(const ast::CaseStatement*); sem::Statement* BreakStatement(const ast::BreakStatement*);
bool ElseStatement(const ast::ElseStatement*); sem::Statement* CallStatement(const ast::CallStatement*);
bool ForLoopStatement(const ast::ForLoopStatement*); sem::SwitchCaseBlockStatement* CaseStatement(const ast::CaseStatement*);
bool Parameter(const ast::Variable* param); sem::Statement* ContinueStatement(const ast::ContinueStatement*);
bool GlobalVariable(const ast::Variable* var); sem::Statement* DiscardStatement(const ast::DiscardStatement*);
bool IfStatement(const ast::IfStatement*); sem::ElseStatement* ElseStatement(const ast::ElseStatement*);
bool LoopStatement(const ast::LoopStatement*); sem::Statement* FallthroughStatement(const ast::FallthroughStatement*);
bool Return(const ast::ReturnStatement* ret); sem::ForLoopStatement* ForLoopStatement(const ast::ForLoopStatement*);
bool Statement(const ast::Statement*); sem::Statement* Parameter(const ast::Variable*);
sem::IfStatement* IfStatement(const ast::IfStatement*);
sem::LoopStatement* LoopStatement(const ast::LoopStatement*);
sem::Statement* ReturnStatement(const ast::ReturnStatement*);
sem::Statement* Statement(const ast::Statement*);
sem::SwitchStatement* SwitchStatement(const ast::SwitchStatement* s);
sem::Statement* VariableDeclStatement(const ast::VariableDeclStatement*);
bool Statements(const ast::StatementList&); bool Statements(const ast::StatementList&);
bool SwitchStatement(const ast::SwitchStatement* s);
bool VariableDeclStatement(const ast::VariableDeclStatement*); bool GlobalVariable(const ast::Variable*);
// AST and Type validation methods // AST and Type validation methods
// Each return true on success, false on failure. // Each return true on success, false on failure.
@ -224,13 +237,19 @@ class Resolver {
bool ValidateAtomic(const ast::Atomic* a, const sem::Atomic* s); bool ValidateAtomic(const ast::Atomic* a, const sem::Atomic* s);
bool ValidateAtomicVariable(const sem::Variable* var); bool ValidateAtomicVariable(const sem::Variable* var);
bool ValidateAssignment(const ast::AssignmentStatement* a); bool ValidateAssignment(const ast::AssignmentStatement* a);
bool ValidateBreakStatement(const sem::Statement* stmt);
bool ValidateContinueStatement(const sem::Statement* stmt);
bool ValidateDiscardStatement(const sem::Statement* stmt);
bool ValidateBuiltinDecoration(const ast::BuiltinDecoration* deco, bool ValidateBuiltinDecoration(const ast::BuiltinDecoration* deco,
const sem::Type* storage_type, const sem::Type* storage_type,
const bool is_input); const bool is_input);
bool ValidateElseStatement(const sem::ElseStatement* stmt);
bool ValidateEntryPoint(const sem::Function* func); bool ValidateEntryPoint(const sem::Function* func);
bool ValidateForLoopStatement(const sem::ForLoopStatement* stmt);
bool ValidateFunction(const sem::Function* func); bool ValidateFunction(const sem::Function* func);
bool ValidateFunctionCall(const sem::Call* call); bool ValidateFunctionCall(const sem::Call* call);
bool ValidateGlobalVariable(const sem::Variable* var); bool ValidateGlobalVariable(const sem::Variable* var);
bool ValidateIfStatement(const sem::IfStatement* stmt);
bool ValidateInterpolateDecoration(const ast::InterpolateDecoration* deco, bool ValidateInterpolateDecoration(const ast::InterpolateDecoration* deco,
const sem::Type* storage_type); const sem::Type* storage_type);
bool ValidateIntrinsicCall(const sem::Call* call); bool ValidateIntrinsicCall(const sem::Call* call);
@ -369,14 +388,19 @@ class Resolver {
/// @param lit the literal /// @param lit the literal
sem::Type* TypeOf(const ast::LiteralExpression* lit); sem::Type* TypeOf(const ast::LiteralExpression* lit);
/// Assigns `stmt` to #current_statement_, #current_compound_statement_, and /// StatementScope() does the following:
/// possibly #current_block_, pushes the variable scope, then calls /// * Creates the AST -> SEM mapping.
/// `callback`. Before returning #current_statement_, /// * Assigns `sem` to #current_statement_
/// #current_compound_statement_, and #current_block_ are restored to their /// * Assigns `sem` to #current_compound_statement_ if `sem` derives from
/// original values, and the variable scope is popped. /// sem::CompoundStatement.
/// @returns the value returned by callback /// * Assigns `sem` to #current_block_ if `sem` derives from
template <typename F> /// sem::BlockStatement.
bool Scope(sem::CompoundStatement* stmt, F&& callback); /// * Then calls `callback`.
/// * Before returning #current_statement_, #current_compound_statement_, and
/// #current_block_ are restored to their original values.
/// @returns `sem` if `callback` returns true, otherwise `nullptr`.
template <typename SEM, typename F>
SEM* StatementScope(const ast::Statement* ast, SEM* sem, F&& callback);
/// Returns a human-readable string representation of the vector type name /// Returns a human-readable string representation of the vector type name
/// with the given parameters. /// with the given parameters.

View File

@ -1344,6 +1344,82 @@ bool Resolver::ValidateStatements(const ast::StatementList& stmts) {
return true; return true;
} }
bool Resolver::ValidateBreakStatement(const sem::Statement* stmt) {
if (!stmt->FindFirstParent<sem::LoopBlockStatement>() &&
!stmt->FindFirstParent<sem::SwitchCaseBlockStatement>()) {
AddError("break statement must be in a loop or switch case",
stmt->Declaration()->source);
return false;
}
return true;
}
bool Resolver::ValidateContinueStatement(const sem::Statement* stmt) {
if (auto* block =
stmt->FindFirstParent<sem::LoopBlockStatement,
sem::LoopContinuingBlockStatement>()) {
if (block->Is<sem::LoopContinuingBlockStatement>()) {
AddError("continuing blocks must not contain a continue statement",
stmt->Declaration()->source);
return false;
}
} else {
AddError("continue statement must be in a loop",
stmt->Declaration()->source);
return false;
}
return true;
}
bool Resolver::ValidateDiscardStatement(const sem::Statement* stmt) {
if (auto* continuing =
stmt->FindFirstParent<sem::LoopContinuingBlockStatement>()) {
AddError("continuing blocks must not contain a discard statement",
stmt->Declaration()->source);
if (continuing != stmt->Parent()) {
AddNote("see continuing block here", continuing->Declaration()->source);
}
return false;
}
return true;
}
bool Resolver::ValidateElseStatement(const sem::ElseStatement* stmt) {
if (auto* cond = stmt->Condition()) {
auto* cond_ty = cond->Type()->UnwrapRef();
if (!cond_ty->Is<sem::Bool>()) {
AddError(
"else statement condition must be bool, got " + TypeNameOf(cond_ty),
stmt->Condition()->Declaration()->source);
return false;
}
}
return true;
}
bool Resolver::ValidateForLoopStatement(const sem::ForLoopStatement* stmt) {
if (auto* cond = stmt->Condition()) {
auto* cond_ty = cond->Type()->UnwrapRef();
if (!cond_ty->Is<sem::Bool>()) {
AddError("for-loop condition must be bool, got " + TypeNameOf(cond_ty),
stmt->Condition()->Declaration()->source);
return false;
}
}
return true;
}
bool Resolver::ValidateIfStatement(const sem::IfStatement* stmt) {
auto* cond_ty = stmt->Condition()->Type()->UnwrapRef();
if (!cond_ty->Is<sem::Bool>()) {
AddError("if statement condition must be bool, got " + TypeNameOf(cond_ty),
stmt->Condition()->Declaration()->source);
return false;
}
return true;
}
bool Resolver::ValidateIntrinsicCall(const sem::Call* call) { bool Resolver::ValidateIntrinsicCall(const sem::Call* call) {
if (call->Type()->Is<sem::Void>()) { if (call->Type()->Is<sem::Void>()) {
bool is_call_statement = false; bool is_call_statement = false;
@ -2103,8 +2179,8 @@ bool Resolver::ValidateReturn(const ast::ReturnStatement* ret) {
} }
bool Resolver::ValidateSwitch(const ast::SwitchStatement* s) { bool Resolver::ValidateSwitch(const ast::SwitchStatement* s) {
auto* cond_type = TypeOf(s->condition)->UnwrapRef(); auto* cond_ty = TypeOf(s->condition)->UnwrapRef();
if (!cond_type->is_integer_scalar()) { if (!cond_ty->is_integer_scalar()) {
AddError( AddError(
"switch statement selector expression must be of a " "switch statement selector expression must be of a "
"scalar integer type", "scalar integer type",
@ -2127,7 +2203,7 @@ bool Resolver::ValidateSwitch(const ast::SwitchStatement* s) {
} }
for (auto* selector : case_stmt->selectors) { for (auto* selector : case_stmt->selectors) {
if (cond_type != TypeOf(selector)) { if (cond_ty != TypeOf(selector)) {
AddError( AddError(
"the case selector values must have the same " "the case selector values must have the same "
"type as the selector expression.", "type as the selector expression.",

View File

@ -21,6 +21,9 @@ namespace tint {
namespace ast { namespace ast {
class ForLoopStatement; class ForLoopStatement;
} // namespace ast } // namespace ast
namespace sem {
class Expression;
} // namespace sem
} // namespace tint } // namespace tint
namespace tint { namespace tint {
@ -39,6 +42,16 @@ class ForLoopStatement : public Castable<ForLoopStatement, CompoundStatement> {
/// Destructor /// Destructor
~ForLoopStatement() override; ~ForLoopStatement() override;
/// @returns the for-loop condition expression
const Expression* Condition() const { return condition_; }
/// Sets the for-loop condition expression
/// @param condition the for-loop condition expression
void SetCondition(const Expression* condition) { condition_ = condition; }
private:
const Expression* condition_ = nullptr;
}; };
} // namespace sem } // namespace sem

View File

@ -23,6 +23,9 @@ namespace ast {
class IfStatement; class IfStatement;
class ElseStatement; class ElseStatement;
} // namespace ast } // namespace ast
namespace sem {
class Expression;
} // namespace sem
} // namespace tint } // namespace tint
namespace tint { namespace tint {
@ -41,6 +44,16 @@ class IfStatement : public Castable<IfStatement, CompoundStatement> {
/// Destructor /// Destructor
~IfStatement() override; ~IfStatement() override;
/// @returns the if-statement condition expression
const Expression* Condition() const { return condition_; }
/// Sets the if-statement condition expression
/// @param condition the if condition expression
void SetCondition(const Expression* condition) { condition_ = condition; }
private:
const Expression* condition_ = nullptr;
}; };
/// Holds semantic information about an else statement /// Holds semantic information about an else statement
@ -56,6 +69,16 @@ class ElseStatement : public Castable<ElseStatement, CompoundStatement> {
/// Destructor /// Destructor
~ElseStatement() override; ~ElseStatement() override;
/// @returns the else-statement condition expression
const Expression* Condition() const { return condition_; }
/// Sets the else-statement condition expression
/// @param condition the else condition expression
void SetCondition(const Expression* condition) { condition_ = condition; }
private:
const Expression* condition_ = nullptr;
}; };
} // namespace sem } // namespace sem

View File

@ -82,7 +82,7 @@ class Swizzle : public Castable<Swizzle, MemberAccessorExpression> {
/// Constructor /// Constructor
/// @param declaration the AST node /// @param declaration the AST node
/// @param type the resolved type of the expression /// @param type the resolved type of the expression
/// @param statement the statement that /// @param statement the statement that owns this expression
/// @param indices the swizzle indices /// @param indices the swizzle indices
Swizzle(const ast::MemberAccessorExpression* declaration, Swizzle(const ast::MemberAccessorExpression* declaration,
const sem::Type* type, const sem::Type* type,

View File

@ -234,7 +234,7 @@ class VariableUser : public Castable<VariableUser, Expression> {
const sem::Variable* Variable() const { return variable_; } const sem::Variable* Variable() const { return variable_; }
private: private:
sem::Variable const* const variable_; const sem::Variable* const variable_;
}; };
} // namespace sem } // namespace sem