diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc index 3c42772054..723f35174e 100644 --- a/src/reader/spirv/function.cc +++ b/src/reader/spirv/function.cc @@ -14,6 +14,7 @@ #include "src/reader/spirv/function.h" +#include #include #include #include @@ -28,6 +29,7 @@ #include "src/ast/assignment_statement.h" #include "src/ast/binary_expression.h" #include "src/ast/break_statement.h" +#include "src/ast/case_statement.h" #include "src/ast/continue_statement.h" #include "src/ast/else_statement.h" #include "src/ast/fallthrough_statement.h" @@ -38,6 +40,7 @@ #include "src/ast/member_accessor_expression.h" #include "src/ast/return_statement.h" #include "src/ast/scalar_constructor_expression.h" +#include "src/ast/sint_literal.h" #include "src/ast/storage_class.h" #include "src/ast/switch_statement.h" #include "src/ast/uint_literal.h" @@ -387,7 +390,7 @@ FunctionEmitter::StatementBlock::StatementBlock( uint32_t end_id, CompletionAction completion_action, ast::StatementList statements, - ast::CaseStatementList cases) + std::unique_ptr cases) : construct_(construct), end_id_(end_id), completion_action_(completion_action), @@ -401,9 +404,8 @@ FunctionEmitter::StatementBlock::~StatementBlock() = default; void FunctionEmitter::PushNewStatementBlock(const Construct* construct, uint32_t end_id, CompletionAction action) { - statements_stack_.emplace_back(StatementBlock(construct, end_id, action, - ast::StatementList{}, - ast::CaseStatementList{})); + statements_stack_.emplace_back( + StatementBlock{construct, end_id, action, ast::StatementList{}, nullptr}); } const ast::StatementList& FunctionEmitter::ast_body() { @@ -981,7 +983,6 @@ bool FunctionEmitter::FindSwitchCaseHeaders() { // Process case targets. for (uint32_t iarg = 2; iarg + 1 < branch->NumInOperands(); iarg += 2) { - const auto o = branch->GetInOperand(iarg); const auto value = branch->GetInOperand(iarg).AsLiteralUint64(); const auto case_target_id = branch->GetSingleWordInOperand(iarg + 1); @@ -1715,8 +1716,14 @@ bool FunctionEmitter::EmitBasicBlock(const BlockInfo& block_info) { break; case Construct::kSwitchSelection: + if (!EmitStatementsInBasicBlock(block_info, &emitted)) { + return false; + } + if (!EmitSwitchStart(block_info)) { + return false; + } has_normal_terminator = false; - return Fail() << "unhandled: switch construct"; + break; } } @@ -1827,6 +1834,128 @@ bool FunctionEmitter::EmitIfStart(const BlockInfo& block_info) { return success(); } +bool FunctionEmitter::EmitSwitchStart(const BlockInfo& block_info) { + // The block is the if-header block. So its construct is the if construct. + auto* construct = block_info.construct; + assert(construct->kind == Construct::kSwitchSelection); + assert(construct->begin_id == block_info.id); + const auto* branch = block_info.basic_block->terminator(); + + auto* const switch_stmt = + AddStatement(std::make_unique())->AsSwitch(); + const auto selector_id = branch->GetSingleWordInOperand(0); + // Generate the code for the selector. + auto selector = MakeExpression(selector_id); + switch_stmt->set_condition(std::move(selector.expr)); + + // First, push the statement block for the entire switch. All the actual + // work is done by completion actions of the case/default clauses. + PushNewStatementBlock( + construct, construct->end_id, [switch_stmt](StatementBlock* s) { + switch_stmt->set_body(std::move(*std::move(s->cases_))); + }); + statements_stack_.back().cases_ = std::make_unique(); + // Grab a pointer to the case list. It will get buried in the statement block + // stack. + auto* cases = statements_stack_.back().cases_.get(); + + // We will push statement-blocks onto the stack to gather the statements in + // the default clause and cases clauses. Determine the list of blocks + // that start each clause. + std::vector clause_heads; + + // Collect the case clauses, even if they are just the merge block. + // First the default clause. + const auto default_id = branch->GetSingleWordInOperand(1); + const auto* default_info = GetBlockInfo(default_id); + clause_heads.push_back(default_info); + // Now the case clauses. + for (uint32_t iarg = 2; iarg + 1 < branch->NumInOperands(); iarg += 2) { + const auto case_target_id = branch->GetSingleWordInOperand(iarg + 1); + clause_heads.push_back(GetBlockInfo(case_target_id)); + } + + std::stable_sort(clause_heads.begin(), clause_heads.end(), + [](const BlockInfo* lhs, const BlockInfo* rhs) { + return lhs->pos < rhs->pos; + }); + // Remove duplicates + { + // Use read index r, and write index w. + // Invariant: w <= r; + size_t w = 0; + for (size_t r = 0; r < clause_heads.size(); ++r) { + if (clause_heads[r] != clause_heads[w]) { + ++w; // Advance the write cursor. + } + clause_heads[w] = clause_heads[r]; + } + // We know it's not empty because it always has at least a default clause. + assert(!clause_heads.empty()); + clause_heads.resize(w + 1); + } + + // Push them on in reverse order. + const auto last_clause_index = clause_heads.size() - 1; + for (size_t i = last_clause_index;; --i) { + // Create the case clause. Temporarily put it in the wrong order + // on the case statement list. + cases->emplace_back(std::make_unique()); + auto* clause = cases->back().get(); + + // Create a list of integer literals for the selector values leading to + // this case clause. + ast::CaseSelectorList selectors; + const auto* values_ptr = clause_heads[i]->case_values.get(); + const bool has_selectors = (values_ptr && !values_ptr->empty()); + if (has_selectors) { + std::vector values(values_ptr->begin(), values_ptr->end()); + std::stable_sort(values.begin(), values.end()); + for (auto value : values) { + // The rest of this module can handle up to 64 bit switch values. + // The Tint AST handles 32-bit values. + const uint32_t value32 = uint32_t(value & 0xFFFFFFFF); + if (selector.type->is_unsigned_scalar_or_vector()) { + selectors.emplace_back( + std::make_unique(selector.type, value32)); + } else { + selectors.emplace_back( + std::make_unique(selector.type, value32)); + } + } + clause->set_selectors(std::move(selectors)); + } + + // Where does this clause end? + const auto end_id = (i + 1 < clause_heads.size()) ? clause_heads[i + 1]->id + : construct->end_id; + + PushNewStatementBlock(construct, end_id, [clause](StatementBlock* s) { + clause->set_body(std::move(s->statements_)); + }); + + if ((default_info == clause_heads[i]) && has_selectors && + construct->ContainsPos(default_info->pos)) { + // Generate a default clause with a just fallthrough. + ast::StatementList stmts; + stmts.emplace_back(std::make_unique()); + auto case_stmt = std::make_unique(); + case_stmt->set_body(std::move(stmts)); + cases->emplace_back(std::move(case_stmt)); + } + + if (i == 0) { + break; + } + } + + // We've listed cases in reverse order in the switch statement. Reorder them + // to match the presentation order in WGSL. + std::reverse(cases->begin(), cases->end()); + + return success(); +} + bool FunctionEmitter::EmitLoopStart(const Construct* construct) { auto* loop = AddStatement(std::make_unique())->AsLoop(); PushNewStatementBlock( @@ -1946,7 +2075,31 @@ std::unique_ptr FunctionEmitter::MakeBranch( case EdgeKind::kBack: // Nothing to do. The loop backedge is implicit. break; - case EdgeKind::kSwitchBreak: + case EdgeKind::kSwitchBreak: { + // Don't bother with a break at the end of a case. + const auto header = dest_info.header_for_merge; + assert(header != 0); + const auto* exiting_construct = GetBlockInfo(header)->construct; + assert(exiting_construct->kind == Construct::kSwitchSelection); + const auto candidate_next_case_pos = src_info.pos + 1; + // Leaving the last block from the last case? + if (candidate_next_case_pos == dest_info.pos) { + // No break needed. + return nullptr; + } + // Leaving the last block from not-the-last-case? + if (exiting_construct->ContainsPos(candidate_next_case_pos)) { + const auto* candidate_next_case = + GetBlockInfo(block_order_[candidate_next_case_pos]); + if (candidate_next_case->case_head_for == exiting_construct || + candidate_next_case->default_head_for == exiting_construct) { + // No break needed. + return nullptr; + } + } + // We need a break. + return std::make_unique(); + } case EdgeKind::kLoopBreak: return std::make_unique(); case EdgeKind::kLoopContinue: diff --git a/src/reader/spirv/function.h b/src/reader/spirv/function.h index a96f7e8cb6..a5a53471eb 100644 --- a/src/reader/spirv/function.h +++ b/src/reader/spirv/function.h @@ -289,6 +289,13 @@ class FunctionEmitter { /// @returns false if emission failed. bool EmitIfStart(const BlockInfo& block_info); + /// Emits a SwitchStatement, including its condition expression, and sets + /// up the statement stack to accumulate subsequent basic blocks into + /// the default clause and case clauses. + /// @param block_info the switch-selection header block + /// @returns false if emission failed. + bool EmitSwitchStart(const BlockInfo& block_info); + /// Emits a LoopStatement, and pushes a new StatementBlock to accumulate /// the remaining instructions in the current block and subsequent blocks /// in the loop. @@ -375,7 +382,7 @@ class FunctionEmitter { /// Gets the block info for a block ID, if any exists /// @param id the SPIR-V ID of the OpLabel instruction starting the block /// @returns the block info for the given ID, if it exists, or nullptr - BlockInfo* GetBlockInfo(uint32_t id) { + BlockInfo* GetBlockInfo(uint32_t id) const { auto where = block_info_.find(id); if (where == block_info_.end()) return nullptr; @@ -434,7 +441,7 @@ class FunctionEmitter { uint32_t end_id, CompletionAction completion_action, ast::StatementList statements, - ast::CaseStatementList cases); + std::unique_ptr cases); StatementBlock(StatementBlock&&); ~StatementBlock(); @@ -449,10 +456,13 @@ class FunctionEmitter { // Only one of |statements| or |cases| is active. - // The list of statements being built. + // The list of statements being built, if this construct is not a switch. ast::StatementList statements_; - // The list of cases being built, for a switch. - ast::CaseStatementList cases_; + // The list of switch cases being built, if this construct is a switch. + // The algorithm will cache a pointer to the vector. We want that pointer + // to be stable no matter how |statements_stack_| is resized. That's + // why we make this a unique_ptr rather than just a plain vector. + std::unique_ptr cases_; }; /// Pushes an empty statement block onto the statements stack. diff --git a/src/reader/spirv/function_cfg_test.cc b/src/reader/spirv/function_cfg_test.cc index 4d64864987..4cf029dc29 100644 --- a/src/reader/spirv/function_cfg_test.cc +++ b/src/reader/spirv/function_cfg_test.cc @@ -60,7 +60,9 @@ std::string CommonTypes() { %cond3 = OpConstantFalse %bool %uint = OpTypeInt 32 0 + %int = OpTypeInt 32 1 %selector = OpConstant %uint 42 + %signed_selector = OpConstant %int 42 %uintfn = OpTypeFunction %uint @@ -72,6 +74,11 @@ std::string CommonTypes() { %uint_5 = OpConstant %uint 5 %uint_6 = OpConstant %uint 6 %uint_7 = OpConstant %uint 7 + %uint_8 = OpConstant %uint 8 + %uint_20 = OpConstant %uint 20 + %uint_30 = OpConstant %uint 30 + %uint_40 = OpConstant %uint 40 + %uint_50 = OpConstant %uint 50 %ptr_Private_uint = OpTypePointer Private %uint %var = OpVariable %ptr_Private_uint Private @@ -8757,6 +8764,486 @@ Return{} )")) << ToString(fe.ast_body()); } +TEST_F(SpvParserTest, EmitBody_Switch_DefaultIsMerge_NoCases) { + auto* p = parser(test::Assemble(CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + + %10 = OpLabel + OpStore %var %uint_1 + OpSelectionMerge %99 None + OpSwitch %selector %99 + + %99 = OpLabel + OpStore %var %uint_7 + OpReturn + + OpFunctionEnd + )")); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error(); + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + + EXPECT_THAT(ToString(fe.ast_body()), Eq(R"(Assignment{ + Identifier{var} + ScalarConstructor{1} +} +Switch{ + ScalarConstructor{42} + { + Default{ + } + } +} +Assignment{ + Identifier{var} + ScalarConstructor{7} +} +Return{} +)")) << ToString(fe.ast_body()); +} + +// First do no special control flow: no fallthroughs, breaks, continues. +TEST_F(SpvParserTest, EmitBody_Switch_DefaultIsMerge_OneCase) { + auto* p = parser(test::Assemble(CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + + %10 = OpLabel + OpStore %var %uint_1 + OpSelectionMerge %99 None + OpSwitch %selector %99 20 %20 + + %20 = OpLabel + OpStore %var %uint_20 + OpBranch %99 + + %99 = OpLabel + OpStore %var %uint_7 + OpReturn + + OpFunctionEnd + )")); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error(); + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + + EXPECT_THAT(ToString(fe.ast_body()), Eq(R"(Assignment{ + Identifier{var} + ScalarConstructor{1} +} +Switch{ + ScalarConstructor{42} + { + Case 20{ + Assignment{ + Identifier{var} + ScalarConstructor{20} + } + } + Default{ + } + } +} +Assignment{ + Identifier{var} + ScalarConstructor{7} +} +Return{} +)")) << ToString(fe.ast_body()); +} + +TEST_F(SpvParserTest, EmitBody_Switch_DefaultIsMerge_TwoCases) { + auto* p = parser(test::Assemble(CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + + %10 = OpLabel + OpStore %var %uint_1 + OpSelectionMerge %99 None + OpSwitch %selector %99 20 %20 30 %30 + + %20 = OpLabel + OpStore %var %uint_20 + OpBranch %99 + + %30 = OpLabel + OpStore %var %uint_30 + OpBranch %99 + + %99 = OpLabel + OpStore %var %uint_7 + OpReturn + + OpFunctionEnd + )")); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error(); + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + + EXPECT_THAT(ToString(fe.ast_body()), Eq(R"(Assignment{ + Identifier{var} + ScalarConstructor{1} +} +Switch{ + ScalarConstructor{42} + { + Case 30{ + Assignment{ + Identifier{var} + ScalarConstructor{30} + } + } + Case 20{ + Assignment{ + Identifier{var} + ScalarConstructor{20} + } + } + Default{ + } + } +} +Assignment{ + Identifier{var} + ScalarConstructor{7} +} +Return{} +)")) << ToString(fe.ast_body()); +} + +TEST_F(SpvParserTest, EmitBody_Switch_DefaultIsMerge_CasesWithDup) { + auto* p = parser(test::Assemble(CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + + %10 = OpLabel + OpStore %var %uint_1 + OpSelectionMerge %99 None + OpSwitch %selector %99 20 %20 30 %30 40 %20 + + %20 = OpLabel + OpStore %var %uint_20 + OpBranch %99 + + %30 = OpLabel + OpStore %var %uint_30 + OpBranch %99 + + %99 = OpLabel + OpStore %var %uint_7 + OpReturn + + OpFunctionEnd + )")); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error(); + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + + EXPECT_THAT(ToString(fe.ast_body()), Eq(R"(Assignment{ + Identifier{var} + ScalarConstructor{1} +} +Switch{ + ScalarConstructor{42} + { + Case 30{ + Assignment{ + Identifier{var} + ScalarConstructor{30} + } + } + Case 20, 40{ + Assignment{ + Identifier{var} + ScalarConstructor{20} + } + } + Default{ + } + } +} +Assignment{ + Identifier{var} + ScalarConstructor{7} +} +Return{} +)")) << ToString(fe.ast_body()); +} + +TEST_F(SpvParserTest, EmitBody_Switch_DefaultIsCase_NoDupCases) { + // The default block is not the merge block. But not the same as a case + // either. + auto* p = parser(test::Assemble(CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + + %10 = OpLabel + OpStore %var %uint_1 + OpSelectionMerge %99 None + OpSwitch %selector %30 20 %20 40 %40 + + %20 = OpLabel + OpStore %var %uint_20 + OpBranch %99 + + %30 = OpLabel ; the named default block + OpStore %var %uint_30 + OpBranch %99 + + %40 = OpLabel + OpStore %var %uint_40 + OpBranch %99 + + %99 = OpLabel + OpStore %var %uint_7 + OpReturn + + OpFunctionEnd + )")); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error(); + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + + EXPECT_THAT(ToString(fe.ast_body()), Eq(R"(Assignment{ + Identifier{var} + ScalarConstructor{1} +} +Switch{ + ScalarConstructor{42} + { + Case 40{ + Assignment{ + Identifier{var} + ScalarConstructor{40} + } + } + Case 20{ + Assignment{ + Identifier{var} + ScalarConstructor{20} + } + } + Default{ + Assignment{ + Identifier{var} + ScalarConstructor{30} + } + } + } +} +Assignment{ + Identifier{var} + ScalarConstructor{7} +} +Return{} +)")) << ToString(fe.ast_body()); +} + +TEST_F(SpvParserTest, EmitBody_Switch_DefaultIsCase_WithDupCase) { + // The default block is not the merge block and is the same as a case. + // We emit the default case separately, but just before the labeled + // case, and with a fallthrough. + auto* p = parser(test::Assemble(CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + + %10 = OpLabel + OpStore %var %uint_1 + OpSelectionMerge %99 None + OpSwitch %selector %30 20 %20 30 %30 40 %40 + + %20 = OpLabel + OpStore %var %uint_20 + OpBranch %99 + + %30 = OpLabel ; the named default block, also a case + OpStore %var %uint_30 + OpBranch %99 + + %40 = OpLabel + OpStore %var %uint_40 + OpBranch %99 + + %99 = OpLabel + OpStore %var %uint_7 + OpReturn + + OpFunctionEnd + )")); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error(); + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + + EXPECT_THAT(ToString(fe.ast_body()), Eq(R"(Assignment{ + Identifier{var} + ScalarConstructor{1} +} +Switch{ + ScalarConstructor{42} + { + Case 40{ + Assignment{ + Identifier{var} + ScalarConstructor{40} + } + } + Case 20{ + Assignment{ + Identifier{var} + ScalarConstructor{20} + } + } + Default{ + Fallthrough{} + } + Case 30{ + Assignment{ + Identifier{var} + ScalarConstructor{30} + } + } + } +} +Assignment{ + Identifier{var} + ScalarConstructor{7} +} +Return{} +)")) << ToString(fe.ast_body()); +} + +TEST_F(SpvParserTest, EmitBody_Switch_Case_SintValue) { + auto* p = parser(test::Assemble(CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + + %10 = OpLabel + OpStore %var %uint_1 + OpSelectionMerge %99 None + ; SPIR-V assembler doesn't support negative literals in switch + OpSwitch %signed_selector %99 20 %20 2000000000 %30 !4000000000 %40 + + %20 = OpLabel + OpStore %var %uint_20 + OpBranch %99 + + %30 = OpLabel + OpStore %var %uint_30 + OpBranch %99 + + %40 = OpLabel + OpStore %var %uint_40 + OpBranch %99 + + %99 = OpLabel + OpStore %var %uint_7 + OpReturn + + OpFunctionEnd + )")); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error(); + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + + EXPECT_THAT(ToString(fe.ast_body()), Eq(R"(Assignment{ + Identifier{var} + ScalarConstructor{1} +} +Switch{ + ScalarConstructor{42} + { + Case -294967296{ + Assignment{ + Identifier{var} + ScalarConstructor{40} + } + } + Case 2000000000{ + Assignment{ + Identifier{var} + ScalarConstructor{30} + } + } + Case 20{ + Assignment{ + Identifier{var} + ScalarConstructor{20} + } + } + Default{ + } + } +} +Assignment{ + Identifier{var} + ScalarConstructor{7} +} +Return{} +)")) << ToString(fe.ast_body()); +} + +TEST_F(SpvParserTest, EmitBody_Switch_Case_UintValue) { + auto* p = parser(test::Assemble(CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + + %10 = OpLabel + OpStore %var %uint_1 + OpSelectionMerge %99 None + OpSwitch %selector %99 20 %20 2000000000 %30 50 %40 + + %20 = OpLabel + OpStore %var %uint_20 + OpBranch %99 + + %30 = OpLabel + OpStore %var %uint_30 + OpBranch %99 + + %40 = OpLabel + OpStore %var %uint_40 + OpBranch %99 + + %99 = OpLabel + OpStore %var %uint_7 + OpReturn + + OpFunctionEnd + )")); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error(); + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + + EXPECT_THAT(ToString(fe.ast_body()), Eq(R"(Assignment{ + Identifier{var} + ScalarConstructor{1} +} +Switch{ + ScalarConstructor{42} + { + Case 50{ + Assignment{ + Identifier{var} + ScalarConstructor{40} + } + } + Case 2000000000{ + Assignment{ + Identifier{var} + ScalarConstructor{30} + } + } + Case 20{ + Assignment{ + Identifier{var} + ScalarConstructor{20} + } + } + Default{ + } + } +} +Assignment{ + Identifier{var} + ScalarConstructor{7} +} +Return{} +)")) << ToString(fe.ast_body()); +} + TEST_F(SpvParserTest, EmitBody_Return_TopLevel) { auto* p = parser(test::Assemble(CommonTypes() + R"( %100 = OpFunction %void None %voidfn @@ -9227,8 +9714,128 @@ Return{} )")) << ToString(fe.ast_body()); } -TEST_F(SpvParserTest, DISABLED_EmitBody_Branch_SwitchBreak) { - // TODO(dneto): support switch first. +TEST_F(SpvParserTest, EmitBody_Branch_SwitchBreak_LastInCase) { + // When the break is last in its case, we omit it because it's implicit in + // WGSL. + auto* p = parser(test::Assemble(CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + + %10 = OpLabel + OpStore %var %uint_1 + OpSelectionMerge %99 None + OpSwitch %selector %99 20 %20 + + %20 = OpLabel + OpStore %var %uint_20 + OpBranch %99 ; branch to merge. Last in case + + %99 = OpLabel + OpStore %var %uint_7 + OpReturn + + OpFunctionEnd + )")); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error(); + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + + EXPECT_THAT(ToString(fe.ast_body()), Eq(R"(Assignment{ + Identifier{var} + ScalarConstructor{1} +} +Switch{ + ScalarConstructor{42} + { + Case 20{ + Assignment{ + Identifier{var} + ScalarConstructor{20} + } + } + Default{ + } + } +} +Assignment{ + Identifier{var} + ScalarConstructor{7} +} +Return{} +)")) << ToString(fe.ast_body()); +} + +TEST_F(SpvParserTest, EmitBody_Branch_SwitchBreak_NotLastInCase) { + // When the break is not last in its case, we must emit a 'break' + auto* p = parser(test::Assemble(CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + + %10 = OpLabel + OpStore %var %uint_1 + OpSelectionMerge %99 None + OpSwitch %selector %99 20 %20 + + %20 = OpLabel + OpStore %var %uint_20 + OpSelectionMerge %50 None + OpBranchConditional %cond %40 %50 + + %40 = OpLabel + OpStore %var %uint_40 + OpBranch %99 ; branch to merge. Not last in case + + %50 = OpLabel ; inner merge + OpStore %var %uint_50 + OpBranch %99 + + %99 = OpLabel + OpStore %var %uint_7 + OpReturn + + OpFunctionEnd + )")); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error(); + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + + EXPECT_THAT(ToString(fe.ast_body()), Eq(R"(Assignment{ + Identifier{var} + ScalarConstructor{1} +} +Switch{ + ScalarConstructor{42} + { + Case 20{ + Assignment{ + Identifier{var} + ScalarConstructor{20} + } + If{ + ( + ScalarConstructor{false} + ) + { + Assignment{ + Identifier{var} + ScalarConstructor{40} + } + Break{} + } + } + Assignment{ + Identifier{var} + ScalarConstructor{50} + } + } + Default{ + } + } +} +Assignment{ + Identifier{var} + ScalarConstructor{7} +} +Return{} +)")) << ToString(fe.ast_body()); } TEST_F(SpvParserTest, EmitBody_Branch_LoopBreak_MultiBlockLoop_FromBody) { @@ -9459,6 +10066,91 @@ Return{} )")) << ToString(fe.ast_body()); } +TEST_F(SpvParserTest, EmitBody_Branch_LoopContinue_FromSwitch) { + auto* p = parser(test::Assemble(CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + + %10 = OpLabel + OpStore %var %uint_1 + OpBranch %20 + + %20 = OpLabel + OpStore %var %uint_2 + OpLoopMerge %99 %80 None + OpBranch %30 + + %30 = OpLabel + OpStore %var %uint_3 + OpSelectionMerge %79 None + OpSwitch %selector %79 40 %40 + + %40 = OpLabel + OpStore %var %uint_4 + OpBranch %80 ; continue edge + + %79 = OpLabel ; switch merge + OpStore %var %uint_5 + OpBranch %80 + + %80 = OpLabel ; continue target + OpStore %var %uint_6 + OpBranch %20 + + %99 = OpLabel + OpStore %var %uint_7 + OpReturn + + OpFunctionEnd + )")); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error(); + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + EXPECT_THAT(ToString(fe.ast_body()), Eq(R"(Assignment{ + Identifier{var} + ScalarConstructor{1} +} +Loop{ + Assignment{ + Identifier{var} + ScalarConstructor{2} + } + Assignment{ + Identifier{var} + ScalarConstructor{3} + } + Switch{ + ScalarConstructor{42} + { + Case 40{ + Assignment{ + Identifier{var} + ScalarConstructor{4} + } + Continue{} + } + Default{ + } + } + } + Assignment{ + Identifier{var} + ScalarConstructor{5} + } + continuing { + Assignment{ + Identifier{var} + ScalarConstructor{6} + } + } +} +Assignment{ + Identifier{var} + ScalarConstructor{7} +} +Return{} +)")) << ToString(fe.ast_body()); +} + TEST_F(SpvParserTest, EmitBody_Branch_IfBreak_FromThen) { // When unconditional, the if-break must be last in the then clause. auto* p = parser(test::Assemble(CommonTypes() + R"( @@ -9546,7 +10238,7 @@ Return{} } TEST_F(SpvParserTest, DISABLED_EmitBody_Branch_Fallthrough) { - // TODO(dneto): support switch first. + // TODO(dneto): support fallthrough first. } TEST_F(SpvParserTest, EmitBody_Branch_Forward) { @@ -9901,33 +10593,459 @@ Return{} } TEST_F(SpvParserTest, - DISABLED_EmitBody_BranchConditional_SwitchBreak_Continue_OnTrue) { - // TODO(dneto): needs switch support + EmitBody_BranchConditional_SwitchBreak_SwitchBreak_LastInCase) { + // When the break is last in its case, we omit it because it's implicit in + // WGSL. + auto* p = parser(test::Assemble(CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + + %10 = OpLabel + OpStore %var %uint_1 + OpSelectionMerge %99 None + OpSwitch %selector %99 20 %20 + + %20 = OpLabel + OpStore %var %uint_20 + OpBranchConditional %cond2 %99 %99 ; branch to merge. Last in case + + %99 = OpLabel + OpStore %var %uint_7 + OpReturn + + OpFunctionEnd + )")); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error(); + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + + EXPECT_THAT(ToString(fe.ast_body()), Eq(R"(Assignment{ + Identifier{var} + ScalarConstructor{1} +} +Switch{ + ScalarConstructor{42} + { + Case 20{ + Assignment{ + Identifier{var} + ScalarConstructor{20} + } + } + Default{ + } + } +} +Assignment{ + Identifier{var} + ScalarConstructor{7} +} +Return{} +)")) << ToString(fe.ast_body()); } TEST_F(SpvParserTest, - DISABLED_EmitBody_BranchConditional_SwitchBreak_Continue_OnFalse) { - // TODO(dneto): needs switch support + EmitBody_BranchConditional_SwitchBreak_SwitchBreak_NotLastInCase) { + // When the break is not last in its case, we must emit a 'break' + auto* p = parser(test::Assemble(CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + + %10 = OpLabel + OpStore %var %uint_1 + OpSelectionMerge %99 None + OpSwitch %selector %99 20 %20 + + %20 = OpLabel + OpStore %var %uint_20 + OpSelectionMerge %50 None + OpBranchConditional %cond %40 %50 + + %40 = OpLabel + OpStore %var %uint_40 + OpBranchConditional %cond2 %99 %99 ; branch to merge. Not last in case + + %50 = OpLabel ; inner merge + OpStore %var %uint_50 + OpBranch %99 + + %99 = OpLabel + OpStore %var %uint_7 + OpReturn + + OpFunctionEnd + )")); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error(); + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + + EXPECT_THAT(ToString(fe.ast_body()), Eq(R"(Assignment{ + Identifier{var} + ScalarConstructor{1} +} +Switch{ + ScalarConstructor{42} + { + Case 20{ + Assignment{ + Identifier{var} + ScalarConstructor{20} + } + If{ + ( + ScalarConstructor{false} + ) + { + Assignment{ + Identifier{var} + ScalarConstructor{40} + } + Break{} + } + } + Assignment{ + Identifier{var} + ScalarConstructor{50} + } + } + Default{ + } + } +} +Assignment{ + Identifier{var} + ScalarConstructor{7} +} +Return{} +)")) << ToString(fe.ast_body()); } -TEST_F(SpvParserTest, - DISABLED_EmitBody_BranchConditional_SwitchBreak_Forward_OnTrue) { - // TODO(dneto): needs switch support +TEST_F(SpvParserTest, EmitBody_BranchConditional_SwitchBreak_Continue_OnTrue) { + auto* p = parser(test::Assemble(CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + + %10 = OpLabel + OpStore %var %uint_1 + OpBranch %20 + + %20 = OpLabel + OpStore %var %uint_2 + OpLoopMerge %99 %80 None + OpBranch %30 + + %30 = OpLabel + OpStore %var %uint_3 + OpSelectionMerge %79 None + OpSwitch %selector %79 40 %40 + + %40 = OpLabel + OpStore %var %uint_40 + OpBranchConditional %cond %80 %79 ; break; continue on true + + %79 = OpLabel + OpStore %var %uint_6 + OpBranch %80 + + %80 = OpLabel ; continue target + OpStore %var %uint_7 + OpBranch %20 + + %99 = OpLabel ; loop merge + OpStore %var %uint_8 + OpReturn + + OpFunctionEnd + )")); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error(); + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + + EXPECT_THAT(ToString(fe.ast_body()), Eq(R"(Assignment{ + Identifier{var} + ScalarConstructor{1} +} +Loop{ + Assignment{ + Identifier{var} + ScalarConstructor{2} + } + Assignment{ + Identifier{var} + ScalarConstructor{3} + } + Switch{ + ScalarConstructor{42} + { + Case 40{ + Assignment{ + Identifier{var} + ScalarConstructor{40} + } + If{ + ( + ScalarConstructor{false} + ) + { + Continue{} + } + } + } + Default{ + } + } + } + Assignment{ + Identifier{var} + ScalarConstructor{6} + } + continuing { + Assignment{ + Identifier{var} + ScalarConstructor{7} + } + } +} +Assignment{ + Identifier{var} + ScalarConstructor{8} +} +Return{} +)")) << ToString(fe.ast_body()); } -TEST_F(SpvParserTest, - DISABLED_EmitBody_BranchConditional_SwitchBreak_Forward_OnFalse) { - // TODO(dneto): needs switch support +TEST_F(SpvParserTest, EmitBody_BranchConditional_SwitchBreak_Continue_OnFalse) { + auto* p = parser(test::Assemble(CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + + %10 = OpLabel + OpStore %var %uint_1 + OpBranch %20 + + %20 = OpLabel + OpStore %var %uint_2 + OpLoopMerge %99 %80 None + OpBranch %30 + + %30 = OpLabel + OpStore %var %uint_3 + OpSelectionMerge %79 None + OpSwitch %selector %79 40 %40 + + %40 = OpLabel + OpStore %var %uint_40 + OpBranchConditional %cond %79 %80 ; break; continue on false + + %79 = OpLabel + OpStore %var %uint_6 + OpBranch %80 + + %80 = OpLabel ; continue target + OpStore %var %uint_7 + OpBranch %20 + + %99 = OpLabel ; loop merge + OpStore %var %uint_8 + OpReturn + + OpFunctionEnd + )")); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error(); + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + + EXPECT_THAT(ToString(fe.ast_body()), Eq(R"(Assignment{ + Identifier{var} + ScalarConstructor{1} +} +Loop{ + Assignment{ + Identifier{var} + ScalarConstructor{2} + } + Assignment{ + Identifier{var} + ScalarConstructor{3} + } + Switch{ + ScalarConstructor{42} + { + Case 40{ + Assignment{ + Identifier{var} + ScalarConstructor{40} + } + If{ + ( + ScalarConstructor{false} + ) + { + } + } + Else{ + { + Continue{} + } + } + } + Default{ + } + } + } + Assignment{ + Identifier{var} + ScalarConstructor{6} + } + continuing { + Assignment{ + Identifier{var} + ScalarConstructor{7} + } + } +} +Assignment{ + Identifier{var} + ScalarConstructor{8} +} +Return{} +)")) << ToString(fe.ast_body()); +} + +TEST_F(SpvParserTest, EmitBody_BranchConditional_SwitchBreak_Forward_OnTrue) { + auto* p = parser(test::Assemble(CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + + %10 = OpLabel + OpStore %var %uint_1 + OpSelectionMerge %99 None + OpSwitch %selector %99 20 %20 + + %20 = OpLabel + OpStore %var %uint_20 + OpBranchConditional %cond %30 %99 ; break; forward on true + + %30 = OpLabel + OpStore %var %uint_30 + OpBranch %99 + + %99 = OpLabel ; switch merge + OpStore %var %uint_8 + OpReturn + + OpFunctionEnd + )")); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error(); + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + EXPECT_THAT(ToString(fe.ast_body()), Eq(R"(Assignment{ + Identifier{var} + ScalarConstructor{1} +} +Switch{ + ScalarConstructor{42} + { + Case 20{ + Assignment{ + Identifier{var} + ScalarConstructor{20} + } + If{ + ( + ScalarConstructor{false} + ) + { + } + } + Else{ + { + Break{} + } + } + Assignment{ + Identifier{var} + ScalarConstructor{30} + } + } + Default{ + } + } +} +Assignment{ + Identifier{var} + ScalarConstructor{8} +} +Return{} +)")) << ToString(fe.ast_body()); +} + +TEST_F(SpvParserTest, EmitBody_BranchConditional_SwitchBreak_Forward_OnFalse) { + auto* p = parser(test::Assemble(CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + + %10 = OpLabel + OpStore %var %uint_1 + OpSelectionMerge %99 None + OpSwitch %selector %99 20 %20 + + %20 = OpLabel + OpStore %var %uint_20 + OpBranchConditional %cond %99 %30 ; break; forward on false + + %30 = OpLabel + OpStore %var %uint_30 + OpBranch %99 + + %99 = OpLabel ; switch merge + OpStore %var %uint_8 + OpReturn + + OpFunctionEnd + )")); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error(); + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + EXPECT_THAT(ToString(fe.ast_body()), Eq(R"(Assignment{ + Identifier{var} + ScalarConstructor{1} +} +Switch{ + ScalarConstructor{42} + { + Case 20{ + Assignment{ + Identifier{var} + ScalarConstructor{20} + } + If{ + ( + ScalarConstructor{false} + ) + { + Break{} + } + } + Assignment{ + Identifier{var} + ScalarConstructor{30} + } + } + Default{ + } + } +} +Assignment{ + Identifier{var} + ScalarConstructor{8} +} +Return{} +)")) << ToString(fe.ast_body()); } TEST_F(SpvParserTest, DISABLED_EmitBody_BranchConditional_SwitchBreak_Fallthrough_OnTrue) { - // TODO(dneto): needs switch support + // TODO(dneto): needs fallthrough support } TEST_F(SpvParserTest, DISABLED_EmitBody_BranchConditional_SwitchBreak_Fallthrough_OnFalse) { - // TODO(dneto): needs switch support + // TODO(dneto): needs fallthrough support } TEST_F(SpvParserTest, @@ -10232,11 +11350,11 @@ Return{} TEST_F(SpvParserTest, DISABLED_EmitBody_BranchConditional_LoopBreak_Fallthrough_OnTrue) { - // TODO(dneto): needs switch support + // TODO(dneto): needs fallthrough support } TEST_F(SpvParserTest, DISABLED_EmitBody_BranchConditional_LoopBreak_Fallthrough_OnFalse) { - // TODO(dneto): needs switch support + // TODO(dneto): needs fallthrough support } TEST_F(SpvParserTest, EmitBody_BranchConditional_LoopBreak_Forward_OnTrue) { @@ -10668,6 +11786,91 @@ Return{} )")) << ToString(fe.ast_body()); } +TEST_F(SpvParserTest, EmitBody_BranchConditional_LoopContinue_FromSwitch) { + auto* p = parser(test::Assemble(CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + + %10 = OpLabel + OpStore %var %uint_1 + OpBranch %20 + + %20 = OpLabel + OpStore %var %uint_2 + OpLoopMerge %99 %80 None + OpBranch %30 + + %30 = OpLabel + OpStore %var %uint_3 + OpSelectionMerge %79 None + OpSwitch %selector %79 40 %40 + + %40 = OpLabel + OpStore %var %uint_4 + OpBranchConditional %cond2 %80 %80; dup continue edge + + %79 = OpLabel ; switch merge + OpStore %var %uint_5 + OpBranch %80 + + %80 = OpLabel ; continue target + OpStore %var %uint_6 + OpBranch %20 + + %99 = OpLabel + OpStore %var %uint_7 + OpReturn + + OpFunctionEnd + )")); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error(); + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + EXPECT_THAT(ToString(fe.ast_body()), Eq(R"(Assignment{ + Identifier{var} + ScalarConstructor{1} +} +Loop{ + Assignment{ + Identifier{var} + ScalarConstructor{2} + } + Assignment{ + Identifier{var} + ScalarConstructor{3} + } + Switch{ + ScalarConstructor{42} + { + Case 40{ + Assignment{ + Identifier{var} + ScalarConstructor{4} + } + Continue{} + } + Default{ + } + } + } + Assignment{ + Identifier{var} + ScalarConstructor{5} + } + continuing { + Assignment{ + Identifier{var} + ScalarConstructor{6} + } + } +} +Assignment{ + Identifier{var} + ScalarConstructor{7} +} +Return{} +)")) << ToString(fe.ast_body()); +} + TEST_F(SpvParserTest, EmitBody_BranchConditional_Continue_IfBreak_OnTrue) { auto* p = parser(test::Assemble(CommonTypes() + R"( %100 = OpFunction %void None %voidfn @@ -10856,11 +12059,11 @@ Return{} TEST_F(SpvParserTest, DISABLED_EmitBody_BranchConditional_Continue_Fallthrough_OnTrue) { - // TODO(dneto): needs switch support + // TODO(dneto): needs fallthrough support } TEST_F(SpvParserTest, DISABLED_EmitBody_BranchConditional_Continue_Fallthrough_OnFalse) { - // TODO(dneto): needs switch support + // TODO(dneto): needs fallthrough support } TEST_F(SpvParserTest, EmitBody_BranchConditional_Continue_Forward_OnTrue) { @@ -11098,20 +12301,67 @@ TEST_F(SpvParserTest, TEST_F(SpvParserTest, DISABLED_EmitBody_BranchConditional_Fallthrough_Fallthrough_Same) { // Can only be to the same target. - // TODO(dneto): needs switch support + // TODO(dneto): needs fallthrough support } + TEST_F( SpvParserTest, DISABLED_EmitBody_BranchConditional_Fallthrough_Fallthrough_Different_IsError) { - // TODO(dneto): needs switch support + // TODO(dneto): needs fallthrough support } -TEST_F(SpvParserTest, - DISABLED_EmitBody_BranchConditional_Forward_Forward_Same) { - // TODO(dneto): needs switch support + +TEST_F(SpvParserTest, EmitBody_BranchConditional_Forward_Forward_Same) { + auto* p = parser(test::Assemble(CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + + %10 = OpLabel + OpStore %var %uint_1 + OpBranchConditional %cond %99 %99; forward + + %99 = OpLabel + OpStore %var %uint_2 + OpReturn + + OpFunctionEnd + )")); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error(); + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + EXPECT_THAT(ToString(fe.ast_body()), Eq(R"(Assignment{ + Identifier{var} + ScalarConstructor{1} } +Assignment{ + Identifier{var} + ScalarConstructor{2} +} +Return{} +)")) << ToString(fe.ast_body()); +} + TEST_F(SpvParserTest, - DISABLED_EmitBody_BranchConditional_Forward_Forward_Different_IsError) { - // TODO(dneto): needs switch support + EmitBody_BranchConditional_Forward_Forward_Different_IsError) { + auto* p = parser(test::Assemble(CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + + %10 = OpLabel + OpBranchConditional %cond %20 %99 + + %20 = OpLabel + OpReturn + + %99 = OpLabel + OpStore %var %uint_2 + OpReturn + + OpFunctionEnd + )")); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error(); + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_FALSE(fe.EmitBody()); + EXPECT_THAT(p->error(), + Eq("Control flow diverges at block 10 (to 20, 99) but it is not " + "a structured header (it has no merge instruction)")); } TEST_F(SpvParserTest,