diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn index f15e7e29f5..7f2dd0036c 100644 --- a/src/tint/BUILD.gn +++ b/src/tint/BUILD.gn @@ -1968,6 +1968,7 @@ if (tint_build_unittests) { "writer/spirv/ir/generator_impl_ir_constant_test.cc", "writer/spirv/ir/generator_impl_ir_function_test.cc", "writer/spirv/ir/generator_impl_ir_if_test.cc", + "writer/spirv/ir/generator_impl_ir_loop_test.cc", "writer/spirv/ir/generator_impl_ir_test.cc", "writer/spirv/ir/generator_impl_ir_type_test.cc", "writer/spirv/ir/generator_impl_ir_var_test.cc", diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt index 7217a3513d..2e097f449f 100644 --- a/src/tint/CMakeLists.txt +++ b/src/tint/CMakeLists.txt @@ -1258,6 +1258,7 @@ if(TINT_BUILD_TESTS) writer/spirv/ir/generator_impl_ir_constant_test.cc writer/spirv/ir/generator_impl_ir_function_test.cc writer/spirv/ir/generator_impl_ir_if_test.cc + writer/spirv/ir/generator_impl_ir_loop_test.cc writer/spirv/ir/generator_impl_ir_test.cc writer/spirv/ir/generator_impl_ir_type_test.cc writer/spirv/ir/generator_impl_ir_var_test.cc diff --git a/src/tint/writer/spirv/ir/generator_impl_ir.cc b/src/tint/writer/spirv/ir/generator_impl_ir.cc index 4bc7955fba..89a1a1ca44 100644 --- a/src/tint/writer/spirv/ir/generator_impl_ir.cc +++ b/src/tint/writer/spirv/ir/generator_impl_ir.cc @@ -20,11 +20,16 @@ #include "spirv/unified1/spirv.h" #include "src/tint/ir/binary.h" #include "src/tint/ir/block.h" +#include "src/tint/ir/break_if.h" #include "src/tint/ir/builtin.h" +#include "src/tint/ir/continue.h" #include "src/tint/ir/exit_if.h" +#include "src/tint/ir/exit_loop.h" #include "src/tint/ir/if.h" #include "src/tint/ir/load.h" +#include "src/tint/ir/loop.h" #include "src/tint/ir/module.h" +#include "src/tint/ir/next_iteration.h" #include "src/tint/ir/return.h" #include "src/tint/ir/store.h" #include "src/tint/ir/transform/add_empty_entry_point.h" @@ -334,6 +339,10 @@ void GeneratorImplIr::EmitBlock(const ir::Block* block) { [&](const ir::Binary* b) { return EmitBinary(b); }, [&](const ir::Builtin* b) { return EmitBuiltin(b); }, [&](const ir::Load* l) { return EmitLoad(l); }, + [&](const ir::Loop* l) { + EmitLoop(l); + return 0u; + }, [&](const ir::Store* s) { EmitStore(s); return 0u; @@ -371,9 +380,26 @@ void GeneratorImplIr::EmitBranch(const ir::Branch* b) { } return; }, + [&](const ir::BreakIf* breakif) { + current_function_.push_inst(spv::Op::OpBranchConditional, + { + Value(breakif->Condition()), + Label(breakif->Loop()->Merge()), + Label(breakif->Loop()->Start()), + }); + }, + [&](const ir::Continue* cont) { + current_function_.push_inst(spv::Op::OpBranch, {Label(cont->Loop()->Continuing())}); + }, [&](const ir::ExitIf* if_) { current_function_.push_inst(spv::Op::OpBranch, {Label(if_->If()->Merge())}); }, + [&](const ir::ExitLoop* loop) { + current_function_.push_inst(spv::Op::OpBranch, {Label(loop->Loop()->Merge())}); + }, + [&](const ir::NextIteration* loop) { + current_function_.push_inst(spv::Op::OpBranch, {Label(loop->Loop()->Start())}); + }, [&](Default) { TINT_ICE(Writer, diagnostics_) << "unimplemented branch: " << b->TypeInfo().name; }); @@ -594,6 +620,37 @@ uint32_t GeneratorImplIr::EmitLoad(const ir::Load* load) { return id; } +void GeneratorImplIr::EmitLoop(const ir::Loop* loop) { + auto header_label = module_.NextId(); + auto body_label = Label(loop->Start()); + auto continuing_label = Label(loop->Continuing()); + auto merge_label = Label(loop->Merge()); + + // Branch to and emit the loop header, which contains OpLoopMerge and OpBranch instructions. + current_function_.push_inst(spv::Op::OpBranch, {header_label}); + current_function_.push_inst(spv::Op::OpLabel, {header_label}); + current_function_.push_inst( + spv::Op::OpLoopMerge, {merge_label, continuing_label, U32Operand(SpvLoopControlMaskNone)}); + current_function_.push_inst(spv::Op::OpBranch, {body_label}); + + // Emit the loop body. + EmitBlock(loop->Start()); + + // Emit the loop continuing block. + // The back-edge needs to go to the loop header, so update the label for the start block. + block_labels_.Replace(loop->Start(), header_label); + if (loop->Continuing()->HasBranchTarget()) { + EmitBlock(loop->Continuing()); + } else { + // We still need to emit a continuing block with a back-edge, even if it is unreachable. + current_function_.push_inst(spv::Op::OpLabel, {continuing_label}); + current_function_.push_inst(spv::Op::OpBranch, {header_label}); + } + + // Emit the loop merge block. + EmitBlock(loop->Merge()); +} + void GeneratorImplIr::EmitStore(const ir::Store* store) { current_function_.push_inst(spv::Op::OpStore, {Value(store->To()), Value(store->From())}); } diff --git a/src/tint/writer/spirv/ir/generator_impl_ir.h b/src/tint/writer/spirv/ir/generator_impl_ir.h index d42cce694e..5293fea35c 100644 --- a/src/tint/writer/spirv/ir/generator_impl_ir.h +++ b/src/tint/writer/spirv/ir/generator_impl_ir.h @@ -36,6 +36,7 @@ class Builtin; class If; class Function; class Load; +class Loop; class Module; class Store; class UserCall; @@ -121,6 +122,10 @@ class GeneratorImplIr { /// @returns the result ID of the instruction uint32_t EmitLoad(const ir::Load* load); + /// Emit a loop instruction. + /// @param loop the loop instruction to emit + void EmitLoop(const ir::Loop* loop); + /// Emit a store instruction. /// @param store the store instruction to emit void EmitStore(const ir::Store* store); diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_loop_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_loop_test.cc new file mode 100644 index 0000000000..8ec7a9c8c2 --- /dev/null +++ b/src/tint/writer/spirv/ir/generator_impl_ir_loop_test.cc @@ -0,0 +1,334 @@ +// Copyright 2023 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/tint/writer/spirv/ir/test_helper_ir.h" + +using namespace tint::number_suffixes; // NOLINT + +namespace tint::writer::spirv { +namespace { + +TEST_F(SpvGeneratorImplTest, Loop_BreakIf) { + auto* func = b.CreateFunction("foo", mod.Types().void_()); + + auto* loop = b.CreateLoop(); + + loop->Start()->Instructions().Push(b.Continue(loop)); + loop->Continuing()->Instructions().Push(b.BreakIf(b.Constant(true), loop)); + loop->Merge()->Instructions().Push(b.Return(func)); + + func->StartTarget()->Instructions().Push(loop); + + generator_.EmitFunction(func); + EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo" +%2 = OpTypeVoid +%3 = OpTypeFunction %2 +%10 = OpTypeBool +%9 = OpConstantTrue %10 +%1 = OpFunction %2 None %3 +%4 = OpLabel +OpBranch %5 +%5 = OpLabel +OpLoopMerge %8 %7 None +OpBranch %6 +%6 = OpLabel +OpBranch %7 +%7 = OpLabel +OpBranchConditional %9 %8 %5 +%8 = OpLabel +OpReturn +OpFunctionEnd +)"); +} + +// Test that we still emit the continuing block with a back-edge, even when it is unreachable. +TEST_F(SpvGeneratorImplTest, Loop_UnconditionalBreakInBody) { + auto* func = b.CreateFunction("foo", mod.Types().void_()); + + auto* loop = b.CreateLoop(); + + loop->Start()->Instructions().Push(b.ExitLoop(loop)); + loop->Merge()->Instructions().Push(b.Return(func)); + + func->StartTarget()->Instructions().Push(loop); + + generator_.EmitFunction(func); + EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo" +%2 = OpTypeVoid +%3 = OpTypeFunction %2 +%1 = OpFunction %2 None %3 +%4 = OpLabel +OpBranch %5 +%5 = OpLabel +OpLoopMerge %8 %7 None +OpBranch %6 +%6 = OpLabel +OpBranch %8 +%7 = OpLabel +OpBranch %5 +%8 = OpLabel +OpReturn +OpFunctionEnd +)"); +} + +TEST_F(SpvGeneratorImplTest, Loop_ConditionalBreakInBody) { + auto* func = b.CreateFunction("foo", mod.Types().void_()); + + auto* loop = b.CreateLoop(); + + auto* cond_break = b.CreateIf(b.Constant(true)); + cond_break->True()->Instructions().Push(b.ExitLoop(loop)); + cond_break->False()->Instructions().Push(b.ExitIf(cond_break)); + cond_break->Merge()->Instructions().Push(b.Continue(loop)); + + loop->Start()->Instructions().Push(cond_break); + loop->Continuing()->Instructions().Push(b.NextIteration(loop)); + loop->Merge()->Instructions().Push(b.Return(func)); + + func->StartTarget()->Instructions().Push(loop); + + generator_.EmitFunction(func); + EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo" +%2 = OpTypeVoid +%3 = OpTypeFunction %2 +%12 = OpTypeBool +%11 = OpConstantTrue %12 +%1 = OpFunction %2 None %3 +%4 = OpLabel +OpBranch %5 +%5 = OpLabel +OpLoopMerge %8 %7 None +OpBranch %6 +%6 = OpLabel +OpSelectionMerge %9 None +OpBranchConditional %11 %10 %9 +%10 = OpLabel +OpBranch %8 +%9 = OpLabel +OpBranch %7 +%7 = OpLabel +OpBranch %5 +%8 = OpLabel +OpReturn +OpFunctionEnd +)"); +} + +TEST_F(SpvGeneratorImplTest, Loop_ConditionalContinueInBody) { + auto* func = b.CreateFunction("foo", mod.Types().void_()); + + auto* loop = b.CreateLoop(); + + auto* cond_break = b.CreateIf(b.Constant(true)); + cond_break->True()->Instructions().Push(b.Continue(loop)); + cond_break->False()->Instructions().Push(b.ExitIf(cond_break)); + cond_break->Merge()->Instructions().Push(b.ExitLoop(loop)); + + loop->Start()->Instructions().Push(cond_break); + loop->Continuing()->Instructions().Push(b.NextIteration(loop)); + loop->Merge()->Instructions().Push(b.Return(func)); + + func->StartTarget()->Instructions().Push(loop); + + generator_.EmitFunction(func); + EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo" +%2 = OpTypeVoid +%3 = OpTypeFunction %2 +%12 = OpTypeBool +%11 = OpConstantTrue %12 +%1 = OpFunction %2 None %3 +%4 = OpLabel +OpBranch %5 +%5 = OpLabel +OpLoopMerge %8 %7 None +OpBranch %6 +%6 = OpLabel +OpSelectionMerge %9 None +OpBranchConditional %11 %10 %9 +%10 = OpLabel +OpBranch %7 +%9 = OpLabel +OpBranch %8 +%7 = OpLabel +OpBranch %5 +%8 = OpLabel +OpReturn +OpFunctionEnd +)"); +} + +// Test that we still emit the continuing block with a back-edge, and the merge block, even when +// they are unreachable. +TEST_F(SpvGeneratorImplTest, Loop_UnconditionalReturnInBody) { + auto* func = b.CreateFunction("foo", mod.Types().void_()); + + auto* loop = b.CreateLoop(); + + loop->Start()->Instructions().Push(b.Return(func)); + + func->StartTarget()->Instructions().Push(loop); + + generator_.EmitFunction(func); + EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo" +%2 = OpTypeVoid +%3 = OpTypeFunction %2 +%1 = OpFunction %2 None %3 +%4 = OpLabel +OpBranch %5 +%5 = OpLabel +OpLoopMerge %8 %7 None +OpBranch %6 +%6 = OpLabel +OpReturn +%7 = OpLabel +OpBranch %5 +%8 = OpLabel +OpUnreachable +OpFunctionEnd +)"); +} + +TEST_F(SpvGeneratorImplTest, Loop_UseResultFromBodyInContinuing) { + auto* func = b.CreateFunction("foo", mod.Types().void_()); + + auto* loop = b.CreateLoop(); + + auto* result = b.Equal(mod.Types().i32(), b.Constant(1_i), b.Constant(2_i)); + + loop->Start()->Instructions().Push(result); + loop->Continuing()->Instructions().Push(b.BreakIf(result, loop)); + loop->Merge()->Instructions().Push(b.Return(func)); + + func->StartTarget()->Instructions().Push(loop); + + generator_.EmitFunction(func); + EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo" +%2 = OpTypeVoid +%3 = OpTypeFunction %2 +%10 = OpTypeInt 32 1 +%11 = OpConstant %10 1 +%12 = OpConstant %10 2 +%1 = OpFunction %2 None %3 +%4 = OpLabel +OpBranch %5 +%5 = OpLabel +OpLoopMerge %8 %7 None +OpBranch %6 +%6 = OpLabel +%9 = OpIEqual %10 %11 %12 +%7 = OpLabel +OpBranchConditional %9 %8 %5 +%8 = OpLabel +OpReturn +OpFunctionEnd +)"); +} + +TEST_F(SpvGeneratorImplTest, Loop_NestedLoopInBody) { + auto* func = b.CreateFunction("foo", mod.Types().void_()); + + auto* outer_loop = b.CreateLoop(); + auto* inner_loop = b.CreateLoop(); + + inner_loop->Start()->Instructions().Push(b.ExitLoop(inner_loop)); + inner_loop->Continuing()->Instructions().Push(b.NextIteration(inner_loop)); + inner_loop->Merge()->Instructions().Push(b.Continue(outer_loop)); + + outer_loop->Start()->Instructions().Push(inner_loop); + outer_loop->Continuing()->Instructions().Push(b.BreakIf(b.Constant(true), outer_loop)); + outer_loop->Merge()->Instructions().Push(b.Return(func)); + + func->StartTarget()->Instructions().Push(outer_loop); + + generator_.EmitFunction(func); + EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo" +%2 = OpTypeVoid +%3 = OpTypeFunction %2 +%14 = OpTypeBool +%13 = OpConstantTrue %14 +%1 = OpFunction %2 None %3 +%4 = OpLabel +OpBranch %5 +%5 = OpLabel +OpLoopMerge %8 %7 None +OpBranch %6 +%6 = OpLabel +OpBranch %9 +%9 = OpLabel +OpLoopMerge %12 %11 None +OpBranch %10 +%10 = OpLabel +OpBranch %12 +%11 = OpLabel +OpBranch %9 +%12 = OpLabel +OpBranch %7 +%7 = OpLabel +OpBranchConditional %13 %8 %5 +%8 = OpLabel +OpReturn +OpFunctionEnd +)"); +} + +TEST_F(SpvGeneratorImplTest, Loop_NestedLoopInContinuing) { + auto* func = b.CreateFunction("foo", mod.Types().void_()); + + auto* outer_loop = b.CreateLoop(); + auto* inner_loop = b.CreateLoop(); + + inner_loop->Start()->Instructions().Push(b.Continue(inner_loop)); + inner_loop->Continuing()->Instructions().Push(b.BreakIf(b.Constant(true), inner_loop)); + inner_loop->Merge()->Instructions().Push(b.BreakIf(b.Constant(true), outer_loop)); + + outer_loop->Start()->Instructions().Push(b.Continue(outer_loop)); + outer_loop->Continuing()->Instructions().Push(inner_loop); + outer_loop->Merge()->Instructions().Push(b.Return(func)); + + func->StartTarget()->Instructions().Push(outer_loop); + + generator_.EmitFunction(func); + EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo" +%2 = OpTypeVoid +%3 = OpTypeFunction %2 +%14 = OpTypeBool +%13 = OpConstantTrue %14 +%1 = OpFunction %2 None %3 +%4 = OpLabel +OpBranch %5 +%5 = OpLabel +OpLoopMerge %8 %7 None +OpBranch %6 +%6 = OpLabel +OpBranch %7 +%7 = OpLabel +OpBranch %9 +%9 = OpLabel +OpLoopMerge %12 %11 None +OpBranch %10 +%10 = OpLabel +OpBranch %11 +%11 = OpLabel +OpBranchConditional %13 %12 %9 +%12 = OpLabel +OpBranchConditional %13 %8 %5 +%8 = OpLabel +OpReturn +OpFunctionEnd +)"); +} + +} // namespace +} // namespace tint::writer::spirv