diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn index 2de9310131..617d1f6e47 100644 --- a/src/tint/BUILD.gn +++ b/src/tint/BUILD.gn @@ -1940,6 +1940,7 @@ if (tint_build_unittests) { "writer/spirv/generator_impl_binary_test.cc", "writer/spirv/generator_impl_constant_test.cc", "writer/spirv/generator_impl_function_test.cc", + "writer/spirv/generator_impl_if_test.cc", "writer/spirv/generator_impl_ir_test.cc", "writer/spirv/generator_impl_type_test.cc", ] diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt index 1f32aaf01a..9bf2714fa0 100644 --- a/src/tint/CMakeLists.txt +++ b/src/tint/CMakeLists.txt @@ -1241,6 +1241,7 @@ if(TINT_BUILD_TESTS) writer/spirv/generator_impl_binary_test.cc writer/spirv/generator_impl_constant_test.cc writer/spirv/generator_impl_function_test.cc + writer/spirv/generator_impl_if_test.cc writer/spirv/generator_impl_ir_test.cc writer/spirv/generator_impl_type_test.cc writer/spirv/test_helper_ir.h diff --git a/src/tint/writer/spirv/generator_impl_if_test.cc b/src/tint/writer/spirv/generator_impl_if_test.cc new file mode 100644 index 0000000000..f594c8bd2e --- /dev/null +++ b/src/tint/writer/spirv/generator_impl_if_test.cc @@ -0,0 +1,149 @@ +// Copyright 2023 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/tint/writer/spirv/test_helper_ir.h" + +using namespace tint::number_suffixes; // NOLINT + +namespace tint::writer::spirv { +namespace { + +TEST_F(SpvGeneratorImplTest, If_TrueEmpty_FalseEmpty) { + auto* func = b.CreateFunction(mod.symbols.Register("foo"), mod.types.Get()); + + auto* i = b.CreateIf(b.Constant(true)); + b.Branch(i->true_.target->As(), i->merge.target); + b.Branch(i->false_.target->As(), i->merge.target); + b.Branch(i->merge.target->As(), func->end_target); + + b.Branch(func->start_target, i); + + generator_.EmitFunction(func); + EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo" +%2 = OpTypeVoid +%3 = OpTypeFunction %2 +%7 = OpTypeBool +%6 = OpConstantTrue %7 +%1 = OpFunction %2 None %3 +%4 = OpLabel +OpSelectionMerge %5 None +OpBranchConditional %6 %5 %5 +%5 = OpLabel +OpReturn +OpFunctionEnd +)"); +} + +TEST_F(SpvGeneratorImplTest, If_FalseEmpty) { + auto* func = b.CreateFunction(mod.symbols.Register("foo"), mod.types.Get()); + + auto* i = b.CreateIf(b.Constant(true)); + b.Branch(i->false_.target->As(), i->merge.target); + b.Branch(i->merge.target->As(), func->end_target); + + auto* true_block = i->true_.target->As(); + true_block->instructions.Push( + b.Add(mod.types.Get(), b.Constant(1_i), b.Constant(1_i))); + b.Branch(true_block, i->merge.target); + + b.Branch(func->start_target, i); + + generator_.EmitFunction(func); + EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo" +%2 = OpTypeVoid +%3 = OpTypeFunction %2 +%8 = OpTypeBool +%7 = OpConstantTrue %8 +%10 = OpTypeInt 32 1 +%11 = OpConstant %10 1 +%1 = OpFunction %2 None %3 +%4 = OpLabel +OpSelectionMerge %5 None +OpBranchConditional %7 %6 %5 +%6 = OpLabel +%9 = OpIAdd %10 %11 %11 +OpBranch %5 +%5 = OpLabel +OpReturn +OpFunctionEnd +)"); +} + +TEST_F(SpvGeneratorImplTest, If_TrueEmpty) { + auto* func = b.CreateFunction(mod.symbols.Register("foo"), mod.types.Get()); + + auto* i = b.CreateIf(b.Constant(true)); + b.Branch(i->true_.target->As(), i->merge.target); + b.Branch(i->merge.target->As(), func->end_target); + + auto* false_block = i->false_.target->As(); + false_block->instructions.Push( + b.Add(mod.types.Get(), b.Constant(1_i), b.Constant(1_i))); + b.Branch(false_block, i->merge.target); + + b.Branch(func->start_target, i); + + generator_.EmitFunction(func); + EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo" +%2 = OpTypeVoid +%3 = OpTypeFunction %2 +%8 = OpTypeBool +%7 = OpConstantTrue %8 +%10 = OpTypeInt 32 1 +%11 = OpConstant %10 1 +%1 = OpFunction %2 None %3 +%4 = OpLabel +OpSelectionMerge %5 None +OpBranchConditional %7 %5 %6 +%6 = OpLabel +%9 = OpIAdd %10 %11 %11 +OpBranch %5 +%5 = OpLabel +OpReturn +OpFunctionEnd +)"); +} + +TEST_F(SpvGeneratorImplTest, If_BothBranchesReturn) { + auto* func = b.CreateFunction(mod.symbols.Register("foo"), mod.types.Get()); + + auto* i = b.CreateIf(b.Constant(true)); + b.Branch(i->true_.target->As(), func->end_target); + b.Branch(i->false_.target->As(), func->end_target); + i->merge.target->As()->branch.target = nullptr; + + b.Branch(func->start_target, i); + + generator_.EmitFunction(func); + EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo" +%2 = OpTypeVoid +%3 = OpTypeFunction %2 +%9 = OpTypeBool +%8 = OpConstantTrue %9 +%1 = OpFunction %2 None %3 +%4 = OpLabel +OpSelectionMerge %5 None +OpBranchConditional %8 %6 %7 +%6 = OpLabel +OpReturn +%7 = OpLabel +OpReturn +%5 = OpLabel +OpUnreachable +OpFunctionEnd +)"); +} + +} // namespace +} // namespace tint::writer::spirv diff --git a/src/tint/writer/spirv/generator_impl_ir.cc b/src/tint/writer/spirv/generator_impl_ir.cc index 124955447c..120499525d 100644 --- a/src/tint/writer/spirv/generator_impl_ir.cc +++ b/src/tint/writer/spirv/generator_impl_ir.cc @@ -18,6 +18,7 @@ #include "src/tint/ir/binary.h" #include "src/tint/ir/block.h" #include "src/tint/ir/function_terminator.h" +#include "src/tint/ir/if.h" #include "src/tint/ir/module.h" #include "src/tint/ir/transform/add_empty_entry_point.h" #include "src/tint/switch.h" @@ -178,6 +179,10 @@ uint32_t GeneratorImplIr::Value(const ir::Value* value) { }); } +uint32_t GeneratorImplIr::Label(const ir::Block* block) { + return block_labels_.GetOrCreate(block, [&]() { return module_.NextId(); }); +} + void GeneratorImplIr::EmitFunction(const ir::Function* func) { // Make an ID for the function. auto id = module_.NextId(); @@ -255,6 +260,12 @@ void GeneratorImplIr::EmitEntryPoint(const ir::Function* func, uint32_t id) { } void GeneratorImplIr::EmitBlock(const ir::Block* block) { + // Emit the label. + // Skip if this is the function's entry block, as it will be emitted by the function object. + if (!current_function_.instructions().empty()) { + current_function_.push_inst(spv::Op::OpLabel, {Label(block)}); + } + // Emit the instructions. for (auto* inst : block->instructions) { auto result = Switch( @@ -271,6 +282,8 @@ void GeneratorImplIr::EmitBlock(const ir::Block* block) { // Handle the branch at the end of the block. Switch( block->branch.target, + [&](const ir::Block* b) { current_function_.push_inst(spv::Op::OpBranch, {Label(b)}); }, + [&](const ir::If* i) { EmitIf(i); }, [&](const ir::FunctionTerminator*) { // TODO(jrprice): Handle the return value, which will be a branch argument. if (!block->branch.args.IsEmpty()) { @@ -278,7 +291,52 @@ void GeneratorImplIr::EmitBlock(const ir::Block* block) { } current_function_.push_inst(spv::Op::OpReturn, {}); }, - [&](Default) { TINT_ICE(Writer, diagnostics_) << "unimplemented branch target"; }); + [&](Default) { + if (!block->branch.target) { + // A block may not have an outward branch (e.g. an unreachable merge block). + current_function_.push_inst(spv::Op::OpUnreachable, {}); + } else { + TINT_ICE(Writer, diagnostics_) + << "unimplemented branch target: " << block->branch.target->TypeInfo().name; + } + }); +} + +void GeneratorImplIr::EmitIf(const ir::If* i) { + auto* merge_block = i->merge.target->As(); + auto* true_block = i->true_.target->As(); + auto* false_block = i->false_.target->As(); + + // Generate labels for the blocks. We emit the true or false block if it: + // 1. contains instructions, or + // 2. branches somewhere other then the merge target. + // Otherwise we skip them and branch straight to the merge block. + uint32_t merge_label = Label(merge_block); + uint32_t true_label = merge_label; + uint32_t false_label = merge_label; + if (!true_block->instructions.IsEmpty() || true_block->branch.target != merge_block) { + true_label = Label(true_block); + } + if (!false_block->instructions.IsEmpty() || false_block->branch.target != merge_block) { + false_label = Label(false_block); + } + + // Emit the OpSelectionMerge and OpBranchConditional instructions. + current_function_.push_inst(spv::Op::OpSelectionMerge, + {merge_label, U32Operand(SpvSelectionControlMaskNone)}); + current_function_.push_inst(spv::Op::OpBranchConditional, + {Value(i->condition), true_label, false_label}); + + // Emit the `true` and `false` blocks, if they're not being skipped. + if (true_label != merge_label) { + EmitBlock(true_block); + } + if (false_label != merge_label) { + EmitBlock(false_block); + } + + // Emit the merge block. + EmitBlock(merge_block); } uint32_t GeneratorImplIr::EmitBinary(const ir::Binary* binary) { diff --git a/src/tint/writer/spirv/generator_impl_ir.h b/src/tint/writer/spirv/generator_impl_ir.h index d20754149a..bff13e6c76 100644 --- a/src/tint/writer/spirv/generator_impl_ir.h +++ b/src/tint/writer/spirv/generator_impl_ir.h @@ -30,6 +30,7 @@ namespace tint::ir { class Binary; class Block; +class If; class Function; class Module; class Value; @@ -76,6 +77,11 @@ class GeneratorImplIr { /// @returns the result ID of the value uint32_t Value(const ir::Value* value); + /// Get the ID of the label for `block`. + /// @param block the block to get the label ID for + /// @returns the ID of the block's label + uint32_t Label(const ir::Block* block); + /// Emit a function. /// @param func the function to emit void EmitFunction(const ir::Function* func); @@ -89,6 +95,10 @@ class GeneratorImplIr { /// @param block the block to emit void EmitBlock(const ir::Block* block); + /// Emit an `if` flow node. + /// @param i the if node to emit + void EmitIf(const ir::If* i); + /// Emit a binary instruction. /// @param binary the binary instruction to emit /// @returns the result ID of the instruction @@ -161,6 +171,9 @@ class GeneratorImplIr { /// The map of instructions to their result IDs. utils::Hashmap instructions_; + /// The map of blocks to the IDs of their label instructions. + utils::Hashmap block_labels_; + /// The current function that is being emitted. Function current_function_;