Resolver: Validate if() conditions are bools

Fixed: tint:317
Change-Id: Ica56dfb12d6b1a45e61d0d791f723414b464da5f
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/44163
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
This commit is contained in:
Ben Clayton 2021-03-09 15:17:28 +00:00 committed by Commit Bot service account
parent 2c41f4fbdf
commit 5fb87dd915
4 changed files with 42 additions and 12 deletions

View File

@ -254,16 +254,7 @@ bool Resolver::Statement(ast::Statement* stmt) {
return true;
}
if (auto* i = stmt->As<ast::IfStatement>()) {
if (!Expression(i->condition()) || !BlockStatement(i->body())) {
return false;
}
for (auto* else_stmt : i->else_statements()) {
if (!Statement(else_stmt)) {
return false;
}
}
return true;
return IfStatement(i);
}
if (auto* l = stmt->As<ast::LoopStatement>()) {
// We don't call DetermineBlockStatement on the body and continuing block as
@ -317,6 +308,31 @@ bool Resolver::CaseStatement(ast::CaseStatement* stmt) {
[&] { return Statements(stmt->body()->list()); });
}
bool Resolver::IfStatement(ast::IfStatement* stmt) {
if (!Expression(stmt->condition())) {
return false;
}
auto* cond_type = TypeOf(stmt->condition())->UnwrapAll();
if (cond_type != builder_->ty.bool_()) {
diagnostics_.add_error("if statement condition must be bool, got " +
cond_type->FriendlyName(builder_->Symbols()),
stmt->condition()->source());
return false;
}
if (!BlockStatement(stmt->body())) {
return false;
}
for (auto* else_stmt : stmt->else_statements()) {
if (!Statement(else_stmt)) {
return false;
}
}
return true;
}
bool Resolver::Expressions(const ast::ExpressionList& list) {
for (auto* expr : list) {
if (!Expression(expr)) {

View File

@ -202,6 +202,7 @@ class Resolver {
bool CaseStatement(ast::CaseStatement* stmt);
bool Constructor(ast::ConstructorExpression* expr);
bool Identifier(ast::IdentifierExpression* expr);
bool IfStatement(ast::IfStatement* stmt);
bool IntrinsicCall(ast::CallExpression* call,
semantic::IntrinsicType intrinsic_type);
bool MemberAccessor(ast::MemberAccessorExpression* expr);

View File

@ -146,7 +146,7 @@ TEST_F(ResolverTest, Stmt_If) {
auto* assign = create<ast::AssignmentStatement>(lhs, rhs);
auto* body = create<ast::BlockStatement>(ast::StatementList{assign});
auto* cond = Expr(3);
auto* cond = Expr(true);
auto* stmt =
create<ast::IfStatement>(cond, body, ast::ElseStatementList{else_stmt});
WrapInFunction(stmt);
@ -158,7 +158,7 @@ TEST_F(ResolverTest, Stmt_If) {
ASSERT_NE(TypeOf(else_rhs), nullptr);
ASSERT_NE(TypeOf(lhs), nullptr);
ASSERT_NE(TypeOf(rhs), nullptr);
EXPECT_TRUE(TypeOf(stmt->condition())->Is<type::I32>());
EXPECT_TRUE(TypeOf(stmt->condition())->Is<type::Bool>());
EXPECT_TRUE(TypeOf(else_lhs)->Is<type::I32>());
EXPECT_TRUE(TypeOf(else_rhs)->Is<type::F32>());
EXPECT_TRUE(TypeOf(lhs)->Is<type::I32>());

View File

@ -133,6 +133,19 @@ TEST_F(ResolverValidationTest, Stmt_Call_recursive) {
"itself.");
}
TEST_F(ResolverValidationTest, Stmt_If_NonBool) {
// if (1.23f) {}
WrapInFunction(If(create<ast::ScalarConstructorExpression>(Source{{12, 34}},
Literal(1.23f)),
Block()));
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: if statement condition must be bool, got f32");
}
TEST_F(ResolverValidationTest,
Stmt_VariableDecl_MismatchedTypeScalarConstructor) {
u32 unsigned_value = 2u; // Type does not match variable type