[spirv-writer] Add fallthrough support

This CL adds support for the fallthrough statement in a `case` block.

Bug: tint:5
Change-Id: I282643a304846a19212d41bd8bd20a60398bd793
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/22220
Reviewed-by: David Neto <dneto@google.com>
This commit is contained in:
dan sinclair 2020-06-01 18:56:56 +00:00
parent 916b408111
commit dadd149d9b
2 changed files with 152 additions and 7 deletions

View File

@ -89,6 +89,10 @@ uint32_t pipeline_stage_to_execution_model(ast::PipelineStage stage) {
return model; return model;
} }
bool LastIsFallthrough(const ast::StatementList& stmts) {
return !stmts.empty() && stmts.back()->IsFallthrough();
}
// A terminator is anything which will case a SPIR-V terminator to be emitted. // A terminator is anything which will case a SPIR-V terminator to be emitted.
// This means things like breaks, fallthroughs and continues which all emit an // This means things like breaks, fallthroughs and continues which all emit an
// OpBranch or return for the OpReturn emission. // OpBranch or return for the OpReturn emission.
@ -1395,7 +1399,13 @@ bool Builder::GenerateSwitchStatement(ast::SwitchStatement* stmt) {
return false; return false;
} }
if (!LastIsTerminator(item->body())) { if (LastIsFallthrough(item->body())) {
if (i == (body.size() - 1)) {
error_ = "fallthrough of last case statement is disallowed";
return false;
}
push_function_inst(spv::Op::OpBranch, {Operand::Int(case_ids[i + 1])});
} else if (!LastIsTerminator(item->body())) {
push_function_inst(spv::Op::OpBranch, {Operand::Int(merge_block_id)}); push_function_inst(spv::Op::OpBranch, {Operand::Int(merge_block_id)});
} }
} }
@ -1491,6 +1501,10 @@ bool Builder::GenerateStatement(ast::Statement* stmt) {
if (stmt->IsContinue()) { if (stmt->IsContinue()) {
return GenerateContinueStatement(stmt->AsContinue()); return GenerateContinueStatement(stmt->AsContinue());
} }
if (stmt->IsFallthrough()) {
// Do nothing here, the fallthrough gets handled by the switch code.
return true;
}
if (stmt->IsIf()) { if (stmt->IsIf()) {
return GenerateIfStatement(stmt->AsIf()); return GenerateIfStatement(stmt->AsIf());
} }

View File

@ -19,6 +19,7 @@
#include "src/ast/bool_literal.h" #include "src/ast/bool_literal.h"
#include "src/ast/break_statement.h" #include "src/ast/break_statement.h"
#include "src/ast/case_statement.h" #include "src/ast/case_statement.h"
#include "src/ast/fallthrough_statement.h"
#include "src/ast/identifier_expression.h" #include "src/ast/identifier_expression.h"
#include "src/ast/if_statement.h" #include "src/ast/if_statement.h"
#include "src/ast/int_literal.h" #include "src/ast/int_literal.h"
@ -321,15 +322,145 @@ OpFunctionEnd
)"); )");
} }
TEST_F(BuilderTest, DISABLED_Switch_CaseWithFallthrough) { TEST_F(BuilderTest, Switch_CaseWithFallthrough) {
// switch (a) { ast::type::I32Type i32;
// switch(a) {
// case 1: // case 1:
// v = 1; // v = 1;
// fallthrough; // fallthrough;
// case 2: // case 2:
// v = 2; // v = 2;
// default:
// v = 3;
// } // }
FAIL();
auto v =
std::make_unique<ast::Variable>("v", ast::StorageClass::kPrivate, &i32);
auto a =
std::make_unique<ast::Variable>("a", ast::StorageClass::kPrivate, &i32);
ast::StatementList case_1_body;
case_1_body.push_back(std::make_unique<ast::AssignmentStatement>(
std::make_unique<ast::IdentifierExpression>("v"),
std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::IntLiteral>(&i32, 1))));
case_1_body.push_back(std::make_unique<ast::FallthroughStatement>());
ast::StatementList case_2_body;
case_2_body.push_back(std::make_unique<ast::AssignmentStatement>(
std::make_unique<ast::IdentifierExpression>("v"),
std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::IntLiteral>(&i32, 2))));
ast::StatementList default_body;
default_body.push_back(std::make_unique<ast::AssignmentStatement>(
std::make_unique<ast::IdentifierExpression>("v"),
std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::IntLiteral>(&i32, 3))));
ast::CaseStatementList cases;
cases.push_back(std::make_unique<ast::CaseStatement>(
std::make_unique<ast::IntLiteral>(&i32, 1), std::move(case_1_body)));
cases.push_back(std::make_unique<ast::CaseStatement>(
std::make_unique<ast::IntLiteral>(&i32, 2), std::move(case_2_body)));
cases.push_back(
std::make_unique<ast::CaseStatement>(std::move(default_body)));
ast::SwitchStatement expr(std::make_unique<ast::IdentifierExpression>("a"),
std::move(cases));
Context ctx;
ast::Module mod;
TypeDeterminer td(&ctx, &mod);
td.RegisterVariableForTesting(v.get());
td.RegisterVariableForTesting(a.get());
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func("a_func", {}, &i32);
Builder b(&mod);
ASSERT_TRUE(b.GenerateGlobalVariable(v.get())) << b.error();
ASSERT_TRUE(b.GenerateGlobalVariable(a.get())) << b.error();
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
EXPECT_TRUE(b.GenerateSwitchStatement(&expr)) << b.error();
EXPECT_EQ(DumpBuilder(b), R"(OpName %1 "v"
OpName %5 "a"
OpName %7 "a_func"
%3 = OpTypeInt 32 1
%2 = OpTypePointer Private %3
%4 = OpConstantNull %3
%1 = OpVariable %2 Private %4
%5 = OpVariable %2 Private %4
%6 = OpTypeFunction %3
%14 = OpConstant %3 1
%15 = OpConstant %3 2
%16 = OpConstant %3 3
%7 = OpFunction %3 None %6
%8 = OpLabel
%10 = OpLoad %3 %5
OpSelectionMerge %9 None
OpSwitch %10 %11 1 %12 2 %13
%12 = OpLabel
OpStore %1 %14
OpBranch %13
%13 = OpLabel
OpStore %1 %15
OpBranch %9
%11 = OpLabel
OpStore %1 %16
OpBranch %9
%9 = OpLabel
OpFunctionEnd
)");
}
TEST_F(BuilderTest, Switch_CaseFallthroughLastStatement) {
ast::type::I32Type i32;
// switch(a) {
// case 1:
// v = 1;
// fallthrough;
// }
auto v =
std::make_unique<ast::Variable>("v", ast::StorageClass::kPrivate, &i32);
auto a =
std::make_unique<ast::Variable>("a", ast::StorageClass::kPrivate, &i32);
ast::StatementList case_1_body;
case_1_body.push_back(std::make_unique<ast::AssignmentStatement>(
std::make_unique<ast::IdentifierExpression>("v"),
std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::IntLiteral>(&i32, 1))));
case_1_body.push_back(std::make_unique<ast::FallthroughStatement>());
ast::CaseStatementList cases;
cases.push_back(std::make_unique<ast::CaseStatement>(
std::make_unique<ast::IntLiteral>(&i32, 1), std::move(case_1_body)));
ast::SwitchStatement expr(std::make_unique<ast::IdentifierExpression>("a"),
std::move(cases));
Context ctx;
ast::Module mod;
TypeDeterminer td(&ctx, &mod);
td.RegisterVariableForTesting(v.get());
td.RegisterVariableForTesting(a.get());
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func("a_func", {}, &i32);
Builder b(&mod);
ASSERT_TRUE(b.GenerateGlobalVariable(v.get())) << b.error();
ASSERT_TRUE(b.GenerateGlobalVariable(a.get())) << b.error();
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
EXPECT_FALSE(b.GenerateSwitchStatement(&expr)) << b.error();
EXPECT_EQ(b.error(), "fallthrough of last case statement is disallowed");
} }
// TODO(dsinclair): Implement when parsing is handled for multi-value // TODO(dsinclair): Implement when parsing is handled for multi-value