[ir][spirv-writer] Implement user function calls

Bug: tint:1906
Change-Id: Icf9a0a00409b61d3c8baa844b66865c1a4dd9b69
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/134202
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: James Price <jrprice@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
James Price 2023-05-24 17:42:35 +00:00 committed by Dawn LUCI CQ
parent c1fd6316de
commit b54b58d57d
3 changed files with 92 additions and 0 deletions

View File

@ -25,6 +25,7 @@
#include "src/tint/ir/module.h" #include "src/tint/ir/module.h"
#include "src/tint/ir/store.h" #include "src/tint/ir/store.h"
#include "src/tint/ir/transform/add_empty_entry_point.h" #include "src/tint/ir/transform/add_empty_entry_point.h"
#include "src/tint/ir/user_call.h"
#include "src/tint/ir/var.h" #include "src/tint/ir/var.h"
#include "src/tint/switch.h" #include "src/tint/switch.h"
#include "src/tint/transform/manager.h" #include "src/tint/transform/manager.h"
@ -210,6 +211,7 @@ uint32_t GeneratorImplIr::Label(const ir::Block* block) {
void GeneratorImplIr::EmitFunction(const ir::Function* func) { void GeneratorImplIr::EmitFunction(const ir::Function* func) {
// Make an ID for the function. // Make an ID for the function.
auto id = module_.NextId(); auto id = module_.NextId();
functions_.Add(func->Name(), id);
// Emit the function name. // Emit the function name.
module_.PushDebug(spv::Op::OpName, {id, Operand(func->Name().Name())}); module_.PushDebug(spv::Op::OpName, {id, Operand(func->Name().Name())});
@ -319,6 +321,7 @@ void GeneratorImplIr::EmitBlock(const ir::Block* block) {
EmitStore(s); EmitStore(s);
return 0u; return 0u;
}, },
[&](const ir::UserCall* c) { return EmitUserCall(c); },
[&](const ir::Var* v) { return EmitVar(v); }, [&](const ir::Var* v) { return EmitVar(v); },
[&](const ir::If* i) { [&](const ir::If* i) {
EmitIf(i); EmitIf(i);
@ -432,6 +435,16 @@ void GeneratorImplIr::EmitStore(const ir::Store* store) {
current_function_.push_inst(spv::Op::OpStore, {Value(store->To()), Value(store->From())}); current_function_.push_inst(spv::Op::OpStore, {Value(store->To()), Value(store->From())});
} }
uint32_t GeneratorImplIr::EmitUserCall(const ir::UserCall* call) {
auto id = module_.NextId();
OperandList operands = {Type(call->Type()), id, functions_.Get(call->Name()).value()};
for (auto* arg : call->Args()) {
operands.push_back(Value(arg));
}
current_function_.push_inst(spv::Op::OpFunctionCall, operands);
return id;
}
uint32_t GeneratorImplIr::EmitVar(const ir::Var* var) { uint32_t GeneratorImplIr::EmitVar(const ir::Var* var) {
auto id = module_.NextId(); auto id = module_.NextId();
auto* ptr = var->Type()->As<type::Pointer>(); auto* ptr = var->Type()->As<type::Pointer>();

View File

@ -20,6 +20,7 @@
#include "src/tint/constant/value.h" #include "src/tint/constant/value.h"
#include "src/tint/diagnostic/diagnostic.h" #include "src/tint/diagnostic/diagnostic.h"
#include "src/tint/ir/constant.h" #include "src/tint/ir/constant.h"
#include "src/tint/symbol.h"
#include "src/tint/utils/hashmap.h" #include "src/tint/utils/hashmap.h"
#include "src/tint/utils/vector.h" #include "src/tint/utils/vector.h"
#include "src/tint/writer/spirv/binary_writer.h" #include "src/tint/writer/spirv/binary_writer.h"
@ -36,6 +37,7 @@ class Function;
class Load; class Load;
class Module; class Module;
class Store; class Store;
class UserCall;
class Value; class Value;
class Var; class Var;
} // namespace tint::ir } // namespace tint::ir
@ -117,6 +119,11 @@ class GeneratorImplIr {
/// @param store the store instruction to emit /// @param store the store instruction to emit
void EmitStore(const ir::Store* store); void EmitStore(const ir::Store* store);
/// Emit a user call instruction.
/// @param call the user call instruction to emit
/// @returns the result ID of the instruction
uint32_t EmitUserCall(const ir::UserCall* call);
/// Emit a var instruction. /// Emit a var instruction.
/// @param var the var instruction to emit /// @param var the var instruction to emit
/// @returns the result ID of the instruction /// @returns the result ID of the instruction
@ -171,6 +178,10 @@ class GeneratorImplIr {
/// The map of constants to their result IDs. /// The map of constants to their result IDs.
utils::Hashmap<const constant::Value*, uint32_t, 16> constants_; utils::Hashmap<const constant::Value*, uint32_t, 16> constants_;
/// The map of functions to their result IDs.
/// TODO(jrprice): Merge into `values_` map when `ir::Function` becomes an `ir::Value`.
utils::Hashmap<Symbol, uint32_t, 8> functions_;
/// The map of non-constant values to their result IDs. /// The map of non-constant values to their result IDs.
utils::Hashmap<const ir::Value*, uint32_t, 8> values_; utils::Hashmap<const ir::Value*, uint32_t, 8> values_;

View File

@ -183,5 +183,73 @@ OpFunctionEnd
)"); )");
} }
TEST_F(SpvGeneratorImplTest, Function_Call) {
auto* i32_ty = mod.types.i32();
auto* x = b.FunctionParam(i32_ty);
auto* y = b.FunctionParam(i32_ty);
auto* result = b.Add(i32_ty, x, y);
auto* foo = b.CreateFunction("foo", i32_ty);
foo->SetParams(utils::Vector{x, y});
foo->StartTarget()->SetInstructions(
utils::Vector{result, b.Branch(foo->EndTarget(), utils::Vector{result})});
auto* bar = b.CreateFunction("bar", mod.types.void_());
bar->StartTarget()->SetInstructions(
utils::Vector{b.UserCall(i32_ty, mod.symbols.Get("foo"),
utils::Vector{b.Constant(i32(2)), b.Constant(i32(3))}),
b.Branch(bar->EndTarget())});
generator_.EmitFunction(foo);
generator_.EmitFunction(bar);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
OpName %8 "bar"
%2 = OpTypeInt 32 1
%5 = OpTypeFunction %2 %2 %2
%9 = OpTypeVoid
%10 = OpTypeFunction %9
%13 = OpConstant %2 2
%14 = OpConstant %2 3
%1 = OpFunction %2 None %5
%3 = OpFunctionParameter %2
%4 = OpFunctionParameter %2
%6 = OpLabel
%7 = OpIAdd %2 %3 %4
OpReturnValue %7
OpFunctionEnd
%8 = OpFunction %9 None %10
%11 = OpLabel
%12 = OpFunctionCall %2 %1 %13 %14
OpReturn
OpFunctionEnd
)");
}
TEST_F(SpvGeneratorImplTest, Function_Call_Void) {
auto* foo = b.CreateFunction("foo", mod.types.void_());
foo->StartTarget()->SetInstructions(utils::Vector{b.Branch(foo->EndTarget())});
auto* bar = b.CreateFunction("bar", mod.types.void_());
bar->StartTarget()->SetInstructions(
utils::Vector{b.UserCall(mod.types.void_(), mod.symbols.Get("foo"), utils::Empty),
b.Branch(bar->EndTarget())});
generator_.EmitFunction(foo);
generator_.EmitFunction(bar);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
OpName %5 "bar"
%2 = OpTypeVoid
%3 = OpTypeFunction %2
%1 = OpFunction %2 None %3
%4 = OpLabel
OpReturn
OpFunctionEnd
%5 = OpFunction %2 None %3
%6 = OpLabel
%7 = OpFunctionCall %2 %1
OpReturn
OpFunctionEnd
)");
}
} // namespace } // namespace
} // namespace tint::writer::spirv } // namespace tint::writer::spirv