[spirv-writer] Add fallthrough support
This CL adds support for the fallthrough statement in a `case` block. Bug: tint:5 Change-Id: I282643a304846a19212d41bd8bd20a60398bd793 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/22220 Reviewed-by: David Neto <dneto@google.com>
This commit is contained in:
parent
916b408111
commit
dadd149d9b
|
@ -89,6 +89,10 @@ uint32_t pipeline_stage_to_execution_model(ast::PipelineStage stage) {
|
||||||
return model;
|
return model;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool LastIsFallthrough(const ast::StatementList& stmts) {
|
||||||
|
return !stmts.empty() && stmts.back()->IsFallthrough();
|
||||||
|
}
|
||||||
|
|
||||||
// A terminator is anything which will case a SPIR-V terminator to be emitted.
|
// A terminator is anything which will case a SPIR-V terminator to be emitted.
|
||||||
// This means things like breaks, fallthroughs and continues which all emit an
|
// This means things like breaks, fallthroughs and continues which all emit an
|
||||||
// OpBranch or return for the OpReturn emission.
|
// OpBranch or return for the OpReturn emission.
|
||||||
|
@ -1395,7 +1399,13 @@ bool Builder::GenerateSwitchStatement(ast::SwitchStatement* stmt) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!LastIsTerminator(item->body())) {
|
if (LastIsFallthrough(item->body())) {
|
||||||
|
if (i == (body.size() - 1)) {
|
||||||
|
error_ = "fallthrough of last case statement is disallowed";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
push_function_inst(spv::Op::OpBranch, {Operand::Int(case_ids[i + 1])});
|
||||||
|
} else if (!LastIsTerminator(item->body())) {
|
||||||
push_function_inst(spv::Op::OpBranch, {Operand::Int(merge_block_id)});
|
push_function_inst(spv::Op::OpBranch, {Operand::Int(merge_block_id)});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1491,6 +1501,10 @@ bool Builder::GenerateStatement(ast::Statement* stmt) {
|
||||||
if (stmt->IsContinue()) {
|
if (stmt->IsContinue()) {
|
||||||
return GenerateContinueStatement(stmt->AsContinue());
|
return GenerateContinueStatement(stmt->AsContinue());
|
||||||
}
|
}
|
||||||
|
if (stmt->IsFallthrough()) {
|
||||||
|
// Do nothing here, the fallthrough gets handled by the switch code.
|
||||||
|
return true;
|
||||||
|
}
|
||||||
if (stmt->IsIf()) {
|
if (stmt->IsIf()) {
|
||||||
return GenerateIfStatement(stmt->AsIf());
|
return GenerateIfStatement(stmt->AsIf());
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,6 +19,7 @@
|
||||||
#include "src/ast/bool_literal.h"
|
#include "src/ast/bool_literal.h"
|
||||||
#include "src/ast/break_statement.h"
|
#include "src/ast/break_statement.h"
|
||||||
#include "src/ast/case_statement.h"
|
#include "src/ast/case_statement.h"
|
||||||
|
#include "src/ast/fallthrough_statement.h"
|
||||||
#include "src/ast/identifier_expression.h"
|
#include "src/ast/identifier_expression.h"
|
||||||
#include "src/ast/if_statement.h"
|
#include "src/ast/if_statement.h"
|
||||||
#include "src/ast/int_literal.h"
|
#include "src/ast/int_literal.h"
|
||||||
|
@ -321,15 +322,145 @@ OpFunctionEnd
|
||||||
)");
|
)");
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(BuilderTest, DISABLED_Switch_CaseWithFallthrough) {
|
TEST_F(BuilderTest, Switch_CaseWithFallthrough) {
|
||||||
|
ast::type::I32Type i32;
|
||||||
|
|
||||||
// switch(a) {
|
// switch(a) {
|
||||||
// case 1:
|
// case 1:
|
||||||
// v = 1;
|
// v = 1;
|
||||||
// fallthrough;
|
// fallthrough;
|
||||||
// case 2:
|
// case 2:
|
||||||
// v = 2;
|
// v = 2;
|
||||||
|
// default:
|
||||||
|
// v = 3;
|
||||||
// }
|
// }
|
||||||
FAIL();
|
|
||||||
|
auto v =
|
||||||
|
std::make_unique<ast::Variable>("v", ast::StorageClass::kPrivate, &i32);
|
||||||
|
auto a =
|
||||||
|
std::make_unique<ast::Variable>("a", ast::StorageClass::kPrivate, &i32);
|
||||||
|
|
||||||
|
ast::StatementList case_1_body;
|
||||||
|
case_1_body.push_back(std::make_unique<ast::AssignmentStatement>(
|
||||||
|
std::make_unique<ast::IdentifierExpression>("v"),
|
||||||
|
std::make_unique<ast::ScalarConstructorExpression>(
|
||||||
|
std::make_unique<ast::IntLiteral>(&i32, 1))));
|
||||||
|
case_1_body.push_back(std::make_unique<ast::FallthroughStatement>());
|
||||||
|
|
||||||
|
ast::StatementList case_2_body;
|
||||||
|
case_2_body.push_back(std::make_unique<ast::AssignmentStatement>(
|
||||||
|
std::make_unique<ast::IdentifierExpression>("v"),
|
||||||
|
std::make_unique<ast::ScalarConstructorExpression>(
|
||||||
|
std::make_unique<ast::IntLiteral>(&i32, 2))));
|
||||||
|
|
||||||
|
ast::StatementList default_body;
|
||||||
|
default_body.push_back(std::make_unique<ast::AssignmentStatement>(
|
||||||
|
std::make_unique<ast::IdentifierExpression>("v"),
|
||||||
|
std::make_unique<ast::ScalarConstructorExpression>(
|
||||||
|
std::make_unique<ast::IntLiteral>(&i32, 3))));
|
||||||
|
|
||||||
|
ast::CaseStatementList cases;
|
||||||
|
cases.push_back(std::make_unique<ast::CaseStatement>(
|
||||||
|
std::make_unique<ast::IntLiteral>(&i32, 1), std::move(case_1_body)));
|
||||||
|
cases.push_back(std::make_unique<ast::CaseStatement>(
|
||||||
|
std::make_unique<ast::IntLiteral>(&i32, 2), std::move(case_2_body)));
|
||||||
|
cases.push_back(
|
||||||
|
std::make_unique<ast::CaseStatement>(std::move(default_body)));
|
||||||
|
|
||||||
|
ast::SwitchStatement expr(std::make_unique<ast::IdentifierExpression>("a"),
|
||||||
|
std::move(cases));
|
||||||
|
|
||||||
|
Context ctx;
|
||||||
|
ast::Module mod;
|
||||||
|
TypeDeterminer td(&ctx, &mod);
|
||||||
|
td.RegisterVariableForTesting(v.get());
|
||||||
|
td.RegisterVariableForTesting(a.get());
|
||||||
|
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
|
||||||
|
|
||||||
|
ast::Function func("a_func", {}, &i32);
|
||||||
|
|
||||||
|
Builder b(&mod);
|
||||||
|
ASSERT_TRUE(b.GenerateGlobalVariable(v.get())) << b.error();
|
||||||
|
ASSERT_TRUE(b.GenerateGlobalVariable(a.get())) << b.error();
|
||||||
|
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
|
||||||
|
|
||||||
|
EXPECT_TRUE(b.GenerateSwitchStatement(&expr)) << b.error();
|
||||||
|
|
||||||
|
EXPECT_EQ(DumpBuilder(b), R"(OpName %1 "v"
|
||||||
|
OpName %5 "a"
|
||||||
|
OpName %7 "a_func"
|
||||||
|
%3 = OpTypeInt 32 1
|
||||||
|
%2 = OpTypePointer Private %3
|
||||||
|
%4 = OpConstantNull %3
|
||||||
|
%1 = OpVariable %2 Private %4
|
||||||
|
%5 = OpVariable %2 Private %4
|
||||||
|
%6 = OpTypeFunction %3
|
||||||
|
%14 = OpConstant %3 1
|
||||||
|
%15 = OpConstant %3 2
|
||||||
|
%16 = OpConstant %3 3
|
||||||
|
%7 = OpFunction %3 None %6
|
||||||
|
%8 = OpLabel
|
||||||
|
%10 = OpLoad %3 %5
|
||||||
|
OpSelectionMerge %9 None
|
||||||
|
OpSwitch %10 %11 1 %12 2 %13
|
||||||
|
%12 = OpLabel
|
||||||
|
OpStore %1 %14
|
||||||
|
OpBranch %13
|
||||||
|
%13 = OpLabel
|
||||||
|
OpStore %1 %15
|
||||||
|
OpBranch %9
|
||||||
|
%11 = OpLabel
|
||||||
|
OpStore %1 %16
|
||||||
|
OpBranch %9
|
||||||
|
%9 = OpLabel
|
||||||
|
OpFunctionEnd
|
||||||
|
)");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(BuilderTest, Switch_CaseFallthroughLastStatement) {
|
||||||
|
ast::type::I32Type i32;
|
||||||
|
|
||||||
|
// switch(a) {
|
||||||
|
// case 1:
|
||||||
|
// v = 1;
|
||||||
|
// fallthrough;
|
||||||
|
// }
|
||||||
|
|
||||||
|
auto v =
|
||||||
|
std::make_unique<ast::Variable>("v", ast::StorageClass::kPrivate, &i32);
|
||||||
|
auto a =
|
||||||
|
std::make_unique<ast::Variable>("a", ast::StorageClass::kPrivate, &i32);
|
||||||
|
|
||||||
|
ast::StatementList case_1_body;
|
||||||
|
case_1_body.push_back(std::make_unique<ast::AssignmentStatement>(
|
||||||
|
std::make_unique<ast::IdentifierExpression>("v"),
|
||||||
|
std::make_unique<ast::ScalarConstructorExpression>(
|
||||||
|
std::make_unique<ast::IntLiteral>(&i32, 1))));
|
||||||
|
case_1_body.push_back(std::make_unique<ast::FallthroughStatement>());
|
||||||
|
|
||||||
|
ast::CaseStatementList cases;
|
||||||
|
cases.push_back(std::make_unique<ast::CaseStatement>(
|
||||||
|
std::make_unique<ast::IntLiteral>(&i32, 1), std::move(case_1_body)));
|
||||||
|
|
||||||
|
ast::SwitchStatement expr(std::make_unique<ast::IdentifierExpression>("a"),
|
||||||
|
std::move(cases));
|
||||||
|
|
||||||
|
Context ctx;
|
||||||
|
ast::Module mod;
|
||||||
|
TypeDeterminer td(&ctx, &mod);
|
||||||
|
td.RegisterVariableForTesting(v.get());
|
||||||
|
td.RegisterVariableForTesting(a.get());
|
||||||
|
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
|
||||||
|
|
||||||
|
ast::Function func("a_func", {}, &i32);
|
||||||
|
|
||||||
|
Builder b(&mod);
|
||||||
|
ASSERT_TRUE(b.GenerateGlobalVariable(v.get())) << b.error();
|
||||||
|
ASSERT_TRUE(b.GenerateGlobalVariable(a.get())) << b.error();
|
||||||
|
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
|
||||||
|
|
||||||
|
EXPECT_FALSE(b.GenerateSwitchStatement(&expr)) << b.error();
|
||||||
|
EXPECT_EQ(b.error(), "fallthrough of last case statement is disallowed");
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(dsinclair): Implement when parsing is handled for multi-value
|
// TODO(dsinclair): Implement when parsing is handled for multi-value
|
||||||
|
|
Loading…
Reference in New Issue