[spirv-writer] Emit logical and and logical or
This CL adds support for the && and || operators to the SPIR-V backend. Bug: tint:5 Change-Id: I63b23d9904b5b8027e189034d24949df71cbbe42 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/23501 Reviewed-by: David Neto <dneto@google.com>
This commit is contained in:
parent
6b6e6a16ea
commit
be66f9faf9
|
@ -226,6 +226,11 @@ void Builder::push_capability(uint32_t cap) {
|
|||
Instruction{spv::Op::OpCapability, {Operand::Int(cap)}});
|
||||
}
|
||||
|
||||
void Builder::GenerateLabel(uint32_t id) {
|
||||
push_function_inst(spv::Op::OpLabel, {Operand::Int(id)});
|
||||
current_label_id_ = id;
|
||||
}
|
||||
|
||||
uint32_t Builder::GenerateU32Literal(uint32_t val) {
|
||||
ast::type::U32Type u32;
|
||||
ast::SintLiteral lit(&u32, val);
|
||||
|
@ -1083,7 +1088,72 @@ uint32_t Builder::GenerateLiteralIfNeeded(ast::Literal* lit) {
|
|||
return result_id;
|
||||
}
|
||||
|
||||
uint32_t Builder::GenerateShortCircuitBinaryExpression(
|
||||
ast::BinaryExpression* expr) {
|
||||
auto lhs_id = GenerateExpression(expr->lhs());
|
||||
if (lhs_id == 0) {
|
||||
return false;
|
||||
}
|
||||
lhs_id = GenerateLoadIfNeeded(expr->lhs()->result_type(), lhs_id);
|
||||
|
||||
auto original_label_id = current_label_id_;
|
||||
|
||||
auto type_id = GenerateTypeIfNeeded(expr->result_type());
|
||||
if (type_id == 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
auto merge_block = result_op();
|
||||
auto merge_block_id = merge_block.to_i();
|
||||
|
||||
auto block = result_op();
|
||||
auto block_id = block.to_i();
|
||||
|
||||
auto true_block_id = block_id;
|
||||
auto false_block_id = merge_block_id;
|
||||
|
||||
// For a logical or we want to only check the RHS if the LHS is failed.
|
||||
if (expr->IsLogicalOr()) {
|
||||
std::swap(true_block_id, false_block_id);
|
||||
}
|
||||
|
||||
push_function_inst(spv::Op::OpSelectionMerge,
|
||||
{Operand::Int(merge_block_id),
|
||||
Operand::Int(SpvSelectionControlMaskNone)});
|
||||
push_function_inst(spv::Op::OpBranchConditional,
|
||||
{Operand::Int(lhs_id), Operand::Int(true_block_id),
|
||||
Operand::Int(false_block_id)});
|
||||
|
||||
// Output block to check the RHS
|
||||
GenerateLabel(block_id);
|
||||
auto rhs_id = GenerateExpression(expr->rhs());
|
||||
if (rhs_id == 0) {
|
||||
return 0;
|
||||
}
|
||||
rhs_id = GenerateLoadIfNeeded(expr->rhs()->result_type(), rhs_id);
|
||||
|
||||
push_function_inst(spv::Op::OpBranch, {Operand::Int(merge_block_id)});
|
||||
|
||||
// Output the merge block
|
||||
GenerateLabel(merge_block_id);
|
||||
|
||||
auto result = result_op();
|
||||
auto result_id = result.to_i();
|
||||
|
||||
push_function_inst(spv::Op::OpPhi,
|
||||
{Operand::Int(type_id), result, Operand::Int(lhs_id),
|
||||
Operand::Int(original_label_id), Operand::Int(rhs_id),
|
||||
Operand::Int(block_id)});
|
||||
|
||||
return result_id;
|
||||
}
|
||||
|
||||
uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) {
|
||||
// There is special logic for short circuiting operators.
|
||||
if (expr->IsLogicalAnd() || expr->IsLogicalOr()) {
|
||||
return GenerateShortCircuitBinaryExpression(expr);
|
||||
}
|
||||
|
||||
auto lhs_id = GenerateExpression(expr->lhs());
|
||||
if (lhs_id == 0) {
|
||||
return 0;
|
||||
|
@ -1466,7 +1536,7 @@ bool Builder::GenerateConditionalBlock(
|
|||
Operand::Int(false_block_id)});
|
||||
|
||||
// Output true block
|
||||
push_function_inst(spv::Op::OpLabel, {true_block});
|
||||
GenerateLabel(true_block_id);
|
||||
if (!GenerateStatementList(true_body)) {
|
||||
return false;
|
||||
}
|
||||
|
@ -1477,7 +1547,7 @@ bool Builder::GenerateConditionalBlock(
|
|||
|
||||
// Start the false block if needed
|
||||
if (false_block_id != merge_block_id) {
|
||||
push_function_inst(spv::Op::OpLabel, {Operand::Int(false_block_id)});
|
||||
GenerateLabel(false_block_id);
|
||||
|
||||
auto* else_stmt = else_stmts[cur_else_idx].get();
|
||||
// Handle the else case by just outputting the statements.
|
||||
|
@ -1497,7 +1567,7 @@ bool Builder::GenerateConditionalBlock(
|
|||
}
|
||||
|
||||
// Output the merge block
|
||||
push_function_inst(spv::Op::OpLabel, {merge_block});
|
||||
GenerateLabel(merge_block_id);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
@ -1568,7 +1638,7 @@ bool Builder::GenerateSwitchStatement(ast::SwitchStatement* stmt) {
|
|||
generated_default = true;
|
||||
}
|
||||
|
||||
push_function_inst(spv::Op::OpLabel, {Operand::Int(case_ids[i])});
|
||||
GenerateLabel(case_ids[i]);
|
||||
if (!GenerateStatementList(item->body())) {
|
||||
return false;
|
||||
}
|
||||
|
@ -1585,13 +1655,13 @@ bool Builder::GenerateSwitchStatement(ast::SwitchStatement* stmt) {
|
|||
}
|
||||
|
||||
if (!generated_default) {
|
||||
push_function_inst(spv::Op::OpLabel, {Operand::Int(default_block_id)});
|
||||
GenerateLabel(default_block_id);
|
||||
push_function_inst(spv::Op::OpBranch, {Operand::Int(merge_block_id)});
|
||||
}
|
||||
|
||||
merge_stack_.pop_back();
|
||||
|
||||
push_function_inst(spv::Op::OpLabel, {Operand::Int(merge_block_id)});
|
||||
GenerateLabel(merge_block_id);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -1613,7 +1683,7 @@ bool Builder::GenerateLoopStatement(ast::LoopStatement* stmt) {
|
|||
auto loop_header = result_op();
|
||||
auto loop_header_id = loop_header.to_i();
|
||||
push_function_inst(spv::Op::OpBranch, {Operand::Int(loop_header_id)});
|
||||
push_function_inst(spv::Op::OpLabel, {loop_header});
|
||||
GenerateLabel(loop_header_id);
|
||||
|
||||
auto merge_block = result_op();
|
||||
auto merge_block_id = merge_block.to_i();
|
||||
|
@ -1632,7 +1702,7 @@ bool Builder::GenerateLoopStatement(ast::LoopStatement* stmt) {
|
|||
merge_stack_.push_back(merge_block_id);
|
||||
|
||||
push_function_inst(spv::Op::OpBranch, {Operand::Int(body_block_id)});
|
||||
push_function_inst(spv::Op::OpLabel, {body_block});
|
||||
GenerateLabel(body_block_id);
|
||||
if (!GenerateStatementList(stmt->body())) {
|
||||
return false;
|
||||
}
|
||||
|
@ -1642,7 +1712,7 @@ bool Builder::GenerateLoopStatement(ast::LoopStatement* stmt) {
|
|||
push_function_inst(spv::Op::OpBranch, {Operand::Int(continue_block_id)});
|
||||
}
|
||||
|
||||
push_function_inst(spv::Op::OpLabel, {continue_block});
|
||||
GenerateLabel(continue_block_id);
|
||||
if (!GenerateStatementList(stmt->continuing())) {
|
||||
return false;
|
||||
}
|
||||
|
@ -1651,7 +1721,7 @@ bool Builder::GenerateLoopStatement(ast::LoopStatement* stmt) {
|
|||
merge_stack_.pop_back();
|
||||
continue_stack_.pop_back();
|
||||
|
||||
push_function_inst(spv::Op::OpLabel, {merge_block});
|
||||
GenerateLabel(merge_block_id);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -159,7 +159,10 @@ class Builder {
|
|||
|
||||
/// Adds a function to the builder
|
||||
/// @param func the function to add
|
||||
void push_function(const Function& func) { functions_.push_back(func); }
|
||||
void push_function(const Function& func) {
|
||||
functions_.push_back(func);
|
||||
current_label_id_ = func.label_id();
|
||||
}
|
||||
/// @returns the functions
|
||||
const std::vector<Function>& functions() const { return functions_; }
|
||||
/// Pushes an instruction to the current function
|
||||
|
@ -183,6 +186,9 @@ class Builder {
|
|||
/// @returns the SPIR-V builtin or SpvBuiltInMax on error.
|
||||
SpvBuiltIn ConvertBuiltin(ast::Builtin builtin) const;
|
||||
|
||||
/// Generates a label for the given id
|
||||
/// @param id the id to use for the label
|
||||
void GenerateLabel(uint32_t id);
|
||||
/// Generates a uint32_t literal.
|
||||
/// @param val the value to generate
|
||||
/// @returns the ID of the generated literal
|
||||
|
@ -291,6 +297,10 @@ class Builder {
|
|||
/// @param expr the expression to generate
|
||||
/// @returns the expression ID on success or 0 otherwise
|
||||
uint32_t GenerateBinaryExpression(ast::BinaryExpression* expr);
|
||||
/// Generates a short circuting binary expression
|
||||
/// @param expr the expression to generate
|
||||
/// @returns teh expression ID on success or 0 otherwise
|
||||
uint32_t GenerateShortCircuitBinaryExpression(ast::BinaryExpression* expr);
|
||||
/// Generates a call expression
|
||||
/// @param expr the expression to generate
|
||||
/// @returns the expression ID on success or 0 otherwise
|
||||
|
@ -395,6 +405,7 @@ class Builder {
|
|||
ast::Module* mod_;
|
||||
std::string error_;
|
||||
uint32_t next_id_ = 1;
|
||||
uint32_t current_label_id_ = 0;
|
||||
std::vector<Instruction> capabilities_;
|
||||
std::vector<Instruction> preamble_;
|
||||
std::vector<Instruction> debug_;
|
||||
|
|
|
@ -16,10 +16,12 @@
|
|||
|
||||
#include "gtest/gtest.h"
|
||||
#include "src/ast/binary_expression.h"
|
||||
#include "src/ast/bool_literal.h"
|
||||
#include "src/ast/float_literal.h"
|
||||
#include "src/ast/identifier_expression.h"
|
||||
#include "src/ast/scalar_constructor_expression.h"
|
||||
#include "src/ast/sint_literal.h"
|
||||
#include "src/ast/type/bool_type.h"
|
||||
#include "src/ast/type/f32_type.h"
|
||||
#include "src/ast/type/i32_type.h"
|
||||
#include "src/ast/type/matrix_type.h"
|
||||
|
@ -866,6 +868,216 @@ TEST_F(BuilderTest, Binary_Multiply_MatrixMatrix) {
|
|||
)");
|
||||
}
|
||||
|
||||
TEST_F(BuilderTest, Binary_LogicalAnd) {
|
||||
ast::type::I32Type i32;
|
||||
|
||||
auto lhs = std::make_unique<ast::BinaryExpression>(
|
||||
ast::BinaryOp::kEqual,
|
||||
std::make_unique<ast::ScalarConstructorExpression>(
|
||||
std::make_unique<ast::SintLiteral>(&i32, 1)),
|
||||
std::make_unique<ast::ScalarConstructorExpression>(
|
||||
std::make_unique<ast::SintLiteral>(&i32, 2)));
|
||||
|
||||
auto rhs = std::make_unique<ast::BinaryExpression>(
|
||||
ast::BinaryOp::kEqual,
|
||||
std::make_unique<ast::ScalarConstructorExpression>(
|
||||
std::make_unique<ast::SintLiteral>(&i32, 3)),
|
||||
std::make_unique<ast::ScalarConstructorExpression>(
|
||||
std::make_unique<ast::SintLiteral>(&i32, 4)));
|
||||
|
||||
Context ctx;
|
||||
ast::Module mod;
|
||||
TypeDeterminer td(&ctx, &mod);
|
||||
|
||||
ast::BinaryExpression expr(ast::BinaryOp::kLogicalAnd, std::move(lhs),
|
||||
std::move(rhs));
|
||||
|
||||
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
|
||||
|
||||
Builder b(&mod);
|
||||
b.push_function(Function{});
|
||||
b.GenerateLabel(b.next_id());
|
||||
|
||||
EXPECT_EQ(b.GenerateBinaryExpression(&expr), 12u) << b.error();
|
||||
EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 1
|
||||
%3 = OpConstant %2 1
|
||||
%4 = OpConstant %2 2
|
||||
%6 = OpTypeBool
|
||||
%9 = OpConstant %2 3
|
||||
%10 = OpConstant %2 4
|
||||
)");
|
||||
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
|
||||
R"(%1 = OpLabel
|
||||
%5 = OpIEqual %6 %3 %4
|
||||
OpSelectionMerge %7 None
|
||||
OpBranchConditional %5 %8 %7
|
||||
%8 = OpLabel
|
||||
%11 = OpIEqual %6 %9 %10
|
||||
OpBranch %7
|
||||
%7 = OpLabel
|
||||
%12 = OpPhi %6 %5 %1 %11 %8
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_F(BuilderTest, Binary_LogicalAnd_WithLoads) {
|
||||
ast::type::BoolType bool_type;
|
||||
|
||||
auto a_var = std::make_unique<ast::Variable>(
|
||||
"a", ast::StorageClass::kFunction, &bool_type);
|
||||
a_var->set_constructor(std::make_unique<ast::ScalarConstructorExpression>(
|
||||
std::make_unique<ast::BoolLiteral>(&bool_type, true)));
|
||||
auto b_var = std::make_unique<ast::Variable>(
|
||||
"b", ast::StorageClass::kFunction, &bool_type);
|
||||
b_var->set_constructor(std::make_unique<ast::ScalarConstructorExpression>(
|
||||
std::make_unique<ast::BoolLiteral>(&bool_type, false)));
|
||||
|
||||
auto lhs = std::make_unique<ast::IdentifierExpression>("a");
|
||||
auto rhs = std::make_unique<ast::IdentifierExpression>("b");
|
||||
|
||||
Context ctx;
|
||||
ast::Module mod;
|
||||
TypeDeterminer td(&ctx, &mod);
|
||||
td.RegisterVariableForTesting(a_var.get());
|
||||
td.RegisterVariableForTesting(b_var.get());
|
||||
|
||||
ast::BinaryExpression expr(ast::BinaryOp::kLogicalAnd, std::move(lhs),
|
||||
std::move(rhs));
|
||||
|
||||
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
|
||||
|
||||
Builder b(&mod);
|
||||
b.push_function(Function{});
|
||||
b.GenerateLabel(b.next_id());
|
||||
|
||||
ASSERT_TRUE(b.GenerateGlobalVariable(a_var.get())) << b.error();
|
||||
ASSERT_TRUE(b.GenerateGlobalVariable(b_var.get())) << b.error();
|
||||
|
||||
EXPECT_EQ(b.GenerateBinaryExpression(&expr), 12u) << b.error();
|
||||
EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeBool
|
||||
%3 = OpConstantTrue %2
|
||||
%5 = OpTypePointer Function %2
|
||||
%4 = OpVariable %5 Function %3
|
||||
%6 = OpConstantFalse %2
|
||||
%7 = OpVariable %5 Function %6
|
||||
)");
|
||||
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
|
||||
R"(%1 = OpLabel
|
||||
%8 = OpLoad %2 %4
|
||||
OpSelectionMerge %9 None
|
||||
OpBranchConditional %8 %10 %9
|
||||
%10 = OpLabel
|
||||
%11 = OpLoad %2 %7
|
||||
OpBranch %9
|
||||
%9 = OpLabel
|
||||
%12 = OpPhi %2 %8 %1 %11 %10
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_F(BuilderTest, Binary_LogicalOr) {
|
||||
ast::type::I32Type i32;
|
||||
|
||||
auto lhs = std::make_unique<ast::BinaryExpression>(
|
||||
ast::BinaryOp::kEqual,
|
||||
std::make_unique<ast::ScalarConstructorExpression>(
|
||||
std::make_unique<ast::SintLiteral>(&i32, 1)),
|
||||
std::make_unique<ast::ScalarConstructorExpression>(
|
||||
std::make_unique<ast::SintLiteral>(&i32, 2)));
|
||||
|
||||
auto rhs = std::make_unique<ast::BinaryExpression>(
|
||||
ast::BinaryOp::kEqual,
|
||||
std::make_unique<ast::ScalarConstructorExpression>(
|
||||
std::make_unique<ast::SintLiteral>(&i32, 3)),
|
||||
std::make_unique<ast::ScalarConstructorExpression>(
|
||||
std::make_unique<ast::SintLiteral>(&i32, 4)));
|
||||
|
||||
Context ctx;
|
||||
ast::Module mod;
|
||||
TypeDeterminer td(&ctx, &mod);
|
||||
|
||||
ast::BinaryExpression expr(ast::BinaryOp::kLogicalOr, std::move(lhs),
|
||||
std::move(rhs));
|
||||
|
||||
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
|
||||
|
||||
Builder b(&mod);
|
||||
b.push_function(Function{});
|
||||
b.GenerateLabel(b.next_id());
|
||||
|
||||
EXPECT_EQ(b.GenerateBinaryExpression(&expr), 12u) << b.error();
|
||||
EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 1
|
||||
%3 = OpConstant %2 1
|
||||
%4 = OpConstant %2 2
|
||||
%6 = OpTypeBool
|
||||
%9 = OpConstant %2 3
|
||||
%10 = OpConstant %2 4
|
||||
)");
|
||||
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
|
||||
R"(%1 = OpLabel
|
||||
%5 = OpIEqual %6 %3 %4
|
||||
OpSelectionMerge %7 None
|
||||
OpBranchConditional %5 %7 %8
|
||||
%8 = OpLabel
|
||||
%11 = OpIEqual %6 %9 %10
|
||||
OpBranch %7
|
||||
%7 = OpLabel
|
||||
%12 = OpPhi %6 %5 %1 %11 %8
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_F(BuilderTest, Binary_LogicalOr_WithLoads) {
|
||||
ast::type::BoolType bool_type;
|
||||
|
||||
auto a_var = std::make_unique<ast::Variable>(
|
||||
"a", ast::StorageClass::kFunction, &bool_type);
|
||||
a_var->set_constructor(std::make_unique<ast::ScalarConstructorExpression>(
|
||||
std::make_unique<ast::BoolLiteral>(&bool_type, true)));
|
||||
auto b_var = std::make_unique<ast::Variable>(
|
||||
"b", ast::StorageClass::kFunction, &bool_type);
|
||||
b_var->set_constructor(std::make_unique<ast::ScalarConstructorExpression>(
|
||||
std::make_unique<ast::BoolLiteral>(&bool_type, false)));
|
||||
|
||||
auto lhs = std::make_unique<ast::IdentifierExpression>("a");
|
||||
auto rhs = std::make_unique<ast::IdentifierExpression>("b");
|
||||
|
||||
Context ctx;
|
||||
ast::Module mod;
|
||||
TypeDeterminer td(&ctx, &mod);
|
||||
td.RegisterVariableForTesting(a_var.get());
|
||||
td.RegisterVariableForTesting(b_var.get());
|
||||
|
||||
ast::BinaryExpression expr(ast::BinaryOp::kLogicalOr, std::move(lhs),
|
||||
std::move(rhs));
|
||||
|
||||
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
|
||||
|
||||
Builder b(&mod);
|
||||
b.push_function(Function{});
|
||||
b.GenerateLabel(b.next_id());
|
||||
|
||||
ASSERT_TRUE(b.GenerateGlobalVariable(a_var.get())) << b.error();
|
||||
ASSERT_TRUE(b.GenerateGlobalVariable(b_var.get())) << b.error();
|
||||
|
||||
EXPECT_EQ(b.GenerateBinaryExpression(&expr), 12u) << b.error();
|
||||
EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeBool
|
||||
%3 = OpConstantTrue %2
|
||||
%5 = OpTypePointer Function %2
|
||||
%4 = OpVariable %5 Function %3
|
||||
%6 = OpConstantFalse %2
|
||||
%7 = OpVariable %5 Function %6
|
||||
)");
|
||||
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
|
||||
R"(%1 = OpLabel
|
||||
%8 = OpLoad %2 %4
|
||||
OpSelectionMerge %9 None
|
||||
OpBranchConditional %8 %9 %10
|
||||
%10 = OpLabel
|
||||
%11 = OpLoad %2 %7
|
||||
OpBranch %9
|
||||
%9 = OpLabel
|
||||
%12 = OpPhi %2 %8 %1 %11 %10
|
||||
)");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace spirv
|
||||
} // namespace writer
|
||||
|
|
|
@ -52,6 +52,9 @@ class Function {
|
|||
/// @returns the declaration
|
||||
const Instruction& declaration() const { return declaration_; }
|
||||
|
||||
/// @returns the function label id
|
||||
uint32_t label_id() const { return label_op_.to_i(); }
|
||||
|
||||
/// Adds an instruction to the instruction list
|
||||
/// @param op the op to set
|
||||
/// @param operands the operands for the instruction
|
||||
|
|
Loading…
Reference in New Issue