[ir][spirv-writer] Emit builtin function calls

Add support for `abs()`, `max()`, and `min()`. Import the GLSL
extended instruction set the first time it is needed.

Move testing utilities from the binary tests into the base test helper
class, as they are more widely useful.

Bug: tint:1906
Change-Id: I5faa928b98c621afcca770cb14a8f9c06f36bcfe
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/134521
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: James Price <jrprice@google.com>
This commit is contained in:
James Price 2023-05-26 15:10:30 +00:00 committed by Dawn LUCI CQ
parent 2ee63ffc0c
commit bcf4174c06
7 changed files with 318 additions and 89 deletions

View File

@ -1964,6 +1964,7 @@ if (tint_build_unittests) {
if (tint_build_ir) {
sources += [
"writer/spirv/ir/generator_impl_ir_binary_test.cc",
"writer/spirv/ir/generator_impl_ir_builtin_test.cc",
"writer/spirv/ir/generator_impl_ir_constant_test.cc",
"writer/spirv/ir/generator_impl_ir_function_test.cc",
"writer/spirv/ir/generator_impl_ir_if_test.cc",

View File

@ -1254,6 +1254,7 @@ if(TINT_BUILD_TESTS)
if(${TINT_BUILD_IR})
list(APPEND TINT_TEST_SRCS
writer/spirv/ir/generator_impl_ir_binary_test.cc
writer/spirv/ir/generator_impl_ir_builtin_test.cc
writer/spirv/ir/generator_impl_ir_constant_test.cc
writer/spirv/ir/generator_impl_ir_function_test.cc
writer/spirv/ir/generator_impl_ir_if_test.cc

View File

@ -16,9 +16,11 @@
#include <utility>
#include "spirv/unified1/GLSL.std.450.h"
#include "spirv/unified1/spirv.h"
#include "src/tint/ir/binary.h"
#include "src/tint/ir/block.h"
#include "src/tint/ir/builtin.h"
#include "src/tint/ir/exit_if.h"
#include "src/tint/ir/if.h"
#include "src/tint/ir/load.h"
@ -330,6 +332,7 @@ void GeneratorImplIr::EmitBlock(const ir::Block* block) {
auto result = Switch(
inst, //
[&](const ir::Binary* b) { return EmitBinary(b); },
[&](const ir::Builtin* b) { return EmitBuiltin(b); },
[&](const ir::Load* l) { return EmitLoad(l); },
[&](const ir::Store* s) {
EmitStore(s);
@ -518,6 +521,73 @@ uint32_t GeneratorImplIr::EmitBinary(const ir::Binary* binary) {
return id;
}
uint32_t GeneratorImplIr::EmitBuiltin(const ir::Builtin* builtin) {
auto id = module_.NextId();
auto* result_ty = builtin->Type();
spv::Op op = spv::Op::Max;
OperandList operands = {Type(result_ty), id};
// Helper to set up the opcode and operand list for a GLSL extended instruction.
auto glsl_ext_inst = [&](enum GLSLstd450 inst) {
constexpr const char* kGLSLstd450 = "GLSL.std.450";
op = spv::Op::OpExtInst;
operands.push_back(imports_.GetOrCreate(kGLSLstd450, [&]() {
// Import the instruction set the first time it is requested.
auto import = module_.NextId();
module_.PushExtImport(spv::Op::OpExtInstImport, {import, Operand(kGLSLstd450)});
return import;
}));
operands.push_back(U32Operand(inst));
};
// Determine the opcode.
switch (builtin->Func()) {
case builtin::Function::kAbs:
if (result_ty->is_float_scalar_or_vector()) {
glsl_ext_inst(GLSLstd450FAbs);
} else if (result_ty->is_signed_integer_scalar_or_vector()) {
glsl_ext_inst(GLSLstd450SAbs);
} else if (result_ty->is_unsigned_integer_scalar_or_vector()) {
// abs() is a no-op for unsigned integers.
return Value(builtin->Args()[0]);
}
break;
case builtin::Function::kMax:
if (result_ty->is_float_scalar_or_vector()) {
glsl_ext_inst(GLSLstd450FMax);
} else if (result_ty->is_signed_integer_scalar_or_vector()) {
glsl_ext_inst(GLSLstd450SMax);
} else if (result_ty->is_unsigned_integer_scalar_or_vector()) {
glsl_ext_inst(GLSLstd450UMax);
}
break;
case builtin::Function::kMin:
if (result_ty->is_float_scalar_or_vector()) {
glsl_ext_inst(GLSLstd450FMin);
} else if (result_ty->is_signed_integer_scalar_or_vector()) {
glsl_ext_inst(GLSLstd450SMin);
} else if (result_ty->is_unsigned_integer_scalar_or_vector()) {
glsl_ext_inst(GLSLstd450UMin);
}
break;
default:
TINT_ICE(Writer, diagnostics_) << "unimplemented builtin function: " << builtin->Func();
return 0u;
}
TINT_ASSERT(Writer, op != spv::Op::Max);
// Add the arguments to the builtin call.
for (auto* arg : builtin->Args()) {
operands.push_back(Value(arg));
}
// Emit the instruction.
current_function_.push_inst(op, operands);
return id;
}
uint32_t GeneratorImplIr::EmitLoad(const ir::Load* load) {
auto id = module_.NextId();
current_function_.push_inst(spv::Op::OpLoad, {Type(load->Type()), id, Value(load->From())});

View File

@ -32,6 +32,7 @@ namespace tint::ir {
class Binary;
class Block;
class Branch;
class Builtin;
class If;
class Function;
class Load;
@ -110,6 +111,11 @@ class GeneratorImplIr {
/// @returns the result ID of the instruction
uint32_t EmitBinary(const ir::Binary* binary);
/// Emit a builtin function call instruction.
/// @param call the builtin call instruction to emit
/// @returns the result ID of the instruction
uint32_t EmitBuiltin(const ir::Builtin* call);
/// Emit a load instruction.
/// @param load the load instruction to emit
/// @returns the result ID of the instruction
@ -184,6 +190,9 @@ class GeneratorImplIr {
/// The map of blocks to the IDs of their label instructions.
utils::Hashmap<const ir::Block*, uint32_t, 8> block_labels_;
/// The map of extended instruction set names to their result IDs.
utils::Hashmap<std::string_view, uint32_t, 2> imports_;
/// The current function that is being emitted.
Function current_function_;

View File

@ -22,102 +22,17 @@ using namespace tint::number_suffixes; // NOLINT
namespace tint::writer::spirv {
namespace {
/// The element type of a test.
enum Type {
kBool,
kI32,
kU32,
kF32,
kF16,
};
/// A parameterized test case.
struct BinaryTestCase {
/// The element type to test.
Type type;
TestElementType type;
/// The binary operation.
enum ir::Binary::Kind kind;
/// The expected SPIR-V instruction.
std::string spirv_inst;
};
/// A helper class for parameterized binary instruction tests.
class BinaryInstructionTest : public SpvGeneratorImplTestWithParam<BinaryTestCase> {
protected:
/// Helper to make a scalar type corresponding to the element type `ty`.
/// @param ty the element type
/// @returns the scalar type
const type::Type* MakeScalarType(Type ty) {
switch (ty) {
case kBool:
return mod.Types().bool_();
case kI32:
return mod.Types().i32();
case kU32:
return mod.Types().u32();
case kF32:
return mod.Types().f32();
case kF16:
return mod.Types().f16();
}
return nullptr;
}
/// Helper to make a vector type corresponding to the element type `ty`.
/// @param ty the element type
/// @returns the vector type
const type::Type* MakeVectorType(Type ty) { return mod.Types().vec2(MakeScalarType(ty)); }
/// Helper to make a scalar value with the scalar type `ty`.
/// @param ty the element type
/// @returns the scalar value
ir::Value* MakeScalarValue(Type ty) {
switch (ty) {
case kBool:
return b.Constant(true);
case kI32:
return b.Constant(1_i);
case kU32:
return b.Constant(1_u);
case kF32:
return b.Constant(1_f);
case kF16:
return b.Constant(1_h);
}
return nullptr;
}
/// Helper to make a vector value with an element type of `ty`.
/// @param ty the element type
/// @returns the vector value
ir::Value* MakeVectorValue(Type ty) {
switch (ty) {
case kBool:
return b.Constant(b.ir.constant_values.Composite(
MakeVectorType(ty), utils::Vector{b.ir.constant_values.Get(true),
b.ir.constant_values.Get(false)}));
case kI32:
return b.Constant(b.ir.constant_values.Composite(
MakeVectorType(ty), utils::Vector{b.ir.constant_values.Get(42_i),
b.ir.constant_values.Get(-10_i)}));
case kU32:
return b.Constant(b.ir.constant_values.Composite(
MakeVectorType(ty),
utils::Vector{b.ir.constant_values.Get(42_u), b.ir.constant_values.Get(10_u)}));
case kF32:
return b.Constant(b.ir.constant_values.Composite(
MakeVectorType(ty), utils::Vector{b.ir.constant_values.Get(42_f),
b.ir.constant_values.Get(-0.5_f)}));
case kF16:
return b.Constant(b.ir.constant_values.Composite(
MakeVectorType(ty), utils::Vector{b.ir.constant_values.Get(42_h),
b.ir.constant_values.Get(-0.5_h)}));
}
return nullptr;
}
};
using Arithmetic = BinaryInstructionTest;
using Arithmetic = SpvGeneratorImplTestWithParam<BinaryTestCase>;
TEST_P(Arithmetic, Scalar) {
auto params = GetParam();
@ -164,7 +79,7 @@ INSTANTIATE_TEST_SUITE_P(SpvGeneratorImplTest_Binary_F16,
BinaryTestCase{kF16, ir::Binary::Kind::kSubtract,
"OpFSub"}));
using Bitwise = BinaryInstructionTest;
using Bitwise = SpvGeneratorImplTestWithParam<BinaryTestCase>;
TEST_P(Bitwise, Scalar) {
auto params = GetParam();
@ -203,7 +118,7 @@ INSTANTIATE_TEST_SUITE_P(
BinaryTestCase{kU32, ir::Binary::Kind::kOr, "OpBitwiseOr"},
BinaryTestCase{kU32, ir::Binary::Kind::kXor, "OpBitwiseXor"}));
using Comparison = BinaryInstructionTest;
using Comparison = SpvGeneratorImplTestWithParam<BinaryTestCase>;
TEST_P(Comparison, Scalar) {
auto params = GetParam();

View File

@ -0,0 +1,145 @@
// Copyright 2023 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/writer/spirv/ir/test_helper_ir.h"
#include "gmock/gmock.h"
#include "src/tint/builtin/function.h"
using namespace tint::number_suffixes; // NOLINT
namespace tint::writer::spirv {
namespace {
/// A parameterized builtin function test case.
struct BuiltinTestCase {
/// The element type to test.
TestElementType type;
/// The builtin function.
enum builtin::Function function;
/// The expected SPIR-V instruction string.
std::string spirv_inst;
};
// Tests for builtins with the signature: T = func(T)
using Builtin_1arg = SpvGeneratorImplTestWithParam<BuiltinTestCase>;
TEST_P(Builtin_1arg, Scalar) {
auto params = GetParam();
auto* func = b.CreateFunction("foo", mod.Types().void_());
func->StartTarget()->SetInstructions(
utils::Vector{b.Builtin(MakeScalarType(params.type), params.function,
utils::Vector{MakeScalarValue(params.type)}),
b.Return(func)});
generator_.EmitFunction(func);
EXPECT_THAT(DumpModule(generator_.Module()), ::testing::HasSubstr(params.spirv_inst));
}
TEST_P(Builtin_1arg, Vector) {
auto params = GetParam();
auto* func = b.CreateFunction("foo", mod.Types().void_());
func->StartTarget()->SetInstructions(
utils::Vector{b.Builtin(MakeVectorType(params.type), params.function,
utils::Vector{MakeVectorValue(params.type)}),
b.Return(func)});
generator_.EmitFunction(func);
EXPECT_THAT(DumpModule(generator_.Module()), ::testing::HasSubstr(params.spirv_inst));
}
INSTANTIATE_TEST_SUITE_P(SpvGeneratorImplTest,
Builtin_1arg,
testing::Values(BuiltinTestCase{kI32, builtin::Function::kAbs, "SAbs"},
BuiltinTestCase{kF32, builtin::Function::kAbs, "FAbs"}));
// Test that abs of an unsigned value just folds away.
TEST_F(SpvGeneratorImplTest, Builtin_Abs_u32) {
auto* result = b.Builtin(MakeScalarType(kU32), builtin::Function::kAbs,
utils::Vector{MakeScalarValue(kU32)});
auto* func = b.CreateFunction("foo", MakeScalarType(kU32));
func->StartTarget()->SetInstructions(
utils::Vector{result, b.Return(func, utils::Vector{result})});
generator_.EmitFunction(func);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
%2 = OpTypeInt 32 0
%3 = OpTypeFunction %2
%6 = OpConstant %2 1
%1 = OpFunction %2 None %3
%4 = OpLabel
OpReturnValue %6
OpFunctionEnd
)");
}
TEST_F(SpvGeneratorImplTest, Builtin_Abs_vec2u) {
auto* result = b.Builtin(MakeVectorType(kU32), builtin::Function::kAbs,
utils::Vector{MakeVectorValue(kU32)});
auto* func = b.CreateFunction("foo", MakeVectorType(kU32));
func->StartTarget()->SetInstructions(
utils::Vector{result, b.Return(func, utils::Vector{result})});
generator_.EmitFunction(func);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
%3 = OpTypeInt 32 0
%2 = OpTypeVector %3 2
%4 = OpTypeFunction %2
%8 = OpConstant %3 42
%9 = OpConstant %3 10
%7 = OpConstantComposite %2 %8 %9
%1 = OpFunction %2 None %4
%5 = OpLabel
OpReturnValue %7
OpFunctionEnd
)");
}
// Tests for builtins with the signature: T = func(T, T)
using Builtin_2arg = SpvGeneratorImplTestWithParam<BuiltinTestCase>;
TEST_P(Builtin_2arg, Scalar) {
auto params = GetParam();
auto* func = b.CreateFunction("foo", mod.Types().void_());
func->StartTarget()->SetInstructions(utils::Vector{
b.Builtin(MakeScalarType(params.type), params.function,
utils::Vector{MakeScalarValue(params.type), MakeScalarValue(params.type)}),
b.Return(func)});
generator_.EmitFunction(func);
EXPECT_THAT(DumpModule(generator_.Module()), ::testing::HasSubstr(params.spirv_inst));
}
TEST_P(Builtin_2arg, Vector) {
auto params = GetParam();
auto* func = b.CreateFunction("foo", mod.Types().void_());
func->StartTarget()->SetInstructions(utils::Vector{
b.Builtin(MakeVectorType(params.type), params.function,
utils::Vector{MakeVectorValue(params.type), MakeVectorValue(params.type)}),
b.Return(func)});
generator_.EmitFunction(func);
EXPECT_THAT(DumpModule(generator_.Module()), ::testing::HasSubstr(params.spirv_inst));
}
INSTANTIATE_TEST_SUITE_P(SpvGeneratorImplTest,
Builtin_2arg,
testing::Values(BuiltinTestCase{kF32, builtin::Function::kMax, "FMax"},
BuiltinTestCase{kI32, builtin::Function::kMax, "SMax"},
BuiltinTestCase{kU32, builtin::Function::kMax, "UMax"},
BuiltinTestCase{kF32, builtin::Function::kMin, "FMin"},
BuiltinTestCase{kI32, builtin::Function::kMin, "SMin"},
BuiltinTestCase{kU32, builtin::Function::kMin, "UMin"}));
} // namespace
} // namespace tint::writer::spirv

View File

@ -24,6 +24,15 @@
namespace tint::writer::spirv {
/// The element type of a test.
enum TestElementType {
kBool,
kI32,
kU32,
kF32,
kF16,
};
/// Base helper class for testing the SPIR-V generator implementation.
template <typename BASE>
class SpvGeneratorTestHelperBase : public BASE {
@ -41,6 +50,85 @@ class SpvGeneratorTestHelperBase : public BASE {
/// @returns the disassembled types from the generated module.
std::string DumpTypes() { return DumpInstructions(generator_.Module().Types()); }
/// Helper to make a scalar type corresponding to the element type `ty`.
/// @param ty the element type
/// @returns the scalar type
const type::Type* MakeScalarType(TestElementType ty) {
switch (ty) {
case kBool:
return mod.Types().bool_();
case kI32:
return mod.Types().i32();
case kU32:
return mod.Types().u32();
case kF32:
return mod.Types().f32();
case kF16:
return mod.Types().f16();
}
return nullptr;
}
/// Helper to make a vector type corresponding to the element type `ty`.
/// @param ty the element type
/// @returns the vector type
const type::Type* MakeVectorType(TestElementType ty) {
return mod.Types().vec2(MakeScalarType(ty));
}
/// Helper to make a scalar value with the scalar type `ty`.
/// @param ty the element type
/// @returns the scalar value
ir::Value* MakeScalarValue(TestElementType ty) {
switch (ty) {
case kBool:
return b.Constant(true);
case kI32:
return b.Constant(i32(1));
case kU32:
return b.Constant(u32(1));
case kF32:
return b.Constant(f32(1));
case kF16:
return b.Constant(f16(1));
}
return nullptr;
}
/// Helper to make a vector value with an element type of `ty`.
/// @param ty the element type
/// @returns the vector value
ir::Value* MakeVectorValue(TestElementType ty) {
switch (ty) {
case kBool:
return b.Constant(mod.constant_values.Composite(
MakeVectorType(ty),
utils::Vector<const constant::Value*, 2>{mod.constant_values.Get(true),
mod.constant_values.Get(false)}));
case kI32:
return b.Constant(mod.constant_values.Composite(
MakeVectorType(ty),
utils::Vector<const constant::Value*, 2>{mod.constant_values.Get(i32(42)),
mod.constant_values.Get(i32(-10))}));
case kU32:
return b.Constant(mod.constant_values.Composite(
MakeVectorType(ty),
utils::Vector<const constant::Value*, 2>{mod.constant_values.Get(u32(42)),
mod.constant_values.Get(u32(10))}));
case kF32:
return b.Constant(mod.constant_values.Composite(
MakeVectorType(ty),
utils::Vector<const constant::Value*, 2>{mod.constant_values.Get(f32(42)),
mod.constant_values.Get(f32(-0.5))}));
case kF16:
return b.Constant(mod.constant_values.Composite(
MakeVectorType(ty),
utils::Vector<const constant::Value*, 2>{mod.constant_values.Get(f16(42)),
mod.constant_values.Get(f16(-0.5))}));
}
return nullptr;
}
};
using SpvGeneratorImplTest = SpvGeneratorTestHelperBase<testing::Test>;