[ir] Emit short-circuit as an `If` node

This Cl removes the `&&` and `||` logical binary nodes and replaces them
with a var declaration and if node.

Bug: tint:1925
Change-Id: I9f25411a9b9c909fa25f2f37cbd51181ac584acc
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/130500
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: James Price <jrprice@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
This commit is contained in:
dan sinclair 2023-05-03 14:59:03 +00:00 committed by Dawn LUCI CQ
parent 83d8c42fc6
commit e903396ff2
4 changed files with 171 additions and 49 deletions

View File

@ -778,7 +778,66 @@ utils::Result<Value*> BuilderImpl::EmitUnary(const ast::UnaryOpExpression* expr)
return inst;
}
// A short-circut needs special treatment. The short-circuit is decomposed into the relevant if
// statements and declarations.
utils::Result<Value*> BuilderImpl::EmitShortCircuit(const ast::BinaryExpression* expr) {
switch (expr->op) {
case ast::BinaryOp::kLogicalAnd:
case ast::BinaryOp::kLogicalOr:
break;
default:
TINT_ICE(IR, diagnostics_) << "invalid operation type for short-circut decomposition";
return utils::Failure;
}
auto lhs = EmitExpression(expr->lhs);
if (!lhs) {
return utils::Failure;
}
auto* ty = builder.ir.types.Get<type::Bool>();
auto* result_var =
builder.Declare(ty, builtin::AddressSpace::kFunction, builtin::Access::kReadWrite);
current_flow_block->instructions.Push(result_var);
auto* lhs_store = builder.Store(result_var, lhs.Get());
current_flow_block->instructions.Push(lhs_store);
auto* if_node = builder.CreateIf();
if_node->condition = lhs.Get();
BranchTo(if_node);
utils::Result<Value*> rhs;
{
FlowStackScope scope(this, if_node);
// If this is an `&&` then we only evaluate the RHS expression in the true block.
// If this is an `||` then we only evaluate the RHS expression in the false block.
if (expr->op == ast::BinaryOp::kLogicalAnd) {
current_flow_block = if_node->true_.target->As<Block>();
} else {
current_flow_block = if_node->false_.target->As<Block>();
}
rhs = EmitExpression(expr->rhs);
if (!rhs) {
return utils::Failure;
}
auto* rhs_store = builder.Store(result_var, rhs.Get());
current_flow_block->instructions.Push(rhs_store);
BranchTo(if_node->merge.target);
}
current_flow_block = if_node->merge.target->As<Block>();
return result_var;
}
utils::Result<Value*> BuilderImpl::EmitBinary(const ast::BinaryExpression* expr) {
if (expr->op == ast::BinaryOp::kLogicalAnd || expr->op == ast::BinaryOp::kLogicalOr) {
return EmitShortCircuit(expr);
}
auto lhs = EmitExpression(expr->lhs);
if (!lhs) {
return utils::Failure;
@ -803,12 +862,6 @@ utils::Result<Value*> BuilderImpl::EmitBinary(const ast::BinaryExpression* expr)
case ast::BinaryOp::kXor:
inst = builder.Xor(ty, lhs.Get(), rhs.Get());
break;
case ast::BinaryOp::kLogicalAnd:
inst = builder.LogicalAnd(ty, lhs.Get(), rhs.Get());
break;
case ast::BinaryOp::kLogicalOr:
inst = builder.LogicalOr(ty, lhs.Get(), rhs.Get());
break;
case ast::BinaryOp::kEqual:
inst = builder.Equal(ty, lhs.Get(), rhs.Get());
break;
@ -848,6 +901,10 @@ utils::Result<Value*> BuilderImpl::EmitBinary(const ast::BinaryExpression* expr)
case ast::BinaryOp::kModulo:
inst = builder.Modulo(ty, lhs.Get(), rhs.Get());
break;
case ast::BinaryOp::kLogicalAnd:
case ast::BinaryOp::kLogicalOr:
TINT_ICE(IR, diagnostics_) << "short circuit op should have already been handled";
return utils::Failure;
case ast::BinaryOp::kNone:
TINT_ICE(IR, diagnostics_) << "missing binary operand type";
return utils::Failure;

View File

@ -166,6 +166,11 @@ class BuilderImpl {
/// @returns the value storing the result if successful, utils::Failure otherwise
utils::Result<Value*> EmitUnary(const ast::UnaryOpExpression* expr);
/// Emits a short-circult binary expression
/// @param expr the binary expression
/// @returns the value storing the result if successful, utils::Failure otherwise
utils::Result<Value*> EmitShortCircuit(const ast::BinaryExpression* expr);
/// Emits a binary expression
/// @param expr the binary expression
/// @returns the value storing the result if successful, utils::Failure otherwise

View File

@ -1773,16 +1773,33 @@ TEST_F(IR_BuilderImplTest, EmitExpression_Binary_LogicalAnd) {
auto* expr = LogicalAnd(Call("my_func"), false);
WrapInFunction(expr);
auto& b = CreateBuilder();
InjectFlowBlock();
auto r = b.EmitExpression(expr);
ASSERT_THAT(b.Diagnostics(), testing::IsEmpty());
ASSERT_TRUE(r);
auto r = Build();
ASSERT_TRUE(r) << Error();
auto m = r.Move();
EXPECT_EQ(Disassemble(m), R"(%fn0 = func my_func
%fn1 = block
ret true
func_end
%fn2 = func test_function
%fn3 = block
%1(bool) = call my_func
%2(bool) = var function read_write
store %2(bool), %1(bool)
branch %fn4
%fn4 = if %1(bool) [t: %fn5, f: %fn6, m: %fn7]
# true branch
%fn5 = block
store %2(bool), false
branch %fn7
# if merge
%fn7 = block
ret
func_end
Disassembler d(b.builder.ir);
d.EmitBlockInstructions(b.current_flow_block->As<ir::Block>());
EXPECT_EQ(d.AsString(), R"(%1(bool) = call my_func
%2(bool) = log_and %1(bool), false
)");
}
@ -1791,16 +1808,34 @@ TEST_F(IR_BuilderImplTest, EmitExpression_Binary_LogicalOr) {
auto* expr = LogicalOr(Call("my_func"), true);
WrapInFunction(expr);
auto& b = CreateBuilder();
InjectFlowBlock();
auto r = b.EmitExpression(expr);
ASSERT_THAT(b.Diagnostics(), testing::IsEmpty());
ASSERT_TRUE(r);
auto r = Build();
ASSERT_TRUE(r) << Error();
auto m = r.Move();
EXPECT_EQ(Disassemble(m), R"(%fn0 = func my_func
%fn1 = block
ret true
func_end
%fn2 = func test_function
%fn3 = block
%1(bool) = call my_func
%2(bool) = var function read_write
store %2(bool), %1(bool)
branch %fn4
%fn4 = if %1(bool) [t: %fn5, f: %fn6, m: %fn7]
# true branch
# false branch
%fn6 = block
store %2(bool), true
branch %fn7
# if merge
%fn7 = block
ret
func_end
Disassembler d(b.builder.ir);
d.EmitBlockInstructions(b.current_flow_block->As<ir::Block>());
EXPECT_EQ(d.AsString(), R"(%1(bool) = call my_func
%2(bool) = log_or %1(bool), true
)");
}
@ -1955,22 +1990,39 @@ TEST_F(IR_BuilderImplTest, EmitExpression_Binary_Compound) {
GreaterThan(2.5_f, Div(Call("my_func"), Mul(2.3_f, Call("my_func")))));
WrapInFunction(expr);
auto& b = CreateBuilder();
InjectFlowBlock();
auto r = b.EmitExpression(expr);
ASSERT_THAT(b.Diagnostics(), testing::IsEmpty());
ASSERT_TRUE(r);
auto r = Build();
ASSERT_TRUE(r) << Error();
auto m = r.Move();
Disassembler d(b.builder.ir);
d.EmitBlockInstructions(b.current_flow_block->As<ir::Block>());
EXPECT_EQ(d.AsString(), R"(%1(f32) = call my_func
EXPECT_EQ(Disassemble(m), R"(%fn0 = func my_func
%fn1 = block
ret 0.0f
func_end
%fn2 = func test_function
%fn3 = block
%1(f32) = call my_func
%2(bool) = lt %1(f32), 2.0f
%3(f32) = call my_func
%3(bool) = var function read_write
store %3(bool), %2(bool)
branch %fn4
%fn4 = if %2(bool) [t: %fn5, f: %fn6, m: %fn7]
# true branch
%fn5 = block
%4(f32) = call my_func
%5(f32) = mul 2.29999995231628417969f, %4(f32)
%6(f32) = div %3(f32), %5(f32)
%7(bool) = gt 2.5f, %6(f32)
%8(bool) = log_and %2(bool), %7(bool)
%5(f32) = call my_func
%6(f32) = mul 2.29999995231628417969f, %5(f32)
%7(f32) = div %4(f32), %6(f32)
%8(bool) = gt 2.5f, %7(f32)
store %3(bool), %8(bool)
branch %fn7
# if merge
%fn7 = block
ret
func_end
)");
}
@ -1980,15 +2032,21 @@ TEST_F(IR_BuilderImplTest, EmitExpression_Binary_Compound_WithConstEval) {
GreaterThan(2.5_f, Div(10_f, Mul(2.3_f, 9.4_f)))));
WrapInFunction(expr);
auto& b = CreateBuilder();
InjectFlowBlock();
auto r = b.EmitExpression(expr);
ASSERT_THAT(b.Diagnostics(), testing::IsEmpty());
ASSERT_TRUE(r);
auto r = Build();
ASSERT_TRUE(r) << Error();
auto m = r.Move();
EXPECT_EQ(Disassemble(m), R"(%fn0 = func my_func
%fn1 = block
ret true
func_end
%fn2 = func test_function
%fn3 = block
%1(bool) = call my_func, false
ret
func_end
Disassembler d(b.builder.ir);
d.EmitBlockInstructions(b.current_flow_block->As<ir::Block>());
EXPECT_EQ(d.AsString(), R"(%1(bool) = call my_func, false
)");
}

View File

@ -203,9 +203,11 @@ void Disassembler::Walk(const FlowNode* node) {
Indent() << "# true branch" << std::endl;
Walk(i->true_.target);
if (!i->false_.target->IsDead()) {
Indent() << "# false branch" << std::endl;
Walk(i->false_.target);
}
}
if (i->merge.target->IsConnected()) {
Indent() << "# if merge" << std::endl;