diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc index 566b8c7968..14bcb195ff 100644 --- a/src/writer/spirv/builder.cc +++ b/src/writer/spirv/builder.cc @@ -24,6 +24,7 @@ #include "src/ast/builtin_decoration.h" #include "src/ast/constructor_expression.h" #include "src/ast/decorated_variable.h" +#include "src/ast/else_statement.h" #include "src/ast/float_literal.h" #include "src/ast/identifier_expression.h" #include "src/ast/if_statement.h" @@ -658,19 +659,26 @@ bool Builder::GenerateIfStatement(ast::IfStatement* stmt) { push_function_inst(spv::Op::OpLabel, {true_block}); for (const auto& inst : stmt->body()) { if (!GenerateStatement(inst.get())) { - return 0; + return false; } } // TODO(dsinclair): The branch should be optional based on how the // StatementList ended ... - push_function_inst(spv::Op::OpBranch, {Operand::Int(merge_block_id)}); if (false_block_id != merge_block_id) { push_function_inst(spv::Op::OpLabel, {Operand::Int(false_block_id)}); - // TODO(dsinclair): Output else statements, pass in merge_block_id? + for (const auto& else_stmt : stmt->else_statements()) { + if (!GenerateElseStatement(else_stmt.get())) { + return false; + } + } + + // TODO(dsinclair): The branch should be optional based on how the + // StatementList ended ... + push_function_inst(spv::Op::OpBranch, {Operand::Int(merge_block_id)}); } // Output the merge block @@ -679,6 +687,21 @@ bool Builder::GenerateIfStatement(ast::IfStatement* stmt) { return true; } +bool Builder::GenerateElseStatement(ast::ElseStatement* stmt) { + // TODO(dsinclair): handle else if + if (stmt->HasCondition()) { + error_ = "else if not handled yet"; + return false; + } + + for (const auto& inst : stmt->body()) { + if (!GenerateStatement(inst.get())) { + return false; + } + } + return true; +} + bool Builder::GenerateReturnStatement(ast::ReturnStatement* stmt) { if (stmt->has_value()) { auto val_id = GenerateExpression(stmt->value()); diff --git a/src/writer/spirv/builder.h b/src/writer/spirv/builder.h index d2981a1a2d..dab1e3edc4 100644 --- a/src/writer/spirv/builder.h +++ b/src/writer/spirv/builder.h @@ -148,6 +148,10 @@ class Builder { /// @param assign the statement to generate /// @returns true if the statement was successfully generated bool GenerateAssignStatement(ast::AssignmentStatement* assign); + /// Generates an else statement + /// @param stmt the statement to generate + /// @returns true on successfull generation + bool GenerateElseStatement(ast::ElseStatement* stmt); /// Generates an entry point instruction /// @param ep the entry point /// @returns true if the instruction was generated, false otherwise diff --git a/src/writer/spirv/builder_if_test.cc b/src/writer/spirv/builder_if_test.cc index 1be836a455..15d6809e94 100644 --- a/src/writer/spirv/builder_if_test.cc +++ b/src/writer/spirv/builder_if_test.cc @@ -17,6 +17,7 @@ #include "gtest/gtest.h" #include "src/ast/assignment_statement.h" #include "src/ast/bool_literal.h" +#include "src/ast/else_statement.h" #include "src/ast/identifier_expression.h" #include "src/ast/if_statement.h" #include "src/ast/int_literal.h" @@ -109,19 +110,88 @@ OpBranch %6 )"); } -TEST_F(BuilderTest, DISABLED_If_WithStatements_Returns) { - // if (a) { return; } -} +TEST_F(BuilderTest, If_WithElse) { + ast::type::BoolType bool_type; + ast::type::I32Type i32; -TEST_F(BuilderTest, DISABLED_If_WithElse) {} + auto var = + std::make_unique("v", ast::StorageClass::kPrivate, &i32); + + ast::StatementList body; + body.push_back(std::make_unique( + std::make_unique("v"), + std::make_unique( + std::make_unique(&i32, 2)))); + + ast::StatementList else_body; + else_body.push_back(std::make_unique( + std::make_unique("v"), + std::make_unique( + std::make_unique(&i32, 3)))); + + ast::ElseStatementList else_stmts; + else_stmts.push_back( + std::make_unique(std::move(else_body))); + + auto cond = std::make_unique( + std::make_unique(&bool_type, true)); + + ast::IfStatement expr(std::move(cond), std::move(body)); + expr.set_else_statements(std::move(else_stmts)); + + Context ctx; + TypeDeterminer td(&ctx); + td.RegisterVariableForTesting(var.get()); + + ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); + + Builder b; + b.push_function(Function{}); + ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error(); + + EXPECT_TRUE(b.GenerateIfStatement(&expr)) << b.error(); + EXPECT_EQ(DumpInstructions(b.types()), R"(%3 = OpTypeInt 32 1 +%2 = OpTypePointer Private %3 +%1 = OpVariable %2 Private +%4 = OpTypeBool +%5 = OpConstantTrue %4 +%9 = OpConstant %3 2 +%10 = OpConstant %3 3 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(OpSelectionMerge %6 None +OpBranchConditional %5 %7 %8 +%7 = OpLabel +OpStore %1 %9 +OpBranch %6 +%8 = OpLabel +OpStore %1 %10 +OpBranch %6 +%6 = OpLabel +)"); +} TEST_F(BuilderTest, DISABLED_If_WithElseIf) {} TEST_F(BuilderTest, DISABLED_If_WithMultiple) {} -TEST_F(BuilderTest, DISABLED_If_WithBreak) {} +TEST_F(BuilderTest, DISABLED_If_WithBreak) { + // if (a) { + // break; + // } +} -TEST_F(BuilderTest, DISABLED_If_WithContinue) {} +TEST_F(BuilderTest, DISABLED_If_WithContinue) { + // if (a) { + // continue; + // } +} + +TEST_F(BuilderTest, DISABLED_IF_WithReturn) { + // if (a) { + // return; + // } +} } // namespace } // namespace spirv