diff --git a/src/tint/ir/builder_impl.cc b/src/tint/ir/builder_impl.cc index 19caf6fc3c..fd6f500298 100644 --- a/src/tint/ir/builder_impl.cc +++ b/src/tint/ir/builder_impl.cc @@ -778,7 +778,66 @@ utils::Result BuilderImpl::EmitUnary(const ast::UnaryOpExpression* expr) return inst; } +// A short-circut needs special treatment. The short-circuit is decomposed into the relevant if +// statements and declarations. +utils::Result BuilderImpl::EmitShortCircuit(const ast::BinaryExpression* expr) { + switch (expr->op) { + case ast::BinaryOp::kLogicalAnd: + case ast::BinaryOp::kLogicalOr: + break; + default: + TINT_ICE(IR, diagnostics_) << "invalid operation type for short-circut decomposition"; + return utils::Failure; + } + + auto lhs = EmitExpression(expr->lhs); + if (!lhs) { + return utils::Failure; + } + + auto* ty = builder.ir.types.Get(); + auto* result_var = + builder.Declare(ty, builtin::AddressSpace::kFunction, builtin::Access::kReadWrite); + current_flow_block->instructions.Push(result_var); + + auto* lhs_store = builder.Store(result_var, lhs.Get()); + current_flow_block->instructions.Push(lhs_store); + + auto* if_node = builder.CreateIf(); + if_node->condition = lhs.Get(); + BranchTo(if_node); + + utils::Result rhs; + { + FlowStackScope scope(this, if_node); + + // If this is an `&&` then we only evaluate the RHS expression in the true block. + // If this is an `||` then we only evaluate the RHS expression in the false block. + if (expr->op == ast::BinaryOp::kLogicalAnd) { + current_flow_block = if_node->true_.target->As(); + } else { + current_flow_block = if_node->false_.target->As(); + } + + rhs = EmitExpression(expr->rhs); + if (!rhs) { + return utils::Failure; + } + auto* rhs_store = builder.Store(result_var, rhs.Get()); + current_flow_block->instructions.Push(rhs_store); + + BranchTo(if_node->merge.target); + } + current_flow_block = if_node->merge.target->As(); + + return result_var; +} + utils::Result BuilderImpl::EmitBinary(const ast::BinaryExpression* expr) { + if (expr->op == ast::BinaryOp::kLogicalAnd || expr->op == ast::BinaryOp::kLogicalOr) { + return EmitShortCircuit(expr); + } + auto lhs = EmitExpression(expr->lhs); if (!lhs) { return utils::Failure; @@ -803,12 +862,6 @@ utils::Result BuilderImpl::EmitBinary(const ast::BinaryExpression* expr) case ast::BinaryOp::kXor: inst = builder.Xor(ty, lhs.Get(), rhs.Get()); break; - case ast::BinaryOp::kLogicalAnd: - inst = builder.LogicalAnd(ty, lhs.Get(), rhs.Get()); - break; - case ast::BinaryOp::kLogicalOr: - inst = builder.LogicalOr(ty, lhs.Get(), rhs.Get()); - break; case ast::BinaryOp::kEqual: inst = builder.Equal(ty, lhs.Get(), rhs.Get()); break; @@ -848,6 +901,10 @@ utils::Result BuilderImpl::EmitBinary(const ast::BinaryExpression* expr) case ast::BinaryOp::kModulo: inst = builder.Modulo(ty, lhs.Get(), rhs.Get()); break; + case ast::BinaryOp::kLogicalAnd: + case ast::BinaryOp::kLogicalOr: + TINT_ICE(IR, diagnostics_) << "short circuit op should have already been handled"; + return utils::Failure; case ast::BinaryOp::kNone: TINT_ICE(IR, diagnostics_) << "missing binary operand type"; return utils::Failure; diff --git a/src/tint/ir/builder_impl.h b/src/tint/ir/builder_impl.h index 7b18492fdf..58d4844d53 100644 --- a/src/tint/ir/builder_impl.h +++ b/src/tint/ir/builder_impl.h @@ -166,6 +166,11 @@ class BuilderImpl { /// @returns the value storing the result if successful, utils::Failure otherwise utils::Result EmitUnary(const ast::UnaryOpExpression* expr); + /// Emits a short-circult binary expression + /// @param expr the binary expression + /// @returns the value storing the result if successful, utils::Failure otherwise + utils::Result EmitShortCircuit(const ast::BinaryExpression* expr); + /// Emits a binary expression /// @param expr the binary expression /// @returns the value storing the result if successful, utils::Failure otherwise diff --git a/src/tint/ir/builder_impl_test.cc b/src/tint/ir/builder_impl_test.cc index 1e6a26dae9..bcba641637 100644 --- a/src/tint/ir/builder_impl_test.cc +++ b/src/tint/ir/builder_impl_test.cc @@ -1773,16 +1773,33 @@ TEST_F(IR_BuilderImplTest, EmitExpression_Binary_LogicalAnd) { auto* expr = LogicalAnd(Call("my_func"), false); WrapInFunction(expr); - auto& b = CreateBuilder(); - InjectFlowBlock(); - auto r = b.EmitExpression(expr); - ASSERT_THAT(b.Diagnostics(), testing::IsEmpty()); - ASSERT_TRUE(r); + auto r = Build(); + ASSERT_TRUE(r) << Error(); + auto m = r.Move(); + + EXPECT_EQ(Disassemble(m), R"(%fn0 = func my_func + %fn1 = block + ret true +func_end + +%fn2 = func test_function + %fn3 = block + %1(bool) = call my_func + %2(bool) = var function read_write + store %2(bool), %1(bool) + branch %fn4 + + %fn4 = if %1(bool) [t: %fn5, f: %fn6, m: %fn7] + # true branch + %fn5 = block + store %2(bool), false + branch %fn7 + + # if merge + %fn7 = block + ret +func_end - Disassembler d(b.builder.ir); - d.EmitBlockInstructions(b.current_flow_block->As()); - EXPECT_EQ(d.AsString(), R"(%1(bool) = call my_func -%2(bool) = log_and %1(bool), false )"); } @@ -1791,16 +1808,34 @@ TEST_F(IR_BuilderImplTest, EmitExpression_Binary_LogicalOr) { auto* expr = LogicalOr(Call("my_func"), true); WrapInFunction(expr); - auto& b = CreateBuilder(); - InjectFlowBlock(); - auto r = b.EmitExpression(expr); - ASSERT_THAT(b.Diagnostics(), testing::IsEmpty()); - ASSERT_TRUE(r); + auto r = Build(); + ASSERT_TRUE(r) << Error(); + auto m = r.Move(); + + EXPECT_EQ(Disassemble(m), R"(%fn0 = func my_func + %fn1 = block + ret true +func_end + +%fn2 = func test_function + %fn3 = block + %1(bool) = call my_func + %2(bool) = var function read_write + store %2(bool), %1(bool) + branch %fn4 + + %fn4 = if %1(bool) [t: %fn5, f: %fn6, m: %fn7] + # true branch + # false branch + %fn6 = block + store %2(bool), true + branch %fn7 + + # if merge + %fn7 = block + ret +func_end - Disassembler d(b.builder.ir); - d.EmitBlockInstructions(b.current_flow_block->As()); - EXPECT_EQ(d.AsString(), R"(%1(bool) = call my_func -%2(bool) = log_or %1(bool), true )"); } @@ -1955,22 +1990,39 @@ TEST_F(IR_BuilderImplTest, EmitExpression_Binary_Compound) { GreaterThan(2.5_f, Div(Call("my_func"), Mul(2.3_f, Call("my_func"))))); WrapInFunction(expr); - auto& b = CreateBuilder(); - InjectFlowBlock(); - auto r = b.EmitExpression(expr); - ASSERT_THAT(b.Diagnostics(), testing::IsEmpty()); - ASSERT_TRUE(r); + auto r = Build(); + ASSERT_TRUE(r) << Error(); + auto m = r.Move(); + + EXPECT_EQ(Disassemble(m), R"(%fn0 = func my_func + %fn1 = block + ret 0.0f +func_end + +%fn2 = func test_function + %fn3 = block + %1(f32) = call my_func + %2(bool) = lt %1(f32), 2.0f + %3(bool) = var function read_write + store %3(bool), %2(bool) + branch %fn4 + + %fn4 = if %2(bool) [t: %fn5, f: %fn6, m: %fn7] + # true branch + %fn5 = block + %4(f32) = call my_func + %5(f32) = call my_func + %6(f32) = mul 2.29999995231628417969f, %5(f32) + %7(f32) = div %4(f32), %6(f32) + %8(bool) = gt 2.5f, %7(f32) + store %3(bool), %8(bool) + branch %fn7 + + # if merge + %fn7 = block + ret +func_end - Disassembler d(b.builder.ir); - d.EmitBlockInstructions(b.current_flow_block->As()); - EXPECT_EQ(d.AsString(), R"(%1(f32) = call my_func -%2(bool) = lt %1(f32), 2.0f -%3(f32) = call my_func -%4(f32) = call my_func -%5(f32) = mul 2.29999995231628417969f, %4(f32) -%6(f32) = div %3(f32), %5(f32) -%7(bool) = gt 2.5f, %6(f32) -%8(bool) = log_and %2(bool), %7(bool) )"); } @@ -1980,15 +2032,21 @@ TEST_F(IR_BuilderImplTest, EmitExpression_Binary_Compound_WithConstEval) { GreaterThan(2.5_f, Div(10_f, Mul(2.3_f, 9.4_f))))); WrapInFunction(expr); - auto& b = CreateBuilder(); - InjectFlowBlock(); - auto r = b.EmitExpression(expr); - ASSERT_THAT(b.Diagnostics(), testing::IsEmpty()); - ASSERT_TRUE(r); + auto r = Build(); + ASSERT_TRUE(r) << Error(); + auto m = r.Move(); + + EXPECT_EQ(Disassemble(m), R"(%fn0 = func my_func + %fn1 = block + ret true +func_end + +%fn2 = func test_function + %fn3 = block + %1(bool) = call my_func, false + ret +func_end - Disassembler d(b.builder.ir); - d.EmitBlockInstructions(b.current_flow_block->As()); - EXPECT_EQ(d.AsString(), R"(%1(bool) = call my_func, false )"); } diff --git a/src/tint/ir/disassembler.cc b/src/tint/ir/disassembler.cc index 2b8de0bdf4..30da559ae4 100644 --- a/src/tint/ir/disassembler.cc +++ b/src/tint/ir/disassembler.cc @@ -203,8 +203,10 @@ void Disassembler::Walk(const FlowNode* node) { Indent() << "# true branch" << std::endl; Walk(i->true_.target); - Indent() << "# false branch" << std::endl; - Walk(i->false_.target); + if (!i->false_.target->IsDead()) { + Indent() << "# false branch" << std::endl; + Walk(i->false_.target); + } } if (i->merge.target->IsConnected()) {