Add semantic::Statement::Block() for getting the owning BlockStatement
Required special casing the ElseStatement, as this isn't actually owned by a BlockStatement. Change-Id: Ic33c207598b838a12b865a7694e596b2629c9208 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/46443 Reviewed-by: Antonio Maiorano <amaiorano@google.com> Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
parent
285b8b6e75
commit
dba65b7a34
|
@ -65,6 +65,9 @@ class BlockStatement : public Castable<BlockStatement, Statement> {
|
|||
/// @returns the ending iterator
|
||||
StatementList::const_iterator end() const { return statements_.end(); }
|
||||
|
||||
/// @returns the statement list
|
||||
const StatementList& statements() const { return statements_; }
|
||||
|
||||
/// Clones this node and all transitive child nodes using the `CloneContext`
|
||||
/// `ctx`.
|
||||
/// @param ctx the clone context
|
||||
|
|
|
@ -78,9 +78,10 @@ Resolver::Resolver(ProgramBuilder* builder)
|
|||
|
||||
Resolver::~Resolver() = default;
|
||||
|
||||
Resolver::BlockInfo::BlockInfo(Resolver::BlockInfo::Type ty,
|
||||
Resolver::BlockInfo::BlockInfo(const ast::BlockStatement* b,
|
||||
Resolver::BlockInfo::Type ty,
|
||||
Resolver::BlockInfo* p)
|
||||
: type(ty), parent(p) {}
|
||||
: block(b), type(ty), parent(p) {}
|
||||
|
||||
Resolver::BlockInfo::~BlockInfo() = default;
|
||||
|
||||
|
@ -370,7 +371,7 @@ bool Resolver::Function(ast::Function* func) {
|
|||
}
|
||||
|
||||
bool Resolver::BlockStatement(const ast::BlockStatement* stmt) {
|
||||
return BlockScope(BlockInfo::Type::kGeneric,
|
||||
return BlockScope(stmt, BlockInfo::Type::kGeneric,
|
||||
[&] { return Statements(stmt->list()); });
|
||||
}
|
||||
|
||||
|
@ -384,7 +385,9 @@ bool Resolver::Statements(const ast::StatementList& stmts) {
|
|||
}
|
||||
|
||||
bool Resolver::Statement(ast::Statement* stmt) {
|
||||
auto* sem_statement = builder_->create<semantic::Statement>(stmt);
|
||||
auto* sem_statement =
|
||||
builder_->create<semantic::Statement>(stmt, current_block_->block);
|
||||
builder_->Sem().Add(stmt, sem_statement);
|
||||
|
||||
ScopedAssignment<semantic::Statement*> sa(current_statement_, sem_statement);
|
||||
|
||||
|
@ -427,9 +430,6 @@ bool Resolver::Statement(ast::Statement* stmt) {
|
|||
if (stmt->Is<ast::DiscardStatement>()) {
|
||||
return true;
|
||||
}
|
||||
if (auto* e = stmt->As<ast::ElseStatement>()) {
|
||||
return Expression(e->condition()) && BlockStatement(e->body());
|
||||
}
|
||||
if (stmt->Is<ast::FallthroughStatement>()) {
|
||||
return true;
|
||||
}
|
||||
|
@ -441,13 +441,13 @@ bool Resolver::Statement(ast::Statement* stmt) {
|
|||
// these would make their BlockInfo siblings as in the AST, but we want the
|
||||
// body BlockInfo to parent the continuing BlockInfo for semantics and
|
||||
// validation. Also, we need to set their types differently.
|
||||
return BlockScope(BlockInfo::Type::kLoop, [&] {
|
||||
return BlockScope(l->body(), BlockInfo::Type::kLoop, [&] {
|
||||
if (!Statements(l->body()->list())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (l->has_continuing()) {
|
||||
if (!BlockScope(BlockInfo::Type::kLoopContinuing,
|
||||
if (!BlockScope(l->continuing(), BlockInfo::Type::kLoopContinuing,
|
||||
[&] { return Statements(l->continuing()->list()); })) {
|
||||
return false;
|
||||
}
|
||||
|
@ -473,7 +473,7 @@ bool Resolver::Statement(ast::Statement* stmt) {
|
|||
}
|
||||
|
||||
bool Resolver::CaseStatement(ast::CaseStatement* stmt) {
|
||||
return BlockScope(BlockInfo::Type::kSwitchCase,
|
||||
return BlockScope(stmt->body(), BlockInfo::Type::kSwitchCase,
|
||||
[&] { return Statements(stmt->body()->list()); });
|
||||
}
|
||||
|
||||
|
@ -495,7 +495,18 @@ bool Resolver::IfStatement(ast::IfStatement* stmt) {
|
|||
}
|
||||
|
||||
for (auto* else_stmt : stmt->else_statements()) {
|
||||
if (!Statement(else_stmt)) {
|
||||
// Else statements are a bit unusual - they're owned by the if-statement,
|
||||
// not a BlockStatement.
|
||||
constexpr ast::BlockStatement* no_block_statement = nullptr;
|
||||
auto* sem_else_stmt =
|
||||
builder_->create<semantic::Statement>(else_stmt, no_block_statement);
|
||||
builder_->Sem().Add(else_stmt, sem_else_stmt);
|
||||
ScopedAssignment<semantic::Statement*> sa(current_statement_,
|
||||
sem_else_stmt);
|
||||
if (!Expression(else_stmt->condition())) {
|
||||
return false;
|
||||
}
|
||||
if (!BlockStatement(else_stmt->body())) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
@ -1923,8 +1934,10 @@ bool Resolver::ApplyStorageClassUsageToType(ast::StorageClass sc,
|
|||
}
|
||||
|
||||
template <typename F>
|
||||
bool Resolver::BlockScope(BlockInfo::Type type, F&& callback) {
|
||||
BlockInfo block_info(type, current_block_);
|
||||
bool Resolver::BlockScope(const ast::BlockStatement* block,
|
||||
BlockInfo::Type type,
|
||||
F&& callback) {
|
||||
BlockInfo block_info(block, type, current_block_);
|
||||
ScopedAssignment<BlockInfo*> sa(current_block_, &block_info);
|
||||
variable_stack_.push_scope();
|
||||
bool result = callback();
|
||||
|
|
|
@ -149,7 +149,7 @@ class Resolver {
|
|||
struct BlockInfo {
|
||||
enum class Type { kGeneric, kLoop, kLoopContinuing, kSwitchCase };
|
||||
|
||||
BlockInfo(Type type, BlockInfo* parent);
|
||||
BlockInfo(const ast::BlockStatement* block, Type type, BlockInfo* parent);
|
||||
~BlockInfo();
|
||||
|
||||
template <typename Pred>
|
||||
|
@ -166,6 +166,7 @@ class Resolver {
|
|||
[ty](auto* block_info) { return block_info->type == ty; });
|
||||
}
|
||||
|
||||
ast::BlockStatement const* const block;
|
||||
Type const type;
|
||||
BlockInfo* const parent;
|
||||
std::vector<const ast::Variable*> decls;
|
||||
|
@ -279,7 +280,9 @@ class Resolver {
|
|||
/// its parent, assigns this to #current_block_, and then calls `callback`.
|
||||
/// The original #current_block_ is restored on exit.
|
||||
template <typename F>
|
||||
bool BlockScope(BlockInfo::Type type, F&& callback);
|
||||
bool BlockScope(const ast::BlockStatement* block,
|
||||
BlockInfo::Type type,
|
||||
F&& callback);
|
||||
|
||||
/// Returns a human-readable string representation of the vector type name
|
||||
/// with the given parameters.
|
||||
|
|
|
@ -77,10 +77,10 @@ TEST_F(ResolverTest, Stmt_Case) {
|
|||
auto* rhs = Expr(2.3f);
|
||||
|
||||
auto* assign = create<ast::AssignmentStatement>(lhs, rhs);
|
||||
auto* body = Block(assign);
|
||||
auto* block = Block(assign);
|
||||
ast::CaseSelectorList lit;
|
||||
lit.push_back(create<ast::SintLiteral>(ty.i32(), 3));
|
||||
auto* cse = create<ast::CaseStatement>(lit, body);
|
||||
auto* cse = create<ast::CaseStatement>(lit, block);
|
||||
WrapInFunction(v, cse);
|
||||
|
||||
EXPECT_TRUE(r()->Resolve()) << r()->error();
|
||||
|
@ -91,6 +91,7 @@ TEST_F(ResolverTest, Stmt_Case) {
|
|||
EXPECT_TRUE(TypeOf(rhs)->Is<type::F32>());
|
||||
EXPECT_EQ(StmtOf(lhs), assign);
|
||||
EXPECT_EQ(StmtOf(rhs), assign);
|
||||
EXPECT_EQ(BlockOf(assign), block);
|
||||
}
|
||||
|
||||
TEST_F(ResolverTest, Stmt_Block) {
|
||||
|
@ -110,30 +111,9 @@ TEST_F(ResolverTest, Stmt_Block) {
|
|||
EXPECT_TRUE(TypeOf(rhs)->Is<type::F32>());
|
||||
EXPECT_EQ(StmtOf(lhs), assign);
|
||||
EXPECT_EQ(StmtOf(rhs), assign);
|
||||
}
|
||||
|
||||
TEST_F(ResolverTest, Stmt_Else) {
|
||||
auto* v = Var("v", ty.f32(), ast::StorageClass::kFunction);
|
||||
auto* lhs = Expr("v");
|
||||
auto* rhs = Expr(2.3f);
|
||||
|
||||
auto* assign = create<ast::AssignmentStatement>(lhs, rhs);
|
||||
auto* body = Block(assign);
|
||||
auto* cond = Expr(3);
|
||||
auto* stmt = create<ast::ElseStatement>(cond, body);
|
||||
WrapInFunction(v, stmt);
|
||||
|
||||
EXPECT_TRUE(r()->Resolve()) << r()->error();
|
||||
|
||||
ASSERT_NE(TypeOf(stmt->condition()), nullptr);
|
||||
ASSERT_NE(TypeOf(lhs), nullptr);
|
||||
ASSERT_NE(TypeOf(rhs), nullptr);
|
||||
EXPECT_TRUE(TypeOf(stmt->condition())->Is<type::I32>());
|
||||
EXPECT_TRUE(TypeOf(lhs)->UnwrapAll()->Is<type::F32>());
|
||||
EXPECT_TRUE(TypeOf(rhs)->Is<type::F32>());
|
||||
EXPECT_EQ(StmtOf(lhs), assign);
|
||||
EXPECT_EQ(StmtOf(rhs), assign);
|
||||
EXPECT_EQ(StmtOf(cond), stmt);
|
||||
EXPECT_EQ(BlockOf(lhs), block);
|
||||
EXPECT_EQ(BlockOf(rhs), block);
|
||||
EXPECT_EQ(BlockOf(assign), block);
|
||||
}
|
||||
|
||||
TEST_F(ResolverTest, Stmt_If) {
|
||||
|
@ -172,6 +152,10 @@ TEST_F(ResolverTest, Stmt_If) {
|
|||
EXPECT_EQ(StmtOf(rhs), assign);
|
||||
EXPECT_EQ(StmtOf(cond), stmt);
|
||||
EXPECT_EQ(StmtOf(else_cond), else_stmt);
|
||||
EXPECT_EQ(BlockOf(lhs), body);
|
||||
EXPECT_EQ(BlockOf(rhs), body);
|
||||
EXPECT_EQ(BlockOf(else_lhs), else_body);
|
||||
EXPECT_EQ(BlockOf(else_rhs), else_body);
|
||||
}
|
||||
|
||||
TEST_F(ResolverTest, Stmt_Loop) {
|
||||
|
@ -199,6 +183,10 @@ TEST_F(ResolverTest, Stmt_Loop) {
|
|||
EXPECT_TRUE(TypeOf(body_rhs)->Is<type::F32>());
|
||||
EXPECT_TRUE(TypeOf(continuing_lhs)->UnwrapAll()->Is<type::F32>());
|
||||
EXPECT_TRUE(TypeOf(continuing_rhs)->Is<type::F32>());
|
||||
EXPECT_EQ(BlockOf(body_lhs), body);
|
||||
EXPECT_EQ(BlockOf(body_rhs), body);
|
||||
EXPECT_EQ(BlockOf(continuing_lhs), continuing);
|
||||
EXPECT_EQ(BlockOf(continuing_rhs), continuing);
|
||||
}
|
||||
|
||||
TEST_F(ResolverTest, Stmt_Return) {
|
||||
|
@ -224,9 +212,8 @@ TEST_F(ResolverTest, Stmt_Switch) {
|
|||
auto* v = Var("v", ty.f32(), ast::StorageClass::kFunction);
|
||||
auto* lhs = Expr("v");
|
||||
auto* rhs = Expr(2.3f);
|
||||
|
||||
auto* stmt =
|
||||
Switch(Expr(2), Case(Literal(3), Block(Assign(lhs, rhs))), DefaultCase());
|
||||
auto* case_block = Block(Assign(lhs, rhs));
|
||||
auto* stmt = Switch(Expr(2), Case(Literal(3), case_block), DefaultCase());
|
||||
WrapInFunction(v, stmt);
|
||||
|
||||
EXPECT_TRUE(r()->Resolve()) << r()->error();
|
||||
|
@ -238,6 +225,8 @@ TEST_F(ResolverTest, Stmt_Switch) {
|
|||
EXPECT_TRUE(TypeOf(stmt->condition())->Is<type::I32>());
|
||||
EXPECT_TRUE(TypeOf(lhs)->UnwrapAll()->Is<type::F32>());
|
||||
EXPECT_TRUE(TypeOf(rhs)->Is<type::F32>());
|
||||
EXPECT_EQ(BlockOf(lhs), case_block);
|
||||
EXPECT_EQ(BlockOf(rhs), case_block);
|
||||
}
|
||||
|
||||
TEST_F(ResolverTest, Stmt_Call) {
|
||||
|
|
|
@ -44,11 +44,29 @@ class TestHelper : public ProgramBuilder {
|
|||
/// @param expr the ast::Expression
|
||||
/// @return the ast::Statement of the ast::Expression, or nullptr if the
|
||||
/// expression is not owned by a statement.
|
||||
ast::Statement* StmtOf(ast::Expression* expr) {
|
||||
const ast::Statement* StmtOf(ast::Expression* expr) {
|
||||
auto* sem_stmt = Sem().Get(expr)->Stmt();
|
||||
return sem_stmt ? sem_stmt->Declaration() : nullptr;
|
||||
}
|
||||
|
||||
/// Returns the BlockStatement that holds the given statement.
|
||||
/// @param stmt the ast::Statment
|
||||
/// @return the ast::BlockStatement that holds the ast::Statement, or nullptr
|
||||
/// if the statement is not owned by a BlockStatement.
|
||||
const ast::BlockStatement* BlockOf(ast::Statement* stmt) {
|
||||
auto* sem_stmt = Sem().Get(stmt);
|
||||
return sem_stmt ? sem_stmt->Block() : nullptr;
|
||||
}
|
||||
|
||||
/// Returns the BlockStatement that holds the given expression.
|
||||
/// @param expr the ast::Expression
|
||||
/// @return the ast::Statement of the ast::Expression, or nullptr if the
|
||||
/// expression is not indirectly owned by a BlockStatement.
|
||||
const ast::BlockStatement* BlockOf(ast::Expression* expr) {
|
||||
auto* sem_stmt = Sem().Get(expr)->Stmt();
|
||||
return sem_stmt ? sem_stmt->Block() : nullptr;
|
||||
}
|
||||
|
||||
/// Checks that all the users of the given variable are as expected
|
||||
/// @param var the variable to check
|
||||
/// @param expected_users the expected users of the variable
|
||||
|
|
|
@ -12,6 +12,10 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include "src/ast/block_statement.h"
|
||||
#include "src/debug.h"
|
||||
#include "src/semantic/statement.h"
|
||||
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::semantic::Statement);
|
||||
|
@ -19,7 +23,17 @@ TINT_INSTANTIATE_TYPEINFO(tint::semantic::Statement);
|
|||
namespace tint {
|
||||
namespace semantic {
|
||||
|
||||
Statement::Statement(ast::Statement* declaration) : declaration_(declaration) {}
|
||||
Statement::Statement(const ast::Statement* declaration,
|
||||
const ast::BlockStatement* block)
|
||||
: declaration_(declaration), block_(block) {
|
||||
#ifndef NDEBUG
|
||||
if (block) {
|
||||
auto& stmts = block->statements();
|
||||
TINT_ASSERT(std::find(stmts.begin(), stmts.end(), declaration) !=
|
||||
stmts.end());
|
||||
}
|
||||
#endif // NDEBUG
|
||||
}
|
||||
|
||||
} // namespace semantic
|
||||
} // namespace tint
|
||||
|
|
|
@ -21,6 +21,7 @@ namespace tint {
|
|||
|
||||
// Forward declarations
|
||||
namespace ast {
|
||||
class BlockStatement;
|
||||
class Statement;
|
||||
} // namespace ast
|
||||
|
||||
|
@ -31,13 +32,19 @@ class Statement : public Castable<Statement, Node> {
|
|||
public:
|
||||
/// Constructor
|
||||
/// @param declaration the AST node for this statement
|
||||
explicit Statement(ast::Statement* declaration);
|
||||
/// @param block the owning AST block statement
|
||||
Statement(const ast::Statement* declaration,
|
||||
const ast::BlockStatement* block);
|
||||
|
||||
/// @return the AST node for this statement
|
||||
ast::Statement* Declaration() const { return declaration_; }
|
||||
const ast::Statement* Declaration() const { return declaration_; }
|
||||
|
||||
/// @return the owning AST block statement for this statement
|
||||
const ast::BlockStatement* Block() const { return block_; }
|
||||
|
||||
private:
|
||||
ast::Statement* const declaration_;
|
||||
ast::Statement const* const declaration_;
|
||||
ast::BlockStatement const* const block_;
|
||||
};
|
||||
|
||||
} // namespace semantic
|
||||
|
|
|
@ -25,6 +25,7 @@ class CallExpression;
|
|||
class Expression;
|
||||
class Function;
|
||||
class MemberAccessorExpression;
|
||||
class Statement;
|
||||
class StructMember;
|
||||
class Variable;
|
||||
} // namespace ast
|
||||
|
@ -41,6 +42,7 @@ class Call;
|
|||
class Expression;
|
||||
class Function;
|
||||
class MemberAccessorExpression;
|
||||
class Statement;
|
||||
class Struct;
|
||||
class StructMember;
|
||||
class Variable;
|
||||
|
@ -56,6 +58,7 @@ struct TypeMappings {
|
|||
Expression* operator()(ast::Expression*);
|
||||
Function* operator()(ast::Function*);
|
||||
MemberAccessorExpression* operator()(ast::MemberAccessorExpression*);
|
||||
Statement* operator()(ast::Statement*);
|
||||
Struct* operator()(type::Struct*);
|
||||
StructMember* operator()(ast::StructMember*);
|
||||
Variable* operator()(ast::Variable*);
|
||||
|
|
Loading…
Reference in New Issue