diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn index 5970683500..f15e7e29f5 100644 --- a/src/tint/BUILD.gn +++ b/src/tint/BUILD.gn @@ -1964,6 +1964,7 @@ if (tint_build_unittests) { if (tint_build_ir) { sources += [ "writer/spirv/ir/generator_impl_ir_binary_test.cc", + "writer/spirv/ir/generator_impl_ir_builtin_test.cc", "writer/spirv/ir/generator_impl_ir_constant_test.cc", "writer/spirv/ir/generator_impl_ir_function_test.cc", "writer/spirv/ir/generator_impl_ir_if_test.cc", diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt index 17fa2d3fab..7217a3513d 100644 --- a/src/tint/CMakeLists.txt +++ b/src/tint/CMakeLists.txt @@ -1254,6 +1254,7 @@ if(TINT_BUILD_TESTS) if(${TINT_BUILD_IR}) list(APPEND TINT_TEST_SRCS writer/spirv/ir/generator_impl_ir_binary_test.cc + writer/spirv/ir/generator_impl_ir_builtin_test.cc writer/spirv/ir/generator_impl_ir_constant_test.cc writer/spirv/ir/generator_impl_ir_function_test.cc writer/spirv/ir/generator_impl_ir_if_test.cc diff --git a/src/tint/writer/spirv/ir/generator_impl_ir.cc b/src/tint/writer/spirv/ir/generator_impl_ir.cc index 7da87c1305..4bc7955fba 100644 --- a/src/tint/writer/spirv/ir/generator_impl_ir.cc +++ b/src/tint/writer/spirv/ir/generator_impl_ir.cc @@ -16,9 +16,11 @@ #include +#include "spirv/unified1/GLSL.std.450.h" #include "spirv/unified1/spirv.h" #include "src/tint/ir/binary.h" #include "src/tint/ir/block.h" +#include "src/tint/ir/builtin.h" #include "src/tint/ir/exit_if.h" #include "src/tint/ir/if.h" #include "src/tint/ir/load.h" @@ -330,6 +332,7 @@ void GeneratorImplIr::EmitBlock(const ir::Block* block) { auto result = Switch( inst, // [&](const ir::Binary* b) { return EmitBinary(b); }, + [&](const ir::Builtin* b) { return EmitBuiltin(b); }, [&](const ir::Load* l) { return EmitLoad(l); }, [&](const ir::Store* s) { EmitStore(s); @@ -518,6 +521,73 @@ uint32_t GeneratorImplIr::EmitBinary(const ir::Binary* binary) { return id; } +uint32_t GeneratorImplIr::EmitBuiltin(const ir::Builtin* builtin) { + auto id = module_.NextId(); + auto* result_ty = builtin->Type(); + + spv::Op op = spv::Op::Max; + OperandList operands = {Type(result_ty), id}; + + // Helper to set up the opcode and operand list for a GLSL extended instruction. + auto glsl_ext_inst = [&](enum GLSLstd450 inst) { + constexpr const char* kGLSLstd450 = "GLSL.std.450"; + op = spv::Op::OpExtInst; + operands.push_back(imports_.GetOrCreate(kGLSLstd450, [&]() { + // Import the instruction set the first time it is requested. + auto import = module_.NextId(); + module_.PushExtImport(spv::Op::OpExtInstImport, {import, Operand(kGLSLstd450)}); + return import; + })); + operands.push_back(U32Operand(inst)); + }; + + // Determine the opcode. + switch (builtin->Func()) { + case builtin::Function::kAbs: + if (result_ty->is_float_scalar_or_vector()) { + glsl_ext_inst(GLSLstd450FAbs); + } else if (result_ty->is_signed_integer_scalar_or_vector()) { + glsl_ext_inst(GLSLstd450SAbs); + } else if (result_ty->is_unsigned_integer_scalar_or_vector()) { + // abs() is a no-op for unsigned integers. + return Value(builtin->Args()[0]); + } + break; + case builtin::Function::kMax: + if (result_ty->is_float_scalar_or_vector()) { + glsl_ext_inst(GLSLstd450FMax); + } else if (result_ty->is_signed_integer_scalar_or_vector()) { + glsl_ext_inst(GLSLstd450SMax); + } else if (result_ty->is_unsigned_integer_scalar_or_vector()) { + glsl_ext_inst(GLSLstd450UMax); + } + break; + case builtin::Function::kMin: + if (result_ty->is_float_scalar_or_vector()) { + glsl_ext_inst(GLSLstd450FMin); + } else if (result_ty->is_signed_integer_scalar_or_vector()) { + glsl_ext_inst(GLSLstd450SMin); + } else if (result_ty->is_unsigned_integer_scalar_or_vector()) { + glsl_ext_inst(GLSLstd450UMin); + } + break; + default: + TINT_ICE(Writer, diagnostics_) << "unimplemented builtin function: " << builtin->Func(); + return 0u; + } + TINT_ASSERT(Writer, op != spv::Op::Max); + + // Add the arguments to the builtin call. + for (auto* arg : builtin->Args()) { + operands.push_back(Value(arg)); + } + + // Emit the instruction. + current_function_.push_inst(op, operands); + + return id; +} + uint32_t GeneratorImplIr::EmitLoad(const ir::Load* load) { auto id = module_.NextId(); current_function_.push_inst(spv::Op::OpLoad, {Type(load->Type()), id, Value(load->From())}); diff --git a/src/tint/writer/spirv/ir/generator_impl_ir.h b/src/tint/writer/spirv/ir/generator_impl_ir.h index f4ababac7c..d42cce694e 100644 --- a/src/tint/writer/spirv/ir/generator_impl_ir.h +++ b/src/tint/writer/spirv/ir/generator_impl_ir.h @@ -32,6 +32,7 @@ namespace tint::ir { class Binary; class Block; class Branch; +class Builtin; class If; class Function; class Load; @@ -110,6 +111,11 @@ class GeneratorImplIr { /// @returns the result ID of the instruction uint32_t EmitBinary(const ir::Binary* binary); + /// Emit a builtin function call instruction. + /// @param call the builtin call instruction to emit + /// @returns the result ID of the instruction + uint32_t EmitBuiltin(const ir::Builtin* call); + /// Emit a load instruction. /// @param load the load instruction to emit /// @returns the result ID of the instruction @@ -184,6 +190,9 @@ class GeneratorImplIr { /// The map of blocks to the IDs of their label instructions. utils::Hashmap block_labels_; + /// The map of extended instruction set names to their result IDs. + utils::Hashmap imports_; + /// The current function that is being emitted. Function current_function_; diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc index cfe3dc4eea..f6e501546b 100644 --- a/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc +++ b/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc @@ -22,102 +22,17 @@ using namespace tint::number_suffixes; // NOLINT namespace tint::writer::spirv { namespace { -/// The element type of a test. -enum Type { - kBool, - kI32, - kU32, - kF32, - kF16, -}; - /// A parameterized test case. struct BinaryTestCase { /// The element type to test. - Type type; + TestElementType type; /// The binary operation. enum ir::Binary::Kind kind; /// The expected SPIR-V instruction. std::string spirv_inst; }; -/// A helper class for parameterized binary instruction tests. -class BinaryInstructionTest : public SpvGeneratorImplTestWithParam { - protected: - /// Helper to make a scalar type corresponding to the element type `ty`. - /// @param ty the element type - /// @returns the scalar type - const type::Type* MakeScalarType(Type ty) { - switch (ty) { - case kBool: - return mod.Types().bool_(); - case kI32: - return mod.Types().i32(); - case kU32: - return mod.Types().u32(); - case kF32: - return mod.Types().f32(); - case kF16: - return mod.Types().f16(); - } - return nullptr; - } - - /// Helper to make a vector type corresponding to the element type `ty`. - /// @param ty the element type - /// @returns the vector type - const type::Type* MakeVectorType(Type ty) { return mod.Types().vec2(MakeScalarType(ty)); } - - /// Helper to make a scalar value with the scalar type `ty`. - /// @param ty the element type - /// @returns the scalar value - ir::Value* MakeScalarValue(Type ty) { - switch (ty) { - case kBool: - return b.Constant(true); - case kI32: - return b.Constant(1_i); - case kU32: - return b.Constant(1_u); - case kF32: - return b.Constant(1_f); - case kF16: - return b.Constant(1_h); - } - return nullptr; - } - - /// Helper to make a vector value with an element type of `ty`. - /// @param ty the element type - /// @returns the vector value - ir::Value* MakeVectorValue(Type ty) { - switch (ty) { - case kBool: - return b.Constant(b.ir.constant_values.Composite( - MakeVectorType(ty), utils::Vector{b.ir.constant_values.Get(true), - b.ir.constant_values.Get(false)})); - case kI32: - return b.Constant(b.ir.constant_values.Composite( - MakeVectorType(ty), utils::Vector{b.ir.constant_values.Get(42_i), - b.ir.constant_values.Get(-10_i)})); - case kU32: - return b.Constant(b.ir.constant_values.Composite( - MakeVectorType(ty), - utils::Vector{b.ir.constant_values.Get(42_u), b.ir.constant_values.Get(10_u)})); - case kF32: - return b.Constant(b.ir.constant_values.Composite( - MakeVectorType(ty), utils::Vector{b.ir.constant_values.Get(42_f), - b.ir.constant_values.Get(-0.5_f)})); - case kF16: - return b.Constant(b.ir.constant_values.Composite( - MakeVectorType(ty), utils::Vector{b.ir.constant_values.Get(42_h), - b.ir.constant_values.Get(-0.5_h)})); - } - return nullptr; - } -}; - -using Arithmetic = BinaryInstructionTest; +using Arithmetic = SpvGeneratorImplTestWithParam; TEST_P(Arithmetic, Scalar) { auto params = GetParam(); @@ -164,7 +79,7 @@ INSTANTIATE_TEST_SUITE_P(SpvGeneratorImplTest_Binary_F16, BinaryTestCase{kF16, ir::Binary::Kind::kSubtract, "OpFSub"})); -using Bitwise = BinaryInstructionTest; +using Bitwise = SpvGeneratorImplTestWithParam; TEST_P(Bitwise, Scalar) { auto params = GetParam(); @@ -203,7 +118,7 @@ INSTANTIATE_TEST_SUITE_P( BinaryTestCase{kU32, ir::Binary::Kind::kOr, "OpBitwiseOr"}, BinaryTestCase{kU32, ir::Binary::Kind::kXor, "OpBitwiseXor"})); -using Comparison = BinaryInstructionTest; +using Comparison = SpvGeneratorImplTestWithParam; TEST_P(Comparison, Scalar) { auto params = GetParam(); diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_builtin_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_builtin_test.cc new file mode 100644 index 0000000000..f62848e756 --- /dev/null +++ b/src/tint/writer/spirv/ir/generator_impl_ir_builtin_test.cc @@ -0,0 +1,145 @@ +// Copyright 2023 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/tint/writer/spirv/ir/test_helper_ir.h" + +#include "gmock/gmock.h" +#include "src/tint/builtin/function.h" + +using namespace tint::number_suffixes; // NOLINT + +namespace tint::writer::spirv { +namespace { + +/// A parameterized builtin function test case. +struct BuiltinTestCase { + /// The element type to test. + TestElementType type; + /// The builtin function. + enum builtin::Function function; + /// The expected SPIR-V instruction string. + std::string spirv_inst; +}; + +// Tests for builtins with the signature: T = func(T) +using Builtin_1arg = SpvGeneratorImplTestWithParam; +TEST_P(Builtin_1arg, Scalar) { + auto params = GetParam(); + + auto* func = b.CreateFunction("foo", mod.Types().void_()); + func->StartTarget()->SetInstructions( + utils::Vector{b.Builtin(MakeScalarType(params.type), params.function, + utils::Vector{MakeScalarValue(params.type)}), + b.Return(func)}); + + generator_.EmitFunction(func); + EXPECT_THAT(DumpModule(generator_.Module()), ::testing::HasSubstr(params.spirv_inst)); +} +TEST_P(Builtin_1arg, Vector) { + auto params = GetParam(); + + auto* func = b.CreateFunction("foo", mod.Types().void_()); + func->StartTarget()->SetInstructions( + utils::Vector{b.Builtin(MakeVectorType(params.type), params.function, + utils::Vector{MakeVectorValue(params.type)}), + + b.Return(func)}); + + generator_.EmitFunction(func); + EXPECT_THAT(DumpModule(generator_.Module()), ::testing::HasSubstr(params.spirv_inst)); +} +INSTANTIATE_TEST_SUITE_P(SpvGeneratorImplTest, + Builtin_1arg, + testing::Values(BuiltinTestCase{kI32, builtin::Function::kAbs, "SAbs"}, + BuiltinTestCase{kF32, builtin::Function::kAbs, "FAbs"})); + +// Test that abs of an unsigned value just folds away. +TEST_F(SpvGeneratorImplTest, Builtin_Abs_u32) { + auto* result = b.Builtin(MakeScalarType(kU32), builtin::Function::kAbs, + utils::Vector{MakeScalarValue(kU32)}); + auto* func = b.CreateFunction("foo", MakeScalarType(kU32)); + func->StartTarget()->SetInstructions( + utils::Vector{result, b.Return(func, utils::Vector{result})}); + + generator_.EmitFunction(func); + EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo" +%2 = OpTypeInt 32 0 +%3 = OpTypeFunction %2 +%6 = OpConstant %2 1 +%1 = OpFunction %2 None %3 +%4 = OpLabel +OpReturnValue %6 +OpFunctionEnd +)"); +} +TEST_F(SpvGeneratorImplTest, Builtin_Abs_vec2u) { + auto* result = b.Builtin(MakeVectorType(kU32), builtin::Function::kAbs, + utils::Vector{MakeVectorValue(kU32)}); + auto* func = b.CreateFunction("foo", MakeVectorType(kU32)); + func->StartTarget()->SetInstructions( + utils::Vector{result, b.Return(func, utils::Vector{result})}); + + generator_.EmitFunction(func); + EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo" +%3 = OpTypeInt 32 0 +%2 = OpTypeVector %3 2 +%4 = OpTypeFunction %2 +%8 = OpConstant %3 42 +%9 = OpConstant %3 10 +%7 = OpConstantComposite %2 %8 %9 +%1 = OpFunction %2 None %4 +%5 = OpLabel +OpReturnValue %7 +OpFunctionEnd +)"); +} + +// Tests for builtins with the signature: T = func(T, T) +using Builtin_2arg = SpvGeneratorImplTestWithParam; +TEST_P(Builtin_2arg, Scalar) { + auto params = GetParam(); + + auto* func = b.CreateFunction("foo", mod.Types().void_()); + func->StartTarget()->SetInstructions(utils::Vector{ + b.Builtin(MakeScalarType(params.type), params.function, + utils::Vector{MakeScalarValue(params.type), MakeScalarValue(params.type)}), + b.Return(func)}); + + generator_.EmitFunction(func); + EXPECT_THAT(DumpModule(generator_.Module()), ::testing::HasSubstr(params.spirv_inst)); +} +TEST_P(Builtin_2arg, Vector) { + auto params = GetParam(); + + auto* func = b.CreateFunction("foo", mod.Types().void_()); + func->StartTarget()->SetInstructions(utils::Vector{ + b.Builtin(MakeVectorType(params.type), params.function, + utils::Vector{MakeVectorValue(params.type), MakeVectorValue(params.type)}), + + b.Return(func)}); + + generator_.EmitFunction(func); + EXPECT_THAT(DumpModule(generator_.Module()), ::testing::HasSubstr(params.spirv_inst)); +} +INSTANTIATE_TEST_SUITE_P(SpvGeneratorImplTest, + Builtin_2arg, + testing::Values(BuiltinTestCase{kF32, builtin::Function::kMax, "FMax"}, + BuiltinTestCase{kI32, builtin::Function::kMax, "SMax"}, + BuiltinTestCase{kU32, builtin::Function::kMax, "UMax"}, + BuiltinTestCase{kF32, builtin::Function::kMin, "FMin"}, + BuiltinTestCase{kI32, builtin::Function::kMin, "SMin"}, + BuiltinTestCase{kU32, builtin::Function::kMin, "UMin"})); + +} // namespace +} // namespace tint::writer::spirv diff --git a/src/tint/writer/spirv/ir/test_helper_ir.h b/src/tint/writer/spirv/ir/test_helper_ir.h index 9509b4238c..add9eaf4e5 100644 --- a/src/tint/writer/spirv/ir/test_helper_ir.h +++ b/src/tint/writer/spirv/ir/test_helper_ir.h @@ -24,6 +24,15 @@ namespace tint::writer::spirv { +/// The element type of a test. +enum TestElementType { + kBool, + kI32, + kU32, + kF32, + kF16, +}; + /// Base helper class for testing the SPIR-V generator implementation. template class SpvGeneratorTestHelperBase : public BASE { @@ -41,6 +50,85 @@ class SpvGeneratorTestHelperBase : public BASE { /// @returns the disassembled types from the generated module. std::string DumpTypes() { return DumpInstructions(generator_.Module().Types()); } + + /// Helper to make a scalar type corresponding to the element type `ty`. + /// @param ty the element type + /// @returns the scalar type + const type::Type* MakeScalarType(TestElementType ty) { + switch (ty) { + case kBool: + return mod.Types().bool_(); + case kI32: + return mod.Types().i32(); + case kU32: + return mod.Types().u32(); + case kF32: + return mod.Types().f32(); + case kF16: + return mod.Types().f16(); + } + return nullptr; + } + + /// Helper to make a vector type corresponding to the element type `ty`. + /// @param ty the element type + /// @returns the vector type + const type::Type* MakeVectorType(TestElementType ty) { + return mod.Types().vec2(MakeScalarType(ty)); + } + + /// Helper to make a scalar value with the scalar type `ty`. + /// @param ty the element type + /// @returns the scalar value + ir::Value* MakeScalarValue(TestElementType ty) { + switch (ty) { + case kBool: + return b.Constant(true); + case kI32: + return b.Constant(i32(1)); + case kU32: + return b.Constant(u32(1)); + case kF32: + return b.Constant(f32(1)); + case kF16: + return b.Constant(f16(1)); + } + return nullptr; + } + + /// Helper to make a vector value with an element type of `ty`. + /// @param ty the element type + /// @returns the vector value + ir::Value* MakeVectorValue(TestElementType ty) { + switch (ty) { + case kBool: + return b.Constant(mod.constant_values.Composite( + MakeVectorType(ty), + utils::Vector{mod.constant_values.Get(true), + mod.constant_values.Get(false)})); + case kI32: + return b.Constant(mod.constant_values.Composite( + MakeVectorType(ty), + utils::Vector{mod.constant_values.Get(i32(42)), + mod.constant_values.Get(i32(-10))})); + case kU32: + return b.Constant(mod.constant_values.Composite( + MakeVectorType(ty), + utils::Vector{mod.constant_values.Get(u32(42)), + mod.constant_values.Get(u32(10))})); + case kF32: + return b.Constant(mod.constant_values.Composite( + MakeVectorType(ty), + utils::Vector{mod.constant_values.Get(f32(42)), + mod.constant_values.Get(f32(-0.5))})); + case kF16: + return b.Constant(mod.constant_values.Composite( + MakeVectorType(ty), + utils::Vector{mod.constant_values.Get(f16(42)), + mod.constant_values.Get(f16(-0.5))})); + } + return nullptr; + } }; using SpvGeneratorImplTest = SpvGeneratorTestHelperBase;