Add semantic::Statement::Block() for getting the owning BlockStatement

Required special casing the ElseStatement, as this isn't actually owned by a BlockStatement.

Change-Id: Ic33c207598b838a12b865a7694e596b2629c9208
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/46443
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton 2021-03-31 20:35:46 +00:00 committed by Commit Bot service account
parent 285b8b6e75
commit dba65b7a34
8 changed files with 99 additions and 49 deletions

View File

@ -65,6 +65,9 @@ class BlockStatement : public Castable<BlockStatement, Statement> {
/// @returns the ending iterator /// @returns the ending iterator
StatementList::const_iterator end() const { return statements_.end(); } StatementList::const_iterator end() const { return statements_.end(); }
/// @returns the statement list
const StatementList& statements() const { return statements_; }
/// Clones this node and all transitive child nodes using the `CloneContext` /// Clones this node and all transitive child nodes using the `CloneContext`
/// `ctx`. /// `ctx`.
/// @param ctx the clone context /// @param ctx the clone context

View File

@ -78,9 +78,10 @@ Resolver::Resolver(ProgramBuilder* builder)
Resolver::~Resolver() = default; Resolver::~Resolver() = default;
Resolver::BlockInfo::BlockInfo(Resolver::BlockInfo::Type ty, Resolver::BlockInfo::BlockInfo(const ast::BlockStatement* b,
Resolver::BlockInfo::Type ty,
Resolver::BlockInfo* p) Resolver::BlockInfo* p)
: type(ty), parent(p) {} : block(b), type(ty), parent(p) {}
Resolver::BlockInfo::~BlockInfo() = default; Resolver::BlockInfo::~BlockInfo() = default;
@ -370,7 +371,7 @@ bool Resolver::Function(ast::Function* func) {
} }
bool Resolver::BlockStatement(const ast::BlockStatement* stmt) { bool Resolver::BlockStatement(const ast::BlockStatement* stmt) {
return BlockScope(BlockInfo::Type::kGeneric, return BlockScope(stmt, BlockInfo::Type::kGeneric,
[&] { return Statements(stmt->list()); }); [&] { return Statements(stmt->list()); });
} }
@ -384,7 +385,9 @@ bool Resolver::Statements(const ast::StatementList& stmts) {
} }
bool Resolver::Statement(ast::Statement* stmt) { bool Resolver::Statement(ast::Statement* stmt) {
auto* sem_statement = builder_->create<semantic::Statement>(stmt); auto* sem_statement =
builder_->create<semantic::Statement>(stmt, current_block_->block);
builder_->Sem().Add(stmt, sem_statement);
ScopedAssignment<semantic::Statement*> sa(current_statement_, sem_statement); ScopedAssignment<semantic::Statement*> sa(current_statement_, sem_statement);
@ -427,9 +430,6 @@ bool Resolver::Statement(ast::Statement* stmt) {
if (stmt->Is<ast::DiscardStatement>()) { if (stmt->Is<ast::DiscardStatement>()) {
return true; return true;
} }
if (auto* e = stmt->As<ast::ElseStatement>()) {
return Expression(e->condition()) && BlockStatement(e->body());
}
if (stmt->Is<ast::FallthroughStatement>()) { if (stmt->Is<ast::FallthroughStatement>()) {
return true; return true;
} }
@ -441,13 +441,13 @@ bool Resolver::Statement(ast::Statement* stmt) {
// these would make their BlockInfo siblings as in the AST, but we want the // these would make their BlockInfo siblings as in the AST, but we want the
// body BlockInfo to parent the continuing BlockInfo for semantics and // body BlockInfo to parent the continuing BlockInfo for semantics and
// validation. Also, we need to set their types differently. // validation. Also, we need to set their types differently.
return BlockScope(BlockInfo::Type::kLoop, [&] { return BlockScope(l->body(), BlockInfo::Type::kLoop, [&] {
if (!Statements(l->body()->list())) { if (!Statements(l->body()->list())) {
return false; return false;
} }
if (l->has_continuing()) { if (l->has_continuing()) {
if (!BlockScope(BlockInfo::Type::kLoopContinuing, if (!BlockScope(l->continuing(), BlockInfo::Type::kLoopContinuing,
[&] { return Statements(l->continuing()->list()); })) { [&] { return Statements(l->continuing()->list()); })) {
return false; return false;
} }
@ -473,7 +473,7 @@ bool Resolver::Statement(ast::Statement* stmt) {
} }
bool Resolver::CaseStatement(ast::CaseStatement* stmt) { bool Resolver::CaseStatement(ast::CaseStatement* stmt) {
return BlockScope(BlockInfo::Type::kSwitchCase, return BlockScope(stmt->body(), BlockInfo::Type::kSwitchCase,
[&] { return Statements(stmt->body()->list()); }); [&] { return Statements(stmt->body()->list()); });
} }
@ -495,7 +495,18 @@ bool Resolver::IfStatement(ast::IfStatement* stmt) {
} }
for (auto* else_stmt : stmt->else_statements()) { for (auto* else_stmt : stmt->else_statements()) {
if (!Statement(else_stmt)) { // Else statements are a bit unusual - they're owned by the if-statement,
// not a BlockStatement.
constexpr ast::BlockStatement* no_block_statement = nullptr;
auto* sem_else_stmt =
builder_->create<semantic::Statement>(else_stmt, no_block_statement);
builder_->Sem().Add(else_stmt, sem_else_stmt);
ScopedAssignment<semantic::Statement*> sa(current_statement_,
sem_else_stmt);
if (!Expression(else_stmt->condition())) {
return false;
}
if (!BlockStatement(else_stmt->body())) {
return false; return false;
} }
} }
@ -1923,8 +1934,10 @@ bool Resolver::ApplyStorageClassUsageToType(ast::StorageClass sc,
} }
template <typename F> template <typename F>
bool Resolver::BlockScope(BlockInfo::Type type, F&& callback) { bool Resolver::BlockScope(const ast::BlockStatement* block,
BlockInfo block_info(type, current_block_); BlockInfo::Type type,
F&& callback) {
BlockInfo block_info(block, type, current_block_);
ScopedAssignment<BlockInfo*> sa(current_block_, &block_info); ScopedAssignment<BlockInfo*> sa(current_block_, &block_info);
variable_stack_.push_scope(); variable_stack_.push_scope();
bool result = callback(); bool result = callback();

View File

@ -149,7 +149,7 @@ class Resolver {
struct BlockInfo { struct BlockInfo {
enum class Type { kGeneric, kLoop, kLoopContinuing, kSwitchCase }; enum class Type { kGeneric, kLoop, kLoopContinuing, kSwitchCase };
BlockInfo(Type type, BlockInfo* parent); BlockInfo(const ast::BlockStatement* block, Type type, BlockInfo* parent);
~BlockInfo(); ~BlockInfo();
template <typename Pred> template <typename Pred>
@ -166,6 +166,7 @@ class Resolver {
[ty](auto* block_info) { return block_info->type == ty; }); [ty](auto* block_info) { return block_info->type == ty; });
} }
ast::BlockStatement const* const block;
Type const type; Type const type;
BlockInfo* const parent; BlockInfo* const parent;
std::vector<const ast::Variable*> decls; std::vector<const ast::Variable*> decls;
@ -279,7 +280,9 @@ class Resolver {
/// its parent, assigns this to #current_block_, and then calls `callback`. /// its parent, assigns this to #current_block_, and then calls `callback`.
/// The original #current_block_ is restored on exit. /// The original #current_block_ is restored on exit.
template <typename F> template <typename F>
bool BlockScope(BlockInfo::Type type, F&& callback); bool BlockScope(const ast::BlockStatement* block,
BlockInfo::Type type,
F&& callback);
/// Returns a human-readable string representation of the vector type name /// Returns a human-readable string representation of the vector type name
/// with the given parameters. /// with the given parameters.

View File

@ -77,10 +77,10 @@ TEST_F(ResolverTest, Stmt_Case) {
auto* rhs = Expr(2.3f); auto* rhs = Expr(2.3f);
auto* assign = create<ast::AssignmentStatement>(lhs, rhs); auto* assign = create<ast::AssignmentStatement>(lhs, rhs);
auto* body = Block(assign); auto* block = Block(assign);
ast::CaseSelectorList lit; ast::CaseSelectorList lit;
lit.push_back(create<ast::SintLiteral>(ty.i32(), 3)); lit.push_back(create<ast::SintLiteral>(ty.i32(), 3));
auto* cse = create<ast::CaseStatement>(lit, body); auto* cse = create<ast::CaseStatement>(lit, block);
WrapInFunction(v, cse); WrapInFunction(v, cse);
EXPECT_TRUE(r()->Resolve()) << r()->error(); EXPECT_TRUE(r()->Resolve()) << r()->error();
@ -91,6 +91,7 @@ TEST_F(ResolverTest, Stmt_Case) {
EXPECT_TRUE(TypeOf(rhs)->Is<type::F32>()); EXPECT_TRUE(TypeOf(rhs)->Is<type::F32>());
EXPECT_EQ(StmtOf(lhs), assign); EXPECT_EQ(StmtOf(lhs), assign);
EXPECT_EQ(StmtOf(rhs), assign); EXPECT_EQ(StmtOf(rhs), assign);
EXPECT_EQ(BlockOf(assign), block);
} }
TEST_F(ResolverTest, Stmt_Block) { TEST_F(ResolverTest, Stmt_Block) {
@ -110,30 +111,9 @@ TEST_F(ResolverTest, Stmt_Block) {
EXPECT_TRUE(TypeOf(rhs)->Is<type::F32>()); EXPECT_TRUE(TypeOf(rhs)->Is<type::F32>());
EXPECT_EQ(StmtOf(lhs), assign); EXPECT_EQ(StmtOf(lhs), assign);
EXPECT_EQ(StmtOf(rhs), assign); EXPECT_EQ(StmtOf(rhs), assign);
} EXPECT_EQ(BlockOf(lhs), block);
EXPECT_EQ(BlockOf(rhs), block);
TEST_F(ResolverTest, Stmt_Else) { EXPECT_EQ(BlockOf(assign), block);
auto* v = Var("v", ty.f32(), ast::StorageClass::kFunction);
auto* lhs = Expr("v");
auto* rhs = Expr(2.3f);
auto* assign = create<ast::AssignmentStatement>(lhs, rhs);
auto* body = Block(assign);
auto* cond = Expr(3);
auto* stmt = create<ast::ElseStatement>(cond, body);
WrapInFunction(v, stmt);
EXPECT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(stmt->condition()), nullptr);
ASSERT_NE(TypeOf(lhs), nullptr);
ASSERT_NE(TypeOf(rhs), nullptr);
EXPECT_TRUE(TypeOf(stmt->condition())->Is<type::I32>());
EXPECT_TRUE(TypeOf(lhs)->UnwrapAll()->Is<type::F32>());
EXPECT_TRUE(TypeOf(rhs)->Is<type::F32>());
EXPECT_EQ(StmtOf(lhs), assign);
EXPECT_EQ(StmtOf(rhs), assign);
EXPECT_EQ(StmtOf(cond), stmt);
} }
TEST_F(ResolverTest, Stmt_If) { TEST_F(ResolverTest, Stmt_If) {
@ -172,6 +152,10 @@ TEST_F(ResolverTest, Stmt_If) {
EXPECT_EQ(StmtOf(rhs), assign); EXPECT_EQ(StmtOf(rhs), assign);
EXPECT_EQ(StmtOf(cond), stmt); EXPECT_EQ(StmtOf(cond), stmt);
EXPECT_EQ(StmtOf(else_cond), else_stmt); EXPECT_EQ(StmtOf(else_cond), else_stmt);
EXPECT_EQ(BlockOf(lhs), body);
EXPECT_EQ(BlockOf(rhs), body);
EXPECT_EQ(BlockOf(else_lhs), else_body);
EXPECT_EQ(BlockOf(else_rhs), else_body);
} }
TEST_F(ResolverTest, Stmt_Loop) { TEST_F(ResolverTest, Stmt_Loop) {
@ -199,6 +183,10 @@ TEST_F(ResolverTest, Stmt_Loop) {
EXPECT_TRUE(TypeOf(body_rhs)->Is<type::F32>()); EXPECT_TRUE(TypeOf(body_rhs)->Is<type::F32>());
EXPECT_TRUE(TypeOf(continuing_lhs)->UnwrapAll()->Is<type::F32>()); EXPECT_TRUE(TypeOf(continuing_lhs)->UnwrapAll()->Is<type::F32>());
EXPECT_TRUE(TypeOf(continuing_rhs)->Is<type::F32>()); EXPECT_TRUE(TypeOf(continuing_rhs)->Is<type::F32>());
EXPECT_EQ(BlockOf(body_lhs), body);
EXPECT_EQ(BlockOf(body_rhs), body);
EXPECT_EQ(BlockOf(continuing_lhs), continuing);
EXPECT_EQ(BlockOf(continuing_rhs), continuing);
} }
TEST_F(ResolverTest, Stmt_Return) { TEST_F(ResolverTest, Stmt_Return) {
@ -224,9 +212,8 @@ TEST_F(ResolverTest, Stmt_Switch) {
auto* v = Var("v", ty.f32(), ast::StorageClass::kFunction); auto* v = Var("v", ty.f32(), ast::StorageClass::kFunction);
auto* lhs = Expr("v"); auto* lhs = Expr("v");
auto* rhs = Expr(2.3f); auto* rhs = Expr(2.3f);
auto* case_block = Block(Assign(lhs, rhs));
auto* stmt = auto* stmt = Switch(Expr(2), Case(Literal(3), case_block), DefaultCase());
Switch(Expr(2), Case(Literal(3), Block(Assign(lhs, rhs))), DefaultCase());
WrapInFunction(v, stmt); WrapInFunction(v, stmt);
EXPECT_TRUE(r()->Resolve()) << r()->error(); EXPECT_TRUE(r()->Resolve()) << r()->error();
@ -238,6 +225,8 @@ TEST_F(ResolverTest, Stmt_Switch) {
EXPECT_TRUE(TypeOf(stmt->condition())->Is<type::I32>()); EXPECT_TRUE(TypeOf(stmt->condition())->Is<type::I32>());
EXPECT_TRUE(TypeOf(lhs)->UnwrapAll()->Is<type::F32>()); EXPECT_TRUE(TypeOf(lhs)->UnwrapAll()->Is<type::F32>());
EXPECT_TRUE(TypeOf(rhs)->Is<type::F32>()); EXPECT_TRUE(TypeOf(rhs)->Is<type::F32>());
EXPECT_EQ(BlockOf(lhs), case_block);
EXPECT_EQ(BlockOf(rhs), case_block);
} }
TEST_F(ResolverTest, Stmt_Call) { TEST_F(ResolverTest, Stmt_Call) {

View File

@ -44,11 +44,29 @@ class TestHelper : public ProgramBuilder {
/// @param expr the ast::Expression /// @param expr the ast::Expression
/// @return the ast::Statement of the ast::Expression, or nullptr if the /// @return the ast::Statement of the ast::Expression, or nullptr if the
/// expression is not owned by a statement. /// expression is not owned by a statement.
ast::Statement* StmtOf(ast::Expression* expr) { const ast::Statement* StmtOf(ast::Expression* expr) {
auto* sem_stmt = Sem().Get(expr)->Stmt(); auto* sem_stmt = Sem().Get(expr)->Stmt();
return sem_stmt ? sem_stmt->Declaration() : nullptr; return sem_stmt ? sem_stmt->Declaration() : nullptr;
} }
/// Returns the BlockStatement that holds the given statement.
/// @param stmt the ast::Statment
/// @return the ast::BlockStatement that holds the ast::Statement, or nullptr
/// if the statement is not owned by a BlockStatement.
const ast::BlockStatement* BlockOf(ast::Statement* stmt) {
auto* sem_stmt = Sem().Get(stmt);
return sem_stmt ? sem_stmt->Block() : nullptr;
}
/// Returns the BlockStatement that holds the given expression.
/// @param expr the ast::Expression
/// @return the ast::Statement of the ast::Expression, or nullptr if the
/// expression is not indirectly owned by a BlockStatement.
const ast::BlockStatement* BlockOf(ast::Expression* expr) {
auto* sem_stmt = Sem().Get(expr)->Stmt();
return sem_stmt ? sem_stmt->Block() : nullptr;
}
/// Checks that all the users of the given variable are as expected /// Checks that all the users of the given variable are as expected
/// @param var the variable to check /// @param var the variable to check
/// @param expected_users the expected users of the variable /// @param expected_users the expected users of the variable

View File

@ -12,6 +12,10 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <algorithm>
#include "src/ast/block_statement.h"
#include "src/debug.h"
#include "src/semantic/statement.h" #include "src/semantic/statement.h"
TINT_INSTANTIATE_TYPEINFO(tint::semantic::Statement); TINT_INSTANTIATE_TYPEINFO(tint::semantic::Statement);
@ -19,7 +23,17 @@ TINT_INSTANTIATE_TYPEINFO(tint::semantic::Statement);
namespace tint { namespace tint {
namespace semantic { namespace semantic {
Statement::Statement(ast::Statement* declaration) : declaration_(declaration) {} Statement::Statement(const ast::Statement* declaration,
const ast::BlockStatement* block)
: declaration_(declaration), block_(block) {
#ifndef NDEBUG
if (block) {
auto& stmts = block->statements();
TINT_ASSERT(std::find(stmts.begin(), stmts.end(), declaration) !=
stmts.end());
}
#endif // NDEBUG
}
} // namespace semantic } // namespace semantic
} // namespace tint } // namespace tint

View File

@ -21,6 +21,7 @@ namespace tint {
// Forward declarations // Forward declarations
namespace ast { namespace ast {
class BlockStatement;
class Statement; class Statement;
} // namespace ast } // namespace ast
@ -31,13 +32,19 @@ class Statement : public Castable<Statement, Node> {
public: public:
/// Constructor /// Constructor
/// @param declaration the AST node for this statement /// @param declaration the AST node for this statement
explicit Statement(ast::Statement* declaration); /// @param block the owning AST block statement
Statement(const ast::Statement* declaration,
const ast::BlockStatement* block);
/// @return the AST node for this statement /// @return the AST node for this statement
ast::Statement* Declaration() const { return declaration_; } const ast::Statement* Declaration() const { return declaration_; }
/// @return the owning AST block statement for this statement
const ast::BlockStatement* Block() const { return block_; }
private: private:
ast::Statement* const declaration_; ast::Statement const* const declaration_;
ast::BlockStatement const* const block_;
}; };
} // namespace semantic } // namespace semantic

View File

@ -25,6 +25,7 @@ class CallExpression;
class Expression; class Expression;
class Function; class Function;
class MemberAccessorExpression; class MemberAccessorExpression;
class Statement;
class StructMember; class StructMember;
class Variable; class Variable;
} // namespace ast } // namespace ast
@ -41,6 +42,7 @@ class Call;
class Expression; class Expression;
class Function; class Function;
class MemberAccessorExpression; class MemberAccessorExpression;
class Statement;
class Struct; class Struct;
class StructMember; class StructMember;
class Variable; class Variable;
@ -56,6 +58,7 @@ struct TypeMappings {
Expression* operator()(ast::Expression*); Expression* operator()(ast::Expression*);
Function* operator()(ast::Function*); Function* operator()(ast::Function*);
MemberAccessorExpression* operator()(ast::MemberAccessorExpression*); MemberAccessorExpression* operator()(ast::MemberAccessorExpression*);
Statement* operator()(ast::Statement*);
Struct* operator()(type::Struct*); Struct* operator()(type::Struct*);
StructMember* operator()(ast::StructMember*); StructMember* operator()(ast::StructMember*);
Variable* operator()(ast::Variable*); Variable* operator()(ast::Variable*);