Replace Statement::(Is|As)* with Castable

Change-Id: I5520752a4b5844be0ecac7921616893d123b246a
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/34315
Reviewed-by: dan sinclair <dsinclair@chromium.org>
This commit is contained in:
Ben Clayton 2020-11-30 23:30:58 +00:00
parent 4d3ca7f132
commit 1d8098ae94
75 changed files with 331 additions and 671 deletions

View File

@ -31,10 +31,6 @@ AssignmentStatement::AssignmentStatement(AssignmentStatement&&) = default;
AssignmentStatement::~AssignmentStatement() = default;
bool AssignmentStatement::IsAssign() const {
return true;
}
bool AssignmentStatement::IsValid() const {
if (lhs_ == nullptr || !lhs_->IsValid())
return false;

View File

@ -55,9 +55,6 @@ class AssignmentStatement : public Castable<AssignmentStatement, Statement> {
/// @returns the right side expression
Expression* rhs() const { return rhs_; }
/// @returns true if this is an assignment statement
bool IsAssign() const override;
/// @returns true if the node is valid
bool IsValid() const override;

View File

@ -47,7 +47,7 @@ TEST_F(AssignmentStatementTest, IsAssign) {
auto* rhs = create<ast::IdentifierExpression>("rhs");
AssignmentStatement stmt(lhs, rhs);
EXPECT_TRUE(stmt.IsAssign());
EXPECT_TRUE(stmt.Is<AssignmentStatement>());
}
TEST_F(AssignmentStatementTest, IsValid) {

View File

@ -25,10 +25,6 @@ BlockStatement::BlockStatement(BlockStatement&&) = default;
BlockStatement::~BlockStatement() = default;
bool BlockStatement::IsBlock() const {
return true;
}
bool BlockStatement::IsValid() const {
for (auto* stmt : *this) {
if (stmt == nullptr || !stmt->IsValid()) {

View File

@ -87,9 +87,6 @@ class BlockStatement : public Castable<BlockStatement, Statement> {
return statements_.end();
}
/// @returns true if this is a block statement
bool IsBlock() const override;
/// @returns true if the node is valid
bool IsValid() const override;

View File

@ -65,7 +65,7 @@ TEST_F(BlockStatementTest, Creation_WithSource) {
TEST_F(BlockStatementTest, IsBlock) {
BlockStatement b;
EXPECT_TRUE(b.IsBlock());
EXPECT_TRUE(b.Is<BlockStatement>());
}
TEST_F(BlockStatementTest, IsValid) {

View File

@ -25,10 +25,6 @@ BreakStatement::BreakStatement(BreakStatement&&) = default;
BreakStatement::~BreakStatement() = default;
bool BreakStatement::IsBreak() const {
return true;
}
bool BreakStatement::IsValid() const {
return true;
}

View File

@ -32,9 +32,6 @@ class BreakStatement : public Castable<BreakStatement, Statement> {
BreakStatement(BreakStatement&&);
~BreakStatement() override;
/// @returns true if this is an break statement
bool IsBreak() const override;
/// @returns true if the node is valid
bool IsValid() const override;

View File

@ -31,7 +31,7 @@ TEST_F(BreakStatementTest, Creation_WithSource) {
TEST_F(BreakStatementTest, IsBreak) {
BreakStatement stmt;
EXPECT_TRUE(stmt.IsBreak());
EXPECT_TRUE(stmt.Is<BreakStatement>());
}
TEST_F(BreakStatementTest, IsValid) {

View File

@ -27,10 +27,6 @@ CallStatement::CallStatement(CallStatement&&) = default;
CallStatement::~CallStatement() = default;
bool CallStatement::IsCall() const {
return true;
}
bool CallStatement::IsValid() const {
return call_ != nullptr && call_->IsValid();
}

View File

@ -42,9 +42,6 @@ class CallStatement : public Castable<CallStatement, Statement> {
/// @returns the call expression
CallExpression* expr() const { return call_; }
/// @returns true if this is a call statement
bool IsCall() const override;
/// @returns true if the node is valid
bool IsValid() const override;

View File

@ -34,7 +34,7 @@ TEST_F(CallStatementTest, Creation) {
TEST_F(CallStatementTest, IsCall) {
CallStatement c;
EXPECT_TRUE(c.IsCall());
EXPECT_TRUE(c.Is<CallStatement>());
}
TEST_F(CallStatementTest, IsValid) {

View File

@ -31,10 +31,6 @@ CaseStatement::CaseStatement(CaseStatement&&) = default;
CaseStatement::~CaseStatement() = default;
bool CaseStatement::IsCase() const {
return true;
}
bool CaseStatement::IsValid() const {
return body_ != nullptr && body_->IsValid();
}

View File

@ -70,9 +70,6 @@ class CaseStatement : public Castable<CaseStatement, Statement> {
/// @returns the case body
BlockStatement* body() { return body_; }
/// @returns true if this is a case statement
bool IsCase() const override;
/// @returns true if the node is valid
bool IsValid() const override;

View File

@ -99,7 +99,7 @@ TEST_F(CaseStatementTest, IsDefault_WithSelectors) {
TEST_F(CaseStatementTest, IsCase) {
CaseStatement c(create<ast::BlockStatement>());
EXPECT_TRUE(c.IsCase());
EXPECT_TRUE(c.Is<ast::CaseStatement>());
}
TEST_F(CaseStatementTest, IsValid) {

View File

@ -25,10 +25,6 @@ ContinueStatement::ContinueStatement(ContinueStatement&&) = default;
ContinueStatement::~ContinueStatement() = default;
bool ContinueStatement::IsContinue() const {
return true;
}
bool ContinueStatement::IsValid() const {
return true;
}

View File

@ -35,9 +35,6 @@ class ContinueStatement : public Castable<ContinueStatement, Statement> {
ContinueStatement(ContinueStatement&&);
~ContinueStatement() override;
/// @returns true if this is an continue statement
bool IsContinue() const override;
/// @returns true if the node is valid
bool IsValid() const override;

View File

@ -31,7 +31,7 @@ TEST_F(ContinueStatementTest, Creation_WithSource) {
TEST_F(ContinueStatementTest, IsContinue) {
ContinueStatement stmt;
EXPECT_TRUE(stmt.IsContinue());
EXPECT_TRUE(stmt.Is<ContinueStatement>());
}
TEST_F(ContinueStatementTest, IsValid) {

View File

@ -25,10 +25,6 @@ DiscardStatement::DiscardStatement(DiscardStatement&&) = default;
DiscardStatement::~DiscardStatement() = default;
bool DiscardStatement::IsDiscard() const {
return true;
}
bool DiscardStatement::IsValid() const {
return true;
}

View File

@ -32,9 +32,6 @@ class DiscardStatement : public Castable<DiscardStatement, Statement> {
DiscardStatement(DiscardStatement&&);
~DiscardStatement() override;
/// @returns true if this is a discard statement
bool IsDiscard() const override;
/// @returns true if the node is valid
bool IsValid() const override;

View File

@ -43,7 +43,7 @@ TEST_F(DiscardStatementTest, Creation_WithSource) {
TEST_F(DiscardStatementTest, IsDiscard) {
DiscardStatement stmt;
EXPECT_TRUE(stmt.IsDiscard());
EXPECT_TRUE(stmt.Is<DiscardStatement>());
}
TEST_F(DiscardStatementTest, IsValid) {

View File

@ -34,10 +34,6 @@ ElseStatement::ElseStatement(ElseStatement&&) = default;
ElseStatement::~ElseStatement() = default;
bool ElseStatement::IsElse() const {
return true;
}
bool ElseStatement::IsValid() const {
if (body_ == nullptr || !body_->IsValid()) {
return false;

View File

@ -67,9 +67,6 @@ class ElseStatement : public Castable<ElseStatement, Statement> {
/// @returns the else body
BlockStatement* body() { return body_; }
/// @returns true if this is a else statement
bool IsElse() const override;
/// @returns true if the node is valid
bool IsValid() const override;

View File

@ -51,7 +51,7 @@ TEST_F(ElseStatementTest, Creation_WithSource) {
TEST_F(ElseStatementTest, IsElse) {
ElseStatement e(create<BlockStatement>());
EXPECT_TRUE(e.IsElse());
EXPECT_TRUE(e.Is<ElseStatement>());
}
TEST_F(ElseStatementTest, HasCondition) {

View File

@ -26,10 +26,6 @@ FallthroughStatement::FallthroughStatement(FallthroughStatement&&) = default;
FallthroughStatement::~FallthroughStatement() = default;
bool FallthroughStatement::IsFallthrough() const {
return true;
}
bool FallthroughStatement::IsValid() const {
return true;
}

View File

@ -32,9 +32,6 @@ class FallthroughStatement : public Castable<FallthroughStatement, Statement> {
FallthroughStatement(FallthroughStatement&&);
~FallthroughStatement() override;
/// @returns true if this is an fallthrough statement
bool IsFallthrough() const override;
/// @returns true if the node is valid
bool IsValid() const override;

View File

@ -39,7 +39,7 @@ TEST_F(FallthroughStatementTest, Creation_WithSource) {
TEST_F(FallthroughStatementTest, IsFallthrough) {
FallthroughStatement stmt;
EXPECT_TRUE(stmt.IsFallthrough());
EXPECT_TRUE(stmt.Is<FallthroughStatement>());
}
TEST_F(FallthroughStatementTest, IsValid) {

View File

@ -31,10 +31,6 @@ IfStatement::IfStatement(IfStatement&&) = default;
IfStatement::~IfStatement() = default;
bool IfStatement::IsIf() const {
return true;
}
bool IfStatement::IsValid() const {
if (condition_ == nullptr || !condition_->IsValid()) {
return false;

View File

@ -71,9 +71,6 @@ class IfStatement : public Castable<IfStatement, Statement> {
/// @returns true if there are else statements
bool has_else_statements() const { return !else_statements_.empty(); }
/// @returns true if this is a if statement
bool IsIf() const override;
/// @returns true if the node is valid
bool IsValid() const override;

View File

@ -49,7 +49,7 @@ TEST_F(IfStatementTest, Creation_WithSource) {
TEST_F(IfStatementTest, IsIf) {
IfStatement stmt(nullptr, create<BlockStatement>());
EXPECT_TRUE(stmt.IsIf());
EXPECT_TRUE(stmt.Is<IfStatement>());
}
TEST_F(IfStatementTest, IsValid) {

View File

@ -29,10 +29,6 @@ LoopStatement::LoopStatement(LoopStatement&&) = default;
LoopStatement::~LoopStatement() = default;
bool LoopStatement::IsLoop() const {
return true;
}
bool LoopStatement::IsValid() const {
if (body_ == nullptr || !body_->IsValid()) {
return false;

View File

@ -62,9 +62,6 @@ class LoopStatement : public Castable<LoopStatement, Statement> {
return continuing_ != nullptr && !continuing_->empty();
}
/// @returns true if this is a loop statement
bool IsLoop() const override;
/// @returns true if the node is valid
bool IsValid() const override;

View File

@ -57,7 +57,7 @@ TEST_F(LoopStatementTest, Creation_WithSource) {
TEST_F(LoopStatementTest, IsLoop) {
LoopStatement l(create<BlockStatement>(), create<BlockStatement>());
EXPECT_TRUE(l.IsLoop());
EXPECT_TRUE(l.Is<LoopStatement>());
}
TEST_F(LoopStatementTest, HasContinuing_WithoutContinuing) {

View File

@ -30,10 +30,6 @@ ReturnStatement::ReturnStatement(ReturnStatement&&) = default;
ReturnStatement::~ReturnStatement() = default;
bool ReturnStatement::IsReturn() const {
return true;
}
bool ReturnStatement::IsValid() const {
if (value_ != nullptr) {
return value_->IsValid();

View File

@ -51,9 +51,6 @@ class ReturnStatement : public Castable<ReturnStatement, Statement> {
/// @returns true if the return has a value
bool has_value() const { return value_ != nullptr; }
/// @returns true if this is a return statement
bool IsReturn() const override;
/// @returns true if the node is valid
bool IsValid() const override;

View File

@ -41,7 +41,7 @@ TEST_F(ReturnStatementTest, Creation_WithSource) {
TEST_F(ReturnStatementTest, IsReturn) {
ReturnStatement r;
EXPECT_TRUE(r.IsReturn());
EXPECT_TRUE(r.Is<ReturnStatement>());
}
TEST_F(ReturnStatementTest, HasValue_WithoutValue) {

View File

@ -42,247 +42,51 @@ Statement::Statement(Statement&&) = default;
Statement::~Statement() = default;
bool Statement::IsAssign() const {
return false;
}
bool Statement::IsBlock() const {
return false;
}
bool Statement::IsBreak() const {
return false;
}
bool Statement::IsCase() const {
return false;
}
bool Statement::IsCall() const {
return false;
}
bool Statement::IsContinue() const {
return false;
}
bool Statement::IsDiscard() const {
return false;
}
bool Statement::IsElse() const {
return false;
}
bool Statement::IsFallthrough() const {
return false;
}
bool Statement::IsIf() const {
return false;
}
bool Statement::IsLoop() const {
return false;
}
bool Statement::IsReturn() const {
return false;
}
bool Statement::IsSwitch() const {
return false;
}
bool Statement::IsVariableDecl() const {
return false;
}
const char* Statement::Name() const {
if (IsAssign()) {
if (Is<AssignmentStatement>()) {
return "assignment statement";
}
if (IsBlock()) {
if (Is<BlockStatement>()) {
return "block statement";
}
if (IsBreak()) {
if (Is<BreakStatement>()) {
return "break statement";
}
if (IsCase()) {
if (Is<CaseStatement>()) {
return "case statement";
}
if (IsCall()) {
if (Is<CallStatement>()) {
return "function call";
}
if (IsContinue()) {
if (Is<ContinueStatement>()) {
return "continue statement";
}
if (IsDiscard()) {
if (Is<DiscardStatement>()) {
return "discard statement";
}
if (IsElse()) {
if (Is<ElseStatement>()) {
return "else statement";
}
if (IsFallthrough()) {
if (Is<FallthroughStatement>()) {
return "fallthrough statement";
}
if (IsIf()) {
if (Is<IfStatement>()) {
return "if statement";
}
if (IsLoop()) {
if (Is<LoopStatement>()) {
return "loop statement";
}
if (IsReturn()) {
if (Is<ReturnStatement>()) {
return "return statement";
}
if (IsSwitch()) {
if (Is<SwitchStatement>()) {
return "switch statement";
}
if (IsVariableDecl()) {
if (Is<VariableDeclStatement>()) {
return "variable declaration";
}
return "statement";
}
const AssignmentStatement* Statement::AsAssign() const {
assert(IsAssign());
return static_cast<const AssignmentStatement*>(this);
}
const BlockStatement* Statement::AsBlock() const {
assert(IsBlock());
return static_cast<const BlockStatement*>(this);
}
const BreakStatement* Statement::AsBreak() const {
assert(IsBreak());
return static_cast<const BreakStatement*>(this);
}
const CallStatement* Statement::AsCall() const {
assert(IsCall());
return static_cast<const CallStatement*>(this);
}
const CaseStatement* Statement::AsCase() const {
assert(IsCase());
return static_cast<const CaseStatement*>(this);
}
const ContinueStatement* Statement::AsContinue() const {
assert(IsContinue());
return static_cast<const ContinueStatement*>(this);
}
const DiscardStatement* Statement::AsDiscard() const {
assert(IsDiscard());
return static_cast<const DiscardStatement*>(this);
}
const ElseStatement* Statement::AsElse() const {
assert(IsElse());
return static_cast<const ElseStatement*>(this);
}
const FallthroughStatement* Statement::AsFallthrough() const {
assert(IsFallthrough());
return static_cast<const FallthroughStatement*>(this);
}
const IfStatement* Statement::AsIf() const {
assert(IsIf());
return static_cast<const IfStatement*>(this);
}
const LoopStatement* Statement::AsLoop() const {
assert(IsLoop());
return static_cast<const LoopStatement*>(this);
}
const ReturnStatement* Statement::AsReturn() const {
assert(IsReturn());
return static_cast<const ReturnStatement*>(this);
}
const SwitchStatement* Statement::AsSwitch() const {
assert(IsSwitch());
return static_cast<const SwitchStatement*>(this);
}
const VariableDeclStatement* Statement::AsVariableDecl() const {
assert(IsVariableDecl());
return static_cast<const VariableDeclStatement*>(this);
}
AssignmentStatement* Statement::AsAssign() {
assert(IsAssign());
return static_cast<AssignmentStatement*>(this);
}
BlockStatement* Statement::AsBlock() {
assert(IsBlock());
return static_cast<BlockStatement*>(this);
}
BreakStatement* Statement::AsBreak() {
assert(IsBreak());
return static_cast<BreakStatement*>(this);
}
CallStatement* Statement::AsCall() {
assert(IsCall());
return static_cast<CallStatement*>(this);
}
CaseStatement* Statement::AsCase() {
assert(IsCase());
return static_cast<CaseStatement*>(this);
}
ContinueStatement* Statement::AsContinue() {
assert(IsContinue());
return static_cast<ContinueStatement*>(this);
}
DiscardStatement* Statement::AsDiscard() {
assert(IsDiscard());
return static_cast<DiscardStatement*>(this);
}
ElseStatement* Statement::AsElse() {
assert(IsElse());
return static_cast<ElseStatement*>(this);
}
FallthroughStatement* Statement::AsFallthrough() {
assert(IsFallthrough());
return static_cast<FallthroughStatement*>(this);
}
IfStatement* Statement::AsIf() {
assert(IsIf());
return static_cast<IfStatement*>(this);
}
LoopStatement* Statement::AsLoop() {
assert(IsLoop());
return static_cast<LoopStatement*>(this);
}
ReturnStatement* Statement::AsReturn() {
assert(IsReturn());
return static_cast<ReturnStatement*>(this);
}
SwitchStatement* Statement::AsSwitch() {
assert(IsSwitch());
return static_cast<SwitchStatement*>(this);
}
VariableDeclStatement* Statement::AsVariableDecl() {
assert(IsVariableDecl());
return static_cast<VariableDeclStatement*>(this);
}
} // namespace ast
} // namespace tint

View File

@ -23,116 +23,14 @@
namespace tint {
namespace ast {
class AssignmentStatement;
class BlockStatement;
class BreakStatement;
class CallStatement;
class CaseStatement;
class ContinueStatement;
class DiscardStatement;
class ElseStatement;
class FallthroughStatement;
class IfStatement;
class LoopStatement;
class ReturnStatement;
class SwitchStatement;
class VariableDeclStatement;
/// Base statement class
class Statement : public Castable<Statement, Node> {
public:
~Statement() override;
/// @returns true if this is an assign statement
virtual bool IsAssign() const;
/// @returns true if this is a block statement
virtual bool IsBlock() const;
/// @returns true if this is a break statement
virtual bool IsBreak() const;
/// @returns true if this is a call statement
virtual bool IsCall() const;
/// @returns true if this is a case statement
virtual bool IsCase() const;
/// @returns true if this is a continue statement
virtual bool IsContinue() const;
/// @returns true if this is a discard statement
virtual bool IsDiscard() const;
/// @returns true if this is an else statement
virtual bool IsElse() const;
/// @returns true if this is a fallthrough statement
virtual bool IsFallthrough() const;
/// @returns true if this is an if statement
virtual bool IsIf() const;
/// @returns true if this is a loop statement
virtual bool IsLoop() const;
/// @returns true if this is a return statement
virtual bool IsReturn() const;
/// @returns true if this is a switch statement
virtual bool IsSwitch() const;
/// @returns true if this is an variable statement
virtual bool IsVariableDecl() const;
/// @returns the human readable name for the statement type.
const char* Name() const;
/// @returns the statement as a const assign statement
const AssignmentStatement* AsAssign() const;
/// @returns the statement as a const block statement
const BlockStatement* AsBlock() const;
/// @returns the statement as a const break statement
const BreakStatement* AsBreak() const;
/// @returns the statement as a const call statement
const CallStatement* AsCall() const;
/// @returns the statement as a const case statement
const CaseStatement* AsCase() const;
/// @returns the statement as a const continue statement
const ContinueStatement* AsContinue() const;
/// @returns the statement as a const discard statement
const DiscardStatement* AsDiscard() const;
/// @returns the statement as a const else statement
const ElseStatement* AsElse() const;
/// @returns the statement as a const fallthrough statement
const FallthroughStatement* AsFallthrough() const;
/// @returns the statement as a const if statement
const IfStatement* AsIf() const;
/// @returns the statement as a const loop statement
const LoopStatement* AsLoop() const;
/// @returns the statement as a const return statement
const ReturnStatement* AsReturn() const;
/// @returns the statement as a const switch statement
const SwitchStatement* AsSwitch() const;
/// @returns the statement as a const variable statement
const VariableDeclStatement* AsVariableDecl() const;
/// @returns the statement as an assign statement
AssignmentStatement* AsAssign();
/// @returns the statement as a block statement
BlockStatement* AsBlock();
/// @returns the statement as a break statement
BreakStatement* AsBreak();
/// @returns the statement as a call statement
CallStatement* AsCall();
/// @returns the statement as a case statement
CaseStatement* AsCase();
/// @returns the statement as a continue statement
ContinueStatement* AsContinue();
/// @returns the statement as a discard statement
DiscardStatement* AsDiscard();
/// @returns the statement as a else statement
ElseStatement* AsElse();
/// @returns the statement as a fallthrough statement
FallthroughStatement* AsFallthrough();
/// @returns the statement as a if statement
IfStatement* AsIf();
/// @returns the statement as a loop statement
LoopStatement* AsLoop();
/// @returns the statement as a return statement
ReturnStatement* AsReturn();
/// @returns the statement as a switch statement
SwitchStatement* AsSwitch();
/// @returns the statement as an variable statement
VariableDeclStatement* AsVariableDecl();
protected:
/// Constructor
Statement();

View File

@ -29,10 +29,6 @@ SwitchStatement::SwitchStatement(const Source& source,
CaseStatementList body)
: Base(source), condition_(condition), body_(body) {}
bool SwitchStatement::IsSwitch() const {
return true;
}
SwitchStatement::SwitchStatement(SwitchStatement&&) = default;
SwitchStatement::~SwitchStatement() = default;

View File

@ -60,9 +60,6 @@ class SwitchStatement : public Castable<SwitchStatement, Statement> {
/// @returns the Switch body
const CaseStatementList& body() const { return body_; }
/// @returns true if this is a switch statement
bool IsSwitch() const override;
/// @returns true if the node is valid
bool IsValid() const override;

View File

@ -57,7 +57,7 @@ TEST_F(SwitchStatementTest, Creation_WithSource) {
TEST_F(SwitchStatementTest, IsSwitch) {
SwitchStatement stmt;
EXPECT_TRUE(stmt.IsSwitch());
EXPECT_TRUE(stmt.Is<SwitchStatement>());
}
TEST_F(SwitchStatementTest, IsValid) {

View File

@ -30,10 +30,6 @@ VariableDeclStatement::VariableDeclStatement(VariableDeclStatement&&) = default;
VariableDeclStatement::~VariableDeclStatement() = default;
bool VariableDeclStatement::IsVariableDecl() const {
return true;
}
bool VariableDeclStatement::IsValid() const {
return variable_ != nullptr && variable_->IsValid();
}

View File

@ -48,9 +48,6 @@ class VariableDeclStatement
/// @returns the variable
Variable* variable() const { return variable_; }
/// @returns true if this is an variable statement
bool IsVariableDecl() const override;
/// @returns true if the node is valid
bool IsValid() const override;

View File

@ -44,7 +44,7 @@ TEST_F(VariableDeclStatementTest, Creation_WithSource) {
TEST_F(VariableDeclStatementTest, IsVariableDecl) {
VariableDeclStatement s;
EXPECT_TRUE(s.IsVariableDecl());
EXPECT_TRUE(s.Is<VariableDeclStatement>());
}
TEST_F(VariableDeclStatementTest, IsValid) {

View File

@ -561,8 +561,8 @@ void FunctionEmitter::PushGuard(const std::string& guard_name,
const auto& top = statements_stack_.back();
auto* cond = create<ast::IdentifierExpression>(guard_name);
auto* body = create<ast::BlockStatement>();
auto* const guard_stmt =
AddStatement(create<ast::IfStatement>(cond, body))->AsIf();
auto* const guard_stmt = AddStatement(create<ast::IfStatement>(cond, body))
->As<ast::IfStatement>();
PushNewStatementBlock(top.construct_, end_id,
[guard_stmt](StatementBlock* s) {
guard_stmt->set_body(s->statements_);
@ -574,8 +574,8 @@ void FunctionEmitter::PushTrueGuard(uint32_t end_id) {
const auto& top = statements_stack_.back();
auto* cond = MakeTrue();
auto* body = create<ast::BlockStatement>();
auto* const guard_stmt =
AddStatement(create<ast::IfStatement>(cond, body))->AsIf();
auto* const guard_stmt = AddStatement(create<ast::IfStatement>(cond, body))
->As<ast::IfStatement>();
guard_stmt->set_condition(MakeTrue());
PushNewStatementBlock(top.construct_, end_id,
[guard_stmt](StatementBlock* s) {
@ -2023,8 +2023,8 @@ bool FunctionEmitter::EmitIfStart(const BlockInfo& block_info) {
block_info.basic_block->terminator()->GetSingleWordInOperand(0);
auto* cond = MakeExpression(condition_id).expr;
auto* body = create<ast::BlockStatement>();
auto* const if_stmt =
AddStatement(create<ast::IfStatement>(cond, body))->AsIf();
auto* const if_stmt = AddStatement(create<ast::IfStatement>(cond, body))
->As<ast::IfStatement>();
// Generate the code for the condition.
@ -2137,7 +2137,7 @@ bool FunctionEmitter::EmitSwitchStart(const BlockInfo& block_info) {
const auto* branch = block_info.basic_block->terminator();
auto* const switch_stmt =
AddStatement(create<ast::SwitchStatement>())->AsSwitch();
AddStatement(create<ast::SwitchStatement>())->As<ast::SwitchStatement>();
const auto selector_id = branch->GetSingleWordInOperand(0);
// Generate the code for the selector.
auto selector = MakeExpression(selector_id);
@ -2255,7 +2255,7 @@ bool FunctionEmitter::EmitLoopStart(const Construct* construct) {
auto* loop =
AddStatement(create<ast::LoopStatement>(create<ast::BlockStatement>(),
create<ast::BlockStatement>()))
->AsLoop();
->As<ast::LoopStatement>();
PushNewStatementBlock(
construct, construct->end_id,
[loop](StatementBlock* s) { loop->set_body(s->statements_); });
@ -2266,11 +2266,11 @@ bool FunctionEmitter::EmitContinuingStart(const Construct* construct) {
// A continue construct has the same depth as its associated loop
// construct. Start a continue construct.
auto* loop_candidate = LastStatement();
if (!loop_candidate->IsLoop()) {
if (!loop_candidate->Is<ast::LoopStatement>()) {
return Fail() << "internal error: starting continue construct, "
"expected loop on top of stack";
}
auto* loop = loop_candidate->AsLoop();
auto* loop = loop_candidate->As<ast::LoopStatement>();
PushNewStatementBlock(
construct, construct->end_id,
[loop](StatementBlock* s) { loop->set_continuing(s->statements_); });

View File

@ -27,16 +27,21 @@
#include "src/ast/access_control.h"
#include "src/ast/array_decoration.h"
#include "src/ast/assignment_statement.h"
#include "src/ast/break_statement.h"
#include "src/ast/builtin.h"
#include "src/ast/call_statement.h"
#include "src/ast/case_statement.h"
#include "src/ast/constructor_expression.h"
#include "src/ast/continue_statement.h"
#include "src/ast/else_statement.h"
#include "src/ast/switch_statement.h"
#include "src/ast/function.h"
#include "src/ast/if_statement.h"
#include "src/ast/literal.h"
#include "src/ast/loop_statement.h"
#include "src/ast/module.h"
#include "src/ast/pipeline_stage.h"
#include "src/ast/return_statement.h"
#include "src/ast/statement.h"
#include "src/ast/storage_class.h"
#include "src/ast/struct.h"
@ -44,10 +49,11 @@
#include "src/ast/struct_member.h"
#include "src/ast/struct_member_decoration.h"
#include "src/ast/type/storage_texture_type.h"
#include "src/ast/type/struct_type.h"
#include "src/ast/type/texture_type.h"
#include "src/ast/type/type.h"
#include "src/ast/type/struct_type.h"
#include "src/ast/variable.h"
#include "src/ast/variable_decl_statement.h"
#include "src/ast/variable_decoration.h"
#include "src/context.h"
#include "src/diagnostic/diagnostic.h"

View File

@ -36,7 +36,7 @@ TEST_F(ParserImplTest, AssignmentStmt_Parses_ToVariable) {
EXPECT_FALSE(p->has_error()) << p->error();
ASSERT_NE(e.value, nullptr);
ASSERT_TRUE(e->IsAssign());
ASSERT_TRUE(e->Is<ast::AssignmentStatement>());
ASSERT_NE(e->lhs(), nullptr);
ASSERT_NE(e->rhs(), nullptr);
@ -61,7 +61,7 @@ TEST_F(ParserImplTest, AssignmentStmt_Parses_ToMember) {
EXPECT_FALSE(p->has_error()) << p->error();
ASSERT_NE(e.value, nullptr);
ASSERT_TRUE(e->IsAssign());
ASSERT_TRUE(e->Is<ast::AssignmentStatement>());
ASSERT_NE(e->lhs(), nullptr);
ASSERT_NE(e->rhs(), nullptr);

View File

@ -13,6 +13,7 @@
// limitations under the License.
#include "gtest/gtest.h"
#include "src/ast/discard_statement.h"
#include "src/reader/wgsl/parser_impl.h"
#include "src/reader/wgsl/parser_impl_test_helper.h"
@ -30,8 +31,8 @@ TEST_F(ParserImplTest, BodyStmt) {
ASSERT_FALSE(p->has_error()) << p->error();
ASSERT_FALSE(e.errored);
ASSERT_EQ(e->size(), 2u);
EXPECT_TRUE(e->get(0)->IsDiscard());
EXPECT_TRUE(e->get(1)->IsReturn());
EXPECT_TRUE(e->get(0)->Is<ast::DiscardStatement>());
EXPECT_TRUE(e->get(1)->Is<ast::ReturnStatement>());
}
TEST_F(ParserImplTest, BodyStmt_Empty) {

View File

@ -28,7 +28,7 @@ TEST_F(ParserImplTest, BreakStmt) {
EXPECT_FALSE(e.errored);
EXPECT_FALSE(p->has_error()) << p->error();
ASSERT_NE(e.value, nullptr);
ASSERT_TRUE(e->IsBreak());
ASSERT_TRUE(e->Is<ast::BreakStatement>());
}
} // namespace

View File

@ -32,8 +32,8 @@ TEST_F(ParserImplTest, Statement_Call) {
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
ASSERT_TRUE(e->IsCall());
auto* c = e->AsCall()->expr();
ASSERT_TRUE(e->Is<ast::CallStatement>());
auto* c = e->As<ast::CallStatement>()->expr();
ASSERT_TRUE(c->func()->IsIdentifier());
auto* func = c->func()->AsIdentifier();
@ -50,8 +50,8 @@ TEST_F(ParserImplTest, Statement_Call_WithParams) {
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
ASSERT_TRUE(e->IsCall());
auto* c = e->AsCall()->expr();
ASSERT_TRUE(e->Is<ast::CallStatement>());
auto* c = e->As<ast::CallStatement>()->expr();
ASSERT_TRUE(c->func()->IsIdentifier());
auto* func = c->func()->AsIdentifier();

View File

@ -13,6 +13,7 @@
// limitations under the License.
#include "gtest/gtest.h"
#include "src/ast/fallthrough_statement.h"
#include "src/reader/wgsl/parser_impl.h"
#include "src/reader/wgsl/parser_impl_test_helper.h"
@ -40,8 +41,8 @@ TEST_F(ParserImplTest, CaseBody_Statements) {
EXPECT_FALSE(e.errored);
EXPECT_TRUE(e.matched);
ASSERT_EQ(e->size(), 2u);
EXPECT_TRUE(e->get(0)->IsVariableDecl());
EXPECT_TRUE(e->get(1)->IsAssign());
EXPECT_TRUE(e->get(0)->Is<ast::VariableDeclStatement>());
EXPECT_TRUE(e->get(1)->Is<ast::AssignmentStatement>());
}
TEST_F(ParserImplTest, CaseBody_InvalidStatement) {
@ -60,7 +61,7 @@ TEST_F(ParserImplTest, CaseBody_Fallthrough) {
EXPECT_FALSE(e.errored);
EXPECT_TRUE(e.matched);
ASSERT_EQ(e->size(), 1u);
EXPECT_TRUE(e->get(0)->IsFallthrough());
EXPECT_TRUE(e->get(0)->Is<ast::FallthroughStatement>());
}
TEST_F(ParserImplTest, CaseBody_Fallthrough_MissingSemicolon) {

View File

@ -28,7 +28,7 @@ TEST_F(ParserImplTest, ContinueStmt) {
EXPECT_FALSE(e.errored);
EXPECT_FALSE(p->has_error()) << p->error();
ASSERT_NE(e.value, nullptr);
ASSERT_TRUE(e->IsContinue());
ASSERT_TRUE(e->Is<ast::ContinueStatement>());
}
} // namespace

View File

@ -13,6 +13,7 @@
// limitations under the License.
#include "gtest/gtest.h"
#include "src/ast/discard_statement.h"
#include "src/reader/wgsl/parser_impl.h"
#include "src/reader/wgsl/parser_impl_test_helper.h"
@ -28,7 +29,7 @@ TEST_F(ParserImplTest, ContinuingStmt) {
EXPECT_FALSE(e.errored);
EXPECT_FALSE(p->has_error()) << p->error();
ASSERT_EQ(e->size(), 1u);
ASSERT_TRUE(e->get(0)->IsDiscard());
ASSERT_TRUE(e->get(0)->Is<ast::DiscardStatement>());
}
TEST_F(ParserImplTest, ContinuingStmt_InvalidBody) {

View File

@ -29,7 +29,7 @@ TEST_F(ParserImplTest, ElseStmt) {
EXPECT_FALSE(e.errored);
EXPECT_FALSE(p->has_error()) << p->error();
ASSERT_NE(e.value, nullptr);
ASSERT_TRUE(e->IsElse());
ASSERT_TRUE(e->Is<ast::ElseStatement>());
ASSERT_EQ(e->condition(), nullptr);
EXPECT_EQ(e->body()->size(), 2u);
}

View File

@ -30,7 +30,7 @@ TEST_F(ParserImplTest, ElseIfStmt) {
EXPECT_FALSE(p->has_error()) << p->error();
ASSERT_EQ(e.value.size(), 1u);
ASSERT_TRUE(e.value[0]->IsElse());
ASSERT_TRUE(e.value[0]->Is<ast::ElseStatement>());
ASSERT_NE(e.value[0]->condition(), nullptr);
ASSERT_TRUE(e.value[0]->condition()->IsBinary());
EXPECT_EQ(e.value[0]->body()->size(), 2u);
@ -44,12 +44,12 @@ TEST_F(ParserImplTest, ElseIfStmt_Multiple) {
EXPECT_FALSE(p->has_error()) << p->error();
ASSERT_EQ(e.value.size(), 2u);
ASSERT_TRUE(e.value[0]->IsElse());
ASSERT_TRUE(e.value[0]->Is<ast::ElseStatement>());
ASSERT_NE(e.value[0]->condition(), nullptr);
ASSERT_TRUE(e.value[0]->condition()->IsBinary());
EXPECT_EQ(e.value[0]->body()->size(), 2u);
ASSERT_TRUE(e.value[1]->IsElse());
ASSERT_TRUE(e.value[1]->Is<ast::ElseStatement>());
ASSERT_NE(e.value[1]->condition(), nullptr);
ASSERT_TRUE(e.value[1]->condition()->IsIdentifier());
EXPECT_EQ(e.value[1]->body()->size(), 1u);

View File

@ -50,7 +50,7 @@ TEST_F(ParserImplTest, FunctionDecl) {
auto* body = f->body();
ASSERT_EQ(body->size(), 1u);
EXPECT_TRUE(body->get(0)->IsReturn());
EXPECT_TRUE(body->get(0)->Is<ast::ReturnStatement>());
}
TEST_F(ParserImplTest, FunctionDecl_DecorationList) {
@ -86,7 +86,7 @@ TEST_F(ParserImplTest, FunctionDecl_DecorationList) {
auto* body = f->body();
ASSERT_EQ(body->size(), 1u);
EXPECT_TRUE(body->get(0)->IsReturn());
EXPECT_TRUE(body->get(0)->Is<ast::ReturnStatement>());
}
TEST_F(ParserImplTest, FunctionDecl_DecorationList_MultipleEntries) {
@ -130,7 +130,7 @@ fn main() -> void { return; })");
auto* body = f->body();
ASSERT_EQ(body->size(), 1u);
EXPECT_TRUE(body->get(0)->IsReturn());
EXPECT_TRUE(body->get(0)->Is<ast::ReturnStatement>());
}
TEST_F(ParserImplTest, FunctionDecl_DecorationList_MultipleLists) {
@ -175,7 +175,7 @@ fn main() -> void { return; })");
auto* body = f->body();
ASSERT_EQ(body->size(), 1u);
EXPECT_TRUE(body->get(0)->IsReturn());
EXPECT_TRUE(body->get(0)->Is<ast::ReturnStatement>());
}
TEST_F(ParserImplTest, FunctionDecl_InvalidHeader) {

View File

@ -31,7 +31,7 @@ TEST_F(ParserImplTest, IfStmt) {
EXPECT_FALSE(p->has_error()) << p->error();
ASSERT_NE(e.value, nullptr);
ASSERT_TRUE(e->IsIf());
ASSERT_TRUE(e->Is<ast::IfStatement>());
ASSERT_NE(e->condition(), nullptr);
ASSERT_TRUE(e->condition()->IsBinary());
EXPECT_EQ(e->body()->size(), 2u);
@ -46,7 +46,7 @@ TEST_F(ParserImplTest, IfStmt_WithElse) {
EXPECT_FALSE(p->has_error()) << p->error();
ASSERT_NE(e.value, nullptr);
ASSERT_TRUE(e->IsIf());
ASSERT_TRUE(e->Is<ast::IfStatement>());
ASSERT_NE(e->condition(), nullptr);
ASSERT_TRUE(e->condition()->IsBinary());
EXPECT_EQ(e->body()->size(), 2u);

View File

@ -13,6 +13,7 @@
// limitations under the License.
#include "gtest/gtest.h"
#include "src/ast/discard_statement.h"
#include "src/reader/wgsl/parser_impl.h"
#include "src/reader/wgsl/parser_impl_test_helper.h"
@ -30,7 +31,7 @@ TEST_F(ParserImplTest, LoopStmt_BodyNoContinuing) {
ASSERT_NE(e.value, nullptr);
ASSERT_EQ(e->body()->size(), 1u);
EXPECT_TRUE(e->body()->get(0)->IsDiscard());
EXPECT_TRUE(e->body()->get(0)->Is<ast::DiscardStatement>());
EXPECT_EQ(e->continuing()->size(), 0u);
}
@ -44,10 +45,10 @@ TEST_F(ParserImplTest, LoopStmt_BodyWithContinuing) {
ASSERT_NE(e.value, nullptr);
ASSERT_EQ(e->body()->size(), 1u);
EXPECT_TRUE(e->body()->get(0)->IsDiscard());
EXPECT_TRUE(e->body()->get(0)->Is<ast::DiscardStatement>());
EXPECT_EQ(e->continuing()->size(), 1u);
EXPECT_TRUE(e->continuing()->get(0)->IsDiscard());
EXPECT_TRUE(e->continuing()->get(0)->Is<ast::DiscardStatement>());
}
TEST_F(ParserImplTest, LoopStmt_NoBodyNoContinuing) {
@ -70,7 +71,7 @@ TEST_F(ParserImplTest, LoopStmt_NoBodyWithContinuing) {
ASSERT_NE(e.value, nullptr);
ASSERT_EQ(e->body()->size(), 0u);
ASSERT_EQ(e->continuing()->size(), 1u);
EXPECT_TRUE(e->continuing()->get(0)->IsDiscard());
EXPECT_TRUE(e->continuing()->get(0)->Is<ast::DiscardStatement>());
}
TEST_F(ParserImplTest, LoopStmt_MissingBracketLeft) {

View File

@ -13,6 +13,7 @@
// limitations under the License.
#include "gtest/gtest.h"
#include "src/ast/discard_statement.h"
#include "src/ast/return_statement.h"
#include "src/ast/statement.h"
#include "src/reader/wgsl/parser_impl.h"
@ -29,7 +30,7 @@ TEST_F(ParserImplTest, Statement) {
ASSERT_FALSE(p->has_error()) << p->error();
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
ASSERT_TRUE(e->IsReturn());
ASSERT_TRUE(e->Is<ast::ReturnStatement>());
}
TEST_F(ParserImplTest, Statement_Semicolon) {
@ -44,8 +45,8 @@ TEST_F(ParserImplTest, Statement_Return_NoValue) {
ASSERT_FALSE(p->has_error()) << p->error();
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
ASSERT_TRUE(e->IsReturn());
auto* ret = e->AsReturn();
ASSERT_TRUE(e->Is<ast::ReturnStatement>());
auto* ret = e->As<ast::ReturnStatement>();
ASSERT_EQ(ret->value(), nullptr);
}
@ -56,8 +57,8 @@ TEST_F(ParserImplTest, Statement_Return_Value) {
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
ASSERT_TRUE(e->IsReturn());
auto* ret = e->AsReturn();
ASSERT_TRUE(e->Is<ast::ReturnStatement>());
auto* ret = e->As<ast::ReturnStatement>();
ASSERT_NE(ret->value(), nullptr);
EXPECT_TRUE(ret->value()->IsBinary());
}
@ -88,7 +89,7 @@ TEST_F(ParserImplTest, Statement_If) {
ASSERT_FALSE(p->has_error()) << p->error();
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
ASSERT_TRUE(e->IsIf());
ASSERT_TRUE(e->Is<ast::IfStatement>());
}
TEST_F(ParserImplTest, Statement_If_Invalid) {
@ -107,7 +108,7 @@ TEST_F(ParserImplTest, Statement_Variable) {
ASSERT_FALSE(p->has_error()) << p->error();
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
ASSERT_TRUE(e->IsVariableDecl());
ASSERT_TRUE(e->Is<ast::VariableDeclStatement>());
}
TEST_F(ParserImplTest, Statement_Variable_Invalid) {
@ -136,7 +137,7 @@ TEST_F(ParserImplTest, Statement_Switch) {
ASSERT_FALSE(p->has_error()) << p->error();
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
ASSERT_TRUE(e->IsSwitch());
ASSERT_TRUE(e->Is<ast::SwitchStatement>());
}
TEST_F(ParserImplTest, Statement_Switch_Invalid) {
@ -155,7 +156,7 @@ TEST_F(ParserImplTest, Statement_Loop) {
ASSERT_FALSE(p->has_error()) << p->error();
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
ASSERT_TRUE(e->IsLoop());
ASSERT_TRUE(e->Is<ast::LoopStatement>());
}
TEST_F(ParserImplTest, Statement_Loop_Invalid) {
@ -174,7 +175,7 @@ TEST_F(ParserImplTest, Statement_Assignment) {
ASSERT_FALSE(p->has_error()) << p->error();
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
ASSERT_TRUE(e->IsAssign());
ASSERT_TRUE(e->Is<ast::AssignmentStatement>());
}
TEST_F(ParserImplTest, Statement_Assignment_Invalid) {
@ -203,7 +204,7 @@ TEST_F(ParserImplTest, Statement_Break) {
ASSERT_FALSE(p->has_error()) << p->error();
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
ASSERT_TRUE(e->IsBreak());
ASSERT_TRUE(e->Is<ast::BreakStatement>());
}
TEST_F(ParserImplTest, Statement_Break_MissingSemicolon) {
@ -222,7 +223,7 @@ TEST_F(ParserImplTest, Statement_Continue) {
ASSERT_FALSE(p->has_error()) << p->error();
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
ASSERT_TRUE(e->IsContinue());
ASSERT_TRUE(e->Is<ast::ContinueStatement>());
}
TEST_F(ParserImplTest, Statement_Continue_MissingSemicolon) {
@ -242,7 +243,7 @@ TEST_F(ParserImplTest, Statement_Discard) {
ASSERT_NE(e.value, nullptr);
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
ASSERT_TRUE(e->IsDiscard());
ASSERT_TRUE(e->Is<ast::DiscardStatement>());
}
TEST_F(ParserImplTest, Statement_Discard_MissingSemicolon) {
@ -261,8 +262,9 @@ TEST_F(ParserImplTest, Statement_Body) {
ASSERT_FALSE(p->has_error()) << p->error();
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
ASSERT_TRUE(e->IsBlock());
EXPECT_TRUE(e->AsBlock()->get(0)->IsVariableDecl());
ASSERT_TRUE(e->Is<ast::BlockStatement>());
EXPECT_TRUE(
e->As<ast::BlockStatement>()->get(0)->Is<ast::VariableDeclStatement>());
}
TEST_F(ParserImplTest, Statement_Body_Invalid) {

View File

@ -13,6 +13,7 @@
// limitations under the License.
#include "gtest/gtest.h"
#include "src/ast/discard_statement.h"
#include "src/ast/statement.h"
#include "src/reader/wgsl/parser_impl.h"
#include "src/reader/wgsl/parser_impl_test_helper.h"
@ -28,8 +29,8 @@ TEST_F(ParserImplTest, Statements) {
EXPECT_FALSE(e.errored);
EXPECT_FALSE(p->has_error()) << p->error();
ASSERT_EQ(e->size(), 2u);
EXPECT_TRUE(e->get(0)->IsDiscard());
EXPECT_TRUE(e->get(1)->IsReturn());
EXPECT_TRUE(e->get(0)->Is<ast::DiscardStatement>());
EXPECT_TRUE(e->get(1)->Is<ast::ReturnStatement>());
}
TEST_F(ParserImplTest, Statements_Empty) {

View File

@ -29,10 +29,10 @@ TEST_F(ParserImplTest, SwitchBody_Case) {
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
ASSERT_NE(e.value, nullptr);
ASSERT_TRUE(e->IsCase());
ASSERT_TRUE(e->Is<ast::CaseStatement>());
EXPECT_FALSE(e->IsDefault());
ASSERT_EQ(e->body()->size(), 1u);
EXPECT_TRUE(e->body()->get(0)->IsAssign());
EXPECT_TRUE(e->body()->get(0)->Is<ast::AssignmentStatement>());
}
TEST_F(ParserImplTest, SwitchBody_Case_InvalidConstLiteral) {
@ -112,10 +112,10 @@ TEST_F(ParserImplTest, SwitchBody_Default) {
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
ASSERT_NE(e.value, nullptr);
ASSERT_TRUE(e->IsCase());
ASSERT_TRUE(e->Is<ast::CaseStatement>());
EXPECT_TRUE(e->IsDefault());
ASSERT_EQ(e->body()->size(), 1u);
EXPECT_TRUE(e->body()->get(0)->IsAssign());
EXPECT_TRUE(e->body()->get(0)->Is<ast::AssignmentStatement>());
}
TEST_F(ParserImplTest, SwitchBody_Default_MissingColon) {

View File

@ -33,7 +33,7 @@ TEST_F(ParserImplTest, SwitchStmt_WithoutDefault) {
EXPECT_FALSE(e.errored);
EXPECT_FALSE(p->has_error()) << p->error();
ASSERT_NE(e.value, nullptr);
ASSERT_TRUE(e->IsSwitch());
ASSERT_TRUE(e->Is<ast::SwitchStatement>());
ASSERT_EQ(e->body().size(), 2u);
EXPECT_FALSE(e->body()[0]->IsDefault());
EXPECT_FALSE(e->body()[1]->IsDefault());
@ -46,7 +46,7 @@ TEST_F(ParserImplTest, SwitchStmt_Empty) {
EXPECT_FALSE(e.errored);
EXPECT_FALSE(p->has_error()) << p->error();
ASSERT_NE(e.value, nullptr);
ASSERT_TRUE(e->IsSwitch());
ASSERT_TRUE(e->Is<ast::SwitchStatement>());
ASSERT_EQ(e->body().size(), 0u);
}
@ -61,7 +61,7 @@ TEST_F(ParserImplTest, SwitchStmt_DefaultInMiddle) {
EXPECT_FALSE(e.errored);
EXPECT_FALSE(p->has_error()) << p->error();
ASSERT_NE(e.value, nullptr);
ASSERT_TRUE(e->IsSwitch());
ASSERT_TRUE(e->Is<ast::SwitchStatement>());
ASSERT_EQ(e->body().size(), 3u);
ASSERT_FALSE(e->body()[0]->IsDefault());

View File

@ -30,7 +30,7 @@ TEST_F(ParserImplTest, VariableStmt_VariableDecl) {
EXPECT_FALSE(e.errored);
EXPECT_FALSE(p->has_error()) << p->error();
ASSERT_NE(e.value, nullptr);
ASSERT_TRUE(e->IsVariableDecl());
ASSERT_TRUE(e->Is<ast::VariableDeclStatement>());
ASSERT_NE(e->variable(), nullptr);
EXPECT_EQ(e->variable()->name(), "a");
@ -49,7 +49,7 @@ TEST_F(ParserImplTest, VariableStmt_VariableDecl_WithInit) {
EXPECT_FALSE(e.errored);
EXPECT_FALSE(p->has_error()) << p->error();
ASSERT_NE(e.value, nullptr);
ASSERT_TRUE(e->IsVariableDecl());
ASSERT_TRUE(e->Is<ast::VariableDeclStatement>());
ASSERT_NE(e->variable(), nullptr);
EXPECT_EQ(e->variable()->name(), "a");
@ -89,7 +89,7 @@ TEST_F(ParserImplTest, VariableStmt_Const) {
EXPECT_FALSE(e.errored);
EXPECT_FALSE(p->has_error()) << p->error();
ASSERT_NE(e.value, nullptr);
ASSERT_TRUE(e->IsVariableDecl());
ASSERT_TRUE(e->Is<ast::VariableDeclStatement>());
ASSERT_EQ(e->source().range.begin.line, 1u);
ASSERT_EQ(e->source().range.begin.column, 7u);

View File

@ -21,10 +21,14 @@
#include "src/ast/binary_expression.h"
#include "src/ast/bitcast_expression.h"
#include "src/ast/block_statement.h"
#include "src/ast/break_statement.h"
#include "src/ast/call_expression.h"
#include "src/ast/call_statement.h"
#include "src/ast/case_statement.h"
#include "src/ast/continue_statement.h"
#include "src/ast/discard_statement.h"
#include "src/ast/else_statement.h"
#include "src/ast/fallthrough_statement.h"
#include "src/ast/if_statement.h"
#include "src/ast/loop_statement.h"
#include "src/ast/member_accessor_expression.h"
@ -67,52 +71,47 @@ bool BoundArrayAccessorsTransform::Run() {
}
bool BoundArrayAccessorsTransform::ProcessStatement(ast::Statement* stmt) {
if (stmt->IsAssign()) {
auto* as = stmt->AsAssign();
if (auto* as = stmt->As<ast::AssignmentStatement>()) {
return ProcessExpression(as->lhs()) && ProcessExpression(as->rhs());
} else if (stmt->IsBlock()) {
for (auto* s : *(stmt->AsBlock())) {
} else if (auto* block = stmt->As<ast::BlockStatement>()) {
for (auto* s : *block) {
if (!ProcessStatement(s)) {
return false;
}
}
} else if (stmt->IsBreak()) {
} else if (stmt->Is<ast::BreakStatement>()) {
/* nop */
} else if (stmt->IsCall()) {
return ProcessExpression(stmt->AsCall()->expr());
} else if (stmt->IsCase()) {
return ProcessStatement(stmt->AsCase()->body());
} else if (stmt->IsContinue()) {
} else if (auto* call = stmt->As<ast::CallStatement>()) {
return ProcessExpression(call->expr());
} else if (auto* kase = stmt->As<ast::CaseStatement>()) {
return ProcessStatement(kase->body());
} else if (stmt->Is<ast::ContinueStatement>()) {
/* nop */
} else if (stmt->IsDiscard()) {
} else if (stmt->Is<ast::DiscardStatement>()) {
/* nop */
} else if (stmt->IsElse()) {
auto* e = stmt->AsElse();
} else if (auto* e = stmt->As<ast::ElseStatement>()) {
return ProcessExpression(e->condition()) && ProcessStatement(e->body());
} else if (stmt->IsFallthrough()) {
} else if (stmt->Is<ast::FallthroughStatement>()) {
/* nop */
} else if (stmt->IsIf()) {
auto* e = stmt->AsIf();
if (!ProcessExpression(e->condition()) || !ProcessStatement(e->body())) {
} else if (auto* i = stmt->As<ast::IfStatement>()) {
if (!ProcessExpression(i->condition()) || !ProcessStatement(i->body())) {
return false;
}
for (auto* s : e->else_statements()) {
for (auto* s : i->else_statements()) {
if (!ProcessStatement(s)) {
return false;
}
}
} else if (stmt->IsLoop()) {
auto* l = stmt->AsLoop();
} else if (auto* l = stmt->As<ast::LoopStatement>()) {
if (l->has_continuing() && !ProcessStatement(l->continuing())) {
return false;
}
return ProcessStatement(l->body());
} else if (stmt->IsReturn()) {
if (stmt->AsReturn()->has_value()) {
return ProcessExpression(stmt->AsReturn()->value());
} else if (auto* r = stmt->As<ast::ReturnStatement>()) {
if (r->has_value()) {
return ProcessExpression(r->value());
}
} else if (stmt->IsSwitch()) {
auto* s = stmt->AsSwitch();
} else if (auto* s = stmt->As<ast::SwitchStatement>()) {
if (!ProcessExpression(s->condition())) {
return false;
}
@ -122,8 +121,8 @@ bool BoundArrayAccessorsTransform::ProcessStatement(ast::Statement* stmt) {
return false;
}
}
} else if (stmt->IsVariableDecl()) {
auto* v = stmt->AsVariableDecl()->variable();
} else if (auto* vd = stmt->As<ast::VariableDeclStatement>()) {
auto* v = vd->variable();
if (v->has_constructor() && !ProcessExpression(v->constructor())) {
return false;
}

View File

@ -27,7 +27,9 @@
#include "src/ast/call_statement.h"
#include "src/ast/case_statement.h"
#include "src/ast/continue_statement.h"
#include "src/ast/discard_statement.h"
#include "src/ast/else_statement.h"
#include "src/ast/fallthrough_statement.h"
#include "src/ast/identifier_expression.h"
#include "src/ast/if_statement.h"
#include "src/ast/intrinsic.h"
@ -178,11 +180,11 @@ bool TypeDeterminer::DetermineStatements(const ast::BlockStatement* stmts) {
}
bool TypeDeterminer::DetermineVariableStorageClass(ast::Statement* stmt) {
if (!stmt->IsVariableDecl()) {
if (!stmt->Is<ast::VariableDeclStatement>()) {
return true;
}
auto* var = stmt->AsVariableDecl()->variable();
auto* var = stmt->As<ast::VariableDeclStatement>()->variable();
// Nothing to do for const
if (var->is_const()) {
return true;
@ -203,39 +205,35 @@ bool TypeDeterminer::DetermineVariableStorageClass(ast::Statement* stmt) {
}
bool TypeDeterminer::DetermineResultType(ast::Statement* stmt) {
if (stmt->IsAssign()) {
auto* a = stmt->AsAssign();
if (auto* a = stmt->As<ast::AssignmentStatement>()) {
return DetermineResultType(a->lhs()) && DetermineResultType(a->rhs());
}
if (stmt->IsBlock()) {
return DetermineStatements(stmt->AsBlock());
if (auto* b = stmt->As<ast::BlockStatement>()) {
return DetermineStatements(b);
}
if (stmt->IsBreak()) {
if (stmt->Is<ast::BreakStatement>()) {
return true;
}
if (stmt->IsCall()) {
return DetermineResultType(stmt->AsCall()->expr());
if (auto* c = stmt->As<ast::CallStatement>()) {
return DetermineResultType(c->expr());
}
if (stmt->IsCase()) {
auto* c = stmt->AsCase();
if (auto* c = stmt->As<ast::CaseStatement>()) {
return DetermineStatements(c->body());
}
if (stmt->IsContinue()) {
if (stmt->Is<ast::ContinueStatement>()) {
return true;
}
if (stmt->IsDiscard()) {
if (stmt->Is<ast::DiscardStatement>()) {
return true;
}
if (stmt->IsElse()) {
auto* e = stmt->AsElse();
if (auto* e = stmt->As<ast::ElseStatement>()) {
return DetermineResultType(e->condition()) &&
DetermineStatements(e->body());
}
if (stmt->IsFallthrough()) {
if (stmt->Is<ast::FallthroughStatement>()) {
return true;
}
if (stmt->IsIf()) {
auto* i = stmt->AsIf();
if (auto* i = stmt->As<ast::IfStatement>()) {
if (!DetermineResultType(i->condition()) ||
!DetermineStatements(i->body())) {
return false;
@ -248,17 +246,14 @@ bool TypeDeterminer::DetermineResultType(ast::Statement* stmt) {
}
return true;
}
if (stmt->IsLoop()) {
auto* l = stmt->AsLoop();
if (auto* l = stmt->As<ast::LoopStatement>()) {
return DetermineStatements(l->body()) &&
DetermineStatements(l->continuing());
}
if (stmt->IsReturn()) {
auto* r = stmt->AsReturn();
if (auto* r = stmt->As<ast::ReturnStatement>()) {
return DetermineResultType(r->value());
}
if (stmt->IsSwitch()) {
auto* s = stmt->AsSwitch();
if (auto* s = stmt->As<ast::SwitchStatement>()) {
if (!DetermineResultType(s->condition())) {
return false;
}
@ -269,8 +264,7 @@ bool TypeDeterminer::DetermineResultType(ast::Statement* stmt) {
}
return true;
}
if (stmt->IsVariableDecl()) {
auto* v = stmt->AsVariableDecl();
if (auto* v = stmt->As<ast::VariableDeclStatement>()) {
variable_stack_.set(v->variable()->name(), v->variable());
return DetermineResultType(v->variable()->constructor());
}

View File

@ -19,6 +19,7 @@
#include <utility>
#include "src/ast/call_statement.h"
#include "src/ast/fallthrough_statement.h"
#include "src/ast/function.h"
#include "src/ast/int_literal.h"
#include "src/ast/intrinsic.h"
@ -208,7 +209,7 @@ bool ValidatorImpl::ValidateFunction(const ast::Function* func) {
if (!current_function_->return_type()->Is<ast::type::VoidType>()) {
if (!func->get_last_statement() ||
!func->get_last_statement()->IsReturn()) {
!func->get_last_statement()->Is<ast::ReturnStatement>()) {
add_error(func->source(), "v-0002",
"non-void function must end with a return statement");
return false;
@ -284,29 +285,28 @@ bool ValidatorImpl::ValidateStatement(const ast::Statement* stmt) {
if (!stmt) {
return false;
}
if (stmt->IsVariableDecl()) {
auto* v = stmt->AsVariableDecl();
if (auto* v = stmt->As<ast::VariableDeclStatement>()) {
bool constructor_valid =
v->variable()->has_constructor()
? ValidateExpression(v->variable()->constructor())
: true;
return constructor_valid && ValidateDeclStatement(stmt->AsVariableDecl());
return constructor_valid && ValidateDeclStatement(v);
}
if (stmt->IsAssign()) {
return ValidateAssign(stmt->AsAssign());
if (auto* a = stmt->As<ast::AssignmentStatement>()) {
return ValidateAssign(a);
}
if (stmt->IsReturn()) {
return ValidateReturnStatement(stmt->AsReturn());
if (auto* r = stmt->As<ast::ReturnStatement>()) {
return ValidateReturnStatement(r);
}
if (stmt->IsCall()) {
return ValidateCallExpr(stmt->AsCall()->expr());
if (auto* c = stmt->As<ast::CallStatement>()) {
return ValidateCallExpr(c->expr());
}
if (stmt->IsSwitch()) {
return ValidateSwitch(stmt->AsSwitch());
if (auto* s = stmt->As<ast::SwitchStatement>()) {
return ValidateSwitch(s);
}
if (stmt->IsCase()) {
return ValidateCase(stmt->AsCase());
if (auto* c = stmt->As<ast::CaseStatement>()) {
return ValidateCase(c);
}
return true;
}
@ -368,8 +368,10 @@ bool ValidatorImpl::ValidateSwitch(const ast::SwitchStatement* s) {
}
auto* last_clause = s->body().back();
auto* last_stmt_of_last_clause = last_clause->AsCase()->body()->last();
if (last_stmt_of_last_clause && last_stmt_of_last_clause->IsFallthrough()) {
auto* last_stmt_of_last_clause =
last_clause->As<ast::CaseStatement>()->body()->last();
if (last_stmt_of_last_clause &&
last_stmt_of_last_clause->Is<ast::FallthroughStatement>()) {
add_error(last_stmt_of_last_clause->source(), "v-0028",
"a fallthrough statement must not appear as "
"the last statement in last clause of a switch");

View File

@ -26,7 +26,9 @@
#include "src/ast/module.h"
#include "src/ast/return_statement.h"
#include "src/ast/statement.h"
#include "src/ast/switch_statement.h"
#include "src/ast/variable.h"
#include "src/ast/variable_decl_statement.h"
#include "src/diagnostic/diagnostic.h"
#include "src/diagnostic/formatter.h"
#include "src/scope_stack.h"

View File

@ -29,6 +29,7 @@
#include "src/ast/case_statement.h"
#include "src/ast/decorated_variable.h"
#include "src/ast/else_statement.h"
#include "src/ast/fallthrough_statement.h"
#include "src/ast/float_literal.h"
#include "src/ast/identifier_expression.h"
#include "src/ast/if_statement.h"
@ -75,7 +76,8 @@ bool last_is_break_or_fallthrough(const ast::BlockStatement* stmts) {
return false;
}
return stmts->last()->IsBreak() || stmts->last()->IsFallthrough();
return stmts->last()->Is<ast::BreakStatement>() ||
stmts->last()->Is<ast::FallthroughStatement>();
}
std::string get_buffer_name(ast::Expression* expr) {
@ -1601,11 +1603,10 @@ bool GeneratorImpl::EmitLoop(std::ostream& out, ast::LoopStatement* stmt) {
// the for loop into the continuing scope. Then, the variable declarations
// will be turned into assignments.
for (auto* s : *stmt->body()) {
if (!s->IsVariableDecl()) {
continue;
}
if (!EmitVariable(out, s->AsVariableDecl()->variable(), true)) {
return false;
if (auto* v = s->As<ast::VariableDeclStatement>()) {
if (!EmitVariable(out, v->variable(), true)) {
return false;
}
}
}
}
@ -1630,10 +1631,11 @@ bool GeneratorImpl::EmitLoop(std::ostream& out, ast::LoopStatement* stmt) {
for (auto* s : *(stmt->body())) {
// If we have a continuing block we've already emitted the variable
// declaration before the loop, so treat it as an assignment.
if (s->IsVariableDecl() && stmt->has_continuing()) {
auto* decl = s->As<ast::VariableDeclStatement>();
if (decl != nullptr && stmt->has_continuing()) {
make_indent(out);
auto* var = s->AsVariableDecl()->variable();
auto* var = decl->variable();
std::ostringstream pre;
std::ostringstream constructor_out;
@ -1963,51 +1965,51 @@ bool GeneratorImpl::EmitReturn(std::ostream& out, ast::ReturnStatement* stmt) {
}
bool GeneratorImpl::EmitStatement(std::ostream& out, ast::Statement* stmt) {
if (stmt->IsAssign()) {
return EmitAssign(out, stmt->AsAssign());
if (auto* a = stmt->As<ast::AssignmentStatement>()) {
return EmitAssign(out, a);
}
if (stmt->IsBlock()) {
return EmitIndentedBlockAndNewline(out, stmt->AsBlock());
if (auto* b = stmt->As<ast::BlockStatement>()) {
return EmitIndentedBlockAndNewline(out, b);
}
if (stmt->IsBreak()) {
return EmitBreak(out, stmt->AsBreak());
if (auto* b = stmt->As<ast::BreakStatement>()) {
return EmitBreak(out, b);
}
if (stmt->IsCall()) {
if (auto* c = stmt->As<ast::CallStatement>()) {
make_indent(out);
std::ostringstream pre;
std::ostringstream call_out;
if (!EmitCall(pre, call_out, stmt->AsCall()->expr())) {
if (!EmitCall(pre, call_out, c->expr())) {
return false;
}
out << pre.str();
out << call_out.str() << ";" << std::endl;
return true;
}
if (stmt->IsContinue()) {
return EmitContinue(out, stmt->AsContinue());
if (auto* c = stmt->As<ast::ContinueStatement>()) {
return EmitContinue(out, c);
}
if (stmt->IsDiscard()) {
return EmitDiscard(out, stmt->AsDiscard());
if (auto* d = stmt->As<ast::DiscardStatement>()) {
return EmitDiscard(out, d);
}
if (stmt->IsFallthrough()) {
if (auto* f = stmt->As<ast::FallthroughStatement>()) {
make_indent(out);
out << "/* fallthrough */" << std::endl;
return true;
}
if (stmt->IsIf()) {
return EmitIf(out, stmt->AsIf());
if (auto* i = stmt->As<ast::IfStatement>()) {
return EmitIf(out, i);
}
if (stmt->IsLoop()) {
return EmitLoop(out, stmt->AsLoop());
if (auto* l = stmt->As<ast::LoopStatement>()) {
return EmitLoop(out, l);
}
if (stmt->IsReturn()) {
return EmitReturn(out, stmt->AsReturn());
if (auto* r = stmt->As<ast::ReturnStatement>()) {
return EmitReturn(out, r);
}
if (stmt->IsSwitch()) {
return EmitSwitch(out, stmt->AsSwitch());
if (auto* s = stmt->As<ast::SwitchStatement>()) {
return EmitSwitch(out, s);
}
if (stmt->IsVariableDecl()) {
return EmitVariable(out, stmt->AsVariableDecl()->variable(), false);
if (auto* v = stmt->As<ast::VariableDeclStatement>()) {
return EmitVariable(out, v->variable(), false);
}
error_ = "unknown statement type: " + stmt->str();

View File

@ -19,10 +19,19 @@
#include <unordered_map>
#include <unordered_set>
#include "src/ast/assignment_statement.h"
#include "src/ast/break_statement.h"
#include "src/ast/case_statement.h"
#include "src/ast/continue_statement.h"
#include "src/ast/discard_statement.h"
#include "src/ast/if_statement.h"
#include "src/ast/intrinsic.h"
#include "src/ast/literal.h"
#include "src/ast/loop_statement.h"
#include "src/ast/module.h"
#include "src/ast/return_statement.h"
#include "src/ast/scalar_constructor_expression.h"
#include "src/ast/switch_statement.h"
#include "src/ast/type/struct_type.h"
#include "src/ast/type_constructor_expression.h"
#include "src/context.h"

View File

@ -32,6 +32,7 @@
#include "src/ast/continue_statement.h"
#include "src/ast/decorated_variable.h"
#include "src/ast/else_statement.h"
#include "src/ast/fallthrough_statement.h"
#include "src/ast/float_literal.h"
#include "src/ast/function.h"
#include "src/ast/identifier_expression.h"
@ -80,7 +81,8 @@ bool last_is_break_or_fallthrough(const ast::BlockStatement* stmts) {
return false;
}
return stmts->last()->IsBreak() || stmts->last()->IsFallthrough();
return stmts->last()->Is<ast::BreakStatement>() ||
stmts->last()->Is<ast::FallthroughStatement>();
}
uint32_t adjust_for_alignment(uint32_t count, uint32_t alignment) {
@ -1540,11 +1542,10 @@ bool GeneratorImpl::EmitLoop(ast::LoopStatement* stmt) {
// the for loop into the continuing scope. Then, the variable declarations
// will be turned into assignments.
for (auto* s : *(stmt->body())) {
if (!s->IsVariableDecl()) {
continue;
}
if (!EmitVariable(s->AsVariableDecl()->variable(), true)) {
return false;
if (auto* decl = s->As<ast::VariableDeclStatement>()) {
if (!EmitVariable(decl->variable(), true)) {
return false;
}
}
}
}
@ -1569,10 +1570,11 @@ bool GeneratorImpl::EmitLoop(ast::LoopStatement* stmt) {
for (auto* s : *(stmt->body())) {
// If we have a continuing block we've already emitted the variable
// declaration before the loop, so treat it as an assignment.
if (s->IsVariableDecl() && stmt->has_continuing()) {
auto* decl = s->As<ast::VariableDeclStatement>();
if (decl != nullptr && stmt->has_continuing()) {
make_indent();
auto* var = s->AsVariableDecl()->variable();
auto* var = decl->variable();
out_ << var->name() << " = ";
if (var->constructor() != nullptr) {
if (!EmitExpression(var->constructor())) {
@ -1716,48 +1718,48 @@ bool GeneratorImpl::EmitIndentedBlockAndNewline(ast::BlockStatement* stmt) {
}
bool GeneratorImpl::EmitStatement(ast::Statement* stmt) {
if (stmt->IsAssign()) {
return EmitAssign(stmt->AsAssign());
if (auto* a = stmt->As<ast::AssignmentStatement>()) {
return EmitAssign(a);
}
if (stmt->IsBlock()) {
return EmitIndentedBlockAndNewline(stmt->AsBlock());
if (auto* b = stmt->As<ast::BlockStatement>()) {
return EmitIndentedBlockAndNewline(b);
}
if (stmt->IsBreak()) {
return EmitBreak(stmt->AsBreak());
if (auto* b = stmt->As<ast::BreakStatement>()) {
return EmitBreak(b);
}
if (stmt->IsCall()) {
if (auto* c = stmt->As<ast::CallStatement>()) {
make_indent();
if (!EmitCall(stmt->AsCall()->expr())) {
if (!EmitCall(c->expr())) {
return false;
}
out_ << ";" << std::endl;
return true;
}
if (stmt->IsContinue()) {
return EmitContinue(stmt->AsContinue());
if (auto* c = stmt->As<ast::ContinueStatement>()) {
return EmitContinue(c);
}
if (stmt->IsDiscard()) {
return EmitDiscard(stmt->AsDiscard());
if (auto* d = stmt->As<ast::DiscardStatement>()) {
return EmitDiscard(d);
}
if (stmt->IsFallthrough()) {
if (auto* f = stmt->As<ast::FallthroughStatement>()) {
make_indent();
out_ << "/* fallthrough */" << std::endl;
return true;
}
if (stmt->IsIf()) {
return EmitIf(stmt->AsIf());
if (auto* i = stmt->As<ast::IfStatement>()) {
return EmitIf(i);
}
if (stmt->IsLoop()) {
return EmitLoop(stmt->AsLoop());
if (auto* l = stmt->As<ast::LoopStatement>()) {
return EmitLoop(l);
}
if (stmt->IsReturn()) {
return EmitReturn(stmt->AsReturn());
if (auto* r = stmt->As<ast::ReturnStatement>()) {
return EmitReturn(r);
}
if (stmt->IsSwitch()) {
return EmitSwitch(stmt->AsSwitch());
if (auto* s = stmt->As<ast::SwitchStatement>()) {
return EmitSwitch(s);
}
if (stmt->IsVariableDecl()) {
return EmitVariable(stmt->AsVariableDecl()->variable(), false);
if (auto* v = stmt->As<ast::VariableDeclStatement>()) {
return EmitVariable(v->variable(), false);
}
error_ = "unknown statement type: " + stmt->str();

View File

@ -19,10 +19,20 @@
#include <string>
#include <unordered_map>
#include "src/ast/assignment_statement.h"
#include "src/ast/break_statement.h"
#include "src/ast/case_statement.h"
#include "src/ast/continue_statement.h"
#include "src/ast/discard_statement.h"
#include "src/ast/else_statement.h"
#include "src/ast/if_statement.h"
#include "src/ast/intrinsic.h"
#include "src/ast/literal.h"
#include "src/ast/loop_statement.h"
#include "src/ast/module.h"
#include "src/ast/return_statement.h"
#include "src/ast/scalar_constructor_expression.h"
#include "src/ast/switch_statement.h"
#include "src/ast/type/struct_type.h"
#include "src/ast/type_constructor_expression.h"
#include "src/scope_stack.h"

View File

@ -35,6 +35,7 @@
#include "src/ast/constructor_expression.h"
#include "src/ast/decorated_variable.h"
#include "src/ast/else_statement.h"
#include "src/ast/fallthrough_statement.h"
#include "src/ast/float_literal.h"
#include "src/ast/identifier_expression.h"
#include "src/ast/if_statement.h"
@ -109,7 +110,7 @@ uint32_t pipeline_stage_to_execution_model(ast::PipelineStage stage) {
}
bool LastIsFallthrough(const ast::BlockStatement* stmts) {
return !stmts->empty() && stmts->last()->IsFallthrough();
return !stmts->empty() && stmts->last()->Is<ast::FallthroughStatement>();
}
// A terminator is anything which will case a SPIR-V terminator to be emitted.
@ -121,8 +122,11 @@ bool LastIsTerminator(const ast::BlockStatement* stmts) {
}
auto* last = stmts->last();
return last->IsBreak() || last->IsContinue() || last->IsDiscard() ||
last->IsReturn() || last->IsFallthrough();
return last->Is<ast::BreakStatement>() ||
last->Is<ast::ContinueStatement>() ||
last->Is<ast::DiscardStatement>() ||
last->Is<ast::ReturnStatement>() ||
last->Is<ast::FallthroughStatement>();
}
uint32_t IndexFromName(char name) {
@ -2359,42 +2363,42 @@ bool Builder::GenerateLoopStatement(ast::LoopStatement* stmt) {
}
bool Builder::GenerateStatement(ast::Statement* stmt) {
if (stmt->IsAssign()) {
return GenerateAssignStatement(stmt->AsAssign());
if (auto* a = stmt->As<ast::AssignmentStatement>()) {
return GenerateAssignStatement(a);
}
if (stmt->IsBlock()) {
return GenerateBlockStatement(stmt->AsBlock());
if (auto* b = stmt->As<ast::BlockStatement>()) {
return GenerateBlockStatement(b);
}
if (stmt->IsBreak()) {
return GenerateBreakStatement(stmt->AsBreak());
if (auto* b = stmt->As<ast::BreakStatement>()) {
return GenerateBreakStatement(b);
}
if (stmt->IsCall()) {
return GenerateCallExpression(stmt->AsCall()->expr()) != 0;
if (auto* c = stmt->As<ast::CallStatement>()) {
return GenerateCallExpression(c->expr()) != 0;
}
if (stmt->IsContinue()) {
return GenerateContinueStatement(stmt->AsContinue());
if (auto* c = stmt->As<ast::ContinueStatement>()) {
return GenerateContinueStatement(c);
}
if (stmt->IsDiscard()) {
return GenerateDiscardStatement(stmt->AsDiscard());
if (auto* d = stmt->As<ast::DiscardStatement>()) {
return GenerateDiscardStatement(d);
}
if (stmt->IsFallthrough()) {
if (stmt->Is<ast::FallthroughStatement>()) {
// Do nothing here, the fallthrough gets handled by the switch code.
return true;
}
if (stmt->IsIf()) {
return GenerateIfStatement(stmt->AsIf());
if (auto* i = stmt->As<ast::IfStatement>()) {
return GenerateIfStatement(i);
}
if (stmt->IsLoop()) {
return GenerateLoopStatement(stmt->AsLoop());
if (auto* l = stmt->As<ast::LoopStatement>()) {
return GenerateLoopStatement(l);
}
if (stmt->IsReturn()) {
return GenerateReturnStatement(stmt->AsReturn());
if (auto* r = stmt->As<ast::ReturnStatement>()) {
return GenerateReturnStatement(r);
}
if (stmt->IsSwitch()) {
return GenerateSwitchStatement(stmt->AsSwitch());
if (auto* s = stmt->As<ast::SwitchStatement>()) {
return GenerateSwitchStatement(s);
}
if (stmt->IsVariableDecl()) {
return GenerateVariableDeclStatement(stmt->AsVariableDecl());
if (auto* v = stmt->As<ast::VariableDeclStatement>()) {
return GenerateVariableDeclStatement(v);
}
error_ = "Unknown statement: " + stmt->str();

View File

@ -22,11 +22,19 @@
#include <vector>
#include "spirv/unified1/spirv.h"
#include "src/ast/assignment_statement.h"
#include "src/ast/break_statement.h"
#include "src/ast/builtin.h"
#include "src/ast/continue_statement.h"
#include "src/ast/discard_statement.h"
#include "src/ast/else_statement.h"
#include "src/ast/if_statement.h"
#include "src/ast/literal.h"
#include "src/ast/loop_statement.h"
#include "src/ast/module.h"
#include "src/ast/return_statement.h"
#include "src/ast/struct_member.h"
#include "src/ast/switch_statement.h"
#include "src/ast/type/access_control_type.h"
#include "src/ast/type/array_type.h"
#include "src/ast/type/matrix_type.h"
@ -35,6 +43,7 @@
#include "src/ast/type/struct_type.h"
#include "src/ast/type/vector_type.h"
#include "src/ast/type_constructor_expression.h"
#include "src/ast/variable_decl_statement.h"
#include "src/context.h"
#include "src/scope_stack.h"
#include "src/writer/spirv/function.h"

View File

@ -796,46 +796,46 @@ bool GeneratorImpl::EmitBlockAndNewline(const ast::BlockStatement* stmt) {
}
bool GeneratorImpl::EmitStatement(ast::Statement* stmt) {
if (stmt->IsAssign()) {
return EmitAssign(stmt->AsAssign());
if (auto* a = stmt->As<ast::AssignmentStatement>()) {
return EmitAssign(a);
}
if (stmt->IsBlock()) {
return EmitIndentedBlockAndNewline(stmt->AsBlock());
if (auto* b = stmt->As<ast::BlockStatement>()) {
return EmitIndentedBlockAndNewline(b);
}
if (stmt->IsBreak()) {
return EmitBreak(stmt->AsBreak());
if (auto* b = stmt->As<ast::BreakStatement>()) {
return EmitBreak(b);
}
if (stmt->IsCall()) {
if (auto* c = stmt->As<ast::CallStatement>()) {
make_indent();
if (!EmitCall(stmt->AsCall()->expr())) {
if (!EmitCall(c->expr())) {
return false;
}
out_ << ";" << std::endl;
return true;
}
if (stmt->IsContinue()) {
return EmitContinue(stmt->AsContinue());
if (auto* c = stmt->As<ast::ContinueStatement>()) {
return EmitContinue(c);
}
if (stmt->IsDiscard()) {
return EmitDiscard(stmt->AsDiscard());
if (auto* d = stmt->As<ast::DiscardStatement>()) {
return EmitDiscard(d);
}
if (stmt->IsFallthrough()) {
return EmitFallthrough(stmt->AsFallthrough());
if (auto* f = stmt->As<ast::FallthroughStatement>()) {
return EmitFallthrough(f);
}
if (stmt->IsIf()) {
return EmitIf(stmt->AsIf());
if (auto* i = stmt->As<ast::IfStatement>()) {
return EmitIf(i);
}
if (stmt->IsLoop()) {
return EmitLoop(stmt->AsLoop());
if (auto* l = stmt->As<ast::LoopStatement>()) {
return EmitLoop(l);
}
if (stmt->IsReturn()) {
return EmitReturn(stmt->AsReturn());
if (auto* r = stmt->As<ast::ReturnStatement>()) {
return EmitReturn(r);
}
if (stmt->IsSwitch()) {
return EmitSwitch(stmt->AsSwitch());
if (auto* s = stmt->As<ast::SwitchStatement>()) {
return EmitSwitch(s);
}
if (stmt->IsVariableDecl()) {
return EmitVariable(stmt->AsVariableDecl()->variable());
if (auto* v = stmt->As<ast::VariableDeclStatement>()) {
return EmitVariable(v->variable());
}
error_ = "unknown statement type: " + stmt->str();

View File

@ -19,10 +19,20 @@
#include <string>
#include "src/ast/array_accessor_expression.h"
#include "src/ast/assignment_statement.h"
#include "src/ast/break_statement.h"
#include "src/ast/case_statement.h"
#include "src/ast/constructor_expression.h"
#include "src/ast/continue_statement.h"
#include "src/ast/discard_statement.h"
#include "src/ast/fallthrough_statement.h"
#include "src/ast/identifier_expression.h"
#include "src/ast/if_statement.h"
#include "src/ast/loop_statement.h"
#include "src/ast/module.h"
#include "src/ast/return_statement.h"
#include "src/ast/scalar_constructor_expression.h"
#include "src/ast/switch_statement.h"
#include "src/ast/type/storage_texture_type.h"
#include "src/ast/type/struct_type.h"
#include "src/ast/type/type.h"