diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc index d633f9c0e9..7169411c68 100644 --- a/src/writer/spirv/builder.cc +++ b/src/writer/spirv/builder.cc @@ -89,6 +89,10 @@ uint32_t pipeline_stage_to_execution_model(ast::PipelineStage stage) { return model; } +bool LastIsFallthrough(const ast::StatementList& stmts) { + return !stmts.empty() && stmts.back()->IsFallthrough(); +} + // A terminator is anything which will case a SPIR-V terminator to be emitted. // This means things like breaks, fallthroughs and continues which all emit an // OpBranch or return for the OpReturn emission. @@ -1395,7 +1399,13 @@ bool Builder::GenerateSwitchStatement(ast::SwitchStatement* stmt) { return false; } - if (!LastIsTerminator(item->body())) { + if (LastIsFallthrough(item->body())) { + if (i == (body.size() - 1)) { + error_ = "fallthrough of last case statement is disallowed"; + return false; + } + push_function_inst(spv::Op::OpBranch, {Operand::Int(case_ids[i + 1])}); + } else if (!LastIsTerminator(item->body())) { push_function_inst(spv::Op::OpBranch, {Operand::Int(merge_block_id)}); } } @@ -1491,6 +1501,10 @@ bool Builder::GenerateStatement(ast::Statement* stmt) { if (stmt->IsContinue()) { return GenerateContinueStatement(stmt->AsContinue()); } + if (stmt->IsFallthrough()) { + // Do nothing here, the fallthrough gets handled by the switch code. + return true; + } if (stmt->IsIf()) { return GenerateIfStatement(stmt->AsIf()); } diff --git a/src/writer/spirv/builder_switch_test.cc b/src/writer/spirv/builder_switch_test.cc index 1284345765..d717c08d91 100644 --- a/src/writer/spirv/builder_switch_test.cc +++ b/src/writer/spirv/builder_switch_test.cc @@ -19,6 +19,7 @@ #include "src/ast/bool_literal.h" #include "src/ast/break_statement.h" #include "src/ast/case_statement.h" +#include "src/ast/fallthrough_statement.h" #include "src/ast/identifier_expression.h" #include "src/ast/if_statement.h" #include "src/ast/int_literal.h" @@ -321,15 +322,145 @@ OpFunctionEnd )"); } -TEST_F(BuilderTest, DISABLED_Switch_CaseWithFallthrough) { - // switch (a) { +TEST_F(BuilderTest, Switch_CaseWithFallthrough) { + ast::type::I32Type i32; + + // switch(a) { // case 1: - // v = 1; - // fallthrough; + // v = 1; + // fallthrough; // case 2: - // v = 2; + // v = 2; + // default: + // v = 3; // } - FAIL(); + + auto v = + std::make_unique("v", ast::StorageClass::kPrivate, &i32); + auto a = + std::make_unique("a", ast::StorageClass::kPrivate, &i32); + + ast::StatementList case_1_body; + case_1_body.push_back(std::make_unique( + std::make_unique("v"), + std::make_unique( + std::make_unique(&i32, 1)))); + case_1_body.push_back(std::make_unique()); + + ast::StatementList case_2_body; + case_2_body.push_back(std::make_unique( + std::make_unique("v"), + std::make_unique( + std::make_unique(&i32, 2)))); + + ast::StatementList default_body; + default_body.push_back(std::make_unique( + std::make_unique("v"), + std::make_unique( + std::make_unique(&i32, 3)))); + + ast::CaseStatementList cases; + cases.push_back(std::make_unique( + std::make_unique(&i32, 1), std::move(case_1_body))); + cases.push_back(std::make_unique( + std::make_unique(&i32, 2), std::move(case_2_body))); + cases.push_back( + std::make_unique(std::move(default_body))); + + ast::SwitchStatement expr(std::make_unique("a"), + std::move(cases)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(v.get()); + td.RegisterVariableForTesting(a.get()); + ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); + + ast::Function func("a_func", {}, &i32); + + Builder b(&mod); + ASSERT_TRUE(b.GenerateGlobalVariable(v.get())) << b.error(); + ASSERT_TRUE(b.GenerateGlobalVariable(a.get())) << b.error(); + ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); + + EXPECT_TRUE(b.GenerateSwitchStatement(&expr)) << b.error(); + + EXPECT_EQ(DumpBuilder(b), R"(OpName %1 "v" +OpName %5 "a" +OpName %7 "a_func" +%3 = OpTypeInt 32 1 +%2 = OpTypePointer Private %3 +%4 = OpConstantNull %3 +%1 = OpVariable %2 Private %4 +%5 = OpVariable %2 Private %4 +%6 = OpTypeFunction %3 +%14 = OpConstant %3 1 +%15 = OpConstant %3 2 +%16 = OpConstant %3 3 +%7 = OpFunction %3 None %6 +%8 = OpLabel +%10 = OpLoad %3 %5 +OpSelectionMerge %9 None +OpSwitch %10 %11 1 %12 2 %13 +%12 = OpLabel +OpStore %1 %14 +OpBranch %13 +%13 = OpLabel +OpStore %1 %15 +OpBranch %9 +%11 = OpLabel +OpStore %1 %16 +OpBranch %9 +%9 = OpLabel +OpFunctionEnd +)"); +} + +TEST_F(BuilderTest, Switch_CaseFallthroughLastStatement) { + ast::type::I32Type i32; + + // switch(a) { + // case 1: + // v = 1; + // fallthrough; + // } + + auto v = + std::make_unique("v", ast::StorageClass::kPrivate, &i32); + auto a = + std::make_unique("a", ast::StorageClass::kPrivate, &i32); + + ast::StatementList case_1_body; + case_1_body.push_back(std::make_unique( + std::make_unique("v"), + std::make_unique( + std::make_unique(&i32, 1)))); + case_1_body.push_back(std::make_unique()); + + ast::CaseStatementList cases; + cases.push_back(std::make_unique( + std::make_unique(&i32, 1), std::move(case_1_body))); + + ast::SwitchStatement expr(std::make_unique("a"), + std::move(cases)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(v.get()); + td.RegisterVariableForTesting(a.get()); + ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); + + ast::Function func("a_func", {}, &i32); + + Builder b(&mod); + ASSERT_TRUE(b.GenerateGlobalVariable(v.get())) << b.error(); + ASSERT_TRUE(b.GenerateGlobalVariable(a.get())) << b.error(); + ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); + + EXPECT_FALSE(b.GenerateSwitchStatement(&expr)) << b.error(); + EXPECT_EQ(b.error(), "fallthrough of last case statement is disallowed"); } // TODO(dsinclair): Implement when parsing is handled for multi-value