[spirv-writer] Add elseif support.
This CL adds support for having elseif statements after an if statement. Bug: tint:5 Change-Id: I3cd3c5bddaa57c998b1a3fbee7bd87536533301d Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/19500 Reviewed-by: David Neto <dneto@google.com>
This commit is contained in:
parent
631a7ac72b
commit
f963128c88
|
@ -630,8 +630,12 @@ uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) {
|
||||||
return result_id;
|
return result_id;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Builder::GenerateIfStatement(ast::IfStatement* stmt) {
|
bool Builder::GenerateConditionalBlock(
|
||||||
auto cond_id = GenerateExpression(stmt->condition());
|
ast::Expression* cond,
|
||||||
|
const ast::StatementList& true_body,
|
||||||
|
size_t cur_else_idx,
|
||||||
|
const ast::ElseStatementList& else_stmts) {
|
||||||
|
auto cond_id = GenerateExpression(cond);
|
||||||
if (cond_id == 0) {
|
if (cond_id == 0) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -646,10 +650,10 @@ bool Builder::GenerateIfStatement(ast::IfStatement* stmt) {
|
||||||
auto true_block = result_op();
|
auto true_block = result_op();
|
||||||
auto true_block_id = true_block.to_i();
|
auto true_block_id = true_block.to_i();
|
||||||
|
|
||||||
// if there are no else statements we branch on false to the merge block
|
// if there are no more else statements we branch on false to the merge block
|
||||||
// otherwise we branch to the false block
|
// otherwise we branch to the false block
|
||||||
auto false_block_id =
|
auto false_block_id =
|
||||||
stmt->has_else_statements() ? next_id() : merge_block_id;
|
cur_else_idx < else_stmts.size() ? next_id() : merge_block_id;
|
||||||
|
|
||||||
push_function_inst(spv::Op::OpBranchConditional,
|
push_function_inst(spv::Op::OpBranchConditional,
|
||||||
{Operand::Int(cond_id), Operand::Int(true_block_id),
|
{Operand::Int(cond_id), Operand::Int(true_block_id),
|
||||||
|
@ -657,28 +661,33 @@ bool Builder::GenerateIfStatement(ast::IfStatement* stmt) {
|
||||||
|
|
||||||
// Output true block
|
// Output true block
|
||||||
push_function_inst(spv::Op::OpLabel, {true_block});
|
push_function_inst(spv::Op::OpLabel, {true_block});
|
||||||
for (const auto& inst : stmt->body()) {
|
if (!GenerateStatementList(true_body)) {
|
||||||
if (!GenerateStatement(inst.get())) {
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// TODO(dsinclair): The branch should be optional based on how the
|
// TODO(dsinclair): The branch should be optional based on how the
|
||||||
// StatementList ended ...
|
// StatementList ended ...
|
||||||
push_function_inst(spv::Op::OpBranch, {Operand::Int(merge_block_id)});
|
push_function_inst(spv::Op::OpBranch, {Operand::Int(merge_block_id)});
|
||||||
|
|
||||||
|
// Start the false block if needed
|
||||||
if (false_block_id != merge_block_id) {
|
if (false_block_id != merge_block_id) {
|
||||||
push_function_inst(spv::Op::OpLabel, {Operand::Int(false_block_id)});
|
push_function_inst(spv::Op::OpLabel, {Operand::Int(false_block_id)});
|
||||||
|
|
||||||
for (const auto& else_stmt : stmt->else_statements()) {
|
auto* else_stmt = else_stmts[cur_else_idx].get();
|
||||||
if (!GenerateElseStatement(else_stmt.get())) {
|
// Handle the else case by just outputting the statements.
|
||||||
|
if (!else_stmt->HasCondition()) {
|
||||||
|
if (!GenerateStatementList(else_stmt->body())) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// TODO(dsinclair): The branch should be optional based on how the
|
// TODO(dsinclair): The branch should be optional based on how the
|
||||||
// StatementList ended ...
|
// StatementList ended ...
|
||||||
push_function_inst(spv::Op::OpBranch, {Operand::Int(merge_block_id)});
|
push_function_inst(spv::Op::OpBranch, {Operand::Int(merge_block_id)});
|
||||||
|
} else {
|
||||||
|
if (!GenerateConditionalBlock(else_stmt->condition(), else_stmt->body(),
|
||||||
|
cur_else_idx + 1, else_stmts)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
push_function_inst(spv::Op::OpBranch, {Operand::Int(merge_block_id)});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Output the merge block
|
// Output the merge block
|
||||||
|
@ -687,18 +696,11 @@ bool Builder::GenerateIfStatement(ast::IfStatement* stmt) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Builder::GenerateElseStatement(ast::ElseStatement* stmt) {
|
bool Builder::GenerateIfStatement(ast::IfStatement* stmt) {
|
||||||
// TODO(dsinclair): handle else if
|
if (!GenerateConditionalBlock(stmt->condition(), stmt->body(), 0,
|
||||||
if (stmt->HasCondition()) {
|
stmt->else_statements())) {
|
||||||
error_ = "else if not handled yet";
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (const auto& inst : stmt->body()) {
|
|
||||||
if (!GenerateStatement(inst.get())) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -716,6 +718,15 @@ bool Builder::GenerateReturnStatement(ast::ReturnStatement* stmt) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool Builder::GenerateStatementList(const ast::StatementList& list) {
|
||||||
|
for (const auto& inst : list) {
|
||||||
|
if (!GenerateStatement(inst.get())) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
bool Builder::GenerateStatement(ast::Statement* stmt) {
|
bool Builder::GenerateStatement(ast::Statement* stmt) {
|
||||||
if (stmt->IsAssign()) {
|
if (stmt->IsAssign()) {
|
||||||
return GenerateAssignStatement(stmt->AsAssign());
|
return GenerateAssignStatement(stmt->AsAssign());
|
||||||
|
|
|
@ -22,6 +22,7 @@
|
||||||
|
|
||||||
#include "spirv/unified1/spirv.h"
|
#include "spirv/unified1/spirv.h"
|
||||||
#include "src/ast/builtin.h"
|
#include "src/ast/builtin.h"
|
||||||
|
#include "src/ast/else_statement.h"
|
||||||
#include "src/ast/literal.h"
|
#include "src/ast/literal.h"
|
||||||
#include "src/ast/module.h"
|
#include "src/ast/module.h"
|
||||||
#include "src/ast/struct_member.h"
|
#include "src/ast/struct_member.h"
|
||||||
|
@ -148,10 +149,6 @@ class Builder {
|
||||||
/// @param assign the statement to generate
|
/// @param assign the statement to generate
|
||||||
/// @returns true if the statement was successfully generated
|
/// @returns true if the statement was successfully generated
|
||||||
bool GenerateAssignStatement(ast::AssignmentStatement* assign);
|
bool GenerateAssignStatement(ast::AssignmentStatement* assign);
|
||||||
/// Generates an else statement
|
|
||||||
/// @param stmt the statement to generate
|
|
||||||
/// @returns true on successfull generation
|
|
||||||
bool GenerateElseStatement(ast::ElseStatement* stmt);
|
|
||||||
/// Generates an entry point instruction
|
/// Generates an entry point instruction
|
||||||
/// @param ep the entry point
|
/// @param ep the entry point
|
||||||
/// @returns true if the instruction was generated, false otherwise
|
/// @returns true if the instruction was generated, false otherwise
|
||||||
|
@ -209,10 +206,24 @@ class Builder {
|
||||||
/// @param stmt the statement to generate
|
/// @param stmt the statement to generate
|
||||||
/// @returns true on success, false otherwise
|
/// @returns true on success, false otherwise
|
||||||
bool GenerateReturnStatement(ast::ReturnStatement* stmt);
|
bool GenerateReturnStatement(ast::ReturnStatement* stmt);
|
||||||
|
/// Generates a conditional section merge block
|
||||||
|
/// @param cond the condition
|
||||||
|
/// @param true_body the statements making up the true block
|
||||||
|
/// @param cur_else_idx the index of the current else statement to process
|
||||||
|
/// @param else_stmts the list of all else statements
|
||||||
|
/// @returns true on success, false on failure
|
||||||
|
bool GenerateConditionalBlock(ast::Expression* cond,
|
||||||
|
const ast::StatementList& true_body,
|
||||||
|
size_t cur_else_idx,
|
||||||
|
const ast::ElseStatementList& else_stmts);
|
||||||
/// Generates a statement
|
/// Generates a statement
|
||||||
/// @param stmt the statement to generate
|
/// @param stmt the statement to generate
|
||||||
/// @returns true if the statement was generated
|
/// @returns true if the statement was generated
|
||||||
bool GenerateStatement(ast::Statement* stmt);
|
bool GenerateStatement(ast::Statement* stmt);
|
||||||
|
/// Generates a list of statements
|
||||||
|
/// @param list the statement list to generate
|
||||||
|
/// @returns true on successful generation
|
||||||
|
bool GenerateStatementList(const ast::StatementList& list);
|
||||||
/// Geneates an OpStore
|
/// Geneates an OpStore
|
||||||
/// @param to the ID to store too
|
/// @param to the ID to store too
|
||||||
/// @param from the ID to store from
|
/// @param from the ID to store from
|
||||||
|
|
|
@ -39,6 +39,8 @@ using BuilderTest = testing::Test;
|
||||||
TEST_F(BuilderTest, If_Empty) {
|
TEST_F(BuilderTest, If_Empty) {
|
||||||
ast::type::BoolType bool_type;
|
ast::type::BoolType bool_type;
|
||||||
|
|
||||||
|
// if (true) {
|
||||||
|
// }
|
||||||
auto cond = std::make_unique<ast::ScalarConstructorExpression>(
|
auto cond = std::make_unique<ast::ScalarConstructorExpression>(
|
||||||
std::make_unique<ast::BoolLiteral>(&bool_type, true));
|
std::make_unique<ast::BoolLiteral>(&bool_type, true));
|
||||||
|
|
||||||
|
@ -68,6 +70,9 @@ TEST_F(BuilderTest, If_WithStatements) {
|
||||||
ast::type::BoolType bool_type;
|
ast::type::BoolType bool_type;
|
||||||
ast::type::I32Type i32;
|
ast::type::I32Type i32;
|
||||||
|
|
||||||
|
// if (true) {
|
||||||
|
// v = 2;
|
||||||
|
// }
|
||||||
auto var =
|
auto var =
|
||||||
std::make_unique<ast::Variable>("v", ast::StorageClass::kPrivate, &i32);
|
std::make_unique<ast::Variable>("v", ast::StorageClass::kPrivate, &i32);
|
||||||
|
|
||||||
|
@ -114,6 +119,11 @@ TEST_F(BuilderTest, If_WithElse) {
|
||||||
ast::type::BoolType bool_type;
|
ast::type::BoolType bool_type;
|
||||||
ast::type::I32Type i32;
|
ast::type::I32Type i32;
|
||||||
|
|
||||||
|
// if (true) {
|
||||||
|
// v = 2;
|
||||||
|
// } else {
|
||||||
|
// v = 3;
|
||||||
|
// }
|
||||||
auto var =
|
auto var =
|
||||||
std::make_unique<ast::Variable>("v", ast::StorageClass::kPrivate, &i32);
|
std::make_unique<ast::Variable>("v", ast::StorageClass::kPrivate, &i32);
|
||||||
|
|
||||||
|
@ -171,9 +181,186 @@ OpBranch %6
|
||||||
)");
|
)");
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(BuilderTest, DISABLED_If_WithElseIf) {}
|
TEST_F(BuilderTest, If_WithElseIf) {
|
||||||
|
ast::type::BoolType bool_type;
|
||||||
|
ast::type::I32Type i32;
|
||||||
|
|
||||||
TEST_F(BuilderTest, DISABLED_If_WithMultiple) {}
|
// if (true) {
|
||||||
|
// v = 2;
|
||||||
|
// } elseif (true) {
|
||||||
|
// v = 3;
|
||||||
|
// }
|
||||||
|
auto var =
|
||||||
|
std::make_unique<ast::Variable>("v", ast::StorageClass::kPrivate, &i32);
|
||||||
|
|
||||||
|
ast::StatementList body;
|
||||||
|
body.push_back(std::make_unique<ast::AssignmentStatement>(
|
||||||
|
std::make_unique<ast::IdentifierExpression>("v"),
|
||||||
|
std::make_unique<ast::ScalarConstructorExpression>(
|
||||||
|
std::make_unique<ast::IntLiteral>(&i32, 2))));
|
||||||
|
|
||||||
|
ast::StatementList else_body;
|
||||||
|
else_body.push_back(std::make_unique<ast::AssignmentStatement>(
|
||||||
|
std::make_unique<ast::IdentifierExpression>("v"),
|
||||||
|
std::make_unique<ast::ScalarConstructorExpression>(
|
||||||
|
std::make_unique<ast::IntLiteral>(&i32, 3))));
|
||||||
|
|
||||||
|
auto else_cond = std::make_unique<ast::ScalarConstructorExpression>(
|
||||||
|
std::make_unique<ast::BoolLiteral>(&bool_type, true));
|
||||||
|
|
||||||
|
ast::ElseStatementList else_stmts;
|
||||||
|
else_stmts.push_back(std::make_unique<ast::ElseStatement>(
|
||||||
|
std::move(else_cond), std::move(else_body)));
|
||||||
|
|
||||||
|
auto cond = std::make_unique<ast::ScalarConstructorExpression>(
|
||||||
|
std::make_unique<ast::BoolLiteral>(&bool_type, true));
|
||||||
|
|
||||||
|
ast::IfStatement expr(std::move(cond), std::move(body));
|
||||||
|
expr.set_else_statements(std::move(else_stmts));
|
||||||
|
|
||||||
|
Context ctx;
|
||||||
|
TypeDeterminer td(&ctx);
|
||||||
|
td.RegisterVariableForTesting(var.get());
|
||||||
|
|
||||||
|
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
|
||||||
|
|
||||||
|
Builder b;
|
||||||
|
b.push_function(Function{});
|
||||||
|
ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error();
|
||||||
|
|
||||||
|
EXPECT_TRUE(b.GenerateIfStatement(&expr)) << b.error();
|
||||||
|
EXPECT_EQ(DumpInstructions(b.types()), R"(%3 = OpTypeInt 32 1
|
||||||
|
%2 = OpTypePointer Private %3
|
||||||
|
%1 = OpVariable %2 Private
|
||||||
|
%4 = OpTypeBool
|
||||||
|
%5 = OpConstantTrue %4
|
||||||
|
%9 = OpConstant %3 2
|
||||||
|
%12 = OpConstant %3 3
|
||||||
|
)");
|
||||||
|
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
|
||||||
|
R"(OpSelectionMerge %6 None
|
||||||
|
OpBranchConditional %5 %7 %8
|
||||||
|
%7 = OpLabel
|
||||||
|
OpStore %1 %9
|
||||||
|
OpBranch %6
|
||||||
|
%8 = OpLabel
|
||||||
|
OpSelectionMerge %10 None
|
||||||
|
OpBranchConditional %5 %11 %10
|
||||||
|
%11 = OpLabel
|
||||||
|
OpStore %1 %12
|
||||||
|
OpBranch %10
|
||||||
|
%10 = OpLabel
|
||||||
|
OpBranch %6
|
||||||
|
%6 = OpLabel
|
||||||
|
)");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(BuilderTest, If_WithMultiple) {
|
||||||
|
ast::type::BoolType bool_type;
|
||||||
|
ast::type::I32Type i32;
|
||||||
|
|
||||||
|
// if (true) {
|
||||||
|
// v = 2;
|
||||||
|
// } elseif (true) {
|
||||||
|
// v = 3;
|
||||||
|
// } elseif (false) {
|
||||||
|
// v = 4;
|
||||||
|
// } else {
|
||||||
|
// v = 5;
|
||||||
|
// }
|
||||||
|
auto var =
|
||||||
|
std::make_unique<ast::Variable>("v", ast::StorageClass::kPrivate, &i32);
|
||||||
|
|
||||||
|
ast::StatementList body;
|
||||||
|
body.push_back(std::make_unique<ast::AssignmentStatement>(
|
||||||
|
std::make_unique<ast::IdentifierExpression>("v"),
|
||||||
|
std::make_unique<ast::ScalarConstructorExpression>(
|
||||||
|
std::make_unique<ast::IntLiteral>(&i32, 2))));
|
||||||
|
ast::StatementList elseif_1_body;
|
||||||
|
elseif_1_body.push_back(std::make_unique<ast::AssignmentStatement>(
|
||||||
|
std::make_unique<ast::IdentifierExpression>("v"),
|
||||||
|
std::make_unique<ast::ScalarConstructorExpression>(
|
||||||
|
std::make_unique<ast::IntLiteral>(&i32, 3))));
|
||||||
|
ast::StatementList elseif_2_body;
|
||||||
|
elseif_2_body.push_back(std::make_unique<ast::AssignmentStatement>(
|
||||||
|
std::make_unique<ast::IdentifierExpression>("v"),
|
||||||
|
std::make_unique<ast::ScalarConstructorExpression>(
|
||||||
|
std::make_unique<ast::IntLiteral>(&i32, 4))));
|
||||||
|
ast::StatementList else_body;
|
||||||
|
else_body.push_back(std::make_unique<ast::AssignmentStatement>(
|
||||||
|
std::make_unique<ast::IdentifierExpression>("v"),
|
||||||
|
std::make_unique<ast::ScalarConstructorExpression>(
|
||||||
|
std::make_unique<ast::IntLiteral>(&i32, 5))));
|
||||||
|
|
||||||
|
auto elseif_1_cond = std::make_unique<ast::ScalarConstructorExpression>(
|
||||||
|
std::make_unique<ast::BoolLiteral>(&bool_type, true));
|
||||||
|
auto elseif_2_cond = std::make_unique<ast::ScalarConstructorExpression>(
|
||||||
|
std::make_unique<ast::BoolLiteral>(&bool_type, false));
|
||||||
|
|
||||||
|
ast::ElseStatementList else_stmts;
|
||||||
|
else_stmts.push_back(std::make_unique<ast::ElseStatement>(
|
||||||
|
std::move(elseif_1_cond), std::move(elseif_1_body)));
|
||||||
|
else_stmts.push_back(std::make_unique<ast::ElseStatement>(
|
||||||
|
std::move(elseif_2_cond), std::move(elseif_2_body)));
|
||||||
|
else_stmts.push_back(
|
||||||
|
std::make_unique<ast::ElseStatement>(std::move(else_body)));
|
||||||
|
|
||||||
|
auto cond = std::make_unique<ast::ScalarConstructorExpression>(
|
||||||
|
std::make_unique<ast::BoolLiteral>(&bool_type, true));
|
||||||
|
|
||||||
|
ast::IfStatement expr(std::move(cond), std::move(body));
|
||||||
|
expr.set_else_statements(std::move(else_stmts));
|
||||||
|
|
||||||
|
Context ctx;
|
||||||
|
TypeDeterminer td(&ctx);
|
||||||
|
td.RegisterVariableForTesting(var.get());
|
||||||
|
|
||||||
|
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
|
||||||
|
|
||||||
|
Builder b;
|
||||||
|
b.push_function(Function{});
|
||||||
|
ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error();
|
||||||
|
|
||||||
|
EXPECT_TRUE(b.GenerateIfStatement(&expr)) << b.error();
|
||||||
|
EXPECT_EQ(DumpInstructions(b.types()), R"(%3 = OpTypeInt 32 1
|
||||||
|
%2 = OpTypePointer Private %3
|
||||||
|
%1 = OpVariable %2 Private
|
||||||
|
%4 = OpTypeBool
|
||||||
|
%5 = OpConstantTrue %4
|
||||||
|
%9 = OpConstant %3 2
|
||||||
|
%13 = OpConstant %3 3
|
||||||
|
%14 = OpConstantFalse %4
|
||||||
|
%18 = OpConstant %3 4
|
||||||
|
%19 = OpConstant %3 5
|
||||||
|
)");
|
||||||
|
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
|
||||||
|
R"(OpSelectionMerge %6 None
|
||||||
|
OpBranchConditional %5 %7 %8
|
||||||
|
%7 = OpLabel
|
||||||
|
OpStore %1 %9
|
||||||
|
OpBranch %6
|
||||||
|
%8 = OpLabel
|
||||||
|
OpSelectionMerge %10 None
|
||||||
|
OpBranchConditional %5 %11 %12
|
||||||
|
%11 = OpLabel
|
||||||
|
OpStore %1 %13
|
||||||
|
OpBranch %10
|
||||||
|
%12 = OpLabel
|
||||||
|
OpSelectionMerge %15 None
|
||||||
|
OpBranchConditional %14 %16 %17
|
||||||
|
%16 = OpLabel
|
||||||
|
OpStore %1 %18
|
||||||
|
OpBranch %15
|
||||||
|
%17 = OpLabel
|
||||||
|
OpStore %1 %19
|
||||||
|
OpBranch %15
|
||||||
|
%15 = OpLabel
|
||||||
|
OpBranch %10
|
||||||
|
%10 = OpLabel
|
||||||
|
OpBranch %6
|
||||||
|
%6 = OpLabel
|
||||||
|
)");
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(BuilderTest, DISABLED_If_WithBreak) {
|
TEST_F(BuilderTest, DISABLED_If_WithBreak) {
|
||||||
// if (a) {
|
// if (a) {
|
||||||
|
|
Loading…
Reference in New Issue