mirror of
https://github.com/encounter/dawn-cmake.git
synced 2025-07-03 03:35:59 +00:00
tint/writer/spirv: Support for F16 type, constructor, and convertor
This patch make SPIRV writer support emitting f16 types, f16 literals, f16 constructor and convertor. Unittests are also implemented. Currently SPIRV writer will require 4 capabilities in generated SPIRV: `Float16`, `UniformAndStorageBuffer16BitAccess`, `StorageBuffer16BitAccess`, and `storageInputOutput16`. Bug: tint:1473, tint:1502 Change-Id: Ia1af04f1f4a02bf1b1c2599a5d89791854eabc16 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/95920 Kokoro: Kokoro <noreply+kokoro@google.com> Reviewed-by: Ben Clayton <bclayton@google.com> Commit-Queue: Zhaoming Jiang <zhaoming.jiang@intel.com>
This commit is contained in:
parent
edf650caad
commit
66d4f6e6fb
@ -379,22 +379,18 @@ void Builder::push_extension(const char* extension) {
|
||||
}
|
||||
|
||||
bool Builder::GenerateExtension(ast::Extension extension) {
|
||||
/*
|
||||
For each supported extension, push corresponding capability into the builder.
|
||||
For example:
|
||||
if (kind == ast::Extension::Kind::kF16) {
|
||||
push_capability(SpvCapabilityFloat16);
|
||||
push_capability(SpvCapabilityUniformAndStorageBuffer16BitAccess);
|
||||
push_capability(SpvCapabilityStorageBuffer16BitAccess);
|
||||
push_capability(SpvCapabilityStorageInputOutput16);
|
||||
}
|
||||
*/
|
||||
switch (extension) {
|
||||
case ast::Extension::kChromiumExperimentalDP4a:
|
||||
push_extension("SPV_KHR_integer_dot_product");
|
||||
push_capability(SpvCapabilityDotProductKHR);
|
||||
push_capability(SpvCapabilityDotProductInput4x8BitPackedKHR);
|
||||
break;
|
||||
case ast::Extension::kF16:
|
||||
push_capability(SpvCapabilityFloat16);
|
||||
push_capability(SpvCapabilityUniformAndStorageBuffer16BitAccess);
|
||||
push_capability(SpvCapabilityStorageBuffer16BitAccess);
|
||||
push_capability(SpvCapabilityStorageInputOutput16);
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
@ -1354,6 +1350,9 @@ uint32_t Builder::GenerateTypeConstructorOrConversion(const sem::Call* call,
|
||||
if (result_type->Is<sem::F32>()) {
|
||||
return GenerateConstantIfNeeded(ScalarConstant::F32(0).AsSpecOp(constant_id));
|
||||
}
|
||||
if (result_type->Is<sem::F16>()) {
|
||||
return GenerateConstantIfNeeded(ScalarConstant::F16(0).AsSpecOp(constant_id));
|
||||
}
|
||||
if (result_type->Is<sem::Bool>()) {
|
||||
return GenerateConstantIfNeeded(ScalarConstant::Bool(false).AsSpecOp(constant_id));
|
||||
}
|
||||
@ -1560,22 +1559,23 @@ uint32_t Builder::GenerateCastOrCopyOrPassthrough(const sem::Type* to_type,
|
||||
auto* from_type = TypeOf(from_expr)->UnwrapRef();
|
||||
|
||||
spv::Op op = spv::Op::OpNop;
|
||||
if ((from_type->Is<sem::I32>() && to_type->Is<sem::F32>()) ||
|
||||
if ((from_type->Is<sem::I32>() && to_type->is_float_scalar()) ||
|
||||
(from_type->is_signed_integer_vector() && to_type->is_float_vector())) {
|
||||
op = spv::Op::OpConvertSToF;
|
||||
} else if ((from_type->Is<sem::U32>() && to_type->Is<sem::F32>()) ||
|
||||
} else if ((from_type->Is<sem::U32>() && to_type->is_float_scalar()) ||
|
||||
(from_type->is_unsigned_integer_vector() && to_type->is_float_vector())) {
|
||||
op = spv::Op::OpConvertUToF;
|
||||
} else if ((from_type->Is<sem::F32>() && to_type->Is<sem::I32>()) ||
|
||||
} else if ((from_type->is_float_scalar() && to_type->Is<sem::I32>()) ||
|
||||
(from_type->is_float_vector() && to_type->is_signed_integer_vector())) {
|
||||
op = spv::Op::OpConvertFToS;
|
||||
} else if ((from_type->Is<sem::F32>() && to_type->Is<sem::U32>()) ||
|
||||
} else if ((from_type->is_float_scalar() && to_type->Is<sem::U32>()) ||
|
||||
(from_type->is_float_vector() && to_type->is_unsigned_integer_vector())) {
|
||||
op = spv::Op::OpConvertFToU;
|
||||
} else if ((from_type->Is<sem::Bool>() && to_type->Is<sem::Bool>()) ||
|
||||
(from_type->Is<sem::U32>() && to_type->Is<sem::U32>()) ||
|
||||
(from_type->Is<sem::I32>() && to_type->Is<sem::I32>()) ||
|
||||
(from_type->Is<sem::F32>() && to_type->Is<sem::F32>()) ||
|
||||
(from_type->Is<sem::F16>() && to_type->Is<sem::F16>()) ||
|
||||
(from_type->Is<sem::Vector>() && (from_type == to_type))) {
|
||||
return val_id;
|
||||
} else if ((from_type->Is<sem::I32>() && to_type->Is<sem::U32>()) ||
|
||||
@ -1608,6 +1608,9 @@ uint32_t Builder::GenerateCastOrCopyOrPassthrough(const sem::Type* to_type,
|
||||
if (to_elem_type->Is<sem::F32>()) {
|
||||
zero_id = GenerateConstantIfNeeded(ScalarConstant::F32(0));
|
||||
one_id = GenerateConstantIfNeeded(ScalarConstant::F32(1));
|
||||
} else if (to_elem_type->Is<sem::F16>()) {
|
||||
zero_id = GenerateConstantIfNeeded(ScalarConstant::F16(0));
|
||||
one_id = GenerateConstantIfNeeded(ScalarConstant::F16(1));
|
||||
} else if (to_elem_type->Is<sem::U32>()) {
|
||||
zero_id = GenerateConstantIfNeeded(ScalarConstant::U32(0));
|
||||
one_id = GenerateConstantIfNeeded(ScalarConstant::U32(1));
|
||||
@ -1691,7 +1694,9 @@ uint32_t Builder::GenerateLiteralIfNeeded(const ast::Variable* var,
|
||||
constant.value.f32 = static_cast<float>(f->value);
|
||||
return;
|
||||
case ast::FloatLiteralExpression::Suffix::kH:
|
||||
error_ = "Type f16 is not completely implemented yet";
|
||||
constant.kind = ScalarConstant::Kind::kF16;
|
||||
constant.value.f16 = {f16(static_cast<float>(f->value)).BitsRepresentation()};
|
||||
return;
|
||||
}
|
||||
},
|
||||
[&](Default) { error_ = "unknown literal type"; });
|
||||
@ -1750,6 +1755,10 @@ uint32_t Builder::GenerateConstantIfNeeded(const sem::Constant* constant) {
|
||||
auto val = constant->As<f32>();
|
||||
return GenerateConstantIfNeeded(ScalarConstant::F32(val.value));
|
||||
},
|
||||
[&](const sem::F16*) {
|
||||
auto val = constant->As<f16>();
|
||||
return GenerateConstantIfNeeded(ScalarConstant::F16(val.value));
|
||||
},
|
||||
[&](const sem::I32*) {
|
||||
auto val = constant->As<i32>();
|
||||
return GenerateConstantIfNeeded(ScalarConstant::I32(val.value));
|
||||
@ -1788,6 +1797,10 @@ uint32_t Builder::GenerateConstantIfNeeded(const ScalarConstant& constant) {
|
||||
type_id = GenerateTypeIfNeeded(builder_.create<sem::F32>());
|
||||
break;
|
||||
}
|
||||
case ScalarConstant::Kind::kF16: {
|
||||
type_id = GenerateTypeIfNeeded(builder_.create<sem::F16>());
|
||||
break;
|
||||
}
|
||||
case ScalarConstant::Kind::kBool: {
|
||||
type_id = GenerateTypeIfNeeded(builder_.create<sem::Bool>());
|
||||
break;
|
||||
@ -1822,6 +1835,12 @@ uint32_t Builder::GenerateConstantIfNeeded(const ScalarConstant& constant) {
|
||||
{Operand(type_id), result, Operand(constant.value.f32)});
|
||||
break;
|
||||
}
|
||||
case ScalarConstant::Kind::kF16: {
|
||||
push_type(
|
||||
constant.is_spec_op ? spv::Op::OpSpecConstant : spv::Op::OpConstant,
|
||||
{Operand(type_id), result, U32Operand(constant.value.f16.bits_representation)});
|
||||
break;
|
||||
}
|
||||
case ScalarConstant::Kind::kBool: {
|
||||
if (constant.value.b) {
|
||||
push_type(
|
||||
@ -3795,9 +3814,8 @@ uint32_t Builder::GenerateTypeIfNeeded(const sem::Type* type) {
|
||||
return true;
|
||||
},
|
||||
[&](const sem::F16*) {
|
||||
// Should be `push_type(spv::Op::OpTypeFloat, {result, Operand(16u)});`
|
||||
error_ = "Type f16 is not completely implemented yet.";
|
||||
return false;
|
||||
push_type(spv::Op::OpTypeFloat, {result, Operand(16u)});
|
||||
return true;
|
||||
},
|
||||
[&](const sem::I32*) {
|
||||
push_type(spv::Op::OpTypeInt, {result, Operand(32u), Operand(1u)});
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -115,6 +115,36 @@ TEST_F(BuilderTest, GlobalConst_Vec_Constructor) {
|
||||
Validate(b);
|
||||
}
|
||||
|
||||
TEST_F(BuilderTest, GlobalConst_Vec_F16_Constructor) {
|
||||
// const c = vec3<f16>(1h, 2h, 3h);
|
||||
// var v = c;
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
auto* c = GlobalConst("c", nullptr, vec3<f16>(1_h, 2_h, 3_h));
|
||||
GlobalVar("v", nullptr, ast::StorageClass::kPrivate, Expr(c));
|
||||
|
||||
spirv::Builder& b = SanitizeAndBuild();
|
||||
|
||||
ASSERT_TRUE(b.Build()) << b.error();
|
||||
|
||||
EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 16
|
||||
%1 = OpTypeVector %2 3
|
||||
%3 = OpConstant %2 0x1p+0
|
||||
%4 = OpConstant %2 0x1p+1
|
||||
%5 = OpConstant %2 0x1.8p+1
|
||||
%6 = OpConstantComposite %1 %3 %4 %5
|
||||
%8 = OpTypePointer Private %1
|
||||
%7 = OpVariable %8 Private %6
|
||||
%10 = OpTypeVoid
|
||||
%9 = OpTypeFunction %10
|
||||
)");
|
||||
EXPECT_EQ(DumpInstructions(b.functions()[0].variables()), R"()");
|
||||
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), R"(OpReturn
|
||||
)");
|
||||
|
||||
Validate(b);
|
||||
}
|
||||
|
||||
TEST_F(BuilderTest, GlobalConst_Vec_AInt_Constructor) {
|
||||
// const c = vec3(1, 2, 3);
|
||||
// var v = c;
|
||||
|
@ -163,4 +163,39 @@ TEST_F(BuilderTest, Literal_F32_Dedup) {
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_F(BuilderTest, Literal_F16) {
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
auto* i = create<ast::FloatLiteralExpression>(23.245, ast::FloatLiteralExpression::Suffix::kH);
|
||||
WrapInFunction(i);
|
||||
|
||||
spirv::Builder& b = Build();
|
||||
|
||||
auto id = b.GenerateLiteralIfNeeded(nullptr, i);
|
||||
ASSERT_FALSE(b.has_error()) << b.error();
|
||||
EXPECT_EQ(2u, id);
|
||||
|
||||
EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeFloat 16
|
||||
%2 = OpConstant %1 0x1.73cp+4
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_F(BuilderTest, Literal_F16_Dedup) {
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
auto* i1 = create<ast::FloatLiteralExpression>(23.245, ast::FloatLiteralExpression::Suffix::kH);
|
||||
auto* i2 = create<ast::FloatLiteralExpression>(23.245, ast::FloatLiteralExpression::Suffix::kH);
|
||||
WrapInFunction(i1, i2);
|
||||
|
||||
spirv::Builder& b = Build();
|
||||
|
||||
ASSERT_NE(b.GenerateLiteralIfNeeded(nullptr, i1), 0u);
|
||||
ASSERT_NE(b.GenerateLiteralIfNeeded(nullptr, i2), 0u);
|
||||
ASSERT_FALSE(b.has_error()) << b.error();
|
||||
|
||||
EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeFloat 16
|
||||
%2 = OpConstant %1 0x1.73cp+4
|
||||
)");
|
||||
}
|
||||
|
||||
} // namespace tint::writer::spirv
|
||||
|
@ -175,6 +175,34 @@ TEST_F(BuilderTest_Type, ReturnsGeneratedF32) {
|
||||
ASSERT_FALSE(b.has_error()) << b.error();
|
||||
}
|
||||
|
||||
TEST_F(BuilderTest_Type, GenerateF16) {
|
||||
auto* f16 = create<sem::F16>();
|
||||
|
||||
spirv::Builder& b = Build();
|
||||
|
||||
auto id = b.GenerateTypeIfNeeded(f16);
|
||||
ASSERT_FALSE(b.has_error()) << b.error();
|
||||
EXPECT_EQ(id, 1u);
|
||||
|
||||
ASSERT_EQ(b.types().size(), 1u);
|
||||
EXPECT_EQ(DumpInstruction(b.types()[0]), R"(%1 = OpTypeFloat 16
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_F(BuilderTest_Type, ReturnsGeneratedF16) {
|
||||
auto* f16 = create<sem::F16>();
|
||||
auto* i32 = create<sem::I32>();
|
||||
|
||||
spirv::Builder& b = Build();
|
||||
|
||||
EXPECT_EQ(b.GenerateTypeIfNeeded(f16), 1u);
|
||||
ASSERT_FALSE(b.has_error()) << b.error();
|
||||
EXPECT_EQ(b.GenerateTypeIfNeeded(i32), 2u);
|
||||
ASSERT_FALSE(b.has_error()) << b.error();
|
||||
EXPECT_EQ(b.GenerateTypeIfNeeded(f16), 1u);
|
||||
ASSERT_FALSE(b.has_error()) << b.error();
|
||||
}
|
||||
|
||||
TEST_F(BuilderTest_Type, GenerateI32) {
|
||||
auto* i32 = create<sem::I32>();
|
||||
|
||||
@ -236,6 +264,39 @@ TEST_F(BuilderTest_Type, ReturnsGeneratedMatrix) {
|
||||
ASSERT_FALSE(b.has_error()) << b.error();
|
||||
}
|
||||
|
||||
TEST_F(BuilderTest_Type, GenerateF16Matrix) {
|
||||
auto* f16 = create<sem::F16>();
|
||||
auto* vec3 = create<sem::Vector>(f16, 3u);
|
||||
auto* mat2x3 = create<sem::Matrix>(vec3, 2u);
|
||||
|
||||
spirv::Builder& b = Build();
|
||||
|
||||
auto id = b.GenerateTypeIfNeeded(mat2x3);
|
||||
ASSERT_FALSE(b.has_error()) << b.error();
|
||||
EXPECT_EQ(id, 1u);
|
||||
|
||||
EXPECT_EQ(b.types().size(), 3u);
|
||||
EXPECT_EQ(DumpInstructions(b.types()), R"(%3 = OpTypeFloat 16
|
||||
%2 = OpTypeVector %3 3
|
||||
%1 = OpTypeMatrix %2 2
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_F(BuilderTest_Type, ReturnsGeneratedF16Matrix) {
|
||||
auto* f16 = create<sem::F16>();
|
||||
auto* col = create<sem::Vector>(f16, 4u);
|
||||
auto* mat = create<sem::Matrix>(col, 3u);
|
||||
|
||||
spirv::Builder& b = Build();
|
||||
|
||||
EXPECT_EQ(b.GenerateTypeIfNeeded(mat), 1u);
|
||||
ASSERT_FALSE(b.has_error()) << b.error();
|
||||
EXPECT_EQ(b.GenerateTypeIfNeeded(f16), 3u);
|
||||
ASSERT_FALSE(b.has_error()) << b.error();
|
||||
EXPECT_EQ(b.GenerateTypeIfNeeded(mat), 1u);
|
||||
ASSERT_FALSE(b.has_error()) << b.error();
|
||||
}
|
||||
|
||||
TEST_F(BuilderTest_Type, GeneratePtr) {
|
||||
auto* i32 = create<sem::I32>();
|
||||
auto* ptr = create<sem::Pointer>(i32, ast::StorageClass::kOutput, ast::Access::kReadWrite);
|
||||
|
@ -20,6 +20,7 @@
|
||||
#include <cstring>
|
||||
#include <functional>
|
||||
|
||||
#include "src/tint/number.h"
|
||||
#include "src/tint/utils/hash.h"
|
||||
|
||||
// Forward declarations
|
||||
@ -31,6 +32,12 @@ namespace tint::writer::spirv {
|
||||
|
||||
/// ScalarConstant represents a scalar constant value
|
||||
struct ScalarConstant {
|
||||
/// The struct type to hold the bits representation of f16 in the Value union
|
||||
struct F16 {
|
||||
/// The 16 bits representation of the f16, stored as uint16_t
|
||||
uint16_t bits_representation;
|
||||
};
|
||||
|
||||
/// The constant value
|
||||
union Value {
|
||||
/// The value as a bool
|
||||
@ -41,6 +48,8 @@ struct ScalarConstant {
|
||||
int32_t i32;
|
||||
/// The value as a float
|
||||
float f32;
|
||||
/// The value as bits representation of a f16
|
||||
F16 f16;
|
||||
|
||||
/// The value that is wide enough to encompass all other types (including
|
||||
/// future 64-bit data types).
|
||||
@ -48,7 +57,7 @@ struct ScalarConstant {
|
||||
};
|
||||
|
||||
/// The kind of constant
|
||||
enum class Kind { kBool, kU32, kI32, kF32 };
|
||||
enum class Kind { kBool, kU32, kI32, kF32, kF16 };
|
||||
|
||||
/// Constructor
|
||||
inline ScalarConstant() { value.u64 = 0; }
|
||||
@ -72,7 +81,7 @@ struct ScalarConstant {
|
||||
}
|
||||
|
||||
/// @param value the value of the constant
|
||||
/// @returns a new ScalarConstant with the provided value and kind Kind::kI32
|
||||
/// @returns a new ScalarConstant with the provided value and kind Kind::kF32
|
||||
static inline ScalarConstant F32(float value) {
|
||||
ScalarConstant c;
|
||||
c.value.f32 = value;
|
||||
@ -80,6 +89,15 @@ struct ScalarConstant {
|
||||
return c;
|
||||
}
|
||||
|
||||
/// @param value the value of the constant
|
||||
/// @returns a new ScalarConstant with the provided value and kind Kind::kF16
|
||||
static inline ScalarConstant F16(f16::type value) {
|
||||
ScalarConstant c;
|
||||
c.value.f16 = {f16(value).BitsRepresentation()};
|
||||
c.kind = Kind::kF16;
|
||||
return c;
|
||||
}
|
||||
|
||||
/// @param value the value of the constant
|
||||
/// @returns a new ScalarConstant with the provided value and kind Kind::kBool
|
||||
static inline ScalarConstant Bool(bool value) {
|
||||
|
@ -52,5 +52,12 @@ TEST_F(SpirvScalarConstantTest, U32) {
|
||||
EXPECT_EQ(c.kind, ScalarConstant::Kind::kU32);
|
||||
}
|
||||
|
||||
TEST_F(SpirvScalarConstantTest, F16) {
|
||||
auto c = ScalarConstant::F16(123.456f);
|
||||
// 123.456f will be quantized to f16 123.4375h, bit pattern 0x57b7
|
||||
EXPECT_EQ(c.value.f16.bits_representation, 0x57b7u);
|
||||
EXPECT_EQ(c.kind, ScalarConstant::Kind::kF16);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tint::writer::spirv
|
||||
|
@ -4,6 +4,10 @@
|
||||
; Bound: 19
|
||||
; Schema: 0
|
||||
OpCapability Shader
|
||||
OpCapability Float16
|
||||
OpCapability UniformAndStorageBuffer16BitAccess
|
||||
OpCapability StorageBuffer16BitAccess
|
||||
OpCapability StorageInputOutput16
|
||||
OpMemoryModel Logical GLSL450
|
||||
OpEntryPoint Fragment %main "main" %value
|
||||
OpExecutionMode %main OriginUpperLeft
|
||||
|
@ -4,6 +4,10 @@
|
||||
; Bound: 19
|
||||
; Schema: 0
|
||||
OpCapability Shader
|
||||
OpCapability Float16
|
||||
OpCapability UniformAndStorageBuffer16BitAccess
|
||||
OpCapability StorageBuffer16BitAccess
|
||||
OpCapability StorageInputOutput16
|
||||
OpMemoryModel Logical GLSL450
|
||||
OpEntryPoint Fragment %main "main" %value
|
||||
OpExecutionMode %main OriginUpperLeft
|
||||
|
Loading…
x
Reference in New Issue
Block a user