diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc index 14bcb195ff..8bf9da499b 100644 --- a/src/writer/spirv/builder.cc +++ b/src/writer/spirv/builder.cc @@ -630,8 +630,12 @@ uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) { return result_id; } -bool Builder::GenerateIfStatement(ast::IfStatement* stmt) { - auto cond_id = GenerateExpression(stmt->condition()); +bool Builder::GenerateConditionalBlock( + ast::Expression* cond, + const ast::StatementList& true_body, + size_t cur_else_idx, + const ast::ElseStatementList& else_stmts) { + auto cond_id = GenerateExpression(cond); if (cond_id == 0) { return false; } @@ -646,10 +650,10 @@ bool Builder::GenerateIfStatement(ast::IfStatement* stmt) { auto true_block = result_op(); auto true_block_id = true_block.to_i(); - // if there are no else statements we branch on false to the merge block + // if there are no more else statements we branch on false to the merge block // otherwise we branch to the false block auto false_block_id = - stmt->has_else_statements() ? next_id() : merge_block_id; + cur_else_idx < else_stmts.size() ? next_id() : merge_block_id; push_function_inst(spv::Op::OpBranchConditional, {Operand::Int(cond_id), Operand::Int(true_block_id), @@ -657,28 +661,33 @@ bool Builder::GenerateIfStatement(ast::IfStatement* stmt) { // Output true block push_function_inst(spv::Op::OpLabel, {true_block}); - for (const auto& inst : stmt->body()) { - if (!GenerateStatement(inst.get())) { - return false; - } + if (!GenerateStatementList(true_body)) { + return false; } - // TODO(dsinclair): The branch should be optional based on how the // StatementList ended ... push_function_inst(spv::Op::OpBranch, {Operand::Int(merge_block_id)}); + // Start the false block if needed if (false_block_id != merge_block_id) { push_function_inst(spv::Op::OpLabel, {Operand::Int(false_block_id)}); - for (const auto& else_stmt : stmt->else_statements()) { - if (!GenerateElseStatement(else_stmt.get())) { + auto* else_stmt = else_stmts[cur_else_idx].get(); + // Handle the else case by just outputting the statements. + if (!else_stmt->HasCondition()) { + if (!GenerateStatementList(else_stmt->body())) { return false; } + // TODO(dsinclair): The branch should be optional based on how the + // StatementList ended ... + push_function_inst(spv::Op::OpBranch, {Operand::Int(merge_block_id)}); + } else { + if (!GenerateConditionalBlock(else_stmt->condition(), else_stmt->body(), + cur_else_idx + 1, else_stmts)) { + return false; + } + push_function_inst(spv::Op::OpBranch, {Operand::Int(merge_block_id)}); } - - // TODO(dsinclair): The branch should be optional based on how the - // StatementList ended ... - push_function_inst(spv::Op::OpBranch, {Operand::Int(merge_block_id)}); } // Output the merge block @@ -687,18 +696,11 @@ bool Builder::GenerateIfStatement(ast::IfStatement* stmt) { return true; } -bool Builder::GenerateElseStatement(ast::ElseStatement* stmt) { - // TODO(dsinclair): handle else if - if (stmt->HasCondition()) { - error_ = "else if not handled yet"; +bool Builder::GenerateIfStatement(ast::IfStatement* stmt) { + if (!GenerateConditionalBlock(stmt->condition(), stmt->body(), 0, + stmt->else_statements())) { return false; } - - for (const auto& inst : stmt->body()) { - if (!GenerateStatement(inst.get())) { - return false; - } - } return true; } @@ -716,6 +718,15 @@ bool Builder::GenerateReturnStatement(ast::ReturnStatement* stmt) { return true; } +bool Builder::GenerateStatementList(const ast::StatementList& list) { + for (const auto& inst : list) { + if (!GenerateStatement(inst.get())) { + return false; + } + } + return true; +} + bool Builder::GenerateStatement(ast::Statement* stmt) { if (stmt->IsAssign()) { return GenerateAssignStatement(stmt->AsAssign()); diff --git a/src/writer/spirv/builder.h b/src/writer/spirv/builder.h index dab1e3edc4..0faf52d1ec 100644 --- a/src/writer/spirv/builder.h +++ b/src/writer/spirv/builder.h @@ -22,6 +22,7 @@ #include "spirv/unified1/spirv.h" #include "src/ast/builtin.h" +#include "src/ast/else_statement.h" #include "src/ast/literal.h" #include "src/ast/module.h" #include "src/ast/struct_member.h" @@ -148,10 +149,6 @@ class Builder { /// @param assign the statement to generate /// @returns true if the statement was successfully generated bool GenerateAssignStatement(ast::AssignmentStatement* assign); - /// Generates an else statement - /// @param stmt the statement to generate - /// @returns true on successfull generation - bool GenerateElseStatement(ast::ElseStatement* stmt); /// Generates an entry point instruction /// @param ep the entry point /// @returns true if the instruction was generated, false otherwise @@ -209,10 +206,24 @@ class Builder { /// @param stmt the statement to generate /// @returns true on success, false otherwise bool GenerateReturnStatement(ast::ReturnStatement* stmt); + /// Generates a conditional section merge block + /// @param cond the condition + /// @param true_body the statements making up the true block + /// @param cur_else_idx the index of the current else statement to process + /// @param else_stmts the list of all else statements + /// @returns true on success, false on failure + bool GenerateConditionalBlock(ast::Expression* cond, + const ast::StatementList& true_body, + size_t cur_else_idx, + const ast::ElseStatementList& else_stmts); /// Generates a statement /// @param stmt the statement to generate /// @returns true if the statement was generated bool GenerateStatement(ast::Statement* stmt); + /// Generates a list of statements + /// @param list the statement list to generate + /// @returns true on successful generation + bool GenerateStatementList(const ast::StatementList& list); /// Geneates an OpStore /// @param to the ID to store too /// @param from the ID to store from diff --git a/src/writer/spirv/builder_if_test.cc b/src/writer/spirv/builder_if_test.cc index 15d6809e94..e7f8bfb3bf 100644 --- a/src/writer/spirv/builder_if_test.cc +++ b/src/writer/spirv/builder_if_test.cc @@ -39,6 +39,8 @@ using BuilderTest = testing::Test; TEST_F(BuilderTest, If_Empty) { ast::type::BoolType bool_type; + // if (true) { + // } auto cond = std::make_unique( std::make_unique(&bool_type, true)); @@ -68,6 +70,9 @@ TEST_F(BuilderTest, If_WithStatements) { ast::type::BoolType bool_type; ast::type::I32Type i32; + // if (true) { + // v = 2; + // } auto var = std::make_unique("v", ast::StorageClass::kPrivate, &i32); @@ -114,6 +119,11 @@ TEST_F(BuilderTest, If_WithElse) { ast::type::BoolType bool_type; ast::type::I32Type i32; + // if (true) { + // v = 2; + // } else { + // v = 3; + // } auto var = std::make_unique("v", ast::StorageClass::kPrivate, &i32); @@ -171,9 +181,186 @@ OpBranch %6 )"); } -TEST_F(BuilderTest, DISABLED_If_WithElseIf) {} +TEST_F(BuilderTest, If_WithElseIf) { + ast::type::BoolType bool_type; + ast::type::I32Type i32; -TEST_F(BuilderTest, DISABLED_If_WithMultiple) {} + // if (true) { + // v = 2; + // } elseif (true) { + // v = 3; + // } + auto var = + std::make_unique("v", ast::StorageClass::kPrivate, &i32); + + ast::StatementList body; + body.push_back(std::make_unique( + std::make_unique("v"), + std::make_unique( + std::make_unique(&i32, 2)))); + + ast::StatementList else_body; + else_body.push_back(std::make_unique( + std::make_unique("v"), + std::make_unique( + std::make_unique(&i32, 3)))); + + auto else_cond = std::make_unique( + std::make_unique(&bool_type, true)); + + ast::ElseStatementList else_stmts; + else_stmts.push_back(std::make_unique( + std::move(else_cond), std::move(else_body))); + + auto cond = std::make_unique( + std::make_unique(&bool_type, true)); + + ast::IfStatement expr(std::move(cond), std::move(body)); + expr.set_else_statements(std::move(else_stmts)); + + Context ctx; + TypeDeterminer td(&ctx); + td.RegisterVariableForTesting(var.get()); + + ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); + + Builder b; + b.push_function(Function{}); + ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error(); + + EXPECT_TRUE(b.GenerateIfStatement(&expr)) << b.error(); + EXPECT_EQ(DumpInstructions(b.types()), R"(%3 = OpTypeInt 32 1 +%2 = OpTypePointer Private %3 +%1 = OpVariable %2 Private +%4 = OpTypeBool +%5 = OpConstantTrue %4 +%9 = OpConstant %3 2 +%12 = OpConstant %3 3 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(OpSelectionMerge %6 None +OpBranchConditional %5 %7 %8 +%7 = OpLabel +OpStore %1 %9 +OpBranch %6 +%8 = OpLabel +OpSelectionMerge %10 None +OpBranchConditional %5 %11 %10 +%11 = OpLabel +OpStore %1 %12 +OpBranch %10 +%10 = OpLabel +OpBranch %6 +%6 = OpLabel +)"); +} + +TEST_F(BuilderTest, If_WithMultiple) { + ast::type::BoolType bool_type; + ast::type::I32Type i32; + + // if (true) { + // v = 2; + // } elseif (true) { + // v = 3; + // } elseif (false) { + // v = 4; + // } else { + // v = 5; + // } + auto var = + std::make_unique("v", ast::StorageClass::kPrivate, &i32); + + ast::StatementList body; + body.push_back(std::make_unique( + std::make_unique("v"), + std::make_unique( + std::make_unique(&i32, 2)))); + ast::StatementList elseif_1_body; + elseif_1_body.push_back(std::make_unique( + std::make_unique("v"), + std::make_unique( + std::make_unique(&i32, 3)))); + ast::StatementList elseif_2_body; + elseif_2_body.push_back(std::make_unique( + std::make_unique("v"), + std::make_unique( + std::make_unique(&i32, 4)))); + ast::StatementList else_body; + else_body.push_back(std::make_unique( + std::make_unique("v"), + std::make_unique( + std::make_unique(&i32, 5)))); + + auto elseif_1_cond = std::make_unique( + std::make_unique(&bool_type, true)); + auto elseif_2_cond = std::make_unique( + std::make_unique(&bool_type, false)); + + ast::ElseStatementList else_stmts; + else_stmts.push_back(std::make_unique( + std::move(elseif_1_cond), std::move(elseif_1_body))); + else_stmts.push_back(std::make_unique( + std::move(elseif_2_cond), std::move(elseif_2_body))); + else_stmts.push_back( + std::make_unique(std::move(else_body))); + + auto cond = std::make_unique( + std::make_unique(&bool_type, true)); + + ast::IfStatement expr(std::move(cond), std::move(body)); + expr.set_else_statements(std::move(else_stmts)); + + Context ctx; + TypeDeterminer td(&ctx); + td.RegisterVariableForTesting(var.get()); + + ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); + + Builder b; + b.push_function(Function{}); + ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error(); + + EXPECT_TRUE(b.GenerateIfStatement(&expr)) << b.error(); + EXPECT_EQ(DumpInstructions(b.types()), R"(%3 = OpTypeInt 32 1 +%2 = OpTypePointer Private %3 +%1 = OpVariable %2 Private +%4 = OpTypeBool +%5 = OpConstantTrue %4 +%9 = OpConstant %3 2 +%13 = OpConstant %3 3 +%14 = OpConstantFalse %4 +%18 = OpConstant %3 4 +%19 = OpConstant %3 5 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(OpSelectionMerge %6 None +OpBranchConditional %5 %7 %8 +%7 = OpLabel +OpStore %1 %9 +OpBranch %6 +%8 = OpLabel +OpSelectionMerge %10 None +OpBranchConditional %5 %11 %12 +%11 = OpLabel +OpStore %1 %13 +OpBranch %10 +%12 = OpLabel +OpSelectionMerge %15 None +OpBranchConditional %14 %16 %17 +%16 = OpLabel +OpStore %1 %18 +OpBranch %15 +%17 = OpLabel +OpStore %1 %19 +OpBranch %15 +%15 = OpLabel +OpBranch %10 +%10 = OpLabel +OpBranch %6 +%6 = OpLabel +)"); +} TEST_F(BuilderTest, DISABLED_If_WithBreak) { // if (a) {