[ir][spirv-writer] Emit `If` flow nodes
Adds support for both `If` and `Block` flow nodes as branch targets. Also support a nullptr branch target by emitting OpUnreachable. Bug: tint:1906 Change-Id: I1adea83ce6c7c85c6a2e2dae9327499cb7f850bd Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/132861 Reviewed-by: Dan Sinclair <dsinclair@chromium.org> Kokoro: Kokoro <noreply+kokoro@google.com> Commit-Queue: James Price <jrprice@google.com>
This commit is contained in:
parent
0dec17aeb6
commit
82db91ac96
|
@ -1940,6 +1940,7 @@ if (tint_build_unittests) {
|
|||
"writer/spirv/generator_impl_binary_test.cc",
|
||||
"writer/spirv/generator_impl_constant_test.cc",
|
||||
"writer/spirv/generator_impl_function_test.cc",
|
||||
"writer/spirv/generator_impl_if_test.cc",
|
||||
"writer/spirv/generator_impl_ir_test.cc",
|
||||
"writer/spirv/generator_impl_type_test.cc",
|
||||
]
|
||||
|
|
|
@ -1241,6 +1241,7 @@ if(TINT_BUILD_TESTS)
|
|||
writer/spirv/generator_impl_binary_test.cc
|
||||
writer/spirv/generator_impl_constant_test.cc
|
||||
writer/spirv/generator_impl_function_test.cc
|
||||
writer/spirv/generator_impl_if_test.cc
|
||||
writer/spirv/generator_impl_ir_test.cc
|
||||
writer/spirv/generator_impl_type_test.cc
|
||||
writer/spirv/test_helper_ir.h
|
||||
|
|
|
@ -0,0 +1,149 @@
|
|||
// Copyright 2023 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/tint/writer/spirv/test_helper_ir.h"
|
||||
|
||||
using namespace tint::number_suffixes; // NOLINT
|
||||
|
||||
namespace tint::writer::spirv {
|
||||
namespace {
|
||||
|
||||
TEST_F(SpvGeneratorImplTest, If_TrueEmpty_FalseEmpty) {
|
||||
auto* func = b.CreateFunction(mod.symbols.Register("foo"), mod.types.Get<type::Void>());
|
||||
|
||||
auto* i = b.CreateIf(b.Constant(true));
|
||||
b.Branch(i->true_.target->As<ir::Block>(), i->merge.target);
|
||||
b.Branch(i->false_.target->As<ir::Block>(), i->merge.target);
|
||||
b.Branch(i->merge.target->As<ir::Block>(), func->end_target);
|
||||
|
||||
b.Branch(func->start_target, i);
|
||||
|
||||
generator_.EmitFunction(func);
|
||||
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
|
||||
%2 = OpTypeVoid
|
||||
%3 = OpTypeFunction %2
|
||||
%7 = OpTypeBool
|
||||
%6 = OpConstantTrue %7
|
||||
%1 = OpFunction %2 None %3
|
||||
%4 = OpLabel
|
||||
OpSelectionMerge %5 None
|
||||
OpBranchConditional %6 %5 %5
|
||||
%5 = OpLabel
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_F(SpvGeneratorImplTest, If_FalseEmpty) {
|
||||
auto* func = b.CreateFunction(mod.symbols.Register("foo"), mod.types.Get<type::Void>());
|
||||
|
||||
auto* i = b.CreateIf(b.Constant(true));
|
||||
b.Branch(i->false_.target->As<ir::Block>(), i->merge.target);
|
||||
b.Branch(i->merge.target->As<ir::Block>(), func->end_target);
|
||||
|
||||
auto* true_block = i->true_.target->As<ir::Block>();
|
||||
true_block->instructions.Push(
|
||||
b.Add(mod.types.Get<type::I32>(), b.Constant(1_i), b.Constant(1_i)));
|
||||
b.Branch(true_block, i->merge.target);
|
||||
|
||||
b.Branch(func->start_target, i);
|
||||
|
||||
generator_.EmitFunction(func);
|
||||
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
|
||||
%2 = OpTypeVoid
|
||||
%3 = OpTypeFunction %2
|
||||
%8 = OpTypeBool
|
||||
%7 = OpConstantTrue %8
|
||||
%10 = OpTypeInt 32 1
|
||||
%11 = OpConstant %10 1
|
||||
%1 = OpFunction %2 None %3
|
||||
%4 = OpLabel
|
||||
OpSelectionMerge %5 None
|
||||
OpBranchConditional %7 %6 %5
|
||||
%6 = OpLabel
|
||||
%9 = OpIAdd %10 %11 %11
|
||||
OpBranch %5
|
||||
%5 = OpLabel
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_F(SpvGeneratorImplTest, If_TrueEmpty) {
|
||||
auto* func = b.CreateFunction(mod.symbols.Register("foo"), mod.types.Get<type::Void>());
|
||||
|
||||
auto* i = b.CreateIf(b.Constant(true));
|
||||
b.Branch(i->true_.target->As<ir::Block>(), i->merge.target);
|
||||
b.Branch(i->merge.target->As<ir::Block>(), func->end_target);
|
||||
|
||||
auto* false_block = i->false_.target->As<ir::Block>();
|
||||
false_block->instructions.Push(
|
||||
b.Add(mod.types.Get<type::I32>(), b.Constant(1_i), b.Constant(1_i)));
|
||||
b.Branch(false_block, i->merge.target);
|
||||
|
||||
b.Branch(func->start_target, i);
|
||||
|
||||
generator_.EmitFunction(func);
|
||||
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
|
||||
%2 = OpTypeVoid
|
||||
%3 = OpTypeFunction %2
|
||||
%8 = OpTypeBool
|
||||
%7 = OpConstantTrue %8
|
||||
%10 = OpTypeInt 32 1
|
||||
%11 = OpConstant %10 1
|
||||
%1 = OpFunction %2 None %3
|
||||
%4 = OpLabel
|
||||
OpSelectionMerge %5 None
|
||||
OpBranchConditional %7 %5 %6
|
||||
%6 = OpLabel
|
||||
%9 = OpIAdd %10 %11 %11
|
||||
OpBranch %5
|
||||
%5 = OpLabel
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_F(SpvGeneratorImplTest, If_BothBranchesReturn) {
|
||||
auto* func = b.CreateFunction(mod.symbols.Register("foo"), mod.types.Get<type::Void>());
|
||||
|
||||
auto* i = b.CreateIf(b.Constant(true));
|
||||
b.Branch(i->true_.target->As<ir::Block>(), func->end_target);
|
||||
b.Branch(i->false_.target->As<ir::Block>(), func->end_target);
|
||||
i->merge.target->As<ir::Block>()->branch.target = nullptr;
|
||||
|
||||
b.Branch(func->start_target, i);
|
||||
|
||||
generator_.EmitFunction(func);
|
||||
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
|
||||
%2 = OpTypeVoid
|
||||
%3 = OpTypeFunction %2
|
||||
%9 = OpTypeBool
|
||||
%8 = OpConstantTrue %9
|
||||
%1 = OpFunction %2 None %3
|
||||
%4 = OpLabel
|
||||
OpSelectionMerge %5 None
|
||||
OpBranchConditional %8 %6 %7
|
||||
%6 = OpLabel
|
||||
OpReturn
|
||||
%7 = OpLabel
|
||||
OpReturn
|
||||
%5 = OpLabel
|
||||
OpUnreachable
|
||||
OpFunctionEnd
|
||||
)");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tint::writer::spirv
|
|
@ -18,6 +18,7 @@
|
|||
#include "src/tint/ir/binary.h"
|
||||
#include "src/tint/ir/block.h"
|
||||
#include "src/tint/ir/function_terminator.h"
|
||||
#include "src/tint/ir/if.h"
|
||||
#include "src/tint/ir/module.h"
|
||||
#include "src/tint/ir/transform/add_empty_entry_point.h"
|
||||
#include "src/tint/switch.h"
|
||||
|
@ -178,6 +179,10 @@ uint32_t GeneratorImplIr::Value(const ir::Value* value) {
|
|||
});
|
||||
}
|
||||
|
||||
uint32_t GeneratorImplIr::Label(const ir::Block* block) {
|
||||
return block_labels_.GetOrCreate(block, [&]() { return module_.NextId(); });
|
||||
}
|
||||
|
||||
void GeneratorImplIr::EmitFunction(const ir::Function* func) {
|
||||
// Make an ID for the function.
|
||||
auto id = module_.NextId();
|
||||
|
@ -255,6 +260,12 @@ void GeneratorImplIr::EmitEntryPoint(const ir::Function* func, uint32_t id) {
|
|||
}
|
||||
|
||||
void GeneratorImplIr::EmitBlock(const ir::Block* block) {
|
||||
// Emit the label.
|
||||
// Skip if this is the function's entry block, as it will be emitted by the function object.
|
||||
if (!current_function_.instructions().empty()) {
|
||||
current_function_.push_inst(spv::Op::OpLabel, {Label(block)});
|
||||
}
|
||||
|
||||
// Emit the instructions.
|
||||
for (auto* inst : block->instructions) {
|
||||
auto result = Switch(
|
||||
|
@ -271,6 +282,8 @@ void GeneratorImplIr::EmitBlock(const ir::Block* block) {
|
|||
// Handle the branch at the end of the block.
|
||||
Switch(
|
||||
block->branch.target,
|
||||
[&](const ir::Block* b) { current_function_.push_inst(spv::Op::OpBranch, {Label(b)}); },
|
||||
[&](const ir::If* i) { EmitIf(i); },
|
||||
[&](const ir::FunctionTerminator*) {
|
||||
// TODO(jrprice): Handle the return value, which will be a branch argument.
|
||||
if (!block->branch.args.IsEmpty()) {
|
||||
|
@ -278,7 +291,52 @@ void GeneratorImplIr::EmitBlock(const ir::Block* block) {
|
|||
}
|
||||
current_function_.push_inst(spv::Op::OpReturn, {});
|
||||
},
|
||||
[&](Default) { TINT_ICE(Writer, diagnostics_) << "unimplemented branch target"; });
|
||||
[&](Default) {
|
||||
if (!block->branch.target) {
|
||||
// A block may not have an outward branch (e.g. an unreachable merge block).
|
||||
current_function_.push_inst(spv::Op::OpUnreachable, {});
|
||||
} else {
|
||||
TINT_ICE(Writer, diagnostics_)
|
||||
<< "unimplemented branch target: " << block->branch.target->TypeInfo().name;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void GeneratorImplIr::EmitIf(const ir::If* i) {
|
||||
auto* merge_block = i->merge.target->As<ir::Block>();
|
||||
auto* true_block = i->true_.target->As<ir::Block>();
|
||||
auto* false_block = i->false_.target->As<ir::Block>();
|
||||
|
||||
// Generate labels for the blocks. We emit the true or false block if it:
|
||||
// 1. contains instructions, or
|
||||
// 2. branches somewhere other then the merge target.
|
||||
// Otherwise we skip them and branch straight to the merge block.
|
||||
uint32_t merge_label = Label(merge_block);
|
||||
uint32_t true_label = merge_label;
|
||||
uint32_t false_label = merge_label;
|
||||
if (!true_block->instructions.IsEmpty() || true_block->branch.target != merge_block) {
|
||||
true_label = Label(true_block);
|
||||
}
|
||||
if (!false_block->instructions.IsEmpty() || false_block->branch.target != merge_block) {
|
||||
false_label = Label(false_block);
|
||||
}
|
||||
|
||||
// Emit the OpSelectionMerge and OpBranchConditional instructions.
|
||||
current_function_.push_inst(spv::Op::OpSelectionMerge,
|
||||
{merge_label, U32Operand(SpvSelectionControlMaskNone)});
|
||||
current_function_.push_inst(spv::Op::OpBranchConditional,
|
||||
{Value(i->condition), true_label, false_label});
|
||||
|
||||
// Emit the `true` and `false` blocks, if they're not being skipped.
|
||||
if (true_label != merge_label) {
|
||||
EmitBlock(true_block);
|
||||
}
|
||||
if (false_label != merge_label) {
|
||||
EmitBlock(false_block);
|
||||
}
|
||||
|
||||
// Emit the merge block.
|
||||
EmitBlock(merge_block);
|
||||
}
|
||||
|
||||
uint32_t GeneratorImplIr::EmitBinary(const ir::Binary* binary) {
|
||||
|
|
|
@ -30,6 +30,7 @@
|
|||
namespace tint::ir {
|
||||
class Binary;
|
||||
class Block;
|
||||
class If;
|
||||
class Function;
|
||||
class Module;
|
||||
class Value;
|
||||
|
@ -76,6 +77,11 @@ class GeneratorImplIr {
|
|||
/// @returns the result ID of the value
|
||||
uint32_t Value(const ir::Value* value);
|
||||
|
||||
/// Get the ID of the label for `block`.
|
||||
/// @param block the block to get the label ID for
|
||||
/// @returns the ID of the block's label
|
||||
uint32_t Label(const ir::Block* block);
|
||||
|
||||
/// Emit a function.
|
||||
/// @param func the function to emit
|
||||
void EmitFunction(const ir::Function* func);
|
||||
|
@ -89,6 +95,10 @@ class GeneratorImplIr {
|
|||
/// @param block the block to emit
|
||||
void EmitBlock(const ir::Block* block);
|
||||
|
||||
/// Emit an `if` flow node.
|
||||
/// @param i the if node to emit
|
||||
void EmitIf(const ir::If* i);
|
||||
|
||||
/// Emit a binary instruction.
|
||||
/// @param binary the binary instruction to emit
|
||||
/// @returns the result ID of the instruction
|
||||
|
@ -161,6 +171,9 @@ class GeneratorImplIr {
|
|||
/// The map of instructions to their result IDs.
|
||||
utils::Hashmap<const ir::Instruction*, uint32_t, 8> instructions_;
|
||||
|
||||
/// The map of blocks to the IDs of their label instructions.
|
||||
utils::Hashmap<const ir::Block*, uint32_t, 8> block_labels_;
|
||||
|
||||
/// The current function that is being emitted.
|
||||
Function current_function_;
|
||||
|
||||
|
|
Loading…
Reference in New Issue