diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn index fea93ca8e6..89fa7bbc7e 100644 --- a/src/tint/BUILD.gn +++ b/src/tint/BUILD.gn @@ -1862,6 +1862,7 @@ if (tint_build_unittests) { if (tint_build_ir) { sources += [ + "writer/spirv/generator_impl_constant_test.cc", "writer/spirv/generator_impl_function_test.cc", "writer/spirv/generator_impl_ir_test.cc", "writer/spirv/generator_impl_type_test.cc", diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt index c4a8724147..4686ac0cfb 100644 --- a/src/tint/CMakeLists.txt +++ b/src/tint/CMakeLists.txt @@ -1229,6 +1229,7 @@ if(TINT_BUILD_TESTS) if(${TINT_BUILD_IR}) list(APPEND TINT_TEST_SRCS + writer/spirv/generator_impl_constant_test.cc writer/spirv/generator_impl_function_test.cc writer/spirv/generator_impl_ir_test.cc writer/spirv/generator_impl_type_test.cc diff --git a/src/tint/writer/spirv/generator_impl_constant_test.cc b/src/tint/writer/spirv/generator_impl_constant_test.cc new file mode 100644 index 0000000000..c8d7a10e96 --- /dev/null +++ b/src/tint/writer/spirv/generator_impl_constant_test.cc @@ -0,0 +1,76 @@ +// Copyright 2023 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/tint/writer/spirv/test_helper_ir.h" + +namespace tint::writer::spirv { +namespace { + +TEST_F(SpvGeneratorImplTest, Type_Bool) { + generator_.Constant(Constant(true)); + generator_.Constant(Constant(false)); + EXPECT_EQ(DumpTypes(), R"(%2 = OpTypeBool +%1 = OpConstantTrue %2 +%3 = OpConstantFalse %2 +)"); +} + +TEST_F(SpvGeneratorImplTest, Constant_I32) { + generator_.Constant(Constant(i32(42))); + generator_.Constant(Constant(i32(-1))); + EXPECT_EQ(DumpTypes(), R"(%2 = OpTypeInt 32 1 +%1 = OpConstant %2 42 +%3 = OpConstant %2 -1 +)"); +} + +TEST_F(SpvGeneratorImplTest, Constant_U32) { + generator_.Constant(Constant(u32(42))); + generator_.Constant(Constant(u32(4000000000))); + EXPECT_EQ(DumpTypes(), R"(%2 = OpTypeInt 32 0 +%1 = OpConstant %2 42 +%3 = OpConstant %2 4000000000 +)"); +} + +TEST_F(SpvGeneratorImplTest, Constant_F32) { + generator_.Constant(Constant(f32(42))); + generator_.Constant(Constant(f32(-1))); + EXPECT_EQ(DumpTypes(), R"(%2 = OpTypeFloat 32 +%1 = OpConstant %2 42 +%3 = OpConstant %2 -1 +)"); +} + +TEST_F(SpvGeneratorImplTest, Constant_F16) { + generator_.Constant(Constant(f16(42))); + generator_.Constant(Constant(f16(-1))); + EXPECT_EQ(DumpTypes(), R"(%2 = OpTypeFloat 16 +%1 = OpConstant %2 0x1.5p+5 +%3 = OpConstant %2 -0x1p+0 +)"); +} + +// Test that we do not emit the same constant more than once. +TEST_F(SpvGeneratorImplTest, Constant_Deduplicate) { + generator_.Constant(Constant(i32(42))); + generator_.Constant(Constant(i32(42))); + generator_.Constant(Constant(i32(42))); + EXPECT_EQ(DumpTypes(), R"(%2 = OpTypeInt 32 1 +%1 = OpConstant %2 42 +)"); +} + +} // namespace +} // namespace tint::writer::spirv diff --git a/src/tint/writer/spirv/generator_impl_ir.cc b/src/tint/writer/spirv/generator_impl_ir.cc index e656e1a628..76de13a805 100644 --- a/src/tint/writer/spirv/generator_impl_ir.cc +++ b/src/tint/writer/spirv/generator_impl_ir.cc @@ -55,6 +55,40 @@ bool GeneratorImplIr::Generate() { return true; } +uint32_t GeneratorImplIr::Constant(const ir::Constant* constant) { + return constants_.GetOrCreate(constant, [&]() { + auto id = module_.NextId(); + auto* ty = constant->Type(); + auto* value = constant->value; + Switch( + ty, // + [&](const type::Bool*) { + module_.PushType( + value->ValueAs() ? spv::Op::OpConstantTrue : spv::Op::OpConstantFalse, + {Type(ty), id}); + }, + [&](const type::I32*) { + module_.PushType(spv::Op::OpConstant, {Type(ty), id, value->ValueAs()}); + }, + [&](const type::U32*) { + module_.PushType(spv::Op::OpConstant, + {Type(ty), id, U32Operand(value->ValueAs())}); + }, + [&](const type::F32*) { + module_.PushType(spv::Op::OpConstant, {Type(ty), id, value->ValueAs()}); + }, + [&](const type::F16*) { + module_.PushType( + spv::Op::OpConstant, + {Type(ty), id, U32Operand(value->ValueAs().BitsRepresentation())}); + }, + [&](Default) { + TINT_ICE(Writer, diagnostics_) << "unhandled constant type: " << ty->FriendlyName(); + }); + return id; + }); +} + uint32_t GeneratorImplIr::Type(const type::Type* ty) { return types_.GetOrCreate(ty, [&]() { auto id = module_.NextId(); diff --git a/src/tint/writer/spirv/generator_impl_ir.h b/src/tint/writer/spirv/generator_impl_ir.h index 0aa4e7bbd3..f7a6898f82 100644 --- a/src/tint/writer/spirv/generator_impl_ir.h +++ b/src/tint/writer/spirv/generator_impl_ir.h @@ -17,7 +17,9 @@ #include +#include "src/tint/constant/value.h" #include "src/tint/diagnostic/diagnostic.h" +#include "src/tint/ir/constant.h" #include "src/tint/utils/hashmap.h" #include "src/tint/utils/vector.h" #include "src/tint/writer/spirv/binary_writer.h" @@ -55,6 +57,11 @@ class GeneratorImplIr { /// @returns the list of diagnostics raised by the generator diag::List Diagnostics() const { return diagnostics_; } + /// Get the result ID of the constant `constant`, emitting its instruction if necessary. + /// @param constant the constant to get the ID for + /// @returns the result ID of the constant + uint32_t Constant(const ir::Constant* constant); + /// Get the result ID of the type `ty`, emitting a type declaration instruction if necessary. /// @param ty the type to get the ID for /// @returns the result ID of the type @@ -100,12 +107,34 @@ class GeneratorImplIr { } }; + /// ConstantHasher provides a hash function for an ir::Constant pointer, hashing the value + /// instead of the pointer itself. + struct ConstantHasher { + /// @param c the ir::Constant pointer to create a hash for + /// @return the hash value + inline std::size_t operator()(const ir::Constant* c) const { return c->value->Hash(); } + }; + + /// ConstantEquals provides an equality function for two ir::Constant pointers, comparing their + /// values instead of the pointers. + struct ConstantEquals { + /// @param a the first ir::Constant pointer to compare + /// @param b the second ir::Constant pointer to compare + /// @return the hash value + inline bool operator()(const ir::Constant* a, const ir::Constant* b) const { + return a->value->Equal(b->value); + } + }; + /// The map of types to their result IDs. utils::Hashmap types_; /// The map of function types to their result IDs. utils::Hashmap function_types_; + /// The map of constants to their result IDs. + utils::Hashmap constants_; + bool zero_init_workgroup_memory_ = false; };