tint/writer/spirv: Support for F16 type, constructor, and convertor

This patch make SPIRV writer support emitting f16 types, f16 literals,
f16 constructor and convertor. Unittests are also implemented.

Currently SPIRV writer will require 4 capabilities in generated SPIRV:
`Float16`, `UniformAndStorageBuffer16BitAccess`,
`StorageBuffer16BitAccess`, and `storageInputOutput16`.

Bug: tint:1473, tint:1502
Change-Id: Ia1af04f1f4a02bf1b1c2599a5d89791854eabc16
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/95920
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Zhaoming Jiang <zhaoming.jiang@intel.com>
This commit is contained in:
Zhaoming Jiang 2022-07-12 12:35:09 +00:00 committed by Dawn LUCI CQ
parent edf650caad
commit 66d4f6e6fb
9 changed files with 2198 additions and 46 deletions

View File

@ -379,22 +379,18 @@ void Builder::push_extension(const char* extension) {
} }
bool Builder::GenerateExtension(ast::Extension extension) { bool Builder::GenerateExtension(ast::Extension extension) {
/*
For each supported extension, push corresponding capability into the builder.
For example:
if (kind == ast::Extension::Kind::kF16) {
push_capability(SpvCapabilityFloat16);
push_capability(SpvCapabilityUniformAndStorageBuffer16BitAccess);
push_capability(SpvCapabilityStorageBuffer16BitAccess);
push_capability(SpvCapabilityStorageInputOutput16);
}
*/
switch (extension) { switch (extension) {
case ast::Extension::kChromiumExperimentalDP4a: case ast::Extension::kChromiumExperimentalDP4a:
push_extension("SPV_KHR_integer_dot_product"); push_extension("SPV_KHR_integer_dot_product");
push_capability(SpvCapabilityDotProductKHR); push_capability(SpvCapabilityDotProductKHR);
push_capability(SpvCapabilityDotProductInput4x8BitPackedKHR); push_capability(SpvCapabilityDotProductInput4x8BitPackedKHR);
break; break;
case ast::Extension::kF16:
push_capability(SpvCapabilityFloat16);
push_capability(SpvCapabilityUniformAndStorageBuffer16BitAccess);
push_capability(SpvCapabilityStorageBuffer16BitAccess);
push_capability(SpvCapabilityStorageInputOutput16);
break;
default: default:
return false; return false;
} }
@ -1354,6 +1350,9 @@ uint32_t Builder::GenerateTypeConstructorOrConversion(const sem::Call* call,
if (result_type->Is<sem::F32>()) { if (result_type->Is<sem::F32>()) {
return GenerateConstantIfNeeded(ScalarConstant::F32(0).AsSpecOp(constant_id)); return GenerateConstantIfNeeded(ScalarConstant::F32(0).AsSpecOp(constant_id));
} }
if (result_type->Is<sem::F16>()) {
return GenerateConstantIfNeeded(ScalarConstant::F16(0).AsSpecOp(constant_id));
}
if (result_type->Is<sem::Bool>()) { if (result_type->Is<sem::Bool>()) {
return GenerateConstantIfNeeded(ScalarConstant::Bool(false).AsSpecOp(constant_id)); return GenerateConstantIfNeeded(ScalarConstant::Bool(false).AsSpecOp(constant_id));
} }
@ -1560,22 +1559,23 @@ uint32_t Builder::GenerateCastOrCopyOrPassthrough(const sem::Type* to_type,
auto* from_type = TypeOf(from_expr)->UnwrapRef(); auto* from_type = TypeOf(from_expr)->UnwrapRef();
spv::Op op = spv::Op::OpNop; spv::Op op = spv::Op::OpNop;
if ((from_type->Is<sem::I32>() && to_type->Is<sem::F32>()) || if ((from_type->Is<sem::I32>() && to_type->is_float_scalar()) ||
(from_type->is_signed_integer_vector() && to_type->is_float_vector())) { (from_type->is_signed_integer_vector() && to_type->is_float_vector())) {
op = spv::Op::OpConvertSToF; op = spv::Op::OpConvertSToF;
} else if ((from_type->Is<sem::U32>() && to_type->Is<sem::F32>()) || } else if ((from_type->Is<sem::U32>() && to_type->is_float_scalar()) ||
(from_type->is_unsigned_integer_vector() && to_type->is_float_vector())) { (from_type->is_unsigned_integer_vector() && to_type->is_float_vector())) {
op = spv::Op::OpConvertUToF; op = spv::Op::OpConvertUToF;
} else if ((from_type->Is<sem::F32>() && to_type->Is<sem::I32>()) || } else if ((from_type->is_float_scalar() && to_type->Is<sem::I32>()) ||
(from_type->is_float_vector() && to_type->is_signed_integer_vector())) { (from_type->is_float_vector() && to_type->is_signed_integer_vector())) {
op = spv::Op::OpConvertFToS; op = spv::Op::OpConvertFToS;
} else if ((from_type->Is<sem::F32>() && to_type->Is<sem::U32>()) || } else if ((from_type->is_float_scalar() && to_type->Is<sem::U32>()) ||
(from_type->is_float_vector() && to_type->is_unsigned_integer_vector())) { (from_type->is_float_vector() && to_type->is_unsigned_integer_vector())) {
op = spv::Op::OpConvertFToU; op = spv::Op::OpConvertFToU;
} else if ((from_type->Is<sem::Bool>() && to_type->Is<sem::Bool>()) || } else if ((from_type->Is<sem::Bool>() && to_type->Is<sem::Bool>()) ||
(from_type->Is<sem::U32>() && to_type->Is<sem::U32>()) || (from_type->Is<sem::U32>() && to_type->Is<sem::U32>()) ||
(from_type->Is<sem::I32>() && to_type->Is<sem::I32>()) || (from_type->Is<sem::I32>() && to_type->Is<sem::I32>()) ||
(from_type->Is<sem::F32>() && to_type->Is<sem::F32>()) || (from_type->Is<sem::F32>() && to_type->Is<sem::F32>()) ||
(from_type->Is<sem::F16>() && to_type->Is<sem::F16>()) ||
(from_type->Is<sem::Vector>() && (from_type == to_type))) { (from_type->Is<sem::Vector>() && (from_type == to_type))) {
return val_id; return val_id;
} else if ((from_type->Is<sem::I32>() && to_type->Is<sem::U32>()) || } else if ((from_type->Is<sem::I32>() && to_type->Is<sem::U32>()) ||
@ -1608,6 +1608,9 @@ uint32_t Builder::GenerateCastOrCopyOrPassthrough(const sem::Type* to_type,
if (to_elem_type->Is<sem::F32>()) { if (to_elem_type->Is<sem::F32>()) {
zero_id = GenerateConstantIfNeeded(ScalarConstant::F32(0)); zero_id = GenerateConstantIfNeeded(ScalarConstant::F32(0));
one_id = GenerateConstantIfNeeded(ScalarConstant::F32(1)); one_id = GenerateConstantIfNeeded(ScalarConstant::F32(1));
} else if (to_elem_type->Is<sem::F16>()) {
zero_id = GenerateConstantIfNeeded(ScalarConstant::F16(0));
one_id = GenerateConstantIfNeeded(ScalarConstant::F16(1));
} else if (to_elem_type->Is<sem::U32>()) { } else if (to_elem_type->Is<sem::U32>()) {
zero_id = GenerateConstantIfNeeded(ScalarConstant::U32(0)); zero_id = GenerateConstantIfNeeded(ScalarConstant::U32(0));
one_id = GenerateConstantIfNeeded(ScalarConstant::U32(1)); one_id = GenerateConstantIfNeeded(ScalarConstant::U32(1));
@ -1691,7 +1694,9 @@ uint32_t Builder::GenerateLiteralIfNeeded(const ast::Variable* var,
constant.value.f32 = static_cast<float>(f->value); constant.value.f32 = static_cast<float>(f->value);
return; return;
case ast::FloatLiteralExpression::Suffix::kH: case ast::FloatLiteralExpression::Suffix::kH:
error_ = "Type f16 is not completely implemented yet"; constant.kind = ScalarConstant::Kind::kF16;
constant.value.f16 = {f16(static_cast<float>(f->value)).BitsRepresentation()};
return;
} }
}, },
[&](Default) { error_ = "unknown literal type"; }); [&](Default) { error_ = "unknown literal type"; });
@ -1750,6 +1755,10 @@ uint32_t Builder::GenerateConstantIfNeeded(const sem::Constant* constant) {
auto val = constant->As<f32>(); auto val = constant->As<f32>();
return GenerateConstantIfNeeded(ScalarConstant::F32(val.value)); return GenerateConstantIfNeeded(ScalarConstant::F32(val.value));
}, },
[&](const sem::F16*) {
auto val = constant->As<f16>();
return GenerateConstantIfNeeded(ScalarConstant::F16(val.value));
},
[&](const sem::I32*) { [&](const sem::I32*) {
auto val = constant->As<i32>(); auto val = constant->As<i32>();
return GenerateConstantIfNeeded(ScalarConstant::I32(val.value)); return GenerateConstantIfNeeded(ScalarConstant::I32(val.value));
@ -1788,6 +1797,10 @@ uint32_t Builder::GenerateConstantIfNeeded(const ScalarConstant& constant) {
type_id = GenerateTypeIfNeeded(builder_.create<sem::F32>()); type_id = GenerateTypeIfNeeded(builder_.create<sem::F32>());
break; break;
} }
case ScalarConstant::Kind::kF16: {
type_id = GenerateTypeIfNeeded(builder_.create<sem::F16>());
break;
}
case ScalarConstant::Kind::kBool: { case ScalarConstant::Kind::kBool: {
type_id = GenerateTypeIfNeeded(builder_.create<sem::Bool>()); type_id = GenerateTypeIfNeeded(builder_.create<sem::Bool>());
break; break;
@ -1822,6 +1835,12 @@ uint32_t Builder::GenerateConstantIfNeeded(const ScalarConstant& constant) {
{Operand(type_id), result, Operand(constant.value.f32)}); {Operand(type_id), result, Operand(constant.value.f32)});
break; break;
} }
case ScalarConstant::Kind::kF16: {
push_type(
constant.is_spec_op ? spv::Op::OpSpecConstant : spv::Op::OpConstant,
{Operand(type_id), result, U32Operand(constant.value.f16.bits_representation)});
break;
}
case ScalarConstant::Kind::kBool: { case ScalarConstant::Kind::kBool: {
if (constant.value.b) { if (constant.value.b) {
push_type( push_type(
@ -3795,9 +3814,8 @@ uint32_t Builder::GenerateTypeIfNeeded(const sem::Type* type) {
return true; return true;
}, },
[&](const sem::F16*) { [&](const sem::F16*) {
// Should be `push_type(spv::Op::OpTypeFloat, {result, Operand(16u)});` push_type(spv::Op::OpTypeFloat, {result, Operand(16u)});
error_ = "Type f16 is not completely implemented yet."; return true;
return false;
}, },
[&](const sem::I32*) { [&](const sem::I32*) {
push_type(spv::Op::OpTypeInt, {result, Operand(32u), Operand(1u)}); push_type(spv::Op::OpTypeInt, {result, Operand(32u), Operand(1u)});

File diff suppressed because it is too large Load Diff

View File

@ -115,6 +115,36 @@ TEST_F(BuilderTest, GlobalConst_Vec_Constructor) {
Validate(b); Validate(b);
} }
TEST_F(BuilderTest, GlobalConst_Vec_F16_Constructor) {
// const c = vec3<f16>(1h, 2h, 3h);
// var v = c;
Enable(ast::Extension::kF16);
auto* c = GlobalConst("c", nullptr, vec3<f16>(1_h, 2_h, 3_h));
GlobalVar("v", nullptr, ast::StorageClass::kPrivate, Expr(c));
spirv::Builder& b = SanitizeAndBuild();
ASSERT_TRUE(b.Build()) << b.error();
EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 16
%1 = OpTypeVector %2 3
%3 = OpConstant %2 0x1p+0
%4 = OpConstant %2 0x1p+1
%5 = OpConstant %2 0x1.8p+1
%6 = OpConstantComposite %1 %3 %4 %5
%8 = OpTypePointer Private %1
%7 = OpVariable %8 Private %6
%10 = OpTypeVoid
%9 = OpTypeFunction %10
)");
EXPECT_EQ(DumpInstructions(b.functions()[0].variables()), R"()");
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), R"(OpReturn
)");
Validate(b);
}
TEST_F(BuilderTest, GlobalConst_Vec_AInt_Constructor) { TEST_F(BuilderTest, GlobalConst_Vec_AInt_Constructor) {
// const c = vec3(1, 2, 3); // const c = vec3(1, 2, 3);
// var v = c; // var v = c;

View File

@ -163,4 +163,39 @@ TEST_F(BuilderTest, Literal_F32_Dedup) {
)"); )");
} }
TEST_F(BuilderTest, Literal_F16) {
Enable(ast::Extension::kF16);
auto* i = create<ast::FloatLiteralExpression>(23.245, ast::FloatLiteralExpression::Suffix::kH);
WrapInFunction(i);
spirv::Builder& b = Build();
auto id = b.GenerateLiteralIfNeeded(nullptr, i);
ASSERT_FALSE(b.has_error()) << b.error();
EXPECT_EQ(2u, id);
EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeFloat 16
%2 = OpConstant %1 0x1.73cp+4
)");
}
TEST_F(BuilderTest, Literal_F16_Dedup) {
Enable(ast::Extension::kF16);
auto* i1 = create<ast::FloatLiteralExpression>(23.245, ast::FloatLiteralExpression::Suffix::kH);
auto* i2 = create<ast::FloatLiteralExpression>(23.245, ast::FloatLiteralExpression::Suffix::kH);
WrapInFunction(i1, i2);
spirv::Builder& b = Build();
ASSERT_NE(b.GenerateLiteralIfNeeded(nullptr, i1), 0u);
ASSERT_NE(b.GenerateLiteralIfNeeded(nullptr, i2), 0u);
ASSERT_FALSE(b.has_error()) << b.error();
EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeFloat 16
%2 = OpConstant %1 0x1.73cp+4
)");
}
} // namespace tint::writer::spirv } // namespace tint::writer::spirv

View File

@ -175,6 +175,34 @@ TEST_F(BuilderTest_Type, ReturnsGeneratedF32) {
ASSERT_FALSE(b.has_error()) << b.error(); ASSERT_FALSE(b.has_error()) << b.error();
} }
TEST_F(BuilderTest_Type, GenerateF16) {
auto* f16 = create<sem::F16>();
spirv::Builder& b = Build();
auto id = b.GenerateTypeIfNeeded(f16);
ASSERT_FALSE(b.has_error()) << b.error();
EXPECT_EQ(id, 1u);
ASSERT_EQ(b.types().size(), 1u);
EXPECT_EQ(DumpInstruction(b.types()[0]), R"(%1 = OpTypeFloat 16
)");
}
TEST_F(BuilderTest_Type, ReturnsGeneratedF16) {
auto* f16 = create<sem::F16>();
auto* i32 = create<sem::I32>();
spirv::Builder& b = Build();
EXPECT_EQ(b.GenerateTypeIfNeeded(f16), 1u);
ASSERT_FALSE(b.has_error()) << b.error();
EXPECT_EQ(b.GenerateTypeIfNeeded(i32), 2u);
ASSERT_FALSE(b.has_error()) << b.error();
EXPECT_EQ(b.GenerateTypeIfNeeded(f16), 1u);
ASSERT_FALSE(b.has_error()) << b.error();
}
TEST_F(BuilderTest_Type, GenerateI32) { TEST_F(BuilderTest_Type, GenerateI32) {
auto* i32 = create<sem::I32>(); auto* i32 = create<sem::I32>();
@ -236,6 +264,39 @@ TEST_F(BuilderTest_Type, ReturnsGeneratedMatrix) {
ASSERT_FALSE(b.has_error()) << b.error(); ASSERT_FALSE(b.has_error()) << b.error();
} }
TEST_F(BuilderTest_Type, GenerateF16Matrix) {
auto* f16 = create<sem::F16>();
auto* vec3 = create<sem::Vector>(f16, 3u);
auto* mat2x3 = create<sem::Matrix>(vec3, 2u);
spirv::Builder& b = Build();
auto id = b.GenerateTypeIfNeeded(mat2x3);
ASSERT_FALSE(b.has_error()) << b.error();
EXPECT_EQ(id, 1u);
EXPECT_EQ(b.types().size(), 3u);
EXPECT_EQ(DumpInstructions(b.types()), R"(%3 = OpTypeFloat 16
%2 = OpTypeVector %3 3
%1 = OpTypeMatrix %2 2
)");
}
TEST_F(BuilderTest_Type, ReturnsGeneratedF16Matrix) {
auto* f16 = create<sem::F16>();
auto* col = create<sem::Vector>(f16, 4u);
auto* mat = create<sem::Matrix>(col, 3u);
spirv::Builder& b = Build();
EXPECT_EQ(b.GenerateTypeIfNeeded(mat), 1u);
ASSERT_FALSE(b.has_error()) << b.error();
EXPECT_EQ(b.GenerateTypeIfNeeded(f16), 3u);
ASSERT_FALSE(b.has_error()) << b.error();
EXPECT_EQ(b.GenerateTypeIfNeeded(mat), 1u);
ASSERT_FALSE(b.has_error()) << b.error();
}
TEST_F(BuilderTest_Type, GeneratePtr) { TEST_F(BuilderTest_Type, GeneratePtr) {
auto* i32 = create<sem::I32>(); auto* i32 = create<sem::I32>();
auto* ptr = create<sem::Pointer>(i32, ast::StorageClass::kOutput, ast::Access::kReadWrite); auto* ptr = create<sem::Pointer>(i32, ast::StorageClass::kOutput, ast::Access::kReadWrite);

View File

@ -20,6 +20,7 @@
#include <cstring> #include <cstring>
#include <functional> #include <functional>
#include "src/tint/number.h"
#include "src/tint/utils/hash.h" #include "src/tint/utils/hash.h"
// Forward declarations // Forward declarations
@ -31,6 +32,12 @@ namespace tint::writer::spirv {
/// ScalarConstant represents a scalar constant value /// ScalarConstant represents a scalar constant value
struct ScalarConstant { struct ScalarConstant {
/// The struct type to hold the bits representation of f16 in the Value union
struct F16 {
/// The 16 bits representation of the f16, stored as uint16_t
uint16_t bits_representation;
};
/// The constant value /// The constant value
union Value { union Value {
/// The value as a bool /// The value as a bool
@ -41,6 +48,8 @@ struct ScalarConstant {
int32_t i32; int32_t i32;
/// The value as a float /// The value as a float
float f32; float f32;
/// The value as bits representation of a f16
F16 f16;
/// The value that is wide enough to encompass all other types (including /// The value that is wide enough to encompass all other types (including
/// future 64-bit data types). /// future 64-bit data types).
@ -48,7 +57,7 @@ struct ScalarConstant {
}; };
/// The kind of constant /// The kind of constant
enum class Kind { kBool, kU32, kI32, kF32 }; enum class Kind { kBool, kU32, kI32, kF32, kF16 };
/// Constructor /// Constructor
inline ScalarConstant() { value.u64 = 0; } inline ScalarConstant() { value.u64 = 0; }
@ -72,7 +81,7 @@ struct ScalarConstant {
} }
/// @param value the value of the constant /// @param value the value of the constant
/// @returns a new ScalarConstant with the provided value and kind Kind::kI32 /// @returns a new ScalarConstant with the provided value and kind Kind::kF32
static inline ScalarConstant F32(float value) { static inline ScalarConstant F32(float value) {
ScalarConstant c; ScalarConstant c;
c.value.f32 = value; c.value.f32 = value;
@ -80,6 +89,15 @@ struct ScalarConstant {
return c; return c;
} }
/// @param value the value of the constant
/// @returns a new ScalarConstant with the provided value and kind Kind::kF16
static inline ScalarConstant F16(f16::type value) {
ScalarConstant c;
c.value.f16 = {f16(value).BitsRepresentation()};
c.kind = Kind::kF16;
return c;
}
/// @param value the value of the constant /// @param value the value of the constant
/// @returns a new ScalarConstant with the provided value and kind Kind::kBool /// @returns a new ScalarConstant with the provided value and kind Kind::kBool
static inline ScalarConstant Bool(bool value) { static inline ScalarConstant Bool(bool value) {

View File

@ -52,5 +52,12 @@ TEST_F(SpirvScalarConstantTest, U32) {
EXPECT_EQ(c.kind, ScalarConstant::Kind::kU32); EXPECT_EQ(c.kind, ScalarConstant::Kind::kU32);
} }
TEST_F(SpirvScalarConstantTest, F16) {
auto c = ScalarConstant::F16(123.456f);
// 123.456f will be quantized to f16 123.4375h, bit pattern 0x57b7
EXPECT_EQ(c.value.f16.bits_representation, 0x57b7u);
EXPECT_EQ(c.kind, ScalarConstant::Kind::kF16);
}
} // namespace } // namespace
} // namespace tint::writer::spirv } // namespace tint::writer::spirv

View File

@ -4,6 +4,10 @@
; Bound: 19 ; Bound: 19
; Schema: 0 ; Schema: 0
OpCapability Shader OpCapability Shader
OpCapability Float16
OpCapability UniformAndStorageBuffer16BitAccess
OpCapability StorageBuffer16BitAccess
OpCapability StorageInputOutput16
OpMemoryModel Logical GLSL450 OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %main "main" %value OpEntryPoint Fragment %main "main" %value
OpExecutionMode %main OriginUpperLeft OpExecutionMode %main OriginUpperLeft

View File

@ -4,6 +4,10 @@
; Bound: 19 ; Bound: 19
; Schema: 0 ; Schema: 0
OpCapability Shader OpCapability Shader
OpCapability Float16
OpCapability UniformAndStorageBuffer16BitAccess
OpCapability StorageBuffer16BitAccess
OpCapability StorageInputOutput16
OpMemoryModel Logical GLSL450 OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %main "main" %value OpEntryPoint Fragment %main "main" %value
OpExecutionMode %main OriginUpperLeft OpExecutionMode %main OriginUpperLeft