From be66f9faf90729df43d681ac0aca35dcb4295d14 Mon Sep 17 00:00:00 2001 From: dan sinclair Date: Fri, 19 Jun 2020 19:44:38 +0000 Subject: [PATCH] [spirv-writer] Emit logical and and logical or This CL adds support for the && and || operators to the SPIR-V backend. Bug: tint:5 Change-Id: I63b23d9904b5b8027e189034d24949df71cbbe42 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/23501 Reviewed-by: David Neto --- src/writer/spirv/builder.cc | 90 +++++++- src/writer/spirv/builder.h | 13 +- .../spirv/builder_binary_expression_test.cc | 212 ++++++++++++++++++ src/writer/spirv/function.h | 3 + 4 files changed, 307 insertions(+), 11 deletions(-) diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc index c0dbf31430..3f2d289484 100644 --- a/src/writer/spirv/builder.cc +++ b/src/writer/spirv/builder.cc @@ -226,6 +226,11 @@ void Builder::push_capability(uint32_t cap) { Instruction{spv::Op::OpCapability, {Operand::Int(cap)}}); } +void Builder::GenerateLabel(uint32_t id) { + push_function_inst(spv::Op::OpLabel, {Operand::Int(id)}); + current_label_id_ = id; +} + uint32_t Builder::GenerateU32Literal(uint32_t val) { ast::type::U32Type u32; ast::SintLiteral lit(&u32, val); @@ -1083,7 +1088,72 @@ uint32_t Builder::GenerateLiteralIfNeeded(ast::Literal* lit) { return result_id; } +uint32_t Builder::GenerateShortCircuitBinaryExpression( + ast::BinaryExpression* expr) { + auto lhs_id = GenerateExpression(expr->lhs()); + if (lhs_id == 0) { + return false; + } + lhs_id = GenerateLoadIfNeeded(expr->lhs()->result_type(), lhs_id); + + auto original_label_id = current_label_id_; + + auto type_id = GenerateTypeIfNeeded(expr->result_type()); + if (type_id == 0) { + return 0; + } + + auto merge_block = result_op(); + auto merge_block_id = merge_block.to_i(); + + auto block = result_op(); + auto block_id = block.to_i(); + + auto true_block_id = block_id; + auto false_block_id = merge_block_id; + + // For a logical or we want to only check the RHS if the LHS is failed. + if (expr->IsLogicalOr()) { + std::swap(true_block_id, false_block_id); + } + + push_function_inst(spv::Op::OpSelectionMerge, + {Operand::Int(merge_block_id), + Operand::Int(SpvSelectionControlMaskNone)}); + push_function_inst(spv::Op::OpBranchConditional, + {Operand::Int(lhs_id), Operand::Int(true_block_id), + Operand::Int(false_block_id)}); + + // Output block to check the RHS + GenerateLabel(block_id); + auto rhs_id = GenerateExpression(expr->rhs()); + if (rhs_id == 0) { + return 0; + } + rhs_id = GenerateLoadIfNeeded(expr->rhs()->result_type(), rhs_id); + + push_function_inst(spv::Op::OpBranch, {Operand::Int(merge_block_id)}); + + // Output the merge block + GenerateLabel(merge_block_id); + + auto result = result_op(); + auto result_id = result.to_i(); + + push_function_inst(spv::Op::OpPhi, + {Operand::Int(type_id), result, Operand::Int(lhs_id), + Operand::Int(original_label_id), Operand::Int(rhs_id), + Operand::Int(block_id)}); + + return result_id; +} + uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) { + // There is special logic for short circuiting operators. + if (expr->IsLogicalAnd() || expr->IsLogicalOr()) { + return GenerateShortCircuitBinaryExpression(expr); + } + auto lhs_id = GenerateExpression(expr->lhs()); if (lhs_id == 0) { return 0; @@ -1466,7 +1536,7 @@ bool Builder::GenerateConditionalBlock( Operand::Int(false_block_id)}); // Output true block - push_function_inst(spv::Op::OpLabel, {true_block}); + GenerateLabel(true_block_id); if (!GenerateStatementList(true_body)) { return false; } @@ -1477,7 +1547,7 @@ bool Builder::GenerateConditionalBlock( // Start the false block if needed if (false_block_id != merge_block_id) { - push_function_inst(spv::Op::OpLabel, {Operand::Int(false_block_id)}); + GenerateLabel(false_block_id); auto* else_stmt = else_stmts[cur_else_idx].get(); // Handle the else case by just outputting the statements. @@ -1497,7 +1567,7 @@ bool Builder::GenerateConditionalBlock( } // Output the merge block - push_function_inst(spv::Op::OpLabel, {merge_block}); + GenerateLabel(merge_block_id); return true; } @@ -1568,7 +1638,7 @@ bool Builder::GenerateSwitchStatement(ast::SwitchStatement* stmt) { generated_default = true; } - push_function_inst(spv::Op::OpLabel, {Operand::Int(case_ids[i])}); + GenerateLabel(case_ids[i]); if (!GenerateStatementList(item->body())) { return false; } @@ -1585,13 +1655,13 @@ bool Builder::GenerateSwitchStatement(ast::SwitchStatement* stmt) { } if (!generated_default) { - push_function_inst(spv::Op::OpLabel, {Operand::Int(default_block_id)}); + GenerateLabel(default_block_id); push_function_inst(spv::Op::OpBranch, {Operand::Int(merge_block_id)}); } merge_stack_.pop_back(); - push_function_inst(spv::Op::OpLabel, {Operand::Int(merge_block_id)}); + GenerateLabel(merge_block_id); return true; } @@ -1613,7 +1683,7 @@ bool Builder::GenerateLoopStatement(ast::LoopStatement* stmt) { auto loop_header = result_op(); auto loop_header_id = loop_header.to_i(); push_function_inst(spv::Op::OpBranch, {Operand::Int(loop_header_id)}); - push_function_inst(spv::Op::OpLabel, {loop_header}); + GenerateLabel(loop_header_id); auto merge_block = result_op(); auto merge_block_id = merge_block.to_i(); @@ -1632,7 +1702,7 @@ bool Builder::GenerateLoopStatement(ast::LoopStatement* stmt) { merge_stack_.push_back(merge_block_id); push_function_inst(spv::Op::OpBranch, {Operand::Int(body_block_id)}); - push_function_inst(spv::Op::OpLabel, {body_block}); + GenerateLabel(body_block_id); if (!GenerateStatementList(stmt->body())) { return false; } @@ -1642,7 +1712,7 @@ bool Builder::GenerateLoopStatement(ast::LoopStatement* stmt) { push_function_inst(spv::Op::OpBranch, {Operand::Int(continue_block_id)}); } - push_function_inst(spv::Op::OpLabel, {continue_block}); + GenerateLabel(continue_block_id); if (!GenerateStatementList(stmt->continuing())) { return false; } @@ -1651,7 +1721,7 @@ bool Builder::GenerateLoopStatement(ast::LoopStatement* stmt) { merge_stack_.pop_back(); continue_stack_.pop_back(); - push_function_inst(spv::Op::OpLabel, {merge_block}); + GenerateLabel(merge_block_id); return true; } diff --git a/src/writer/spirv/builder.h b/src/writer/spirv/builder.h index 094add447f..a111094d1e 100644 --- a/src/writer/spirv/builder.h +++ b/src/writer/spirv/builder.h @@ -159,7 +159,10 @@ class Builder { /// Adds a function to the builder /// @param func the function to add - void push_function(const Function& func) { functions_.push_back(func); } + void push_function(const Function& func) { + functions_.push_back(func); + current_label_id_ = func.label_id(); + } /// @returns the functions const std::vector& functions() const { return functions_; } /// Pushes an instruction to the current function @@ -183,6 +186,9 @@ class Builder { /// @returns the SPIR-V builtin or SpvBuiltInMax on error. SpvBuiltIn ConvertBuiltin(ast::Builtin builtin) const; + /// Generates a label for the given id + /// @param id the id to use for the label + void GenerateLabel(uint32_t id); /// Generates a uint32_t literal. /// @param val the value to generate /// @returns the ID of the generated literal @@ -291,6 +297,10 @@ class Builder { /// @param expr the expression to generate /// @returns the expression ID on success or 0 otherwise uint32_t GenerateBinaryExpression(ast::BinaryExpression* expr); + /// Generates a short circuting binary expression + /// @param expr the expression to generate + /// @returns teh expression ID on success or 0 otherwise + uint32_t GenerateShortCircuitBinaryExpression(ast::BinaryExpression* expr); /// Generates a call expression /// @param expr the expression to generate /// @returns the expression ID on success or 0 otherwise @@ -395,6 +405,7 @@ class Builder { ast::Module* mod_; std::string error_; uint32_t next_id_ = 1; + uint32_t current_label_id_ = 0; std::vector capabilities_; std::vector preamble_; std::vector debug_; diff --git a/src/writer/spirv/builder_binary_expression_test.cc b/src/writer/spirv/builder_binary_expression_test.cc index 3d4899f223..62986eced5 100644 --- a/src/writer/spirv/builder_binary_expression_test.cc +++ b/src/writer/spirv/builder_binary_expression_test.cc @@ -16,10 +16,12 @@ #include "gtest/gtest.h" #include "src/ast/binary_expression.h" +#include "src/ast/bool_literal.h" #include "src/ast/float_literal.h" #include "src/ast/identifier_expression.h" #include "src/ast/scalar_constructor_expression.h" #include "src/ast/sint_literal.h" +#include "src/ast/type/bool_type.h" #include "src/ast/type/f32_type.h" #include "src/ast/type/i32_type.h" #include "src/ast/type/matrix_type.h" @@ -866,6 +868,216 @@ TEST_F(BuilderTest, Binary_Multiply_MatrixMatrix) { )"); } +TEST_F(BuilderTest, Binary_LogicalAnd) { + ast::type::I32Type i32; + + auto lhs = std::make_unique( + ast::BinaryOp::kEqual, + std::make_unique( + std::make_unique(&i32, 1)), + std::make_unique( + std::make_unique(&i32, 2))); + + auto rhs = std::make_unique( + ast::BinaryOp::kEqual, + std::make_unique( + std::make_unique(&i32, 3)), + std::make_unique( + std::make_unique(&i32, 4))); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + + ast::BinaryExpression expr(ast::BinaryOp::kLogicalAnd, std::move(lhs), + std::move(rhs)); + + ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + b.GenerateLabel(b.next_id()); + + EXPECT_EQ(b.GenerateBinaryExpression(&expr), 12u) << b.error(); + EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 1 +%3 = OpConstant %2 1 +%4 = OpConstant %2 2 +%6 = OpTypeBool +%9 = OpConstant %2 3 +%10 = OpConstant %2 4 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%1 = OpLabel +%5 = OpIEqual %6 %3 %4 +OpSelectionMerge %7 None +OpBranchConditional %5 %8 %7 +%8 = OpLabel +%11 = OpIEqual %6 %9 %10 +OpBranch %7 +%7 = OpLabel +%12 = OpPhi %6 %5 %1 %11 %8 +)"); +} + +TEST_F(BuilderTest, Binary_LogicalAnd_WithLoads) { + ast::type::BoolType bool_type; + + auto a_var = std::make_unique( + "a", ast::StorageClass::kFunction, &bool_type); + a_var->set_constructor(std::make_unique( + std::make_unique(&bool_type, true))); + auto b_var = std::make_unique( + "b", ast::StorageClass::kFunction, &bool_type); + b_var->set_constructor(std::make_unique( + std::make_unique(&bool_type, false))); + + auto lhs = std::make_unique("a"); + auto rhs = std::make_unique("b"); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(a_var.get()); + td.RegisterVariableForTesting(b_var.get()); + + ast::BinaryExpression expr(ast::BinaryOp::kLogicalAnd, std::move(lhs), + std::move(rhs)); + + ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + b.GenerateLabel(b.next_id()); + + ASSERT_TRUE(b.GenerateGlobalVariable(a_var.get())) << b.error(); + ASSERT_TRUE(b.GenerateGlobalVariable(b_var.get())) << b.error(); + + EXPECT_EQ(b.GenerateBinaryExpression(&expr), 12u) << b.error(); + EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeBool +%3 = OpConstantTrue %2 +%5 = OpTypePointer Function %2 +%4 = OpVariable %5 Function %3 +%6 = OpConstantFalse %2 +%7 = OpVariable %5 Function %6 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%1 = OpLabel +%8 = OpLoad %2 %4 +OpSelectionMerge %9 None +OpBranchConditional %8 %10 %9 +%10 = OpLabel +%11 = OpLoad %2 %7 +OpBranch %9 +%9 = OpLabel +%12 = OpPhi %2 %8 %1 %11 %10 +)"); +} + +TEST_F(BuilderTest, Binary_LogicalOr) { + ast::type::I32Type i32; + + auto lhs = std::make_unique( + ast::BinaryOp::kEqual, + std::make_unique( + std::make_unique(&i32, 1)), + std::make_unique( + std::make_unique(&i32, 2))); + + auto rhs = std::make_unique( + ast::BinaryOp::kEqual, + std::make_unique( + std::make_unique(&i32, 3)), + std::make_unique( + std::make_unique(&i32, 4))); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + + ast::BinaryExpression expr(ast::BinaryOp::kLogicalOr, std::move(lhs), + std::move(rhs)); + + ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + b.GenerateLabel(b.next_id()); + + EXPECT_EQ(b.GenerateBinaryExpression(&expr), 12u) << b.error(); + EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 1 +%3 = OpConstant %2 1 +%4 = OpConstant %2 2 +%6 = OpTypeBool +%9 = OpConstant %2 3 +%10 = OpConstant %2 4 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%1 = OpLabel +%5 = OpIEqual %6 %3 %4 +OpSelectionMerge %7 None +OpBranchConditional %5 %7 %8 +%8 = OpLabel +%11 = OpIEqual %6 %9 %10 +OpBranch %7 +%7 = OpLabel +%12 = OpPhi %6 %5 %1 %11 %8 +)"); +} + +TEST_F(BuilderTest, Binary_LogicalOr_WithLoads) { + ast::type::BoolType bool_type; + + auto a_var = std::make_unique( + "a", ast::StorageClass::kFunction, &bool_type); + a_var->set_constructor(std::make_unique( + std::make_unique(&bool_type, true))); + auto b_var = std::make_unique( + "b", ast::StorageClass::kFunction, &bool_type); + b_var->set_constructor(std::make_unique( + std::make_unique(&bool_type, false))); + + auto lhs = std::make_unique("a"); + auto rhs = std::make_unique("b"); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(a_var.get()); + td.RegisterVariableForTesting(b_var.get()); + + ast::BinaryExpression expr(ast::BinaryOp::kLogicalOr, std::move(lhs), + std::move(rhs)); + + ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + b.GenerateLabel(b.next_id()); + + ASSERT_TRUE(b.GenerateGlobalVariable(a_var.get())) << b.error(); + ASSERT_TRUE(b.GenerateGlobalVariable(b_var.get())) << b.error(); + + EXPECT_EQ(b.GenerateBinaryExpression(&expr), 12u) << b.error(); + EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeBool +%3 = OpConstantTrue %2 +%5 = OpTypePointer Function %2 +%4 = OpVariable %5 Function %3 +%6 = OpConstantFalse %2 +%7 = OpVariable %5 Function %6 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%1 = OpLabel +%8 = OpLoad %2 %4 +OpSelectionMerge %9 None +OpBranchConditional %8 %9 %10 +%10 = OpLabel +%11 = OpLoad %2 %7 +OpBranch %9 +%9 = OpLabel +%12 = OpPhi %2 %8 %1 %11 %10 +)"); +} + } // namespace } // namespace spirv } // namespace writer diff --git a/src/writer/spirv/function.h b/src/writer/spirv/function.h index 26ddfa7a9c..ddc2d0d55e 100644 --- a/src/writer/spirv/function.h +++ b/src/writer/spirv/function.h @@ -52,6 +52,9 @@ class Function { /// @returns the declaration const Instruction& declaration() const { return declaration_; } + /// @returns the function label id + uint32_t label_id() const { return label_op_.to_i(); } + /// Adds an instruction to the instruction list /// @param op the op to set /// @param operands the operands for the instruction