[spirv-writer] Add elseif support.

This CL adds support for having elseif statements after an if statement.

Bug: tint:5
Change-Id: I3cd3c5bddaa57c998b1a3fbee7bd87536533301d
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/19500
Reviewed-by: David Neto <dneto@google.com>
This commit is contained in:
dan sinclair 2020-04-14 16:53:27 +00:00
parent 631a7ac72b
commit f963128c88
3 changed files with 240 additions and 31 deletions

View File

@ -630,8 +630,12 @@ uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) {
return result_id; return result_id;
} }
bool Builder::GenerateIfStatement(ast::IfStatement* stmt) { bool Builder::GenerateConditionalBlock(
auto cond_id = GenerateExpression(stmt->condition()); ast::Expression* cond,
const ast::StatementList& true_body,
size_t cur_else_idx,
const ast::ElseStatementList& else_stmts) {
auto cond_id = GenerateExpression(cond);
if (cond_id == 0) { if (cond_id == 0) {
return false; return false;
} }
@ -646,10 +650,10 @@ bool Builder::GenerateIfStatement(ast::IfStatement* stmt) {
auto true_block = result_op(); auto true_block = result_op();
auto true_block_id = true_block.to_i(); auto true_block_id = true_block.to_i();
// if there are no else statements we branch on false to the merge block // if there are no more else statements we branch on false to the merge block
// otherwise we branch to the false block // otherwise we branch to the false block
auto false_block_id = auto false_block_id =
stmt->has_else_statements() ? next_id() : merge_block_id; cur_else_idx < else_stmts.size() ? next_id() : merge_block_id;
push_function_inst(spv::Op::OpBranchConditional, push_function_inst(spv::Op::OpBranchConditional,
{Operand::Int(cond_id), Operand::Int(true_block_id), {Operand::Int(cond_id), Operand::Int(true_block_id),
@ -657,28 +661,33 @@ bool Builder::GenerateIfStatement(ast::IfStatement* stmt) {
// Output true block // Output true block
push_function_inst(spv::Op::OpLabel, {true_block}); push_function_inst(spv::Op::OpLabel, {true_block});
for (const auto& inst : stmt->body()) { if (!GenerateStatementList(true_body)) {
if (!GenerateStatement(inst.get())) {
return false; return false;
} }
}
// TODO(dsinclair): The branch should be optional based on how the // TODO(dsinclair): The branch should be optional based on how the
// StatementList ended ... // StatementList ended ...
push_function_inst(spv::Op::OpBranch, {Operand::Int(merge_block_id)}); push_function_inst(spv::Op::OpBranch, {Operand::Int(merge_block_id)});
// Start the false block if needed
if (false_block_id != merge_block_id) { if (false_block_id != merge_block_id) {
push_function_inst(spv::Op::OpLabel, {Operand::Int(false_block_id)}); push_function_inst(spv::Op::OpLabel, {Operand::Int(false_block_id)});
for (const auto& else_stmt : stmt->else_statements()) { auto* else_stmt = else_stmts[cur_else_idx].get();
if (!GenerateElseStatement(else_stmt.get())) { // Handle the else case by just outputting the statements.
if (!else_stmt->HasCondition()) {
if (!GenerateStatementList(else_stmt->body())) {
return false; return false;
} }
}
// TODO(dsinclair): The branch should be optional based on how the // TODO(dsinclair): The branch should be optional based on how the
// StatementList ended ... // StatementList ended ...
push_function_inst(spv::Op::OpBranch, {Operand::Int(merge_block_id)}); push_function_inst(spv::Op::OpBranch, {Operand::Int(merge_block_id)});
} else {
if (!GenerateConditionalBlock(else_stmt->condition(), else_stmt->body(),
cur_else_idx + 1, else_stmts)) {
return false;
}
push_function_inst(spv::Op::OpBranch, {Operand::Int(merge_block_id)});
}
} }
// Output the merge block // Output the merge block
@ -687,18 +696,11 @@ bool Builder::GenerateIfStatement(ast::IfStatement* stmt) {
return true; return true;
} }
bool Builder::GenerateElseStatement(ast::ElseStatement* stmt) { bool Builder::GenerateIfStatement(ast::IfStatement* stmt) {
// TODO(dsinclair): handle else if if (!GenerateConditionalBlock(stmt->condition(), stmt->body(), 0,
if (stmt->HasCondition()) { stmt->else_statements())) {
error_ = "else if not handled yet";
return false; return false;
} }
for (const auto& inst : stmt->body()) {
if (!GenerateStatement(inst.get())) {
return false;
}
}
return true; return true;
} }
@ -716,6 +718,15 @@ bool Builder::GenerateReturnStatement(ast::ReturnStatement* stmt) {
return true; return true;
} }
bool Builder::GenerateStatementList(const ast::StatementList& list) {
for (const auto& inst : list) {
if (!GenerateStatement(inst.get())) {
return false;
}
}
return true;
}
bool Builder::GenerateStatement(ast::Statement* stmt) { bool Builder::GenerateStatement(ast::Statement* stmt) {
if (stmt->IsAssign()) { if (stmt->IsAssign()) {
return GenerateAssignStatement(stmt->AsAssign()); return GenerateAssignStatement(stmt->AsAssign());

View File

@ -22,6 +22,7 @@
#include "spirv/unified1/spirv.h" #include "spirv/unified1/spirv.h"
#include "src/ast/builtin.h" #include "src/ast/builtin.h"
#include "src/ast/else_statement.h"
#include "src/ast/literal.h" #include "src/ast/literal.h"
#include "src/ast/module.h" #include "src/ast/module.h"
#include "src/ast/struct_member.h" #include "src/ast/struct_member.h"
@ -148,10 +149,6 @@ class Builder {
/// @param assign the statement to generate /// @param assign the statement to generate
/// @returns true if the statement was successfully generated /// @returns true if the statement was successfully generated
bool GenerateAssignStatement(ast::AssignmentStatement* assign); bool GenerateAssignStatement(ast::AssignmentStatement* assign);
/// Generates an else statement
/// @param stmt the statement to generate
/// @returns true on successfull generation
bool GenerateElseStatement(ast::ElseStatement* stmt);
/// Generates an entry point instruction /// Generates an entry point instruction
/// @param ep the entry point /// @param ep the entry point
/// @returns true if the instruction was generated, false otherwise /// @returns true if the instruction was generated, false otherwise
@ -209,10 +206,24 @@ class Builder {
/// @param stmt the statement to generate /// @param stmt the statement to generate
/// @returns true on success, false otherwise /// @returns true on success, false otherwise
bool GenerateReturnStatement(ast::ReturnStatement* stmt); bool GenerateReturnStatement(ast::ReturnStatement* stmt);
/// Generates a conditional section merge block
/// @param cond the condition
/// @param true_body the statements making up the true block
/// @param cur_else_idx the index of the current else statement to process
/// @param else_stmts the list of all else statements
/// @returns true on success, false on failure
bool GenerateConditionalBlock(ast::Expression* cond,
const ast::StatementList& true_body,
size_t cur_else_idx,
const ast::ElseStatementList& else_stmts);
/// Generates a statement /// Generates a statement
/// @param stmt the statement to generate /// @param stmt the statement to generate
/// @returns true if the statement was generated /// @returns true if the statement was generated
bool GenerateStatement(ast::Statement* stmt); bool GenerateStatement(ast::Statement* stmt);
/// Generates a list of statements
/// @param list the statement list to generate
/// @returns true on successful generation
bool GenerateStatementList(const ast::StatementList& list);
/// Geneates an OpStore /// Geneates an OpStore
/// @param to the ID to store too /// @param to the ID to store too
/// @param from the ID to store from /// @param from the ID to store from

View File

@ -39,6 +39,8 @@ using BuilderTest = testing::Test;
TEST_F(BuilderTest, If_Empty) { TEST_F(BuilderTest, If_Empty) {
ast::type::BoolType bool_type; ast::type::BoolType bool_type;
// if (true) {
// }
auto cond = std::make_unique<ast::ScalarConstructorExpression>( auto cond = std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::BoolLiteral>(&bool_type, true)); std::make_unique<ast::BoolLiteral>(&bool_type, true));
@ -68,6 +70,9 @@ TEST_F(BuilderTest, If_WithStatements) {
ast::type::BoolType bool_type; ast::type::BoolType bool_type;
ast::type::I32Type i32; ast::type::I32Type i32;
// if (true) {
// v = 2;
// }
auto var = auto var =
std::make_unique<ast::Variable>("v", ast::StorageClass::kPrivate, &i32); std::make_unique<ast::Variable>("v", ast::StorageClass::kPrivate, &i32);
@ -114,6 +119,11 @@ TEST_F(BuilderTest, If_WithElse) {
ast::type::BoolType bool_type; ast::type::BoolType bool_type;
ast::type::I32Type i32; ast::type::I32Type i32;
// if (true) {
// v = 2;
// } else {
// v = 3;
// }
auto var = auto var =
std::make_unique<ast::Variable>("v", ast::StorageClass::kPrivate, &i32); std::make_unique<ast::Variable>("v", ast::StorageClass::kPrivate, &i32);
@ -171,9 +181,186 @@ OpBranch %6
)"); )");
} }
TEST_F(BuilderTest, DISABLED_If_WithElseIf) {} TEST_F(BuilderTest, If_WithElseIf) {
ast::type::BoolType bool_type;
ast::type::I32Type i32;
TEST_F(BuilderTest, DISABLED_If_WithMultiple) {} // if (true) {
// v = 2;
// } elseif (true) {
// v = 3;
// }
auto var =
std::make_unique<ast::Variable>("v", ast::StorageClass::kPrivate, &i32);
ast::StatementList body;
body.push_back(std::make_unique<ast::AssignmentStatement>(
std::make_unique<ast::IdentifierExpression>("v"),
std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::IntLiteral>(&i32, 2))));
ast::StatementList else_body;
else_body.push_back(std::make_unique<ast::AssignmentStatement>(
std::make_unique<ast::IdentifierExpression>("v"),
std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::IntLiteral>(&i32, 3))));
auto else_cond = std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::BoolLiteral>(&bool_type, true));
ast::ElseStatementList else_stmts;
else_stmts.push_back(std::make_unique<ast::ElseStatement>(
std::move(else_cond), std::move(else_body)));
auto cond = std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::BoolLiteral>(&bool_type, true));
ast::IfStatement expr(std::move(cond), std::move(body));
expr.set_else_statements(std::move(else_stmts));
Context ctx;
TypeDeterminer td(&ctx);
td.RegisterVariableForTesting(var.get());
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
Builder b;
b.push_function(Function{});
ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error();
EXPECT_TRUE(b.GenerateIfStatement(&expr)) << b.error();
EXPECT_EQ(DumpInstructions(b.types()), R"(%3 = OpTypeInt 32 1
%2 = OpTypePointer Private %3
%1 = OpVariable %2 Private
%4 = OpTypeBool
%5 = OpConstantTrue %4
%9 = OpConstant %3 2
%12 = OpConstant %3 3
)");
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
R"(OpSelectionMerge %6 None
OpBranchConditional %5 %7 %8
%7 = OpLabel
OpStore %1 %9
OpBranch %6
%8 = OpLabel
OpSelectionMerge %10 None
OpBranchConditional %5 %11 %10
%11 = OpLabel
OpStore %1 %12
OpBranch %10
%10 = OpLabel
OpBranch %6
%6 = OpLabel
)");
}
TEST_F(BuilderTest, If_WithMultiple) {
ast::type::BoolType bool_type;
ast::type::I32Type i32;
// if (true) {
// v = 2;
// } elseif (true) {
// v = 3;
// } elseif (false) {
// v = 4;
// } else {
// v = 5;
// }
auto var =
std::make_unique<ast::Variable>("v", ast::StorageClass::kPrivate, &i32);
ast::StatementList body;
body.push_back(std::make_unique<ast::AssignmentStatement>(
std::make_unique<ast::IdentifierExpression>("v"),
std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::IntLiteral>(&i32, 2))));
ast::StatementList elseif_1_body;
elseif_1_body.push_back(std::make_unique<ast::AssignmentStatement>(
std::make_unique<ast::IdentifierExpression>("v"),
std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::IntLiteral>(&i32, 3))));
ast::StatementList elseif_2_body;
elseif_2_body.push_back(std::make_unique<ast::AssignmentStatement>(
std::make_unique<ast::IdentifierExpression>("v"),
std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::IntLiteral>(&i32, 4))));
ast::StatementList else_body;
else_body.push_back(std::make_unique<ast::AssignmentStatement>(
std::make_unique<ast::IdentifierExpression>("v"),
std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::IntLiteral>(&i32, 5))));
auto elseif_1_cond = std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::BoolLiteral>(&bool_type, true));
auto elseif_2_cond = std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::BoolLiteral>(&bool_type, false));
ast::ElseStatementList else_stmts;
else_stmts.push_back(std::make_unique<ast::ElseStatement>(
std::move(elseif_1_cond), std::move(elseif_1_body)));
else_stmts.push_back(std::make_unique<ast::ElseStatement>(
std::move(elseif_2_cond), std::move(elseif_2_body)));
else_stmts.push_back(
std::make_unique<ast::ElseStatement>(std::move(else_body)));
auto cond = std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::BoolLiteral>(&bool_type, true));
ast::IfStatement expr(std::move(cond), std::move(body));
expr.set_else_statements(std::move(else_stmts));
Context ctx;
TypeDeterminer td(&ctx);
td.RegisterVariableForTesting(var.get());
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
Builder b;
b.push_function(Function{});
ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error();
EXPECT_TRUE(b.GenerateIfStatement(&expr)) << b.error();
EXPECT_EQ(DumpInstructions(b.types()), R"(%3 = OpTypeInt 32 1
%2 = OpTypePointer Private %3
%1 = OpVariable %2 Private
%4 = OpTypeBool
%5 = OpConstantTrue %4
%9 = OpConstant %3 2
%13 = OpConstant %3 3
%14 = OpConstantFalse %4
%18 = OpConstant %3 4
%19 = OpConstant %3 5
)");
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
R"(OpSelectionMerge %6 None
OpBranchConditional %5 %7 %8
%7 = OpLabel
OpStore %1 %9
OpBranch %6
%8 = OpLabel
OpSelectionMerge %10 None
OpBranchConditional %5 %11 %12
%11 = OpLabel
OpStore %1 %13
OpBranch %10
%12 = OpLabel
OpSelectionMerge %15 None
OpBranchConditional %14 %16 %17
%16 = OpLabel
OpStore %1 %18
OpBranch %15
%17 = OpLabel
OpStore %1 %19
OpBranch %15
%15 = OpLabel
OpBranch %10
%10 = OpLabel
OpBranch %6
%6 = OpLabel
)");
}
TEST_F(BuilderTest, DISABLED_If_WithBreak) { TEST_F(BuilderTest, DISABLED_If_WithBreak) {
// if (a) { // if (a) {