sem: Replace SwitchCaseBlockStatement with CaseStatement

The SwitchCaseBlockStatement was bound to the BlockStatement of an ast::CaseStatement, but we had nothing that mapped to the actual ast::CaseStatement.
sem::CaseStatement replaces sem::SwitchCaseBlockStatement, and has a Block() accessor, providing a superset of the old behavior.

With this, we can now easily validate the `fallthrough` rules directly, instead of scanning the switch case. This keeps the validation more tigtly coupled to the ast / sem nodes.

Change-Id: I0f22eba37bb164b9e071a6166c7a41fc1a5ac532
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/71460
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton 2021-12-03 15:23:52 +00:00 committed by Tint LUCI CQ
parent 8e39ffd512
commit bf39c8fb19
9 changed files with 103 additions and 55 deletions

View File

@ -105,11 +105,15 @@ Semantic tree:
```
sem::SwitchStatement {
sem::Expression condition
sem::SwitchCaseBlockStatement {
sem::Statement statement_a
sem::CaseStatement {
sem::BlockStatement {
sem::Statement statement_a
}
}
sem::SwitchCaseBlockStatement {
sem::Statement statement_b
sem::CaseStatement {
sem::BlockStatement {
sem::Statement statement_b
}
}
}
```

View File

@ -343,31 +343,34 @@ TEST_F(ResolverCompoundStatementTest, Switch) {
{
auto* s = Sem().Get(stmt_a);
ASSERT_NE(s, nullptr);
EXPECT_EQ(s->Parent(), s->FindFirstParent<sem::SwitchCaseBlockStatement>());
EXPECT_EQ(s->Parent(), s->FindFirstParent<sem::BlockStatement>());
EXPECT_EQ(s->Parent(), s->Block());
EXPECT_EQ(s->Parent()->Parent(),
s->FindFirstParent<sem::SwitchStatement>());
EXPECT_EQ(s->Parent()->Parent(), s->FindFirstParent<sem::CaseStatement>());
EXPECT_EQ(s->Parent()->Parent()->Parent(),
s->FindFirstParent<sem::SwitchStatement>());
EXPECT_EQ(s->Parent()->Parent()->Parent()->Parent(),
s->FindFirstParent<sem::FunctionBlockStatement>());
}
{
auto* s = Sem().Get(stmt_b);
ASSERT_NE(s, nullptr);
EXPECT_EQ(s->Parent(), s->FindFirstParent<sem::SwitchCaseBlockStatement>());
EXPECT_EQ(s->Parent(), s->FindFirstParent<sem::BlockStatement>());
EXPECT_EQ(s->Parent(), s->Block());
EXPECT_EQ(s->Parent()->Parent(),
s->FindFirstParent<sem::SwitchStatement>());
EXPECT_EQ(s->Parent()->Parent(), s->FindFirstParent<sem::CaseStatement>());
EXPECT_EQ(s->Parent()->Parent()->Parent(),
s->FindFirstParent<sem::SwitchStatement>());
EXPECT_EQ(s->Parent()->Parent()->Parent()->Parent(),
s->FindFirstParent<sem::FunctionBlockStatement>());
}
{
auto* s = Sem().Get(stmt_c);
ASSERT_NE(s, nullptr);
EXPECT_EQ(s->Parent(), s->FindFirstParent<sem::SwitchCaseBlockStatement>());
EXPECT_EQ(s->Parent(), s->FindFirstParent<sem::BlockStatement>());
EXPECT_EQ(s->Parent(), s->Block());
EXPECT_EQ(s->Parent()->Parent(),
s->FindFirstParent<sem::SwitchStatement>());
EXPECT_EQ(s->Parent()->Parent(), s->FindFirstParent<sem::CaseStatement>());
EXPECT_EQ(s->Parent()->Parent()->Parent(),
s->FindFirstParent<sem::SwitchStatement>());
EXPECT_EQ(s->Parent()->Parent()->Parent()->Parent(),
s->FindFirstParent<sem::FunctionBlockStatement>());
}
}

View File

@ -294,8 +294,8 @@ TEST_F(ResolverControlBlockValidationTest,
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: a fallthrough statement must not appear as the last "
"statement in last clause of a switch");
"12:34 error: a fallthrough statement must not be used in the last "
"switch case");
}
TEST_F(ResolverControlBlockValidationTest, SwitchCase_Pass) {

View File

@ -874,17 +874,20 @@ sem::Statement* Resolver::Statement(const ast::Statement* stmt) {
return nullptr;
}
sem::SwitchCaseBlockStatement* Resolver::CaseStatement(
const ast::CaseStatement* stmt) {
auto* sem = builder_->create<sem::SwitchCaseBlockStatement>(
stmt->body, current_compound_statement_, current_function_);
sem::CaseStatement* Resolver::CaseStatement(const ast::CaseStatement* stmt) {
auto* sem = builder_->create<sem::CaseStatement>(
stmt, current_compound_statement_, current_function_);
return StatementScope(stmt, sem, [&] {
builder_->Sem().Add(stmt->body, sem);
Mark(stmt->body);
for (auto* sel : stmt->selectors) {
Mark(sel);
}
return Statements(stmt->body->statements);
Mark(stmt->body);
auto* body = BlockStatement(stmt->body);
if (!body) {
return false;
}
sem->SetBlock(body);
return true;
});
}
@ -2361,7 +2364,9 @@ sem::Statement* Resolver::FallthroughStatement(
const ast::FallthroughStatement* stmt) {
auto* sem = builder_->create<sem::Statement>(
stmt, current_compound_statement_, current_function_);
return StatementScope(stmt, sem, [&] { return true; });
return StatementScope(stmt, sem, [&] {
return ValidateFallthroughStatement(sem);
});
}
bool Resolver::ApplyStorageClassUsageToType(ast::StorageClass sc,

View File

@ -60,13 +60,13 @@ namespace sem {
class Array;
class Atomic;
class BlockStatement;
class CaseStatement;
class ElseStatement;
class ForLoopStatement;
class IfStatement;
class Intrinsic;
class LoopStatement;
class Statement;
class SwitchCaseBlockStatement;
class SwitchStatement;
class TypeConstructor;
} // namespace sem
@ -209,7 +209,7 @@ class Resolver {
sem::BlockStatement* BlockStatement(const ast::BlockStatement*);
sem::Statement* BreakStatement(const ast::BreakStatement*);
sem::Statement* CallStatement(const ast::CallStatement*);
sem::SwitchCaseBlockStatement* CaseStatement(const ast::CaseStatement*);
sem::CaseStatement* CaseStatement(const ast::CaseStatement*);
sem::Statement* ContinueStatement(const ast::ContinueStatement*);
sem::Statement* DiscardStatement(const ast::DiscardStatement*);
sem::ElseStatement* ElseStatement(const ast::ElseStatement*);
@ -238,14 +238,15 @@ class Resolver {
bool ValidateAtomicVariable(const sem::Variable* var);
bool ValidateAssignment(const ast::AssignmentStatement* a);
bool ValidateBreakStatement(const sem::Statement* stmt);
bool ValidateContinueStatement(const sem::Statement* stmt);
bool ValidateDiscardStatement(const sem::Statement* stmt);
bool ValidateBuiltinDecoration(const ast::BuiltinDecoration* deco,
const sem::Type* storage_type,
const bool is_input);
bool ValidateContinueStatement(const sem::Statement* stmt);
bool ValidateDiscardStatement(const sem::Statement* stmt);
bool ValidateElseStatement(const sem::ElseStatement* stmt);
bool ValidateEntryPoint(const sem::Function* func);
bool ValidateForLoopStatement(const sem::ForLoopStatement* stmt);
bool ValidateFallthroughStatement(const sem::Statement* stmt);
bool ValidateFunction(const sem::Function* func);
bool ValidateFunctionCall(const sem::Call* call);
bool ValidateGlobalVariable(const sem::Variable* var);

View File

@ -1346,7 +1346,7 @@ bool Resolver::ValidateStatements(const ast::StatementList& stmts) {
bool Resolver::ValidateBreakStatement(const sem::Statement* stmt) {
if (!stmt->FindFirstParent<sem::LoopBlockStatement>() &&
!stmt->FindFirstParent<sem::SwitchCaseBlockStatement>()) {
!stmt->FindFirstParent<sem::CaseStatement>()) {
AddError("break statement must be in a loop or switch case",
stmt->Declaration()->source);
return false;
@ -1385,6 +1385,29 @@ bool Resolver::ValidateDiscardStatement(const sem::Statement* stmt) {
return true;
}
bool Resolver::ValidateFallthroughStatement(const sem::Statement* stmt) {
if (auto* block = As<sem::BlockStatement>(stmt->Parent())) {
if (auto* c = As<sem::CaseStatement>(block->Parent())) {
if (block->Declaration()->Last() == stmt->Declaration()) {
if (auto* s = As<sem::SwitchStatement>(c->Parent())) {
if (c->Declaration() != s->Declaration()->body.back()) {
return true;
}
AddError(
"a fallthrough statement must not be used in the last switch "
"case",
stmt->Declaration()->source);
return false;
}
}
}
}
AddError(
"fallthrough must only be used as the last statement of a case block",
stmt->Declaration()->source);
return false;
}
bool Resolver::ValidateElseStatement(const sem::ElseStatement* stmt) {
if (auto* cond = stmt->Condition()) {
auto* cond_ty = cond->Type()->UnwrapRef();
@ -2231,18 +2254,6 @@ bool Resolver::ValidateSwitch(const ast::SwitchStatement* s) {
return false;
}
if (!s->body.empty()) {
auto* last_clause = s->body.back()->As<ast::CaseStatement>();
auto* last_stmt = last_clause->body->Last();
if (last_stmt && last_stmt->Is<ast::FallthroughStatement>()) {
AddError(
"a fallthrough statement must not appear as "
"the last statement in last clause of a switch",
last_stmt->source);
return false;
}
}
return true;
}

View File

@ -16,8 +16,8 @@
#include "src/program_builder.h"
TINT_INSTANTIATE_TYPEINFO(tint::sem::CaseStatement);
TINT_INSTANTIATE_TYPEINFO(tint::sem::SwitchStatement);
TINT_INSTANTIATE_TYPEINFO(tint::sem::SwitchCaseBlockStatement);
namespace tint {
namespace sem {
@ -32,15 +32,22 @@ SwitchStatement::SwitchStatement(const ast::SwitchStatement* declaration,
SwitchStatement::~SwitchStatement() = default;
SwitchCaseBlockStatement::SwitchCaseBlockStatement(
const ast::BlockStatement* declaration,
const CompoundStatement* parent,
const sem::Function* function)
const ast::SwitchStatement* SwitchStatement::Declaration() const {
return static_cast<const ast::SwitchStatement*>(Base::Declaration());
}
CaseStatement::CaseStatement(const ast::CaseStatement* declaration,
const CompoundStatement* parent,
const sem::Function* function)
: Base(declaration, parent, function) {
TINT_ASSERT(Semantic, parent);
TINT_ASSERT(Semantic, function);
}
SwitchCaseBlockStatement::~SwitchCaseBlockStatement() = default;
CaseStatement::~CaseStatement() = default;
const ast::CaseStatement* CaseStatement::Declaration() const {
return static_cast<const ast::CaseStatement*>(Base::Declaration());
}
} // namespace sem
} // namespace tint

View File

@ -20,6 +20,7 @@
// Forward declarations
namespace tint {
namespace ast {
class CaseStatement;
class SwitchStatement;
} // namespace ast
} // namespace tint
@ -40,22 +41,36 @@ class SwitchStatement : public Castable<SwitchStatement, CompoundStatement> {
/// Destructor
~SwitchStatement() override;
/// @return the AST node for this statement
const ast::SwitchStatement* Declaration() const;
};
/// Holds semantic information about a switch case block
class SwitchCaseBlockStatement
: public Castable<SwitchCaseBlockStatement, BlockStatement> {
/// Holds semantic information about a switch case statement
class CaseStatement : public Castable<CaseStatement, CompoundStatement> {
public:
/// Constructor
/// @param declaration the AST node for this block statement
/// @param declaration the AST node for this case statement
/// @param parent the owning statement
/// @param function the owning function
SwitchCaseBlockStatement(const ast::BlockStatement* declaration,
const CompoundStatement* parent,
const sem::Function* function);
CaseStatement(const ast::CaseStatement* declaration,
const CompoundStatement* parent,
const sem::Function* function);
/// Destructor
~SwitchCaseBlockStatement() override;
~CaseStatement() override;
/// @return the AST node for this statement
const ast::CaseStatement* Declaration() const;
/// @param body the case body block statement
void SetBlock(const BlockStatement* body) { body_ = body; }
/// @returns the case body block statement
const BlockStatement* Body() const { return body_; }
private:
const BlockStatement* body_ = nullptr;
};
} // namespace sem

View File

@ -23,7 +23,9 @@ using WgslGeneratorImplTest = TestHelper;
TEST_F(WgslGeneratorImplTest, Emit_Fallthrough) {
auto* f = create<ast::FallthroughStatement>();
WrapInFunction(f);
WrapInFunction(Switch(1, //
Case(Expr(1), Block(f)), //
DefaultCase()));
GeneratorImpl& gen = Build();