From bcf4174c06eca0a3b19e10d087f276fd0dfc4641 Mon Sep 17 00:00:00 2001 From: James Price Date: Fri, 26 May 2023 15:10:30 +0000 Subject: [PATCH] [ir][spirv-writer] Emit builtin function calls Add support for `abs()`, `max()`, and `min()`. Import the GLSL extended instruction set the first time it is needed. Move testing utilities from the binary tests into the base test helper class, as they are more widely useful. Bug: tint:1906 Change-Id: I5faa928b98c621afcca770cb14a8f9c06f36bcfe Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/134521 Reviewed-by: Ben Clayton Kokoro: Kokoro Commit-Queue: James Price --- src/tint/BUILD.gn | 1 + src/tint/CMakeLists.txt | 1 + src/tint/writer/spirv/ir/generator_impl_ir.cc | 70 +++++++++ src/tint/writer/spirv/ir/generator_impl_ir.h | 9 ++ .../spirv/ir/generator_impl_ir_binary_test.cc | 93 +---------- .../ir/generator_impl_ir_builtin_test.cc | 145 ++++++++++++++++++ src/tint/writer/spirv/ir/test_helper_ir.h | 88 +++++++++++ 7 files changed, 318 insertions(+), 89 deletions(-) create mode 100644 src/tint/writer/spirv/ir/generator_impl_ir_builtin_test.cc diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn index 5970683500..f15e7e29f5 100644 --- a/src/tint/BUILD.gn +++ b/src/tint/BUILD.gn @@ -1964,6 +1964,7 @@ if (tint_build_unittests) { if (tint_build_ir) { sources += [ "writer/spirv/ir/generator_impl_ir_binary_test.cc", + "writer/spirv/ir/generator_impl_ir_builtin_test.cc", "writer/spirv/ir/generator_impl_ir_constant_test.cc", "writer/spirv/ir/generator_impl_ir_function_test.cc", "writer/spirv/ir/generator_impl_ir_if_test.cc", diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt index 17fa2d3fab..7217a3513d 100644 --- a/src/tint/CMakeLists.txt +++ b/src/tint/CMakeLists.txt @@ -1254,6 +1254,7 @@ if(TINT_BUILD_TESTS) if(${TINT_BUILD_IR}) list(APPEND TINT_TEST_SRCS writer/spirv/ir/generator_impl_ir_binary_test.cc + writer/spirv/ir/generator_impl_ir_builtin_test.cc writer/spirv/ir/generator_impl_ir_constant_test.cc writer/spirv/ir/generator_impl_ir_function_test.cc writer/spirv/ir/generator_impl_ir_if_test.cc diff --git a/src/tint/writer/spirv/ir/generator_impl_ir.cc b/src/tint/writer/spirv/ir/generator_impl_ir.cc index 7da87c1305..4bc7955fba 100644 --- a/src/tint/writer/spirv/ir/generator_impl_ir.cc +++ b/src/tint/writer/spirv/ir/generator_impl_ir.cc @@ -16,9 +16,11 @@ #include +#include "spirv/unified1/GLSL.std.450.h" #include "spirv/unified1/spirv.h" #include "src/tint/ir/binary.h" #include "src/tint/ir/block.h" +#include "src/tint/ir/builtin.h" #include "src/tint/ir/exit_if.h" #include "src/tint/ir/if.h" #include "src/tint/ir/load.h" @@ -330,6 +332,7 @@ void GeneratorImplIr::EmitBlock(const ir::Block* block) { auto result = Switch( inst, // [&](const ir::Binary* b) { return EmitBinary(b); }, + [&](const ir::Builtin* b) { return EmitBuiltin(b); }, [&](const ir::Load* l) { return EmitLoad(l); }, [&](const ir::Store* s) { EmitStore(s); @@ -518,6 +521,73 @@ uint32_t GeneratorImplIr::EmitBinary(const ir::Binary* binary) { return id; } +uint32_t GeneratorImplIr::EmitBuiltin(const ir::Builtin* builtin) { + auto id = module_.NextId(); + auto* result_ty = builtin->Type(); + + spv::Op op = spv::Op::Max; + OperandList operands = {Type(result_ty), id}; + + // Helper to set up the opcode and operand list for a GLSL extended instruction. + auto glsl_ext_inst = [&](enum GLSLstd450 inst) { + constexpr const char* kGLSLstd450 = "GLSL.std.450"; + op = spv::Op::OpExtInst; + operands.push_back(imports_.GetOrCreate(kGLSLstd450, [&]() { + // Import the instruction set the first time it is requested. + auto import = module_.NextId(); + module_.PushExtImport(spv::Op::OpExtInstImport, {import, Operand(kGLSLstd450)}); + return import; + })); + operands.push_back(U32Operand(inst)); + }; + + // Determine the opcode. + switch (builtin->Func()) { + case builtin::Function::kAbs: + if (result_ty->is_float_scalar_or_vector()) { + glsl_ext_inst(GLSLstd450FAbs); + } else if (result_ty->is_signed_integer_scalar_or_vector()) { + glsl_ext_inst(GLSLstd450SAbs); + } else if (result_ty->is_unsigned_integer_scalar_or_vector()) { + // abs() is a no-op for unsigned integers. + return Value(builtin->Args()[0]); + } + break; + case builtin::Function::kMax: + if (result_ty->is_float_scalar_or_vector()) { + glsl_ext_inst(GLSLstd450FMax); + } else if (result_ty->is_signed_integer_scalar_or_vector()) { + glsl_ext_inst(GLSLstd450SMax); + } else if (result_ty->is_unsigned_integer_scalar_or_vector()) { + glsl_ext_inst(GLSLstd450UMax); + } + break; + case builtin::Function::kMin: + if (result_ty->is_float_scalar_or_vector()) { + glsl_ext_inst(GLSLstd450FMin); + } else if (result_ty->is_signed_integer_scalar_or_vector()) { + glsl_ext_inst(GLSLstd450SMin); + } else if (result_ty->is_unsigned_integer_scalar_or_vector()) { + glsl_ext_inst(GLSLstd450UMin); + } + break; + default: + TINT_ICE(Writer, diagnostics_) << "unimplemented builtin function: " << builtin->Func(); + return 0u; + } + TINT_ASSERT(Writer, op != spv::Op::Max); + + // Add the arguments to the builtin call. + for (auto* arg : builtin->Args()) { + operands.push_back(Value(arg)); + } + + // Emit the instruction. + current_function_.push_inst(op, operands); + + return id; +} + uint32_t GeneratorImplIr::EmitLoad(const ir::Load* load) { auto id = module_.NextId(); current_function_.push_inst(spv::Op::OpLoad, {Type(load->Type()), id, Value(load->From())}); diff --git a/src/tint/writer/spirv/ir/generator_impl_ir.h b/src/tint/writer/spirv/ir/generator_impl_ir.h index f4ababac7c..d42cce694e 100644 --- a/src/tint/writer/spirv/ir/generator_impl_ir.h +++ b/src/tint/writer/spirv/ir/generator_impl_ir.h @@ -32,6 +32,7 @@ namespace tint::ir { class Binary; class Block; class Branch; +class Builtin; class If; class Function; class Load; @@ -110,6 +111,11 @@ class GeneratorImplIr { /// @returns the result ID of the instruction uint32_t EmitBinary(const ir::Binary* binary); + /// Emit a builtin function call instruction. + /// @param call the builtin call instruction to emit + /// @returns the result ID of the instruction + uint32_t EmitBuiltin(const ir::Builtin* call); + /// Emit a load instruction. /// @param load the load instruction to emit /// @returns the result ID of the instruction @@ -184,6 +190,9 @@ class GeneratorImplIr { /// The map of blocks to the IDs of their label instructions. utils::Hashmap block_labels_; + /// The map of extended instruction set names to their result IDs. + utils::Hashmap imports_; + /// The current function that is being emitted. Function current_function_; diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc index cfe3dc4eea..f6e501546b 100644 --- a/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc +++ b/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc @@ -22,102 +22,17 @@ using namespace tint::number_suffixes; // NOLINT namespace tint::writer::spirv { namespace { -/// The element type of a test. -enum Type { - kBool, - kI32, - kU32, - kF32, - kF16, -}; - /// A parameterized test case. struct BinaryTestCase { /// The element type to test. - Type type; + TestElementType type; /// The binary operation. enum ir::Binary::Kind kind; /// The expected SPIR-V instruction. std::string spirv_inst; }; -/// A helper class for parameterized binary instruction tests. -class BinaryInstructionTest : public SpvGeneratorImplTestWithParam { - protected: - /// Helper to make a scalar type corresponding to the element type `ty`. - /// @param ty the element type - /// @returns the scalar type - const type::Type* MakeScalarType(Type ty) { - switch (ty) { - case kBool: - return mod.Types().bool_(); - case kI32: - return mod.Types().i32(); - case kU32: - return mod.Types().u32(); - case kF32: - return mod.Types().f32(); - case kF16: - return mod.Types().f16(); - } - return nullptr; - } - - /// Helper to make a vector type corresponding to the element type `ty`. - /// @param ty the element type - /// @returns the vector type - const type::Type* MakeVectorType(Type ty) { return mod.Types().vec2(MakeScalarType(ty)); } - - /// Helper to make a scalar value with the scalar type `ty`. - /// @param ty the element type - /// @returns the scalar value - ir::Value* MakeScalarValue(Type ty) { - switch (ty) { - case kBool: - return b.Constant(true); - case kI32: - return b.Constant(1_i); - case kU32: - return b.Constant(1_u); - case kF32: - return b.Constant(1_f); - case kF16: - return b.Constant(1_h); - } - return nullptr; - } - - /// Helper to make a vector value with an element type of `ty`. - /// @param ty the element type - /// @returns the vector value - ir::Value* MakeVectorValue(Type ty) { - switch (ty) { - case kBool: - return b.Constant(b.ir.constant_values.Composite( - MakeVectorType(ty), utils::Vector{b.ir.constant_values.Get(true), - b.ir.constant_values.Get(false)})); - case kI32: - return b.Constant(b.ir.constant_values.Composite( - MakeVectorType(ty), utils::Vector{b.ir.constant_values.Get(42_i), - b.ir.constant_values.Get(-10_i)})); - case kU32: - return b.Constant(b.ir.constant_values.Composite( - MakeVectorType(ty), - utils::Vector{b.ir.constant_values.Get(42_u), b.ir.constant_values.Get(10_u)})); - case kF32: - return b.Constant(b.ir.constant_values.Composite( - MakeVectorType(ty), utils::Vector{b.ir.constant_values.Get(42_f), - b.ir.constant_values.Get(-0.5_f)})); - case kF16: - return b.Constant(b.ir.constant_values.Composite( - MakeVectorType(ty), utils::Vector{b.ir.constant_values.Get(42_h), - b.ir.constant_values.Get(-0.5_h)})); - } - return nullptr; - } -}; - -using Arithmetic = BinaryInstructionTest; +using Arithmetic = SpvGeneratorImplTestWithParam; TEST_P(Arithmetic, Scalar) { auto params = GetParam(); @@ -164,7 +79,7 @@ INSTANTIATE_TEST_SUITE_P(SpvGeneratorImplTest_Binary_F16, BinaryTestCase{kF16, ir::Binary::Kind::kSubtract, "OpFSub"})); -using Bitwise = BinaryInstructionTest; +using Bitwise = SpvGeneratorImplTestWithParam; TEST_P(Bitwise, Scalar) { auto params = GetParam(); @@ -203,7 +118,7 @@ INSTANTIATE_TEST_SUITE_P( BinaryTestCase{kU32, ir::Binary::Kind::kOr, "OpBitwiseOr"}, BinaryTestCase{kU32, ir::Binary::Kind::kXor, "OpBitwiseXor"})); -using Comparison = BinaryInstructionTest; +using Comparison = SpvGeneratorImplTestWithParam; TEST_P(Comparison, Scalar) { auto params = GetParam(); diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_builtin_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_builtin_test.cc new file mode 100644 index 0000000000..f62848e756 --- /dev/null +++ b/src/tint/writer/spirv/ir/generator_impl_ir_builtin_test.cc @@ -0,0 +1,145 @@ +// Copyright 2023 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/tint/writer/spirv/ir/test_helper_ir.h" + +#include "gmock/gmock.h" +#include "src/tint/builtin/function.h" + +using namespace tint::number_suffixes; // NOLINT + +namespace tint::writer::spirv { +namespace { + +/// A parameterized builtin function test case. +struct BuiltinTestCase { + /// The element type to test. + TestElementType type; + /// The builtin function. + enum builtin::Function function; + /// The expected SPIR-V instruction string. + std::string spirv_inst; +}; + +// Tests for builtins with the signature: T = func(T) +using Builtin_1arg = SpvGeneratorImplTestWithParam; +TEST_P(Builtin_1arg, Scalar) { + auto params = GetParam(); + + auto* func = b.CreateFunction("foo", mod.Types().void_()); + func->StartTarget()->SetInstructions( + utils::Vector{b.Builtin(MakeScalarType(params.type), params.function, + utils::Vector{MakeScalarValue(params.type)}), + b.Return(func)}); + + generator_.EmitFunction(func); + EXPECT_THAT(DumpModule(generator_.Module()), ::testing::HasSubstr(params.spirv_inst)); +} +TEST_P(Builtin_1arg, Vector) { + auto params = GetParam(); + + auto* func = b.CreateFunction("foo", mod.Types().void_()); + func->StartTarget()->SetInstructions( + utils::Vector{b.Builtin(MakeVectorType(params.type), params.function, + utils::Vector{MakeVectorValue(params.type)}), + + b.Return(func)}); + + generator_.EmitFunction(func); + EXPECT_THAT(DumpModule(generator_.Module()), ::testing::HasSubstr(params.spirv_inst)); +} +INSTANTIATE_TEST_SUITE_P(SpvGeneratorImplTest, + Builtin_1arg, + testing::Values(BuiltinTestCase{kI32, builtin::Function::kAbs, "SAbs"}, + BuiltinTestCase{kF32, builtin::Function::kAbs, "FAbs"})); + +// Test that abs of an unsigned value just folds away. +TEST_F(SpvGeneratorImplTest, Builtin_Abs_u32) { + auto* result = b.Builtin(MakeScalarType(kU32), builtin::Function::kAbs, + utils::Vector{MakeScalarValue(kU32)}); + auto* func = b.CreateFunction("foo", MakeScalarType(kU32)); + func->StartTarget()->SetInstructions( + utils::Vector{result, b.Return(func, utils::Vector{result})}); + + generator_.EmitFunction(func); + EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo" +%2 = OpTypeInt 32 0 +%3 = OpTypeFunction %2 +%6 = OpConstant %2 1 +%1 = OpFunction %2 None %3 +%4 = OpLabel +OpReturnValue %6 +OpFunctionEnd +)"); +} +TEST_F(SpvGeneratorImplTest, Builtin_Abs_vec2u) { + auto* result = b.Builtin(MakeVectorType(kU32), builtin::Function::kAbs, + utils::Vector{MakeVectorValue(kU32)}); + auto* func = b.CreateFunction("foo", MakeVectorType(kU32)); + func->StartTarget()->SetInstructions( + utils::Vector{result, b.Return(func, utils::Vector{result})}); + + generator_.EmitFunction(func); + EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo" +%3 = OpTypeInt 32 0 +%2 = OpTypeVector %3 2 +%4 = OpTypeFunction %2 +%8 = OpConstant %3 42 +%9 = OpConstant %3 10 +%7 = OpConstantComposite %2 %8 %9 +%1 = OpFunction %2 None %4 +%5 = OpLabel +OpReturnValue %7 +OpFunctionEnd +)"); +} + +// Tests for builtins with the signature: T = func(T, T) +using Builtin_2arg = SpvGeneratorImplTestWithParam; +TEST_P(Builtin_2arg, Scalar) { + auto params = GetParam(); + + auto* func = b.CreateFunction("foo", mod.Types().void_()); + func->StartTarget()->SetInstructions(utils::Vector{ + b.Builtin(MakeScalarType(params.type), params.function, + utils::Vector{MakeScalarValue(params.type), MakeScalarValue(params.type)}), + b.Return(func)}); + + generator_.EmitFunction(func); + EXPECT_THAT(DumpModule(generator_.Module()), ::testing::HasSubstr(params.spirv_inst)); +} +TEST_P(Builtin_2arg, Vector) { + auto params = GetParam(); + + auto* func = b.CreateFunction("foo", mod.Types().void_()); + func->StartTarget()->SetInstructions(utils::Vector{ + b.Builtin(MakeVectorType(params.type), params.function, + utils::Vector{MakeVectorValue(params.type), MakeVectorValue(params.type)}), + + b.Return(func)}); + + generator_.EmitFunction(func); + EXPECT_THAT(DumpModule(generator_.Module()), ::testing::HasSubstr(params.spirv_inst)); +} +INSTANTIATE_TEST_SUITE_P(SpvGeneratorImplTest, + Builtin_2arg, + testing::Values(BuiltinTestCase{kF32, builtin::Function::kMax, "FMax"}, + BuiltinTestCase{kI32, builtin::Function::kMax, "SMax"}, + BuiltinTestCase{kU32, builtin::Function::kMax, "UMax"}, + BuiltinTestCase{kF32, builtin::Function::kMin, "FMin"}, + BuiltinTestCase{kI32, builtin::Function::kMin, "SMin"}, + BuiltinTestCase{kU32, builtin::Function::kMin, "UMin"})); + +} // namespace +} // namespace tint::writer::spirv diff --git a/src/tint/writer/spirv/ir/test_helper_ir.h b/src/tint/writer/spirv/ir/test_helper_ir.h index 9509b4238c..add9eaf4e5 100644 --- a/src/tint/writer/spirv/ir/test_helper_ir.h +++ b/src/tint/writer/spirv/ir/test_helper_ir.h @@ -24,6 +24,15 @@ namespace tint::writer::spirv { +/// The element type of a test. +enum TestElementType { + kBool, + kI32, + kU32, + kF32, + kF16, +}; + /// Base helper class for testing the SPIR-V generator implementation. template class SpvGeneratorTestHelperBase : public BASE { @@ -41,6 +50,85 @@ class SpvGeneratorTestHelperBase : public BASE { /// @returns the disassembled types from the generated module. std::string DumpTypes() { return DumpInstructions(generator_.Module().Types()); } + + /// Helper to make a scalar type corresponding to the element type `ty`. + /// @param ty the element type + /// @returns the scalar type + const type::Type* MakeScalarType(TestElementType ty) { + switch (ty) { + case kBool: + return mod.Types().bool_(); + case kI32: + return mod.Types().i32(); + case kU32: + return mod.Types().u32(); + case kF32: + return mod.Types().f32(); + case kF16: + return mod.Types().f16(); + } + return nullptr; + } + + /// Helper to make a vector type corresponding to the element type `ty`. + /// @param ty the element type + /// @returns the vector type + const type::Type* MakeVectorType(TestElementType ty) { + return mod.Types().vec2(MakeScalarType(ty)); + } + + /// Helper to make a scalar value with the scalar type `ty`. + /// @param ty the element type + /// @returns the scalar value + ir::Value* MakeScalarValue(TestElementType ty) { + switch (ty) { + case kBool: + return b.Constant(true); + case kI32: + return b.Constant(i32(1)); + case kU32: + return b.Constant(u32(1)); + case kF32: + return b.Constant(f32(1)); + case kF16: + return b.Constant(f16(1)); + } + return nullptr; + } + + /// Helper to make a vector value with an element type of `ty`. + /// @param ty the element type + /// @returns the vector value + ir::Value* MakeVectorValue(TestElementType ty) { + switch (ty) { + case kBool: + return b.Constant(mod.constant_values.Composite( + MakeVectorType(ty), + utils::Vector{mod.constant_values.Get(true), + mod.constant_values.Get(false)})); + case kI32: + return b.Constant(mod.constant_values.Composite( + MakeVectorType(ty), + utils::Vector{mod.constant_values.Get(i32(42)), + mod.constant_values.Get(i32(-10))})); + case kU32: + return b.Constant(mod.constant_values.Composite( + MakeVectorType(ty), + utils::Vector{mod.constant_values.Get(u32(42)), + mod.constant_values.Get(u32(10))})); + case kF32: + return b.Constant(mod.constant_values.Composite( + MakeVectorType(ty), + utils::Vector{mod.constant_values.Get(f32(42)), + mod.constant_values.Get(f32(-0.5))})); + case kF16: + return b.Constant(mod.constant_values.Composite( + MakeVectorType(ty), + utils::Vector{mod.constant_values.Get(f16(42)), + mod.constant_values.Get(f16(-0.5))})); + } + return nullptr; + } }; using SpvGeneratorImplTest = SpvGeneratorTestHelperBase;