From d336733c8338ae5bf3f4eadf9bcf8c286cd20fba Mon Sep 17 00:00:00 2001 From: dan sinclair Date: Tue, 1 Nov 2022 14:34:08 +0000 Subject: [PATCH] [IR] Add switch control flow node. This CL updates the IR builder to create control flow nodes for a switch statement and the contained case statements. Bug: tint:1718 Change-Id: I05b73db11ab14676cc123f436ae5912b1dbee0d5 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/107801 Reviewed-by: Ben Clayton Kokoro: Kokoro Commit-Queue: Dan Sinclair --- src/tint/ir/builder.cc | 15 +++ src/tint/ir/builder.h | 14 ++ src/tint/ir/builder_impl.cc | 48 +++++-- src/tint/ir/builder_impl.h | 5 + src/tint/ir/builder_impl_test.cc | 222 +++++++++++++++++++++++++++++++ src/tint/ir/switch.cc | 2 +- src/tint/ir/switch.h | 23 +++- 7 files changed, 317 insertions(+), 12 deletions(-) diff --git a/src/tint/ir/builder.cc b/src/tint/ir/builder.cc index 594b2799ae..b023a0a914 100644 --- a/src/tint/ir/builder.cc +++ b/src/tint/ir/builder.cc @@ -71,6 +71,21 @@ Loop* Builder::CreateLoop(const ast::LoopStatement* stmt) { return ir_loop; } +Switch* Builder::CreateSwitch(const ast::SwitchStatement* stmt) { + auto* ir_switch = ir.flow_nodes.Create(stmt); + ir_switch->merge_target = CreateBlock(); + return ir_switch; +} + +Block* Builder::CreateCase(Switch* s, const utils::VectorRef selectors) { + s->cases.Push(Switch::Case{selectors, CreateBlock()}); + + Block* b = s->cases.Back().start_target; + // Switch branches into the case block + b->inbound_branches.Push(s); + return b; +} + void Builder::Branch(Block* from, FlowNode* to) { TINT_ASSERT(IR, from); TINT_ASSERT(IR, to); diff --git a/src/tint/ir/builder.h b/src/tint/ir/builder.h index a60e7b087a..23d4da4439 100644 --- a/src/tint/ir/builder.h +++ b/src/tint/ir/builder.h @@ -26,6 +26,9 @@ namespace tint { class Program; } // namespace tint +namespace tint::ast { +class CaseSelector; +} // namespace tint::ast namespace tint::ir { @@ -62,6 +65,17 @@ class Builder { /// @returns the flow node Loop* CreateLoop(const ast::LoopStatement* stmt); + /// Creates a switch flow node for the given ast::SwitchStatement + /// @param stmt the ast::SwitchStatment + /// @returns the flow node + Switch* CreateSwitch(const ast::SwitchStatement* stmt); + + /// Creates a case flow node for the given case branch. + /// @param s the switch to create the case into + /// @param selectors the case selectors for the case statement + /// @returns the start block for the case flow node + Block* CreateCase(Switch* s, const utils::VectorRef selectors); + /// Branches the given block to the given flow node. /// @param from the block to branch from /// @param to the node to branch too diff --git a/src/tint/ir/builder_impl.cc b/src/tint/ir/builder_impl.cc index 9c5caff4d9..895e316a4f 100644 --- a/src/tint/ir/builder_impl.cc +++ b/src/tint/ir/builder_impl.cc @@ -24,6 +24,7 @@ #include "src/tint/ast/return_statement.h" #include "src/tint/ast/statement.h" #include "src/tint/ast/static_assert.h" +#include "src/tint/ast/switch_statement.h" #include "src/tint/ir/function.h" #include "src/tint/ir/if.h" #include "src/tint/ir/loop.h" @@ -209,7 +210,7 @@ bool BuilderImpl::EmitStatement(const ast::Statement* stmt) { // [&](const ast::ForLoopStatement* l) { }, // [&](const ast::WhileStatement* l) { }, [&](const ast::ReturnStatement* r) { return EmitReturn(r); }, - // [&](const ast::SwitchStatement* s) { }, + [&](const ast::SwitchStatement* s) { return EmitSwitch(s); }, // [&](const ast::VariableDeclStatement* v) { }, [&](const ast::StaticAssert*) { return true; // Not emitted @@ -254,15 +255,6 @@ bool BuilderImpl::EmitIf(const ast::IfStatement* stmt) { } current_flow_block_ = nullptr; - // If both branches went somewhere, then they both returned, continued or broke. So, - // there is no need for the if merge-block and there is nothing to branch to the merge - // block anyway. - if (IsBranched(if_node->true_target) && IsBranched(if_node->false_target)) { - return true; - } - - current_flow_block_ = if_node->merge_target; - // If the true branch did not execute control flow, then go to the merge target if (!IsBranched(if_node->true_target)) { builder_.Branch(if_node->true_target, if_node->merge_target); @@ -272,6 +264,13 @@ bool BuilderImpl::EmitIf(const ast::IfStatement* stmt) { builder_.Branch(if_node->false_target, if_node->merge_target); } + // If both branches went somewhere, then they both returned, continued or broke. So, + // there is no need for the if merge-block and there is nothing to branch to the merge + // block anyway. + if (IsConnected(if_node->merge_target)) { + current_flow_block_ = if_node->merge_target; + } + return true; } @@ -313,6 +312,35 @@ bool BuilderImpl::EmitLoop(const ast::LoopStatement* stmt) { return true; } +bool BuilderImpl::EmitSwitch(const ast::SwitchStatement* stmt) { + auto* switch_node = builder_.CreateSwitch(stmt); + + // TODO(dsinclair): Emit the condition expression into the current block + + BranchTo(switch_node); + + ast_to_flow_[stmt] = switch_node; + + { + FlowStackScope scope(this, switch_node); + + for (const auto* c : stmt->body) { + current_flow_block_ = builder_.CreateCase(switch_node, c->selectors); + if (!EmitStatement(c->body)) { + return false; + } + BranchToIfNeeded(switch_node->merge_target); + } + } + current_flow_block_ = nullptr; + + if (IsConnected(switch_node->merge_target)) { + current_flow_block_ = switch_node->merge_target; + } + + return true; +} + bool BuilderImpl::EmitReturn(const ast::ReturnStatement*) { // TODO(dsinclair): Emit the return value .... diff --git a/src/tint/ir/builder_impl.h b/src/tint/ir/builder_impl.h index 1ca248223f..ac76bd7cf4 100644 --- a/src/tint/ir/builder_impl.h +++ b/src/tint/ir/builder_impl.h @@ -102,6 +102,11 @@ class BuilderImpl { /// @returns true if successful, false otherwise. bool EmitLoop(const ast::LoopStatement* stmt); + /// Emits a switch statement + /// @param stmt the switch statement + /// @returns true if successfull, false otherwise. + bool EmitSwitch(const ast::SwitchStatement* stmt); + /// Emits a break statement /// @param stmt the break statement /// @returns true if successfull, false otherwise. diff --git a/src/tint/ir/builder_impl_test.cc b/src/tint/ir/builder_impl_test.cc index eb932da748..826cd77ef7 100644 --- a/src/tint/ir/builder_impl_test.cc +++ b/src/tint/ir/builder_impl_test.cc @@ -14,9 +14,14 @@ #include "src/tint/ir/test_helper.h" +#include "src/tint/ast/case_selector.h" +#include "src/tint/ast/int_literal_expression.h" + namespace tint::ir { namespace { +using namespace tint::number_suffixes; // NOLINT + using IRBuilderImplTest = TestHelper; TEST_F(IRBuilderImplTest, Func) { @@ -817,5 +822,222 @@ TEST_F(IRBuilderImplTest, Loop_Nested) { EXPECT_EQ(loop_flow_a->merge_target->branch_target, func->end_target); } +TEST_F(IRBuilderImplTest, Switch) { + // func -> switch -> case 1 + // -> case 2 + // -> default + // + // [case 1] -> switch merge + // [case 2] -> switch merge + // [default] -> switch merge + // [switch merge] -> func end + // + auto* ast_switch = Switch( + 1_i, utils::Vector{Case(utils::Vector{CaseSelector(0_i)}, Block()), + Case(utils::Vector{CaseSelector(1_i)}, Block()), DefaultCase(Block())}); + + WrapInFunction(ast_switch); + auto& b = Build(); + + auto r = b.Build(); + ASSERT_TRUE(r) << b.error(); + auto m = r.Move(); + + auto* ir_switch = b.FlowNodeForAstNode(ast_switch); + ASSERT_NE(ir_switch, nullptr); + ASSERT_TRUE(ir_switch->Is()); + + auto* flow = ir_switch->As(); + ASSERT_NE(flow->merge_target, nullptr); + ASSERT_EQ(3u, flow->cases.Length()); + + ASSERT_EQ(1u, m.functions.Length()); + auto* func = m.functions[0]; + + ASSERT_EQ(1u, flow->cases[0].selectors.Length()); + ASSERT_TRUE(flow->cases[0].selectors[0]->expr->Is()); + EXPECT_EQ(0_i, flow->cases[0].selectors[0]->expr->As()->value); + + ASSERT_EQ(1u, flow->cases[1].selectors.Length()); + ASSERT_TRUE(flow->cases[1].selectors[0]->expr->Is()); + EXPECT_EQ(1_i, flow->cases[1].selectors[0]->expr->As()->value); + + ASSERT_EQ(1u, flow->cases[2].selectors.Length()); + EXPECT_TRUE(flow->cases[2].selectors[0]->IsDefault()); + + EXPECT_EQ(1u, flow->inbound_branches.Length()); + EXPECT_EQ(1u, flow->cases[0].start_target->inbound_branches.Length()); + EXPECT_EQ(1u, flow->cases[1].start_target->inbound_branches.Length()); + EXPECT_EQ(1u, flow->cases[2].start_target->inbound_branches.Length()); + EXPECT_EQ(3u, flow->merge_target->inbound_branches.Length()); + EXPECT_EQ(1u, func->end_target->inbound_branches.Length()); + + EXPECT_EQ(func->start_target->branch_target, ir_switch); + EXPECT_EQ(flow->cases[0].start_target->branch_target, flow->merge_target); + EXPECT_EQ(flow->cases[1].start_target->branch_target, flow->merge_target); + EXPECT_EQ(flow->cases[2].start_target->branch_target, flow->merge_target); + EXPECT_EQ(flow->merge_target->branch_target, func->end_target); +} + +TEST_F(IRBuilderImplTest, Switch_OnlyDefault) { + // func -> switch -> default -> switch merge -> func end + // + auto* ast_switch = Switch(1_i, utils::Vector{DefaultCase(Block())}); + + WrapInFunction(ast_switch); + auto& b = Build(); + + auto r = b.Build(); + ASSERT_TRUE(r) << b.error(); + auto m = r.Move(); + + auto* ir_switch = b.FlowNodeForAstNode(ast_switch); + ASSERT_NE(ir_switch, nullptr); + ASSERT_TRUE(ir_switch->Is()); + + auto* flow = ir_switch->As(); + ASSERT_NE(flow->merge_target, nullptr); + ASSERT_EQ(1u, flow->cases.Length()); + + ASSERT_EQ(1u, m.functions.Length()); + auto* func = m.functions[0]; + + ASSERT_EQ(1u, flow->cases[0].selectors.Length()); + EXPECT_TRUE(flow->cases[0].selectors[0]->IsDefault()); + + EXPECT_EQ(1u, flow->inbound_branches.Length()); + EXPECT_EQ(1u, flow->cases[0].start_target->inbound_branches.Length()); + EXPECT_EQ(1u, flow->merge_target->inbound_branches.Length()); + EXPECT_EQ(1u, func->end_target->inbound_branches.Length()); + + EXPECT_EQ(func->start_target->branch_target, ir_switch); + EXPECT_EQ(flow->cases[0].start_target->branch_target, flow->merge_target); + EXPECT_EQ(flow->merge_target->branch_target, func->end_target); +} + +TEST_F(IRBuilderImplTest, Switch_WithBreak) { + // { + // switch(1) { + // case 0: { + // break; + // if true { return;} // Dead code + // } + // default: {} + // } + // } + // + // func -> switch -> case 1 + // -> default + // + // [case 1] -> switch merge + // [default] -> switch merge + // [switch merge] -> func end + auto* ast_switch = Switch(1_i, utils::Vector{Case(utils::Vector{CaseSelector(0_i)}, + Block(Break(), If(true, Block(Return())))), + DefaultCase(Block())}); + + WrapInFunction(ast_switch); + auto& b = Build(); + + auto r = b.Build(); + ASSERT_TRUE(r) << b.error(); + auto m = r.Move(); + + auto* ir_switch = b.FlowNodeForAstNode(ast_switch); + ASSERT_NE(ir_switch, nullptr); + ASSERT_TRUE(ir_switch->Is()); + + auto* flow = ir_switch->As(); + ASSERT_NE(flow->merge_target, nullptr); + ASSERT_EQ(2u, flow->cases.Length()); + + ASSERT_EQ(1u, m.functions.Length()); + auto* func = m.functions[0]; + + ASSERT_EQ(1u, flow->cases[0].selectors.Length()); + ASSERT_TRUE(flow->cases[0].selectors[0]->expr->Is()); + EXPECT_EQ(0_i, flow->cases[0].selectors[0]->expr->As()->value); + + ASSERT_EQ(1u, flow->cases[1].selectors.Length()); + EXPECT_TRUE(flow->cases[1].selectors[0]->IsDefault()); + + EXPECT_EQ(1u, flow->inbound_branches.Length()); + EXPECT_EQ(1u, flow->cases[0].start_target->inbound_branches.Length()); + EXPECT_EQ(1u, flow->cases[1].start_target->inbound_branches.Length()); + EXPECT_EQ(2u, flow->merge_target->inbound_branches.Length()); + // This is 1 because the if is dead-code eliminated and the return doesn't happen. + EXPECT_EQ(1u, func->end_target->inbound_branches.Length()); + + EXPECT_EQ(func->start_target->branch_target, ir_switch); + EXPECT_EQ(flow->cases[0].start_target->branch_target, flow->merge_target); + EXPECT_EQ(flow->cases[1].start_target->branch_target, flow->merge_target); + EXPECT_EQ(flow->merge_target->branch_target, func->end_target); +} + +TEST_F(IRBuilderImplTest, Switch_AllReturn) { + // { + // switch(1) { + // case 0: { + // return; + // } + // default: { + // return; + // } + // } + // if true { return; } // Dead code + // } + // + // func -> switch -> case 1 + // -> default + // + // [case 1] -> func end + // [default] -> func end + // [switch merge] -> nullptr + // + auto* ast_switch = + Switch(1_i, utils::Vector{Case(utils::Vector{CaseSelector(0_i)}, Block(Return())), + DefaultCase(Block(Return()))}); + + auto* ast_if = If(true, Block(Return())); + + WrapInFunction(ast_switch, ast_if); + auto& b = Build(); + + auto r = b.Build(); + ASSERT_TRUE(r) << b.error(); + auto m = r.Move(); + + ASSERT_EQ(b.FlowNodeForAstNode(ast_if), nullptr); + + auto* ir_switch = b.FlowNodeForAstNode(ast_switch); + ASSERT_NE(ir_switch, nullptr); + ASSERT_TRUE(ir_switch->Is()); + + auto* flow = ir_switch->As(); + ASSERT_NE(flow->merge_target, nullptr); + ASSERT_EQ(2u, flow->cases.Length()); + + ASSERT_EQ(1u, m.functions.Length()); + auto* func = m.functions[0]; + + ASSERT_EQ(1u, flow->cases[0].selectors.Length()); + ASSERT_TRUE(flow->cases[0].selectors[0]->expr->Is()); + EXPECT_EQ(0_i, flow->cases[0].selectors[0]->expr->As()->value); + + ASSERT_EQ(1u, flow->cases[1].selectors.Length()); + EXPECT_TRUE(flow->cases[1].selectors[0]->IsDefault()); + + EXPECT_EQ(1u, flow->inbound_branches.Length()); + EXPECT_EQ(1u, flow->cases[0].start_target->inbound_branches.Length()); + EXPECT_EQ(1u, flow->cases[1].start_target->inbound_branches.Length()); + EXPECT_EQ(0u, flow->merge_target->inbound_branches.Length()); + EXPECT_EQ(2u, func->end_target->inbound_branches.Length()); + + EXPECT_EQ(func->start_target->branch_target, ir_switch); + EXPECT_EQ(flow->cases[0].start_target->branch_target, func->end_target); + EXPECT_EQ(flow->cases[1].start_target->branch_target, func->end_target); + EXPECT_EQ(flow->merge_target->branch_target, nullptr); +} + } // namespace } // namespace tint::ir diff --git a/src/tint/ir/switch.cc b/src/tint/ir/switch.cc index 9ad6d3030f..23b7fbbd21 100644 --- a/src/tint/ir/switch.cc +++ b/src/tint/ir/switch.cc @@ -18,7 +18,7 @@ TINT_INSTANTIATE_TYPEINFO(tint::ir::Switch); namespace tint::ir { -Switch::Switch() : Base() {} +Switch::Switch(const ast::SwitchStatement* stmt) : Base(), source(stmt) {} Switch::~Switch() = default; diff --git a/src/tint/ir/switch.h b/src/tint/ir/switch.h index e9de3ae551..39d3d06b4c 100644 --- a/src/tint/ir/switch.h +++ b/src/tint/ir/switch.h @@ -18,17 +18,38 @@ #include "src/tint/ir/block.h" #include "src/tint/ir/flow_node.h" +// Forward declarations +namespace tint::ast { +class CaseSelector; +class SwitchStatement; +} // namespace tint::ast + namespace tint::ir { /// Flow node representing a switch statement class Switch : public Castable { public: + /// A case label in the struct + struct Case { + /// The case selector for this node + const utils::VectorRef selectors; + /// The start block for the case block. + Block* start_target; + }; + /// Constructor - Switch(); + /// @param stmt the originating ast switch statement + explicit Switch(const ast::SwitchStatement* stmt); ~Switch() override; + /// The originating switch statment in the AST + const ast::SwitchStatement* source; + /// The switch merge target Block* merge_target; + + /// The switch case statements + utils::Vector cases; }; } // namespace tint::ir