diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc index 173220c72f..5d4575d45b 100644 --- a/src/resolver/resolver.cc +++ b/src/resolver/resolver.cc @@ -654,9 +654,9 @@ sem::Function* Resolver::Function(const ast::Function* decl) { << "Resolver::Function() called with a current compound statement"; return nullptr; } - auto* sem_block = builder_->create(func); - builder_->Sem().Add(decl->body, sem_block); - if (!Scope(sem_block, [&] { return Statements(decl->body->statements); })) { + if (!StatementScope(decl->body, + builder_->create(func), + [&] { return Statements(decl->body->statements); })) { return nullptr; } } @@ -796,7 +796,8 @@ bool Resolver::WorkgroupSize(const ast::Function* func) { bool Resolver::Statements(const ast::StatementList& stmts) { for (auto* stmt : stmts) { Mark(stmt); - if (!Statement(stmt)) { + auto* sem = Statement(stmt); + if (!sem) { return false; } } @@ -807,18 +808,18 @@ bool Resolver::Statements(const ast::StatementList& stmts) { return true; } -bool Resolver::Statement(const ast::Statement* stmt) { +sem::Statement* Resolver::Statement(const ast::Statement* stmt) { if (stmt->Is()) { AddError("case statement can only be used inside a switch statement", stmt->source); - return false; + return nullptr; } if (stmt->Is()) { TINT_ICE(Resolver, diagnostics_) << "Resolver::Statement() encountered an Else statement. Else " "statements are embedded in If statements, so should never be " "encountered as top-level statements"; - return false; + return nullptr; } // Compound statements. These create their own sem::CompoundStatement @@ -840,69 +841,26 @@ bool Resolver::Statement(const ast::Statement* stmt) { } // Non-Compound statements - sem::Statement* sem_statement = builder_->create( - stmt, current_compound_statement_, current_function_); - builder_->Sem().Add(stmt, sem_statement); - TINT_SCOPED_ASSIGNMENT(current_statement_, sem_statement); if (auto* a = stmt->As()) { - return Assignment(a); + return AssignmentStatement(a); } - if (stmt->Is()) { - if (!sem_statement->FindFirstParent() && - !sem_statement->FindFirstParent()) { - AddError("break statement must be in a loop or switch case", - stmt->source); - return false; - } - return true; + if (auto* b = stmt->As()) { + return BreakStatement(b); } if (auto* c = stmt->As()) { - if (!Expression(c->expr)) { - return false; - } - return true; + return CallStatement(c); } if (auto* c = stmt->As()) { - // Set if we've hit the first continue statement in our parent loop - if (auto* block = - current_block_->FindFirstParent< - sem::LoopBlockStatement, sem::LoopContinuingBlockStatement>()) { - if (auto* loop_block = block->As()) { - if (!loop_block->FirstContinue()) { - const_cast(loop_block) - ->SetFirstContinue(c, loop_block->Decls().size()); - } - } else { - AddError("continuing blocks must not contain a continue statement", - stmt->source); - return false; - } - } else { - AddError("continue statement must be in a loop", stmt->source); - return false; - } - - return true; + return ContinueStatement(c); } - if (stmt->Is()) { - if (auto* continuing = - sem_statement - ->FindFirstParent()) { - AddError("continuing blocks must not contain a discard statement", - stmt->source); - if (continuing != sem_statement->Parent()) { - AddNote("see continuing block here", continuing->Declaration()->source); - } - return false; - } - current_function_->SetHasDiscard(); - return true; + if (auto* d = stmt->As()) { + return DiscardStatement(d); } - if (stmt->Is()) { - return true; + if (auto* f = stmt->As()) { + return FallthroughStatement(f); } if (auto* r = stmt->As()) { - return Return(r); + return ReturnStatement(r); } if (auto* v = stmt->As()) { return VariableDeclStatement(v); @@ -910,43 +868,38 @@ bool Resolver::Statement(const ast::Statement* stmt) { AddError("unknown statement type: " + std::string(stmt->TypeInfo().name), stmt->source); - return false; + return nullptr; } -bool Resolver::CaseStatement(const ast::CaseStatement* stmt) { +sem::SwitchCaseBlockStatement* Resolver::CaseStatement( + const ast::CaseStatement* stmt) { auto* sem = builder_->create( stmt->body, current_compound_statement_, current_function_); - builder_->Sem().Add(stmt, sem); - builder_->Sem().Add(stmt->body, sem); - Mark(stmt->body); - for (auto* sel : stmt->selectors) { - Mark(sel); - } - return Scope(sem, [&] { return Statements(stmt->body->statements); }); + return StatementScope(stmt, sem, [&] { + builder_->Sem().Add(stmt->body, sem); + Mark(stmt->body); + for (auto* sel : stmt->selectors) { + Mark(sel); + } + return Statements(stmt->body->statements); + }); } -bool Resolver::IfStatement(const ast::IfStatement* stmt) { +sem::IfStatement* Resolver::IfStatement(const ast::IfStatement* stmt) { auto* sem = builder_->create( stmt, current_compound_statement_, current_function_); - builder_->Sem().Add(stmt, sem); - return Scope(sem, [&] { - if (!Expression(stmt->condition)) { - return false; - } - - auto* cond_type = TypeOf(stmt->condition)->UnwrapRef(); - if (!cond_type->Is()) { - AddError( - "if statement condition must be bool, got " + TypeNameOf(cond_type), - stmt->condition->source); + return StatementScope(stmt, sem, [&] { + auto* cond = Expression(stmt->condition); + if (!cond) { return false; } + sem->SetCondition(cond); Mark(stmt->body); auto* body = builder_->create( stmt->body, current_compound_statement_, current_function_); - builder_->Sem().Add(stmt->body, body); - if (!Scope(body, [&] { return Statements(stmt->body->statements); })) { + if (!StatementScope(stmt->body, body, + [&] { return Statements(stmt->body->statements); })) { return false; } @@ -956,59 +909,56 @@ bool Resolver::IfStatement(const ast::IfStatement* stmt) { return false; } } - return true; + + return ValidateIfStatement(sem); }); } -bool Resolver::ElseStatement(const ast::ElseStatement* stmt) { +sem::ElseStatement* Resolver::ElseStatement(const ast::ElseStatement* stmt) { auto* sem = builder_->create( stmt, current_compound_statement_, current_function_); - builder_->Sem().Add(stmt, sem); - return Scope(sem, [&] { - if (auto* cond = stmt->condition) { - if (!Expression(cond)) { - return false; - } - - auto* else_cond_type = TypeOf(cond)->UnwrapRef(); - if (!else_cond_type->Is()) { - AddError("else statement condition must be bool, got " + - TypeNameOf(else_cond_type), - cond->source); + return StatementScope(stmt, sem, [&] { + if (auto* cond_expr = stmt->condition) { + auto* cond = Expression(cond_expr); + if (!cond) { return false; } + sem->SetCondition(cond); } Mark(stmt->body); auto* body = builder_->create( stmt->body, current_compound_statement_, current_function_); - builder_->Sem().Add(stmt->body, body); - return Scope(body, [&] { return Statements(stmt->body->statements); }); + if (!StatementScope(stmt->body, body, + [&] { return Statements(stmt->body->statements); })) { + return false; + } + + return ValidateElseStatement(sem); }); } -bool Resolver::BlockStatement(const ast::BlockStatement* stmt) { +sem::BlockStatement* Resolver::BlockStatement(const ast::BlockStatement* stmt) { auto* sem = builder_->create( stmt->As(), current_compound_statement_, current_function_); - builder_->Sem().Add(stmt, sem); - return Scope(sem, [&] { return Statements(stmt->statements); }); + return StatementScope(stmt, sem, + [&] { return Statements(stmt->statements); }); } -bool Resolver::LoopStatement(const ast::LoopStatement* stmt) { +sem::LoopStatement* Resolver::LoopStatement(const ast::LoopStatement* stmt) { auto* sem = builder_->create( stmt, current_compound_statement_, current_function_); - builder_->Sem().Add(stmt, sem); - return Scope(sem, [&] { + return StatementScope(stmt, sem, [&] { Mark(stmt->body); auto* body = builder_->create( stmt->body, current_compound_statement_, current_function_); - builder_->Sem().Add(stmt->body, body); - return Scope(body, [&] { + return StatementScope(stmt->body, body, [&] { if (!Statements(stmt->body->statements)) { return false; } + if (stmt->continuing) { Mark(stmt->continuing); if (!stmt->continuing->Empty()) { @@ -1016,24 +966,22 @@ bool Resolver::LoopStatement(const ast::LoopStatement* stmt) { builder_->create( stmt->continuing, current_compound_statement_, current_function_); - builder_->Sem().Add(stmt->continuing, continuing); - if (!Scope(continuing, [&] { - return Statements(stmt->continuing->statements); - })) { - return false; - } + return StatementScope(stmt->continuing, continuing, [&] { + return Statements(stmt->continuing->statements); + }) != nullptr; } } + return true; }); }); } -bool Resolver::ForLoopStatement(const ast::ForLoopStatement* stmt) { +sem::ForLoopStatement* Resolver::ForLoopStatement( + const ast::ForLoopStatement* stmt) { auto* sem = builder_->create( stmt, current_compound_statement_, current_function_); - builder_->Sem().Add(stmt, sem); - return Scope(sem, [&] { + return StatementScope(stmt, sem, [&] { if (auto* initializer = stmt->initializer) { Mark(initializer); if (!Statement(initializer)) { @@ -1041,17 +989,12 @@ bool Resolver::ForLoopStatement(const ast::ForLoopStatement* stmt) { } } - if (auto* condition = stmt->condition) { - if (!Expression(condition)) { - return false; - } - - auto* cond_ty = TypeOf(condition)->UnwrapRef(); - if (!cond_ty->Is()) { - AddError("for-loop condition must be bool, got " + TypeNameOf(cond_ty), - condition->source); + if (auto* cond_expr = stmt->condition) { + auto* cond = Expression(cond_expr); + if (!cond) { return false; } + sem->SetCondition(cond); } if (auto* continuing = stmt->continuing) { @@ -1065,8 +1008,12 @@ bool Resolver::ForLoopStatement(const ast::ForLoopStatement* stmt) { auto* body = builder_->create( stmt->body, current_compound_statement_, current_function_); - builder_->Sem().Add(stmt->body, body); - return Scope(body, [&] { return Statements(stmt->body->statements); }); + if (!StatementScope(stmt->body, body, + [&] { return Statements(stmt->body->statements); })) { + return false; + } + + return ValidateForLoopStatement(sem); }); } @@ -1930,33 +1877,6 @@ sem::Expression* Resolver::UnaryOp(const ast::UnaryOpExpression* unary) { return builder_->create(unary, ty, current_statement_, val); } -bool Resolver::VariableDeclStatement(const ast::VariableDeclStatement* stmt) { - Mark(stmt->variable); - - auto* var = Variable(stmt->variable, VariableKind::kLocal); - if (!var) { - return false; - } - - for (auto* deco : stmt->variable->decorations) { - Mark(deco); - if (!deco->Is()) { - AddError("decorations are not valid on local variables", deco->source); - return false; - } - } - - if (current_block_) { // Not all statements are inside a block - current_block_->AddDecl(stmt->variable); - } - - if (!ValidateVariable(var)) { - return false; - } - - return true; -} - sem::Type* Resolver::TypeDecl(const ast::TypeDecl* named_type) { sem::Type* result = nullptr; if (auto* alias = named_type->As()) { @@ -2318,45 +2238,127 @@ sem::Struct* Resolver::Structure(const ast::Struct* str) { return out; } -bool Resolver::Return(const ast::ReturnStatement* ret) { - if (auto* value = ret->value) { - if (!Expression(value)) { - return false; +sem::Statement* Resolver::ReturnStatement(const ast::ReturnStatement* stmt) { + auto* sem = builder_->create( + stmt, current_compound_statement_, current_function_); + return StatementScope(stmt, sem, [&] { + if (auto* value = stmt->value) { + if (!Expression(value)) { + return false; + } } - } - // Validate after processing the return value expression so that its type is - // available for validation. - return ValidateReturn(ret); + // Validate after processing the return value expression so that its type is + // available for validation. + return ValidateReturn(stmt); + }); } -bool Resolver::SwitchStatement(const ast::SwitchStatement* stmt) { +sem::SwitchStatement* Resolver::SwitchStatement( + const ast::SwitchStatement* stmt) { auto* sem = builder_->create( stmt, current_compound_statement_, current_function_); - builder_->Sem().Add(stmt, sem); - return Scope(sem, [&] { + return StatementScope(stmt, sem, [&] { if (!Expression(stmt->condition)) { return false; } + for (auto* case_stmt : stmt->body) { Mark(case_stmt); if (!CaseStatement(case_stmt)) { return false; } } - if (!ValidateSwitch(stmt)) { - return false; - } - return true; + + return ValidateSwitch(stmt); }); } -bool Resolver::Assignment(const ast::AssignmentStatement* a) { - if (!Expression(a->lhs) || !Expression(a->rhs)) { - return false; - } +sem::Statement* Resolver::VariableDeclStatement( + const ast::VariableDeclStatement* stmt) { + auto* sem = builder_->create( + stmt, current_compound_statement_, current_function_); + return StatementScope(stmt, sem, [&] { + Mark(stmt->variable); - return ValidateAssignment(a); + auto* var = Variable(stmt->variable, VariableKind::kLocal); + if (!var) { + return false; + } + + for (auto* deco : stmt->variable->decorations) { + Mark(deco); + if (!deco->Is()) { + AddError("decorations are not valid on local variables", deco->source); + return false; + } + } + + if (current_block_) { // Not all statements are inside a block + current_block_->AddDecl(stmt->variable); + } + + return ValidateVariable(var); + }); +} + +sem::Statement* Resolver::AssignmentStatement( + const ast::AssignmentStatement* stmt) { + auto* sem = builder_->create( + stmt, current_compound_statement_, current_function_); + return StatementScope(stmt, sem, [&] { + if (!Expression(stmt->lhs) || !Expression(stmt->rhs)) { + return false; + } + + return ValidateAssignment(stmt); + }); +} + +sem::Statement* Resolver::BreakStatement(const ast::BreakStatement* stmt) { + auto* sem = builder_->create( + stmt, current_compound_statement_, current_function_); + return StatementScope(stmt, sem, [&] { return ValidateBreakStatement(sem); }); +} + +sem::Statement* Resolver::CallStatement(const ast::CallStatement* stmt) { + auto* sem = builder_->create( + stmt, current_compound_statement_, current_function_); + return StatementScope(stmt, sem, [&] { return Expression(stmt->expr); }); +} + +sem::Statement* Resolver::ContinueStatement( + const ast::ContinueStatement* stmt) { + auto* sem = builder_->create( + stmt, current_compound_statement_, current_function_); + return StatementScope(stmt, sem, [&] { + // Set if we've hit the first continue statement in our parent loop + if (auto* block = sem->FindFirstParent()) { + if (!block->FirstContinue()) { + const_cast(block)->SetFirstContinue( + stmt, block->Decls().size()); + } + } + + return ValidateContinueStatement(sem); + }); +} + +sem::Statement* Resolver::DiscardStatement(const ast::DiscardStatement* stmt) { + auto* sem = builder_->create( + stmt, current_compound_statement_, current_function_); + return StatementScope(stmt, sem, [&] { + current_function_->SetHasDiscard(); + + return ValidateDiscardStatement(sem); + }); +} + +sem::Statement* Resolver::FallthroughStatement( + const ast::FallthroughStatement* stmt) { + auto* sem = builder_->create( + stmt, current_compound_statement_, current_function_); + return StatementScope(stmt, sem, [&] { return true; }); } bool Resolver::ApplyStorageClassUsageToType(ast::StorageClass sc, @@ -2399,22 +2401,28 @@ bool Resolver::ApplyStorageClassUsageToType(ast::StorageClass sc, return true; } -template -bool Resolver::Scope(sem::CompoundStatement* stmt, F&& callback) { - auto* prev_current_statement = current_statement_; - auto* prev_current_compound_statement = current_compound_statement_; - auto* prev_current_block = current_block_; - current_statement_ = stmt; - current_compound_statement_ = stmt; - current_block_ = stmt->As(); +template +SEM* Resolver::StatementScope(const ast::Statement* ast, + SEM* sem, + F&& callback) { + builder_->Sem().Add(ast, sem); - TINT_DEFER({ - current_block_ = prev_current_block; - current_compound_statement_ = prev_current_compound_statement; - current_statement_ = prev_current_statement; - }); + auto* as_compound = + As(sem); + auto* as_block = + As(sem); - return callback(); + TINT_SCOPED_ASSIGNMENT(current_statement_, sem); + TINT_SCOPED_ASSIGNMENT( + current_compound_statement_, + as_compound ? as_compound : current_compound_statement_); + TINT_SCOPED_ASSIGNMENT(current_block_, as_block ? as_block : current_block_); + + if (!callback()) { + return nullptr; + } + + return sem; } std::string Resolver::VectorPretty(uint32_t size, diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h index 08a9c94d74..f2a2fc8e1a 100644 --- a/src/resolver/resolver.h +++ b/src/resolver/resolver.h @@ -59,8 +59,15 @@ class Variable; namespace sem { class Array; class Atomic; +class BlockStatement; +class ElseStatement; +class ForLoopStatement; +class IfStatement; class Intrinsic; +class LoopStatement; class Statement; +class SwitchCaseBlockStatement; +class SwitchStatement; class TypeConstructor; } // namespace sem @@ -198,20 +205,26 @@ class Resolver { // Statement resolving methods // Each return true on success, false on failure. - bool Assignment(const ast::AssignmentStatement* a); - bool BlockStatement(const ast::BlockStatement*); - bool CaseStatement(const ast::CaseStatement*); - bool ElseStatement(const ast::ElseStatement*); - bool ForLoopStatement(const ast::ForLoopStatement*); - bool Parameter(const ast::Variable* param); - bool GlobalVariable(const ast::Variable* var); - bool IfStatement(const ast::IfStatement*); - bool LoopStatement(const ast::LoopStatement*); - bool Return(const ast::ReturnStatement* ret); - bool Statement(const ast::Statement*); + sem::Statement* AssignmentStatement(const ast::AssignmentStatement*); + sem::BlockStatement* BlockStatement(const ast::BlockStatement*); + sem::Statement* BreakStatement(const ast::BreakStatement*); + sem::Statement* CallStatement(const ast::CallStatement*); + sem::SwitchCaseBlockStatement* CaseStatement(const ast::CaseStatement*); + sem::Statement* ContinueStatement(const ast::ContinueStatement*); + sem::Statement* DiscardStatement(const ast::DiscardStatement*); + sem::ElseStatement* ElseStatement(const ast::ElseStatement*); + sem::Statement* FallthroughStatement(const ast::FallthroughStatement*); + sem::ForLoopStatement* ForLoopStatement(const ast::ForLoopStatement*); + sem::Statement* Parameter(const ast::Variable*); + sem::IfStatement* IfStatement(const ast::IfStatement*); + sem::LoopStatement* LoopStatement(const ast::LoopStatement*); + sem::Statement* ReturnStatement(const ast::ReturnStatement*); + sem::Statement* Statement(const ast::Statement*); + sem::SwitchStatement* SwitchStatement(const ast::SwitchStatement* s); + sem::Statement* VariableDeclStatement(const ast::VariableDeclStatement*); bool Statements(const ast::StatementList&); - bool SwitchStatement(const ast::SwitchStatement* s); - bool VariableDeclStatement(const ast::VariableDeclStatement*); + + bool GlobalVariable(const ast::Variable*); // AST and Type validation methods // Each return true on success, false on failure. @@ -224,13 +237,19 @@ class Resolver { bool ValidateAtomic(const ast::Atomic* a, const sem::Atomic* s); bool ValidateAtomicVariable(const sem::Variable* var); bool ValidateAssignment(const ast::AssignmentStatement* a); + bool ValidateBreakStatement(const sem::Statement* stmt); + bool ValidateContinueStatement(const sem::Statement* stmt); + bool ValidateDiscardStatement(const sem::Statement* stmt); bool ValidateBuiltinDecoration(const ast::BuiltinDecoration* deco, const sem::Type* storage_type, const bool is_input); + bool ValidateElseStatement(const sem::ElseStatement* stmt); bool ValidateEntryPoint(const sem::Function* func); + bool ValidateForLoopStatement(const sem::ForLoopStatement* stmt); bool ValidateFunction(const sem::Function* func); bool ValidateFunctionCall(const sem::Call* call); bool ValidateGlobalVariable(const sem::Variable* var); + bool ValidateIfStatement(const sem::IfStatement* stmt); bool ValidateInterpolateDecoration(const ast::InterpolateDecoration* deco, const sem::Type* storage_type); bool ValidateIntrinsicCall(const sem::Call* call); @@ -369,14 +388,19 @@ class Resolver { /// @param lit the literal sem::Type* TypeOf(const ast::LiteralExpression* lit); - /// Assigns `stmt` to #current_statement_, #current_compound_statement_, and - /// possibly #current_block_, pushes the variable scope, then calls - /// `callback`. Before returning #current_statement_, - /// #current_compound_statement_, and #current_block_ are restored to their - /// original values, and the variable scope is popped. - /// @returns the value returned by callback - template - bool Scope(sem::CompoundStatement* stmt, F&& callback); + /// StatementScope() does the following: + /// * Creates the AST -> SEM mapping. + /// * Assigns `sem` to #current_statement_ + /// * Assigns `sem` to #current_compound_statement_ if `sem` derives from + /// sem::CompoundStatement. + /// * Assigns `sem` to #current_block_ if `sem` derives from + /// sem::BlockStatement. + /// * Then calls `callback`. + /// * Before returning #current_statement_, #current_compound_statement_, and + /// #current_block_ are restored to their original values. + /// @returns `sem` if `callback` returns true, otherwise `nullptr`. + template + SEM* StatementScope(const ast::Statement* ast, SEM* sem, F&& callback); /// Returns a human-readable string representation of the vector type name /// with the given parameters. diff --git a/src/resolver/resolver_validation.cc b/src/resolver/resolver_validation.cc index 5698138229..b7cb04ee18 100644 --- a/src/resolver/resolver_validation.cc +++ b/src/resolver/resolver_validation.cc @@ -1344,6 +1344,82 @@ bool Resolver::ValidateStatements(const ast::StatementList& stmts) { return true; } +bool Resolver::ValidateBreakStatement(const sem::Statement* stmt) { + if (!stmt->FindFirstParent() && + !stmt->FindFirstParent()) { + AddError("break statement must be in a loop or switch case", + stmt->Declaration()->source); + return false; + } + return true; +} + +bool Resolver::ValidateContinueStatement(const sem::Statement* stmt) { + if (auto* block = + stmt->FindFirstParent()) { + if (block->Is()) { + AddError("continuing blocks must not contain a continue statement", + stmt->Declaration()->source); + return false; + } + } else { + AddError("continue statement must be in a loop", + stmt->Declaration()->source); + return false; + } + + return true; +} + +bool Resolver::ValidateDiscardStatement(const sem::Statement* stmt) { + if (auto* continuing = + stmt->FindFirstParent()) { + AddError("continuing blocks must not contain a discard statement", + stmt->Declaration()->source); + if (continuing != stmt->Parent()) { + AddNote("see continuing block here", continuing->Declaration()->source); + } + return false; + } + return true; +} + +bool Resolver::ValidateElseStatement(const sem::ElseStatement* stmt) { + if (auto* cond = stmt->Condition()) { + auto* cond_ty = cond->Type()->UnwrapRef(); + if (!cond_ty->Is()) { + AddError( + "else statement condition must be bool, got " + TypeNameOf(cond_ty), + stmt->Condition()->Declaration()->source); + return false; + } + } + return true; +} + +bool Resolver::ValidateForLoopStatement(const sem::ForLoopStatement* stmt) { + if (auto* cond = stmt->Condition()) { + auto* cond_ty = cond->Type()->UnwrapRef(); + if (!cond_ty->Is()) { + AddError("for-loop condition must be bool, got " + TypeNameOf(cond_ty), + stmt->Condition()->Declaration()->source); + return false; + } + } + return true; +} + +bool Resolver::ValidateIfStatement(const sem::IfStatement* stmt) { + auto* cond_ty = stmt->Condition()->Type()->UnwrapRef(); + if (!cond_ty->Is()) { + AddError("if statement condition must be bool, got " + TypeNameOf(cond_ty), + stmt->Condition()->Declaration()->source); + return false; + } + return true; +} + bool Resolver::ValidateIntrinsicCall(const sem::Call* call) { if (call->Type()->Is()) { bool is_call_statement = false; @@ -2103,8 +2179,8 @@ bool Resolver::ValidateReturn(const ast::ReturnStatement* ret) { } bool Resolver::ValidateSwitch(const ast::SwitchStatement* s) { - auto* cond_type = TypeOf(s->condition)->UnwrapRef(); - if (!cond_type->is_integer_scalar()) { + auto* cond_ty = TypeOf(s->condition)->UnwrapRef(); + if (!cond_ty->is_integer_scalar()) { AddError( "switch statement selector expression must be of a " "scalar integer type", @@ -2127,7 +2203,7 @@ bool Resolver::ValidateSwitch(const ast::SwitchStatement* s) { } for (auto* selector : case_stmt->selectors) { - if (cond_type != TypeOf(selector)) { + if (cond_ty != TypeOf(selector)) { AddError( "the case selector values must have the same " "type as the selector expression.", diff --git a/src/sem/for_loop_statement.h b/src/sem/for_loop_statement.h index ff89241440..2f287bb0bb 100644 --- a/src/sem/for_loop_statement.h +++ b/src/sem/for_loop_statement.h @@ -21,6 +21,9 @@ namespace tint { namespace ast { class ForLoopStatement; } // namespace ast +namespace sem { +class Expression; +} // namespace sem } // namespace tint namespace tint { @@ -39,6 +42,16 @@ class ForLoopStatement : public Castable { /// Destructor ~ForLoopStatement() override; + + /// @returns the for-loop condition expression + const Expression* Condition() const { return condition_; } + + /// Sets the for-loop condition expression + /// @param condition the for-loop condition expression + void SetCondition(const Expression* condition) { condition_ = condition; } + + private: + const Expression* condition_ = nullptr; }; } // namespace sem diff --git a/src/sem/if_statement.h b/src/sem/if_statement.h index 6c25fcab01..a8c9c2eb9a 100644 --- a/src/sem/if_statement.h +++ b/src/sem/if_statement.h @@ -23,6 +23,9 @@ namespace ast { class IfStatement; class ElseStatement; } // namespace ast +namespace sem { +class Expression; +} // namespace sem } // namespace tint namespace tint { @@ -41,6 +44,16 @@ class IfStatement : public Castable { /// Destructor ~IfStatement() override; + + /// @returns the if-statement condition expression + const Expression* Condition() const { return condition_; } + + /// Sets the if-statement condition expression + /// @param condition the if condition expression + void SetCondition(const Expression* condition) { condition_ = condition; } + + private: + const Expression* condition_ = nullptr; }; /// Holds semantic information about an else statement @@ -56,6 +69,16 @@ class ElseStatement : public Castable { /// Destructor ~ElseStatement() override; + + /// @returns the else-statement condition expression + const Expression* Condition() const { return condition_; } + + /// Sets the else-statement condition expression + /// @param condition the else condition expression + void SetCondition(const Expression* condition) { condition_ = condition; } + + private: + const Expression* condition_ = nullptr; }; } // namespace sem diff --git a/src/sem/member_accessor_expression.h b/src/sem/member_accessor_expression.h index fcc6c6fbc3..6d444f0338 100644 --- a/src/sem/member_accessor_expression.h +++ b/src/sem/member_accessor_expression.h @@ -82,7 +82,7 @@ class Swizzle : public Castable { /// Constructor /// @param declaration the AST node /// @param type the resolved type of the expression - /// @param statement the statement that + /// @param statement the statement that owns this expression /// @param indices the swizzle indices Swizzle(const ast::MemberAccessorExpression* declaration, const sem::Type* type, diff --git a/src/sem/variable.h b/src/sem/variable.h index a389eed70d..ac7dac151d 100644 --- a/src/sem/variable.h +++ b/src/sem/variable.h @@ -234,7 +234,7 @@ class VariableUser : public Castable { const sem::Variable* Variable() const { return variable_; } private: - sem::Variable const* const variable_; + const sem::Variable* const variable_; }; } // namespace sem