diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn index 639771ed53..47efd523fd 100644 --- a/src/tint/BUILD.gn +++ b/src/tint/BUILD.gn @@ -1949,6 +1949,7 @@ if (tint_build_unittests) { "writer/spirv/ir/generator_impl_ir_if_test.cc", "writer/spirv/ir/generator_impl_ir_test.cc", "writer/spirv/ir/generator_impl_ir_type_test.cc", + "writer/spirv/ir/generator_impl_ir_var_test.cc", "writer/spirv/ir/test_helper_ir.h", ] deps += [ ":libtint_ir_src" ] diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt index 966718aa02..d0c7487a56 100644 --- a/src/tint/CMakeLists.txt +++ b/src/tint/CMakeLists.txt @@ -1246,6 +1246,7 @@ if(TINT_BUILD_TESTS) writer/spirv/ir/generator_impl_ir_if_test.cc writer/spirv/ir/generator_impl_ir_test.cc writer/spirv/ir/generator_impl_ir_type_test.cc + writer/spirv/ir/generator_impl_ir_var_test.cc writer/spirv/ir/test_helper_ir.h ) endif() diff --git a/src/tint/writer/spirv/ir/generator_impl_ir.cc b/src/tint/writer/spirv/ir/generator_impl_ir.cc index b16d087034..e55443acbc 100644 --- a/src/tint/writer/spirv/ir/generator_impl_ir.cc +++ b/src/tint/writer/spirv/ir/generator_impl_ir.cc @@ -21,12 +21,14 @@ #include "src/tint/ir/if.h" #include "src/tint/ir/module.h" #include "src/tint/ir/transform/add_empty_entry_point.h" +#include "src/tint/ir/var.h" #include "src/tint/switch.h" #include "src/tint/transform/manager.h" #include "src/tint/type/bool.h" #include "src/tint/type/f16.h" #include "src/tint/type/f32.h" #include "src/tint/type/i32.h" +#include "src/tint/type/pointer.h" #include "src/tint/type/type.h" #include "src/tint/type/u32.h" #include "src/tint/type/vector.h" @@ -48,6 +50,23 @@ void Sanitize(ir::Module* module) { manager.Run(module, data, outputs); } +SpvStorageClass StorageClass(builtin::AddressSpace addrspace) { + switch (addrspace) { + case builtin::AddressSpace::kFunction: + return SpvStorageClassFunction; + case builtin::AddressSpace::kPrivate: + return SpvStorageClassPrivate; + case builtin::AddressSpace::kStorage: + return SpvStorageClassStorageBuffer; + case builtin::AddressSpace::kUniform: + return SpvStorageClassUniform; + case builtin::AddressSpace::kWorkgroup: + return SpvStorageClassWorkgroup; + default: + return SpvStorageClassMax; + } +} + } // namespace GeneratorImplIr::GeneratorImplIr(ir::Module* module, bool zero_init_workgroup_mem) @@ -154,6 +173,11 @@ uint32_t GeneratorImplIr::Type(const type::Type* ty) { [&](const type::Vector* vec) { module_.PushType(spv::Op::OpTypeVector, {id, Type(vec->type()), vec->Width()}); }, + [&](const type::Pointer* ptr) { + module_.PushType( + spv::Op::OpTypePointer, + {id, U32Operand(StorageClass(ptr->AddressSpace())), Type(ptr->StoreType())}); + }, [&](Default) { TINT_ICE(Writer, diagnostics_) << "unhandled type: " << ty->FriendlyName(); }); @@ -271,6 +295,7 @@ void GeneratorImplIr::EmitBlock(const ir::Block* block) { auto result = Switch( inst, // [&](const ir::Binary* b) { return EmitBinary(b); }, + [&](const ir::Var* v) { return EmitVar(v); }, [&](Default) { TINT_ICE(Writer, diagnostics_) << "unimplemented instruction: " << inst->TypeInfo().name; @@ -366,4 +391,30 @@ uint32_t GeneratorImplIr::EmitBinary(const ir::Binary* binary) { return id; } +uint32_t GeneratorImplIr::EmitVar(const ir::Var* var) { + auto id = module_.NextId(); + auto* ptr = var->Type()->As(); + TINT_ASSERT(Writer, ptr); + auto ty = Type(ptr); + + if (ptr->AddressSpace() == builtin::AddressSpace::kFunction) { + TINT_ASSERT(Writer, current_function_); + current_function_.push_var({ty, id, U32Operand(SpvStorageClassFunction)}); + if (var->initializer) { + current_function_.push_inst(spv::Op::OpStore, {id, Value(var->initializer)}); + } + } else { + TINT_ICE(Writer, diagnostics_) + << "unimplemented variable address space " << ptr->AddressSpace(); + return 0u; + } + + // Set the name if present. + if (auto name = ir_->NameOf(var)) { + module_.PushDebug(spv::Op::OpName, {id, Operand(name.Name())}); + } + + return id; +} + } // namespace tint::writer::spirv diff --git a/src/tint/writer/spirv/ir/generator_impl_ir.h b/src/tint/writer/spirv/ir/generator_impl_ir.h index b18cf9565e..437edb243c 100644 --- a/src/tint/writer/spirv/ir/generator_impl_ir.h +++ b/src/tint/writer/spirv/ir/generator_impl_ir.h @@ -34,6 +34,7 @@ class If; class Function; class Module; class Value; +class Var; } // namespace tint::ir namespace tint::type { class Type; @@ -104,6 +105,11 @@ class GeneratorImplIr { /// @returns the result ID of the instruction uint32_t EmitBinary(const ir::Binary* binary); + /// Emit a var instruction. + /// @param var the var instruction to emit + /// @returns the result ID of the instruction + uint32_t EmitVar(const ir::Var* var); + private: /// Get the result ID of the constant `constant`, emitting its instruction if necessary. /// @param constant the constant to get the ID for diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_var_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_var_test.cc new file mode 100644 index 0000000000..ecf8bf7574 --- /dev/null +++ b/src/tint/writer/spirv/ir/generator_impl_ir_var_test.cc @@ -0,0 +1,142 @@ +// Copyright 2023 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/tint/type/pointer.h" +#include "src/tint/writer/spirv/ir/test_helper_ir.h" + +using namespace tint::number_suffixes; // NOLINT + +namespace tint::writer::spirv { +namespace { + +TEST_F(SpvGeneratorImplTest, FunctionVar_NoInit) { + auto* func = b.CreateFunction(mod.symbols.Register("foo"), mod.types.Get()); + b.Branch(func->start_target, func->end_target); + + auto* ty = mod.types.Get( + mod.types.Get(), builtin::AddressSpace::kFunction, builtin::Access::kReadWrite); + auto* v = b.Declare(ty); + func->start_target->instructions.Push(v); + + generator_.EmitFunction(func); + EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo" +%2 = OpTypeVoid +%3 = OpTypeFunction %2 +%7 = OpTypeInt 32 1 +%6 = OpTypePointer Function %7 +%1 = OpFunction %2 None %3 +%4 = OpLabel +%5 = OpVariable %6 Function +OpReturn +OpFunctionEnd +)"); +} + +TEST_F(SpvGeneratorImplTest, FunctionVar_WithInit) { + auto* func = b.CreateFunction(mod.symbols.Register("foo"), mod.types.Get()); + b.Branch(func->start_target, func->end_target); + + auto* ty = mod.types.Get( + mod.types.Get(), builtin::AddressSpace::kFunction, builtin::Access::kReadWrite); + auto* v = b.Declare(ty); + func->start_target->instructions.Push(v); + v->initializer = b.Constant(42_i); + + generator_.EmitFunction(func); + EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo" +%2 = OpTypeVoid +%3 = OpTypeFunction %2 +%7 = OpTypeInt 32 1 +%6 = OpTypePointer Function %7 +%8 = OpConstant %7 42 +%1 = OpFunction %2 None %3 +%4 = OpLabel +%5 = OpVariable %6 Function +OpStore %5 %8 +OpReturn +OpFunctionEnd +)"); +} + +TEST_F(SpvGeneratorImplTest, FunctionVar_Name) { + auto* func = b.CreateFunction(mod.symbols.Register("foo"), mod.types.Get()); + b.Branch(func->start_target, func->end_target); + + auto* ty = mod.types.Get( + mod.types.Get(), builtin::AddressSpace::kFunction, builtin::Access::kReadWrite); + auto* v = b.Declare(ty); + func->start_target->instructions.Push(v); + mod.SetName(v, "myvar"); + + generator_.EmitFunction(func); + EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo" +OpName %5 "myvar" +%2 = OpTypeVoid +%3 = OpTypeFunction %2 +%7 = OpTypeInt 32 1 +%6 = OpTypePointer Function %7 +%1 = OpFunction %2 None %3 +%4 = OpLabel +%5 = OpVariable %6 Function +OpReturn +OpFunctionEnd +)"); +} + +TEST_F(SpvGeneratorImplTest, FunctionVar_DeclInsideBlock) { + auto* func = b.CreateFunction(mod.symbols.Register("foo"), mod.types.Get()); + b.Branch(func->start_target, func->end_target); + + auto* ty = mod.types.Get( + mod.types.Get(), builtin::AddressSpace::kFunction, builtin::Access::kReadWrite); + auto* v = b.Declare(ty); + v->initializer = b.Constant(42_i); + + auto* i = b.CreateIf(b.Constant(true)); + b.Branch(i->false_.target->As(), func->end_target); + b.Branch(i->merge.target->As(), func->end_target); + + auto* true_block = i->true_.target->As(); + true_block->instructions.Push(v); + b.Branch(true_block, i->merge.target); + + b.Branch(func->start_target, i); + + generator_.EmitFunction(func); + EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo" +%2 = OpTypeVoid +%3 = OpTypeFunction %2 +%9 = OpTypeBool +%8 = OpConstantTrue %9 +%12 = OpTypeInt 32 1 +%11 = OpTypePointer Function %12 +%13 = OpConstant %12 42 +%1 = OpFunction %2 None %3 +%4 = OpLabel +%10 = OpVariable %11 Function +OpSelectionMerge %5 None +OpBranchConditional %8 %6 %7 +%6 = OpLabel +OpStore %10 %13 +OpBranch %5 +%7 = OpLabel +OpReturn +%5 = OpLabel +OpReturn +OpFunctionEnd +)"); +} + +} // namespace +} // namespace tint::writer::spirv