resolver: Refactor Statement handling

Break up Resolver::Statement() into multiple resolver functions.
Move simple statement validation out to resolver_validation.cc

Change-Id: Ifa29433af0a9afa39a66ac3e4f7ca376351adfbf
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/71102
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
This commit is contained in:
Ben Clayton 2021-11-26 16:26:42 +00:00 committed by Tint LUCI CQ
parent c5835eb4a0
commit 8c30d752a0
7 changed files with 358 additions and 214 deletions

View File

@ -654,9 +654,9 @@ sem::Function* Resolver::Function(const ast::Function* decl) {
<< "Resolver::Function() called with a current compound statement";
return nullptr;
}
auto* sem_block = builder_->create<sem::FunctionBlockStatement>(func);
builder_->Sem().Add(decl->body, sem_block);
if (!Scope(sem_block, [&] { return Statements(decl->body->statements); })) {
if (!StatementScope(decl->body,
builder_->create<sem::FunctionBlockStatement>(func),
[&] { return Statements(decl->body->statements); })) {
return nullptr;
}
}
@ -796,7 +796,8 @@ bool Resolver::WorkgroupSize(const ast::Function* func) {
bool Resolver::Statements(const ast::StatementList& stmts) {
for (auto* stmt : stmts) {
Mark(stmt);
if (!Statement(stmt)) {
auto* sem = Statement(stmt);
if (!sem) {
return false;
}
}
@ -807,18 +808,18 @@ bool Resolver::Statements(const ast::StatementList& stmts) {
return true;
}
bool Resolver::Statement(const ast::Statement* stmt) {
sem::Statement* Resolver::Statement(const ast::Statement* stmt) {
if (stmt->Is<ast::CaseStatement>()) {
AddError("case statement can only be used inside a switch statement",
stmt->source);
return false;
return nullptr;
}
if (stmt->Is<ast::ElseStatement>()) {
TINT_ICE(Resolver, diagnostics_)
<< "Resolver::Statement() encountered an Else statement. Else "
"statements are embedded in If statements, so should never be "
"encountered as top-level statements";
return false;
return nullptr;
}
// Compound statements. These create their own sem::CompoundStatement
@ -840,69 +841,26 @@ bool Resolver::Statement(const ast::Statement* stmt) {
}
// Non-Compound statements
sem::Statement* sem_statement = builder_->create<sem::Statement>(
stmt, current_compound_statement_, current_function_);
builder_->Sem().Add(stmt, sem_statement);
TINT_SCOPED_ASSIGNMENT(current_statement_, sem_statement);
if (auto* a = stmt->As<ast::AssignmentStatement>()) {
return Assignment(a);
return AssignmentStatement(a);
}
if (stmt->Is<ast::BreakStatement>()) {
if (!sem_statement->FindFirstParent<sem::LoopBlockStatement>() &&
!sem_statement->FindFirstParent<sem::SwitchCaseBlockStatement>()) {
AddError("break statement must be in a loop or switch case",
stmt->source);
return false;
}
return true;
if (auto* b = stmt->As<ast::BreakStatement>()) {
return BreakStatement(b);
}
if (auto* c = stmt->As<ast::CallStatement>()) {
if (!Expression(c->expr)) {
return false;
}
return true;
return CallStatement(c);
}
if (auto* c = stmt->As<ast::ContinueStatement>()) {
// Set if we've hit the first continue statement in our parent loop
if (auto* block =
current_block_->FindFirstParent<
sem::LoopBlockStatement, sem::LoopContinuingBlockStatement>()) {
if (auto* loop_block = block->As<sem::LoopBlockStatement>()) {
if (!loop_block->FirstContinue()) {
const_cast<sem::LoopBlockStatement*>(loop_block)
->SetFirstContinue(c, loop_block->Decls().size());
}
} else {
AddError("continuing blocks must not contain a continue statement",
stmt->source);
return false;
}
} else {
AddError("continue statement must be in a loop", stmt->source);
return false;
}
return true;
return ContinueStatement(c);
}
if (stmt->Is<ast::DiscardStatement>()) {
if (auto* continuing =
sem_statement
->FindFirstParent<sem::LoopContinuingBlockStatement>()) {
AddError("continuing blocks must not contain a discard statement",
stmt->source);
if (continuing != sem_statement->Parent()) {
AddNote("see continuing block here", continuing->Declaration()->source);
}
return false;
}
current_function_->SetHasDiscard();
return true;
if (auto* d = stmt->As<ast::DiscardStatement>()) {
return DiscardStatement(d);
}
if (stmt->Is<ast::FallthroughStatement>()) {
return true;
if (auto* f = stmt->As<ast::FallthroughStatement>()) {
return FallthroughStatement(f);
}
if (auto* r = stmt->As<ast::ReturnStatement>()) {
return Return(r);
return ReturnStatement(r);
}
if (auto* v = stmt->As<ast::VariableDeclStatement>()) {
return VariableDeclStatement(v);
@ -910,43 +868,38 @@ bool Resolver::Statement(const ast::Statement* stmt) {
AddError("unknown statement type: " + std::string(stmt->TypeInfo().name),
stmt->source);
return false;
return nullptr;
}
bool Resolver::CaseStatement(const ast::CaseStatement* stmt) {
sem::SwitchCaseBlockStatement* Resolver::CaseStatement(
const ast::CaseStatement* stmt) {
auto* sem = builder_->create<sem::SwitchCaseBlockStatement>(
stmt->body, current_compound_statement_, current_function_);
builder_->Sem().Add(stmt, sem);
builder_->Sem().Add(stmt->body, sem);
Mark(stmt->body);
for (auto* sel : stmt->selectors) {
Mark(sel);
}
return Scope(sem, [&] { return Statements(stmt->body->statements); });
return StatementScope(stmt, sem, [&] {
builder_->Sem().Add(stmt->body, sem);
Mark(stmt->body);
for (auto* sel : stmt->selectors) {
Mark(sel);
}
return Statements(stmt->body->statements);
});
}
bool Resolver::IfStatement(const ast::IfStatement* stmt) {
sem::IfStatement* Resolver::IfStatement(const ast::IfStatement* stmt) {
auto* sem = builder_->create<sem::IfStatement>(
stmt, current_compound_statement_, current_function_);
builder_->Sem().Add(stmt, sem);
return Scope(sem, [&] {
if (!Expression(stmt->condition)) {
return false;
}
auto* cond_type = TypeOf(stmt->condition)->UnwrapRef();
if (!cond_type->Is<sem::Bool>()) {
AddError(
"if statement condition must be bool, got " + TypeNameOf(cond_type),
stmt->condition->source);
return StatementScope(stmt, sem, [&] {
auto* cond = Expression(stmt->condition);
if (!cond) {
return false;
}
sem->SetCondition(cond);
Mark(stmt->body);
auto* body = builder_->create<sem::BlockStatement>(
stmt->body, current_compound_statement_, current_function_);
builder_->Sem().Add(stmt->body, body);
if (!Scope(body, [&] { return Statements(stmt->body->statements); })) {
if (!StatementScope(stmt->body, body,
[&] { return Statements(stmt->body->statements); })) {
return false;
}
@ -956,59 +909,56 @@ bool Resolver::IfStatement(const ast::IfStatement* stmt) {
return false;
}
}
return true;
return ValidateIfStatement(sem);
});
}
bool Resolver::ElseStatement(const ast::ElseStatement* stmt) {
sem::ElseStatement* Resolver::ElseStatement(const ast::ElseStatement* stmt) {
auto* sem = builder_->create<sem::ElseStatement>(
stmt, current_compound_statement_, current_function_);
builder_->Sem().Add(stmt, sem);
return Scope(sem, [&] {
if (auto* cond = stmt->condition) {
if (!Expression(cond)) {
return false;
}
auto* else_cond_type = TypeOf(cond)->UnwrapRef();
if (!else_cond_type->Is<sem::Bool>()) {
AddError("else statement condition must be bool, got " +
TypeNameOf(else_cond_type),
cond->source);
return StatementScope(stmt, sem, [&] {
if (auto* cond_expr = stmt->condition) {
auto* cond = Expression(cond_expr);
if (!cond) {
return false;
}
sem->SetCondition(cond);
}
Mark(stmt->body);
auto* body = builder_->create<sem::BlockStatement>(
stmt->body, current_compound_statement_, current_function_);
builder_->Sem().Add(stmt->body, body);
return Scope(body, [&] { return Statements(stmt->body->statements); });
if (!StatementScope(stmt->body, body,
[&] { return Statements(stmt->body->statements); })) {
return false;
}
return ValidateElseStatement(sem);
});
}
bool Resolver::BlockStatement(const ast::BlockStatement* stmt) {
sem::BlockStatement* Resolver::BlockStatement(const ast::BlockStatement* stmt) {
auto* sem = builder_->create<sem::BlockStatement>(
stmt->As<ast::BlockStatement>(), current_compound_statement_,
current_function_);
builder_->Sem().Add(stmt, sem);
return Scope(sem, [&] { return Statements(stmt->statements); });
return StatementScope(stmt, sem,
[&] { return Statements(stmt->statements); });
}
bool Resolver::LoopStatement(const ast::LoopStatement* stmt) {
sem::LoopStatement* Resolver::LoopStatement(const ast::LoopStatement* stmt) {
auto* sem = builder_->create<sem::LoopStatement>(
stmt, current_compound_statement_, current_function_);
builder_->Sem().Add(stmt, sem);
return Scope(sem, [&] {
return StatementScope(stmt, sem, [&] {
Mark(stmt->body);
auto* body = builder_->create<sem::LoopBlockStatement>(
stmt->body, current_compound_statement_, current_function_);
builder_->Sem().Add(stmt->body, body);
return Scope(body, [&] {
return StatementScope(stmt->body, body, [&] {
if (!Statements(stmt->body->statements)) {
return false;
}
if (stmt->continuing) {
Mark(stmt->continuing);
if (!stmt->continuing->Empty()) {
@ -1016,24 +966,22 @@ bool Resolver::LoopStatement(const ast::LoopStatement* stmt) {
builder_->create<sem::LoopContinuingBlockStatement>(
stmt->continuing, current_compound_statement_,
current_function_);
builder_->Sem().Add(stmt->continuing, continuing);
if (!Scope(continuing, [&] {
return Statements(stmt->continuing->statements);
})) {
return false;
}
return StatementScope(stmt->continuing, continuing, [&] {
return Statements(stmt->continuing->statements);
}) != nullptr;
}
}
return true;
});
});
}
bool Resolver::ForLoopStatement(const ast::ForLoopStatement* stmt) {
sem::ForLoopStatement* Resolver::ForLoopStatement(
const ast::ForLoopStatement* stmt) {
auto* sem = builder_->create<sem::ForLoopStatement>(
stmt, current_compound_statement_, current_function_);
builder_->Sem().Add(stmt, sem);
return Scope(sem, [&] {
return StatementScope(stmt, sem, [&] {
if (auto* initializer = stmt->initializer) {
Mark(initializer);
if (!Statement(initializer)) {
@ -1041,17 +989,12 @@ bool Resolver::ForLoopStatement(const ast::ForLoopStatement* stmt) {
}
}
if (auto* condition = stmt->condition) {
if (!Expression(condition)) {
return false;
}
auto* cond_ty = TypeOf(condition)->UnwrapRef();
if (!cond_ty->Is<sem::Bool>()) {
AddError("for-loop condition must be bool, got " + TypeNameOf(cond_ty),
condition->source);
if (auto* cond_expr = stmt->condition) {
auto* cond = Expression(cond_expr);
if (!cond) {
return false;
}
sem->SetCondition(cond);
}
if (auto* continuing = stmt->continuing) {
@ -1065,8 +1008,12 @@ bool Resolver::ForLoopStatement(const ast::ForLoopStatement* stmt) {
auto* body = builder_->create<sem::LoopBlockStatement>(
stmt->body, current_compound_statement_, current_function_);
builder_->Sem().Add(stmt->body, body);
return Scope(body, [&] { return Statements(stmt->body->statements); });
if (!StatementScope(stmt->body, body,
[&] { return Statements(stmt->body->statements); })) {
return false;
}
return ValidateForLoopStatement(sem);
});
}
@ -1930,33 +1877,6 @@ sem::Expression* Resolver::UnaryOp(const ast::UnaryOpExpression* unary) {
return builder_->create<sem::Expression>(unary, ty, current_statement_, val);
}
bool Resolver::VariableDeclStatement(const ast::VariableDeclStatement* stmt) {
Mark(stmt->variable);
auto* var = Variable(stmt->variable, VariableKind::kLocal);
if (!var) {
return false;
}
for (auto* deco : stmt->variable->decorations) {
Mark(deco);
if (!deco->Is<ast::InternalDecoration>()) {
AddError("decorations are not valid on local variables", deco->source);
return false;
}
}
if (current_block_) { // Not all statements are inside a block
current_block_->AddDecl(stmt->variable);
}
if (!ValidateVariable(var)) {
return false;
}
return true;
}
sem::Type* Resolver::TypeDecl(const ast::TypeDecl* named_type) {
sem::Type* result = nullptr;
if (auto* alias = named_type->As<ast::Alias>()) {
@ -2318,45 +2238,127 @@ sem::Struct* Resolver::Structure(const ast::Struct* str) {
return out;
}
bool Resolver::Return(const ast::ReturnStatement* ret) {
if (auto* value = ret->value) {
if (!Expression(value)) {
return false;
sem::Statement* Resolver::ReturnStatement(const ast::ReturnStatement* stmt) {
auto* sem = builder_->create<sem::Statement>(
stmt, current_compound_statement_, current_function_);
return StatementScope(stmt, sem, [&] {
if (auto* value = stmt->value) {
if (!Expression(value)) {
return false;
}
}
}
// Validate after processing the return value expression so that its type is
// available for validation.
return ValidateReturn(ret);
// Validate after processing the return value expression so that its type is
// available for validation.
return ValidateReturn(stmt);
});
}
bool Resolver::SwitchStatement(const ast::SwitchStatement* stmt) {
sem::SwitchStatement* Resolver::SwitchStatement(
const ast::SwitchStatement* stmt) {
auto* sem = builder_->create<sem::SwitchStatement>(
stmt, current_compound_statement_, current_function_);
builder_->Sem().Add(stmt, sem);
return Scope(sem, [&] {
return StatementScope(stmt, sem, [&] {
if (!Expression(stmt->condition)) {
return false;
}
for (auto* case_stmt : stmt->body) {
Mark(case_stmt);
if (!CaseStatement(case_stmt)) {
return false;
}
}
if (!ValidateSwitch(stmt)) {
return false;
}
return true;
return ValidateSwitch(stmt);
});
}
bool Resolver::Assignment(const ast::AssignmentStatement* a) {
if (!Expression(a->lhs) || !Expression(a->rhs)) {
return false;
}
sem::Statement* Resolver::VariableDeclStatement(
const ast::VariableDeclStatement* stmt) {
auto* sem = builder_->create<sem::Statement>(
stmt, current_compound_statement_, current_function_);
return StatementScope(stmt, sem, [&] {
Mark(stmt->variable);
return ValidateAssignment(a);
auto* var = Variable(stmt->variable, VariableKind::kLocal);
if (!var) {
return false;
}
for (auto* deco : stmt->variable->decorations) {
Mark(deco);
if (!deco->Is<ast::InternalDecoration>()) {
AddError("decorations are not valid on local variables", deco->source);
return false;
}
}
if (current_block_) { // Not all statements are inside a block
current_block_->AddDecl(stmt->variable);
}
return ValidateVariable(var);
});
}
sem::Statement* Resolver::AssignmentStatement(
const ast::AssignmentStatement* stmt) {
auto* sem = builder_->create<sem::Statement>(
stmt, current_compound_statement_, current_function_);
return StatementScope(stmt, sem, [&] {
if (!Expression(stmt->lhs) || !Expression(stmt->rhs)) {
return false;
}
return ValidateAssignment(stmt);
});
}
sem::Statement* Resolver::BreakStatement(const ast::BreakStatement* stmt) {
auto* sem = builder_->create<sem::Statement>(
stmt, current_compound_statement_, current_function_);
return StatementScope(stmt, sem, [&] { return ValidateBreakStatement(sem); });
}
sem::Statement* Resolver::CallStatement(const ast::CallStatement* stmt) {
auto* sem = builder_->create<sem::Statement>(
stmt, current_compound_statement_, current_function_);
return StatementScope(stmt, sem, [&] { return Expression(stmt->expr); });
}
sem::Statement* Resolver::ContinueStatement(
const ast::ContinueStatement* stmt) {
auto* sem = builder_->create<sem::Statement>(
stmt, current_compound_statement_, current_function_);
return StatementScope(stmt, sem, [&] {
// Set if we've hit the first continue statement in our parent loop
if (auto* block = sem->FindFirstParent<sem::LoopBlockStatement>()) {
if (!block->FirstContinue()) {
const_cast<sem::LoopBlockStatement*>(block)->SetFirstContinue(
stmt, block->Decls().size());
}
}
return ValidateContinueStatement(sem);
});
}
sem::Statement* Resolver::DiscardStatement(const ast::DiscardStatement* stmt) {
auto* sem = builder_->create<sem::Statement>(
stmt, current_compound_statement_, current_function_);
return StatementScope(stmt, sem, [&] {
current_function_->SetHasDiscard();
return ValidateDiscardStatement(sem);
});
}
sem::Statement* Resolver::FallthroughStatement(
const ast::FallthroughStatement* stmt) {
auto* sem = builder_->create<sem::Statement>(
stmt, current_compound_statement_, current_function_);
return StatementScope(stmt, sem, [&] { return true; });
}
bool Resolver::ApplyStorageClassUsageToType(ast::StorageClass sc,
@ -2399,22 +2401,28 @@ bool Resolver::ApplyStorageClassUsageToType(ast::StorageClass sc,
return true;
}
template <typename F>
bool Resolver::Scope(sem::CompoundStatement* stmt, F&& callback) {
auto* prev_current_statement = current_statement_;
auto* prev_current_compound_statement = current_compound_statement_;
auto* prev_current_block = current_block_;
current_statement_ = stmt;
current_compound_statement_ = stmt;
current_block_ = stmt->As<sem::BlockStatement>();
template <typename SEM, typename F>
SEM* Resolver::StatementScope(const ast::Statement* ast,
SEM* sem,
F&& callback) {
builder_->Sem().Add(ast, sem);
TINT_DEFER({
current_block_ = prev_current_block;
current_compound_statement_ = prev_current_compound_statement;
current_statement_ = prev_current_statement;
});
auto* as_compound =
As<sem::CompoundStatement, CastFlags::kDontErrorOnImpossibleCast>(sem);
auto* as_block =
As<sem::BlockStatement, CastFlags::kDontErrorOnImpossibleCast>(sem);
return callback();
TINT_SCOPED_ASSIGNMENT(current_statement_, sem);
TINT_SCOPED_ASSIGNMENT(
current_compound_statement_,
as_compound ? as_compound : current_compound_statement_);
TINT_SCOPED_ASSIGNMENT(current_block_, as_block ? as_block : current_block_);
if (!callback()) {
return nullptr;
}
return sem;
}
std::string Resolver::VectorPretty(uint32_t size,

View File

@ -59,8 +59,15 @@ class Variable;
namespace sem {
class Array;
class Atomic;
class BlockStatement;
class ElseStatement;
class ForLoopStatement;
class IfStatement;
class Intrinsic;
class LoopStatement;
class Statement;
class SwitchCaseBlockStatement;
class SwitchStatement;
class TypeConstructor;
} // namespace sem
@ -198,20 +205,26 @@ class Resolver {
// Statement resolving methods
// Each return true on success, false on failure.
bool Assignment(const ast::AssignmentStatement* a);
bool BlockStatement(const ast::BlockStatement*);
bool CaseStatement(const ast::CaseStatement*);
bool ElseStatement(const ast::ElseStatement*);
bool ForLoopStatement(const ast::ForLoopStatement*);
bool Parameter(const ast::Variable* param);
bool GlobalVariable(const ast::Variable* var);
bool IfStatement(const ast::IfStatement*);
bool LoopStatement(const ast::LoopStatement*);
bool Return(const ast::ReturnStatement* ret);
bool Statement(const ast::Statement*);
sem::Statement* AssignmentStatement(const ast::AssignmentStatement*);
sem::BlockStatement* BlockStatement(const ast::BlockStatement*);
sem::Statement* BreakStatement(const ast::BreakStatement*);
sem::Statement* CallStatement(const ast::CallStatement*);
sem::SwitchCaseBlockStatement* CaseStatement(const ast::CaseStatement*);
sem::Statement* ContinueStatement(const ast::ContinueStatement*);
sem::Statement* DiscardStatement(const ast::DiscardStatement*);
sem::ElseStatement* ElseStatement(const ast::ElseStatement*);
sem::Statement* FallthroughStatement(const ast::FallthroughStatement*);
sem::ForLoopStatement* ForLoopStatement(const ast::ForLoopStatement*);
sem::Statement* Parameter(const ast::Variable*);
sem::IfStatement* IfStatement(const ast::IfStatement*);
sem::LoopStatement* LoopStatement(const ast::LoopStatement*);
sem::Statement* ReturnStatement(const ast::ReturnStatement*);
sem::Statement* Statement(const ast::Statement*);
sem::SwitchStatement* SwitchStatement(const ast::SwitchStatement* s);
sem::Statement* VariableDeclStatement(const ast::VariableDeclStatement*);
bool Statements(const ast::StatementList&);
bool SwitchStatement(const ast::SwitchStatement* s);
bool VariableDeclStatement(const ast::VariableDeclStatement*);
bool GlobalVariable(const ast::Variable*);
// AST and Type validation methods
// Each return true on success, false on failure.
@ -224,13 +237,19 @@ class Resolver {
bool ValidateAtomic(const ast::Atomic* a, const sem::Atomic* s);
bool ValidateAtomicVariable(const sem::Variable* var);
bool ValidateAssignment(const ast::AssignmentStatement* a);
bool ValidateBreakStatement(const sem::Statement* stmt);
bool ValidateContinueStatement(const sem::Statement* stmt);
bool ValidateDiscardStatement(const sem::Statement* stmt);
bool ValidateBuiltinDecoration(const ast::BuiltinDecoration* deco,
const sem::Type* storage_type,
const bool is_input);
bool ValidateElseStatement(const sem::ElseStatement* stmt);
bool ValidateEntryPoint(const sem::Function* func);
bool ValidateForLoopStatement(const sem::ForLoopStatement* stmt);
bool ValidateFunction(const sem::Function* func);
bool ValidateFunctionCall(const sem::Call* call);
bool ValidateGlobalVariable(const sem::Variable* var);
bool ValidateIfStatement(const sem::IfStatement* stmt);
bool ValidateInterpolateDecoration(const ast::InterpolateDecoration* deco,
const sem::Type* storage_type);
bool ValidateIntrinsicCall(const sem::Call* call);
@ -369,14 +388,19 @@ class Resolver {
/// @param lit the literal
sem::Type* TypeOf(const ast::LiteralExpression* lit);
/// Assigns `stmt` to #current_statement_, #current_compound_statement_, and
/// possibly #current_block_, pushes the variable scope, then calls
/// `callback`. Before returning #current_statement_,
/// #current_compound_statement_, and #current_block_ are restored to their
/// original values, and the variable scope is popped.
/// @returns the value returned by callback
template <typename F>
bool Scope(sem::CompoundStatement* stmt, F&& callback);
/// StatementScope() does the following:
/// * Creates the AST -> SEM mapping.
/// * Assigns `sem` to #current_statement_
/// * Assigns `sem` to #current_compound_statement_ if `sem` derives from
/// sem::CompoundStatement.
/// * Assigns `sem` to #current_block_ if `sem` derives from
/// sem::BlockStatement.
/// * Then calls `callback`.
/// * Before returning #current_statement_, #current_compound_statement_, and
/// #current_block_ are restored to their original values.
/// @returns `sem` if `callback` returns true, otherwise `nullptr`.
template <typename SEM, typename F>
SEM* StatementScope(const ast::Statement* ast, SEM* sem, F&& callback);
/// Returns a human-readable string representation of the vector type name
/// with the given parameters.

View File

@ -1344,6 +1344,82 @@ bool Resolver::ValidateStatements(const ast::StatementList& stmts) {
return true;
}
bool Resolver::ValidateBreakStatement(const sem::Statement* stmt) {
if (!stmt->FindFirstParent<sem::LoopBlockStatement>() &&
!stmt->FindFirstParent<sem::SwitchCaseBlockStatement>()) {
AddError("break statement must be in a loop or switch case",
stmt->Declaration()->source);
return false;
}
return true;
}
bool Resolver::ValidateContinueStatement(const sem::Statement* stmt) {
if (auto* block =
stmt->FindFirstParent<sem::LoopBlockStatement,
sem::LoopContinuingBlockStatement>()) {
if (block->Is<sem::LoopContinuingBlockStatement>()) {
AddError("continuing blocks must not contain a continue statement",
stmt->Declaration()->source);
return false;
}
} else {
AddError("continue statement must be in a loop",
stmt->Declaration()->source);
return false;
}
return true;
}
bool Resolver::ValidateDiscardStatement(const sem::Statement* stmt) {
if (auto* continuing =
stmt->FindFirstParent<sem::LoopContinuingBlockStatement>()) {
AddError("continuing blocks must not contain a discard statement",
stmt->Declaration()->source);
if (continuing != stmt->Parent()) {
AddNote("see continuing block here", continuing->Declaration()->source);
}
return false;
}
return true;
}
bool Resolver::ValidateElseStatement(const sem::ElseStatement* stmt) {
if (auto* cond = stmt->Condition()) {
auto* cond_ty = cond->Type()->UnwrapRef();
if (!cond_ty->Is<sem::Bool>()) {
AddError(
"else statement condition must be bool, got " + TypeNameOf(cond_ty),
stmt->Condition()->Declaration()->source);
return false;
}
}
return true;
}
bool Resolver::ValidateForLoopStatement(const sem::ForLoopStatement* stmt) {
if (auto* cond = stmt->Condition()) {
auto* cond_ty = cond->Type()->UnwrapRef();
if (!cond_ty->Is<sem::Bool>()) {
AddError("for-loop condition must be bool, got " + TypeNameOf(cond_ty),
stmt->Condition()->Declaration()->source);
return false;
}
}
return true;
}
bool Resolver::ValidateIfStatement(const sem::IfStatement* stmt) {
auto* cond_ty = stmt->Condition()->Type()->UnwrapRef();
if (!cond_ty->Is<sem::Bool>()) {
AddError("if statement condition must be bool, got " + TypeNameOf(cond_ty),
stmt->Condition()->Declaration()->source);
return false;
}
return true;
}
bool Resolver::ValidateIntrinsicCall(const sem::Call* call) {
if (call->Type()->Is<sem::Void>()) {
bool is_call_statement = false;
@ -2103,8 +2179,8 @@ bool Resolver::ValidateReturn(const ast::ReturnStatement* ret) {
}
bool Resolver::ValidateSwitch(const ast::SwitchStatement* s) {
auto* cond_type = TypeOf(s->condition)->UnwrapRef();
if (!cond_type->is_integer_scalar()) {
auto* cond_ty = TypeOf(s->condition)->UnwrapRef();
if (!cond_ty->is_integer_scalar()) {
AddError(
"switch statement selector expression must be of a "
"scalar integer type",
@ -2127,7 +2203,7 @@ bool Resolver::ValidateSwitch(const ast::SwitchStatement* s) {
}
for (auto* selector : case_stmt->selectors) {
if (cond_type != TypeOf(selector)) {
if (cond_ty != TypeOf(selector)) {
AddError(
"the case selector values must have the same "
"type as the selector expression.",

View File

@ -21,6 +21,9 @@ namespace tint {
namespace ast {
class ForLoopStatement;
} // namespace ast
namespace sem {
class Expression;
} // namespace sem
} // namespace tint
namespace tint {
@ -39,6 +42,16 @@ class ForLoopStatement : public Castable<ForLoopStatement, CompoundStatement> {
/// Destructor
~ForLoopStatement() override;
/// @returns the for-loop condition expression
const Expression* Condition() const { return condition_; }
/// Sets the for-loop condition expression
/// @param condition the for-loop condition expression
void SetCondition(const Expression* condition) { condition_ = condition; }
private:
const Expression* condition_ = nullptr;
};
} // namespace sem

View File

@ -23,6 +23,9 @@ namespace ast {
class IfStatement;
class ElseStatement;
} // namespace ast
namespace sem {
class Expression;
} // namespace sem
} // namespace tint
namespace tint {
@ -41,6 +44,16 @@ class IfStatement : public Castable<IfStatement, CompoundStatement> {
/// Destructor
~IfStatement() override;
/// @returns the if-statement condition expression
const Expression* Condition() const { return condition_; }
/// Sets the if-statement condition expression
/// @param condition the if condition expression
void SetCondition(const Expression* condition) { condition_ = condition; }
private:
const Expression* condition_ = nullptr;
};
/// Holds semantic information about an else statement
@ -56,6 +69,16 @@ class ElseStatement : public Castable<ElseStatement, CompoundStatement> {
/// Destructor
~ElseStatement() override;
/// @returns the else-statement condition expression
const Expression* Condition() const { return condition_; }
/// Sets the else-statement condition expression
/// @param condition the else condition expression
void SetCondition(const Expression* condition) { condition_ = condition; }
private:
const Expression* condition_ = nullptr;
};
} // namespace sem

View File

@ -82,7 +82,7 @@ class Swizzle : public Castable<Swizzle, MemberAccessorExpression> {
/// Constructor
/// @param declaration the AST node
/// @param type the resolved type of the expression
/// @param statement the statement that
/// @param statement the statement that owns this expression
/// @param indices the swizzle indices
Swizzle(const ast::MemberAccessorExpression* declaration,
const sem::Type* type,

View File

@ -234,7 +234,7 @@ class VariableUser : public Castable<VariableUser, Expression> {
const sem::Variable* Variable() const { return variable_; }
private:
sem::Variable const* const variable_;
const sem::Variable* const variable_;
};
} // namespace sem