[spirv-writer] Add fallthrough support
This CL adds support for the fallthrough statement in a `case` block. Bug: tint:5 Change-Id: I282643a304846a19212d41bd8bd20a60398bd793 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/22220 Reviewed-by: David Neto <dneto@google.com>
This commit is contained in:
parent
916b408111
commit
dadd149d9b
|
@ -89,6 +89,10 @@ uint32_t pipeline_stage_to_execution_model(ast::PipelineStage stage) {
|
|||
return model;
|
||||
}
|
||||
|
||||
bool LastIsFallthrough(const ast::StatementList& stmts) {
|
||||
return !stmts.empty() && stmts.back()->IsFallthrough();
|
||||
}
|
||||
|
||||
// A terminator is anything which will case a SPIR-V terminator to be emitted.
|
||||
// This means things like breaks, fallthroughs and continues which all emit an
|
||||
// OpBranch or return for the OpReturn emission.
|
||||
|
@ -1395,7 +1399,13 @@ bool Builder::GenerateSwitchStatement(ast::SwitchStatement* stmt) {
|
|||
return false;
|
||||
}
|
||||
|
||||
if (!LastIsTerminator(item->body())) {
|
||||
if (LastIsFallthrough(item->body())) {
|
||||
if (i == (body.size() - 1)) {
|
||||
error_ = "fallthrough of last case statement is disallowed";
|
||||
return false;
|
||||
}
|
||||
push_function_inst(spv::Op::OpBranch, {Operand::Int(case_ids[i + 1])});
|
||||
} else if (!LastIsTerminator(item->body())) {
|
||||
push_function_inst(spv::Op::OpBranch, {Operand::Int(merge_block_id)});
|
||||
}
|
||||
}
|
||||
|
@ -1491,6 +1501,10 @@ bool Builder::GenerateStatement(ast::Statement* stmt) {
|
|||
if (stmt->IsContinue()) {
|
||||
return GenerateContinueStatement(stmt->AsContinue());
|
||||
}
|
||||
if (stmt->IsFallthrough()) {
|
||||
// Do nothing here, the fallthrough gets handled by the switch code.
|
||||
return true;
|
||||
}
|
||||
if (stmt->IsIf()) {
|
||||
return GenerateIfStatement(stmt->AsIf());
|
||||
}
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include "src/ast/bool_literal.h"
|
||||
#include "src/ast/break_statement.h"
|
||||
#include "src/ast/case_statement.h"
|
||||
#include "src/ast/fallthrough_statement.h"
|
||||
#include "src/ast/identifier_expression.h"
|
||||
#include "src/ast/if_statement.h"
|
||||
#include "src/ast/int_literal.h"
|
||||
|
@ -321,15 +322,145 @@ OpFunctionEnd
|
|||
)");
|
||||
}
|
||||
|
||||
TEST_F(BuilderTest, DISABLED_Switch_CaseWithFallthrough) {
|
||||
// switch (a) {
|
||||
TEST_F(BuilderTest, Switch_CaseWithFallthrough) {
|
||||
ast::type::I32Type i32;
|
||||
|
||||
// switch(a) {
|
||||
// case 1:
|
||||
// v = 1;
|
||||
// fallthrough;
|
||||
// case 2:
|
||||
// v = 2;
|
||||
// default:
|
||||
// v = 3;
|
||||
// }
|
||||
FAIL();
|
||||
|
||||
auto v =
|
||||
std::make_unique<ast::Variable>("v", ast::StorageClass::kPrivate, &i32);
|
||||
auto a =
|
||||
std::make_unique<ast::Variable>("a", ast::StorageClass::kPrivate, &i32);
|
||||
|
||||
ast::StatementList case_1_body;
|
||||
case_1_body.push_back(std::make_unique<ast::AssignmentStatement>(
|
||||
std::make_unique<ast::IdentifierExpression>("v"),
|
||||
std::make_unique<ast::ScalarConstructorExpression>(
|
||||
std::make_unique<ast::IntLiteral>(&i32, 1))));
|
||||
case_1_body.push_back(std::make_unique<ast::FallthroughStatement>());
|
||||
|
||||
ast::StatementList case_2_body;
|
||||
case_2_body.push_back(std::make_unique<ast::AssignmentStatement>(
|
||||
std::make_unique<ast::IdentifierExpression>("v"),
|
||||
std::make_unique<ast::ScalarConstructorExpression>(
|
||||
std::make_unique<ast::IntLiteral>(&i32, 2))));
|
||||
|
||||
ast::StatementList default_body;
|
||||
default_body.push_back(std::make_unique<ast::AssignmentStatement>(
|
||||
std::make_unique<ast::IdentifierExpression>("v"),
|
||||
std::make_unique<ast::ScalarConstructorExpression>(
|
||||
std::make_unique<ast::IntLiteral>(&i32, 3))));
|
||||
|
||||
ast::CaseStatementList cases;
|
||||
cases.push_back(std::make_unique<ast::CaseStatement>(
|
||||
std::make_unique<ast::IntLiteral>(&i32, 1), std::move(case_1_body)));
|
||||
cases.push_back(std::make_unique<ast::CaseStatement>(
|
||||
std::make_unique<ast::IntLiteral>(&i32, 2), std::move(case_2_body)));
|
||||
cases.push_back(
|
||||
std::make_unique<ast::CaseStatement>(std::move(default_body)));
|
||||
|
||||
ast::SwitchStatement expr(std::make_unique<ast::IdentifierExpression>("a"),
|
||||
std::move(cases));
|
||||
|
||||
Context ctx;
|
||||
ast::Module mod;
|
||||
TypeDeterminer td(&ctx, &mod);
|
||||
td.RegisterVariableForTesting(v.get());
|
||||
td.RegisterVariableForTesting(a.get());
|
||||
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
|
||||
|
||||
ast::Function func("a_func", {}, &i32);
|
||||
|
||||
Builder b(&mod);
|
||||
ASSERT_TRUE(b.GenerateGlobalVariable(v.get())) << b.error();
|
||||
ASSERT_TRUE(b.GenerateGlobalVariable(a.get())) << b.error();
|
||||
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
|
||||
|
||||
EXPECT_TRUE(b.GenerateSwitchStatement(&expr)) << b.error();
|
||||
|
||||
EXPECT_EQ(DumpBuilder(b), R"(OpName %1 "v"
|
||||
OpName %5 "a"
|
||||
OpName %7 "a_func"
|
||||
%3 = OpTypeInt 32 1
|
||||
%2 = OpTypePointer Private %3
|
||||
%4 = OpConstantNull %3
|
||||
%1 = OpVariable %2 Private %4
|
||||
%5 = OpVariable %2 Private %4
|
||||
%6 = OpTypeFunction %3
|
||||
%14 = OpConstant %3 1
|
||||
%15 = OpConstant %3 2
|
||||
%16 = OpConstant %3 3
|
||||
%7 = OpFunction %3 None %6
|
||||
%8 = OpLabel
|
||||
%10 = OpLoad %3 %5
|
||||
OpSelectionMerge %9 None
|
||||
OpSwitch %10 %11 1 %12 2 %13
|
||||
%12 = OpLabel
|
||||
OpStore %1 %14
|
||||
OpBranch %13
|
||||
%13 = OpLabel
|
||||
OpStore %1 %15
|
||||
OpBranch %9
|
||||
%11 = OpLabel
|
||||
OpStore %1 %16
|
||||
OpBranch %9
|
||||
%9 = OpLabel
|
||||
OpFunctionEnd
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_F(BuilderTest, Switch_CaseFallthroughLastStatement) {
|
||||
ast::type::I32Type i32;
|
||||
|
||||
// switch(a) {
|
||||
// case 1:
|
||||
// v = 1;
|
||||
// fallthrough;
|
||||
// }
|
||||
|
||||
auto v =
|
||||
std::make_unique<ast::Variable>("v", ast::StorageClass::kPrivate, &i32);
|
||||
auto a =
|
||||
std::make_unique<ast::Variable>("a", ast::StorageClass::kPrivate, &i32);
|
||||
|
||||
ast::StatementList case_1_body;
|
||||
case_1_body.push_back(std::make_unique<ast::AssignmentStatement>(
|
||||
std::make_unique<ast::IdentifierExpression>("v"),
|
||||
std::make_unique<ast::ScalarConstructorExpression>(
|
||||
std::make_unique<ast::IntLiteral>(&i32, 1))));
|
||||
case_1_body.push_back(std::make_unique<ast::FallthroughStatement>());
|
||||
|
||||
ast::CaseStatementList cases;
|
||||
cases.push_back(std::make_unique<ast::CaseStatement>(
|
||||
std::make_unique<ast::IntLiteral>(&i32, 1), std::move(case_1_body)));
|
||||
|
||||
ast::SwitchStatement expr(std::make_unique<ast::IdentifierExpression>("a"),
|
||||
std::move(cases));
|
||||
|
||||
Context ctx;
|
||||
ast::Module mod;
|
||||
TypeDeterminer td(&ctx, &mod);
|
||||
td.RegisterVariableForTesting(v.get());
|
||||
td.RegisterVariableForTesting(a.get());
|
||||
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
|
||||
|
||||
ast::Function func("a_func", {}, &i32);
|
||||
|
||||
Builder b(&mod);
|
||||
ASSERT_TRUE(b.GenerateGlobalVariable(v.get())) << b.error();
|
||||
ASSERT_TRUE(b.GenerateGlobalVariable(a.get())) << b.error();
|
||||
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
|
||||
|
||||
EXPECT_FALSE(b.GenerateSwitchStatement(&expr)) << b.error();
|
||||
EXPECT_EQ(b.error(), "fallthrough of last case statement is disallowed");
|
||||
}
|
||||
|
||||
// TODO(dsinclair): Implement when parsing is handled for multi-value
|
||||
|
|
Loading…
Reference in New Issue