sem: Replace SwitchCaseBlockStatement with CaseStatement

The SwitchCaseBlockStatement was bound to the BlockStatement of an ast::CaseStatement, but we had nothing that mapped to the actual ast::CaseStatement.
sem::CaseStatement replaces sem::SwitchCaseBlockStatement, and has a Block() accessor, providing a superset of the old behavior.

With this, we can now easily validate the `fallthrough` rules directly, instead of scanning the switch case. This keeps the validation more tigtly coupled to the ast / sem nodes.

Change-Id: I0f22eba37bb164b9e071a6166c7a41fc1a5ac532
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/71460
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton 2021-12-03 15:23:52 +00:00 committed by Tint LUCI CQ
parent 8e39ffd512
commit bf39c8fb19
9 changed files with 103 additions and 55 deletions

View File

@ -105,11 +105,15 @@ Semantic tree:
``` ```
sem::SwitchStatement { sem::SwitchStatement {
sem::Expression condition sem::Expression condition
sem::SwitchCaseBlockStatement { sem::CaseStatement {
sem::Statement statement_a sem::BlockStatement {
sem::Statement statement_a
}
} }
sem::SwitchCaseBlockStatement { sem::CaseStatement {
sem::Statement statement_b sem::BlockStatement {
sem::Statement statement_b
}
} }
} }
``` ```

View File

@ -343,31 +343,34 @@ TEST_F(ResolverCompoundStatementTest, Switch) {
{ {
auto* s = Sem().Get(stmt_a); auto* s = Sem().Get(stmt_a);
ASSERT_NE(s, nullptr); ASSERT_NE(s, nullptr);
EXPECT_EQ(s->Parent(), s->FindFirstParent<sem::SwitchCaseBlockStatement>()); EXPECT_EQ(s->Parent(), s->FindFirstParent<sem::BlockStatement>());
EXPECT_EQ(s->Parent(), s->Block()); EXPECT_EQ(s->Parent(), s->Block());
EXPECT_EQ(s->Parent()->Parent(), EXPECT_EQ(s->Parent()->Parent(), s->FindFirstParent<sem::CaseStatement>());
s->FindFirstParent<sem::SwitchStatement>());
EXPECT_EQ(s->Parent()->Parent()->Parent(), EXPECT_EQ(s->Parent()->Parent()->Parent(),
s->FindFirstParent<sem::SwitchStatement>());
EXPECT_EQ(s->Parent()->Parent()->Parent()->Parent(),
s->FindFirstParent<sem::FunctionBlockStatement>()); s->FindFirstParent<sem::FunctionBlockStatement>());
} }
{ {
auto* s = Sem().Get(stmt_b); auto* s = Sem().Get(stmt_b);
ASSERT_NE(s, nullptr); ASSERT_NE(s, nullptr);
EXPECT_EQ(s->Parent(), s->FindFirstParent<sem::SwitchCaseBlockStatement>()); EXPECT_EQ(s->Parent(), s->FindFirstParent<sem::BlockStatement>());
EXPECT_EQ(s->Parent(), s->Block()); EXPECT_EQ(s->Parent(), s->Block());
EXPECT_EQ(s->Parent()->Parent(), EXPECT_EQ(s->Parent()->Parent(), s->FindFirstParent<sem::CaseStatement>());
s->FindFirstParent<sem::SwitchStatement>());
EXPECT_EQ(s->Parent()->Parent()->Parent(), EXPECT_EQ(s->Parent()->Parent()->Parent(),
s->FindFirstParent<sem::SwitchStatement>());
EXPECT_EQ(s->Parent()->Parent()->Parent()->Parent(),
s->FindFirstParent<sem::FunctionBlockStatement>()); s->FindFirstParent<sem::FunctionBlockStatement>());
} }
{ {
auto* s = Sem().Get(stmt_c); auto* s = Sem().Get(stmt_c);
ASSERT_NE(s, nullptr); ASSERT_NE(s, nullptr);
EXPECT_EQ(s->Parent(), s->FindFirstParent<sem::SwitchCaseBlockStatement>()); EXPECT_EQ(s->Parent(), s->FindFirstParent<sem::BlockStatement>());
EXPECT_EQ(s->Parent(), s->Block()); EXPECT_EQ(s->Parent(), s->Block());
EXPECT_EQ(s->Parent()->Parent(), EXPECT_EQ(s->Parent()->Parent(), s->FindFirstParent<sem::CaseStatement>());
s->FindFirstParent<sem::SwitchStatement>());
EXPECT_EQ(s->Parent()->Parent()->Parent(), EXPECT_EQ(s->Parent()->Parent()->Parent(),
s->FindFirstParent<sem::SwitchStatement>());
EXPECT_EQ(s->Parent()->Parent()->Parent()->Parent(),
s->FindFirstParent<sem::FunctionBlockStatement>()); s->FindFirstParent<sem::FunctionBlockStatement>());
} }
} }

View File

@ -294,8 +294,8 @@ TEST_F(ResolverControlBlockValidationTest,
EXPECT_FALSE(r()->Resolve()); EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), EXPECT_EQ(r()->error(),
"12:34 error: a fallthrough statement must not appear as the last " "12:34 error: a fallthrough statement must not be used in the last "
"statement in last clause of a switch"); "switch case");
} }
TEST_F(ResolverControlBlockValidationTest, SwitchCase_Pass) { TEST_F(ResolverControlBlockValidationTest, SwitchCase_Pass) {

View File

@ -874,17 +874,20 @@ sem::Statement* Resolver::Statement(const ast::Statement* stmt) {
return nullptr; return nullptr;
} }
sem::SwitchCaseBlockStatement* Resolver::CaseStatement( sem::CaseStatement* Resolver::CaseStatement(const ast::CaseStatement* stmt) {
const ast::CaseStatement* stmt) { auto* sem = builder_->create<sem::CaseStatement>(
auto* sem = builder_->create<sem::SwitchCaseBlockStatement>( stmt, current_compound_statement_, current_function_);
stmt->body, current_compound_statement_, current_function_);
return StatementScope(stmt, sem, [&] { return StatementScope(stmt, sem, [&] {
builder_->Sem().Add(stmt->body, sem);
Mark(stmt->body);
for (auto* sel : stmt->selectors) { for (auto* sel : stmt->selectors) {
Mark(sel); Mark(sel);
} }
return Statements(stmt->body->statements); Mark(stmt->body);
auto* body = BlockStatement(stmt->body);
if (!body) {
return false;
}
sem->SetBlock(body);
return true;
}); });
} }
@ -2361,7 +2364,9 @@ sem::Statement* Resolver::FallthroughStatement(
const ast::FallthroughStatement* stmt) { const ast::FallthroughStatement* stmt) {
auto* sem = builder_->create<sem::Statement>( auto* sem = builder_->create<sem::Statement>(
stmt, current_compound_statement_, current_function_); stmt, current_compound_statement_, current_function_);
return StatementScope(stmt, sem, [&] { return true; }); return StatementScope(stmt, sem, [&] {
return ValidateFallthroughStatement(sem);
});
} }
bool Resolver::ApplyStorageClassUsageToType(ast::StorageClass sc, bool Resolver::ApplyStorageClassUsageToType(ast::StorageClass sc,

View File

@ -60,13 +60,13 @@ namespace sem {
class Array; class Array;
class Atomic; class Atomic;
class BlockStatement; class BlockStatement;
class CaseStatement;
class ElseStatement; class ElseStatement;
class ForLoopStatement; class ForLoopStatement;
class IfStatement; class IfStatement;
class Intrinsic; class Intrinsic;
class LoopStatement; class LoopStatement;
class Statement; class Statement;
class SwitchCaseBlockStatement;
class SwitchStatement; class SwitchStatement;
class TypeConstructor; class TypeConstructor;
} // namespace sem } // namespace sem
@ -209,7 +209,7 @@ class Resolver {
sem::BlockStatement* BlockStatement(const ast::BlockStatement*); sem::BlockStatement* BlockStatement(const ast::BlockStatement*);
sem::Statement* BreakStatement(const ast::BreakStatement*); sem::Statement* BreakStatement(const ast::BreakStatement*);
sem::Statement* CallStatement(const ast::CallStatement*); sem::Statement* CallStatement(const ast::CallStatement*);
sem::SwitchCaseBlockStatement* CaseStatement(const ast::CaseStatement*); sem::CaseStatement* CaseStatement(const ast::CaseStatement*);
sem::Statement* ContinueStatement(const ast::ContinueStatement*); sem::Statement* ContinueStatement(const ast::ContinueStatement*);
sem::Statement* DiscardStatement(const ast::DiscardStatement*); sem::Statement* DiscardStatement(const ast::DiscardStatement*);
sem::ElseStatement* ElseStatement(const ast::ElseStatement*); sem::ElseStatement* ElseStatement(const ast::ElseStatement*);
@ -238,14 +238,15 @@ class Resolver {
bool ValidateAtomicVariable(const sem::Variable* var); bool ValidateAtomicVariable(const sem::Variable* var);
bool ValidateAssignment(const ast::AssignmentStatement* a); bool ValidateAssignment(const ast::AssignmentStatement* a);
bool ValidateBreakStatement(const sem::Statement* stmt); bool ValidateBreakStatement(const sem::Statement* stmt);
bool ValidateContinueStatement(const sem::Statement* stmt);
bool ValidateDiscardStatement(const sem::Statement* stmt);
bool ValidateBuiltinDecoration(const ast::BuiltinDecoration* deco, bool ValidateBuiltinDecoration(const ast::BuiltinDecoration* deco,
const sem::Type* storage_type, const sem::Type* storage_type,
const bool is_input); const bool is_input);
bool ValidateContinueStatement(const sem::Statement* stmt);
bool ValidateDiscardStatement(const sem::Statement* stmt);
bool ValidateElseStatement(const sem::ElseStatement* stmt); bool ValidateElseStatement(const sem::ElseStatement* stmt);
bool ValidateEntryPoint(const sem::Function* func); bool ValidateEntryPoint(const sem::Function* func);
bool ValidateForLoopStatement(const sem::ForLoopStatement* stmt); bool ValidateForLoopStatement(const sem::ForLoopStatement* stmt);
bool ValidateFallthroughStatement(const sem::Statement* stmt);
bool ValidateFunction(const sem::Function* func); bool ValidateFunction(const sem::Function* func);
bool ValidateFunctionCall(const sem::Call* call); bool ValidateFunctionCall(const sem::Call* call);
bool ValidateGlobalVariable(const sem::Variable* var); bool ValidateGlobalVariable(const sem::Variable* var);

View File

@ -1346,7 +1346,7 @@ bool Resolver::ValidateStatements(const ast::StatementList& stmts) {
bool Resolver::ValidateBreakStatement(const sem::Statement* stmt) { bool Resolver::ValidateBreakStatement(const sem::Statement* stmt) {
if (!stmt->FindFirstParent<sem::LoopBlockStatement>() && if (!stmt->FindFirstParent<sem::LoopBlockStatement>() &&
!stmt->FindFirstParent<sem::SwitchCaseBlockStatement>()) { !stmt->FindFirstParent<sem::CaseStatement>()) {
AddError("break statement must be in a loop or switch case", AddError("break statement must be in a loop or switch case",
stmt->Declaration()->source); stmt->Declaration()->source);
return false; return false;
@ -1385,6 +1385,29 @@ bool Resolver::ValidateDiscardStatement(const sem::Statement* stmt) {
return true; return true;
} }
bool Resolver::ValidateFallthroughStatement(const sem::Statement* stmt) {
if (auto* block = As<sem::BlockStatement>(stmt->Parent())) {
if (auto* c = As<sem::CaseStatement>(block->Parent())) {
if (block->Declaration()->Last() == stmt->Declaration()) {
if (auto* s = As<sem::SwitchStatement>(c->Parent())) {
if (c->Declaration() != s->Declaration()->body.back()) {
return true;
}
AddError(
"a fallthrough statement must not be used in the last switch "
"case",
stmt->Declaration()->source);
return false;
}
}
}
}
AddError(
"fallthrough must only be used as the last statement of a case block",
stmt->Declaration()->source);
return false;
}
bool Resolver::ValidateElseStatement(const sem::ElseStatement* stmt) { bool Resolver::ValidateElseStatement(const sem::ElseStatement* stmt) {
if (auto* cond = stmt->Condition()) { if (auto* cond = stmt->Condition()) {
auto* cond_ty = cond->Type()->UnwrapRef(); auto* cond_ty = cond->Type()->UnwrapRef();
@ -2231,18 +2254,6 @@ bool Resolver::ValidateSwitch(const ast::SwitchStatement* s) {
return false; return false;
} }
if (!s->body.empty()) {
auto* last_clause = s->body.back()->As<ast::CaseStatement>();
auto* last_stmt = last_clause->body->Last();
if (last_stmt && last_stmt->Is<ast::FallthroughStatement>()) {
AddError(
"a fallthrough statement must not appear as "
"the last statement in last clause of a switch",
last_stmt->source);
return false;
}
}
return true; return true;
} }

View File

@ -16,8 +16,8 @@
#include "src/program_builder.h" #include "src/program_builder.h"
TINT_INSTANTIATE_TYPEINFO(tint::sem::CaseStatement);
TINT_INSTANTIATE_TYPEINFO(tint::sem::SwitchStatement); TINT_INSTANTIATE_TYPEINFO(tint::sem::SwitchStatement);
TINT_INSTANTIATE_TYPEINFO(tint::sem::SwitchCaseBlockStatement);
namespace tint { namespace tint {
namespace sem { namespace sem {
@ -32,15 +32,22 @@ SwitchStatement::SwitchStatement(const ast::SwitchStatement* declaration,
SwitchStatement::~SwitchStatement() = default; SwitchStatement::~SwitchStatement() = default;
SwitchCaseBlockStatement::SwitchCaseBlockStatement( const ast::SwitchStatement* SwitchStatement::Declaration() const {
const ast::BlockStatement* declaration, return static_cast<const ast::SwitchStatement*>(Base::Declaration());
const CompoundStatement* parent, }
const sem::Function* function)
CaseStatement::CaseStatement(const ast::CaseStatement* declaration,
const CompoundStatement* parent,
const sem::Function* function)
: Base(declaration, parent, function) { : Base(declaration, parent, function) {
TINT_ASSERT(Semantic, parent); TINT_ASSERT(Semantic, parent);
TINT_ASSERT(Semantic, function); TINT_ASSERT(Semantic, function);
} }
SwitchCaseBlockStatement::~SwitchCaseBlockStatement() = default; CaseStatement::~CaseStatement() = default;
const ast::CaseStatement* CaseStatement::Declaration() const {
return static_cast<const ast::CaseStatement*>(Base::Declaration());
}
} // namespace sem } // namespace sem
} // namespace tint } // namespace tint

View File

@ -20,6 +20,7 @@
// Forward declarations // Forward declarations
namespace tint { namespace tint {
namespace ast { namespace ast {
class CaseStatement;
class SwitchStatement; class SwitchStatement;
} // namespace ast } // namespace ast
} // namespace tint } // namespace tint
@ -40,22 +41,36 @@ class SwitchStatement : public Castable<SwitchStatement, CompoundStatement> {
/// Destructor /// Destructor
~SwitchStatement() override; ~SwitchStatement() override;
/// @return the AST node for this statement
const ast::SwitchStatement* Declaration() const;
}; };
/// Holds semantic information about a switch case block /// Holds semantic information about a switch case statement
class SwitchCaseBlockStatement class CaseStatement : public Castable<CaseStatement, CompoundStatement> {
: public Castable<SwitchCaseBlockStatement, BlockStatement> {
public: public:
/// Constructor /// Constructor
/// @param declaration the AST node for this block statement /// @param declaration the AST node for this case statement
/// @param parent the owning statement /// @param parent the owning statement
/// @param function the owning function /// @param function the owning function
SwitchCaseBlockStatement(const ast::BlockStatement* declaration, CaseStatement(const ast::CaseStatement* declaration,
const CompoundStatement* parent, const CompoundStatement* parent,
const sem::Function* function); const sem::Function* function);
/// Destructor /// Destructor
~SwitchCaseBlockStatement() override; ~CaseStatement() override;
/// @return the AST node for this statement
const ast::CaseStatement* Declaration() const;
/// @param body the case body block statement
void SetBlock(const BlockStatement* body) { body_ = body; }
/// @returns the case body block statement
const BlockStatement* Body() const { return body_; }
private:
const BlockStatement* body_ = nullptr;
}; };
} // namespace sem } // namespace sem

View File

@ -23,7 +23,9 @@ using WgslGeneratorImplTest = TestHelper;
TEST_F(WgslGeneratorImplTest, Emit_Fallthrough) { TEST_F(WgslGeneratorImplTest, Emit_Fallthrough) {
auto* f = create<ast::FallthroughStatement>(); auto* f = create<ast::FallthroughStatement>();
WrapInFunction(f); WrapInFunction(Switch(1, //
Case(Expr(1), Block(f)), //
DefaultCase()));
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();