Tint&Dawn: Enable f16 override

This CL enable using f16 override, and also fix related tests in Dawn
and Tint.

Bug: tint:1473, tint:1502
Change-Id: I8336770e8a73e5023c1aba224b7b5f21692fbaa6
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/124544
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Austin Eng <enga@chromium.org>
Commit-Queue: Zhaoming Jiang <zhaoming.jiang@intel.com>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
This commit is contained in:
Zhaoming Jiang 2023-03-20 10:32:45 +00:00 committed by Dawn LUCI CQ
parent a0c34124d1
commit 6af073cecc
11 changed files with 227 additions and 63 deletions

View File

@ -25,6 +25,14 @@
#include "dawn/native/PipelineLayout.h"
#include "dawn/native/ShaderModule.h"
namespace {
bool IsDoubleValueRepresentableAsF16(double value) {
constexpr double kLowestF16 = -65504.0;
constexpr double kMaxF16 = 65504.0;
return kLowestF16 <= value && value <= kMaxF16;
}
} // namespace
namespace dawn::native {
MaybeError ValidateProgrammableStage(DeviceBase* device,
const ShaderModuleBase* module,
@ -80,6 +88,12 @@ MaybeError ValidateProgrammableStage(DeviceBase* device,
"representable in type (%s)",
constants[i].key, constants[i].value, "f32");
break;
case EntryPointMetadata::Override::Type::Float16:
DAWN_INVALID_IF(!IsDoubleValueRepresentableAsF16(constants[i].value),
"Pipeline overridable constant \"%s\" with value (%f) is not "
"representable in type (%s)",
constants[i].key, constants[i].value, "f16");
break;
case EntryPointMetadata::Override::Type::Int32:
DAWN_INVALID_IF(!IsDoubleValueRepresentable<int32_t>(constants[i].value),
"Pipeline overridable constant \"%s\" with value (%f) is not "

View File

@ -288,6 +288,8 @@ EntryPointMetadata::Override::Type FromTintOverrideType(tint::inspector::Overrid
return EntryPointMetadata::Override::Type::Boolean;
case tint::inspector::Override::Type::kFloat32:
return EntryPointMetadata::Override::Type::Float32;
case tint::inspector::Override::Type::kFloat16:
return EntryPointMetadata::Override::Type::Float16;
case tint::inspector::Override::Type::kInt32:
return EntryPointMetadata::Override::Type::Int32;
case tint::inspector::Override::Type::kUint32:

View File

@ -221,7 +221,7 @@ struct EntryPointMetadata {
// Match tint::inspector::Override::Type
// Bool is defined as a macro on linux X11 and cannot compile
enum class Type { Boolean, Float32, Uint32, Int32 } type;
enum class Type { Boolean, Float32, Uint32, Int32, Float16 } type;
// If the constant doesn't not have an initializer in the shader
// Then it is required for the pipeline stage to have a constant record to initialize a

View File

@ -21,8 +21,32 @@
class ComputePipelineOverridableConstantsValidationTest : public ValidationTest {
protected:
WGPUDevice CreateTestDevice(dawn::native::Adapter dawnAdapter) override {
std::vector<const char*> enabledToggles;
std::vector<const char*> disabledToggles;
disabledToggles.push_back("disallow_unsafe_apis");
wgpu::DawnTogglesDescriptor deviceTogglesDesc;
deviceTogglesDesc.enabledToggles = enabledToggles.data();
deviceTogglesDesc.enabledTogglesCount = enabledToggles.size();
deviceTogglesDesc.disabledToggles = disabledToggles.data();
deviceTogglesDesc.disabledTogglesCount = disabledToggles.size();
const wgpu::FeatureName requiredFeatures[] = {wgpu::FeatureName::ShaderF16};
wgpu::DeviceDescriptor deviceDescriptor;
deviceDescriptor.nextInChain = &deviceTogglesDesc;
deviceDescriptor.requiredFeatures = requiredFeatures;
deviceDescriptor.requiredFeaturesCount = 1;
return dawnAdapter.CreateDevice(&deviceDescriptor);
}
void SetUpShadersWithDefaultValueConstants() {
computeModule = utils::CreateShaderModule(device, R"(
enable f16;
override c0: bool = true; // type: bool
override c1: bool = false; // default override
override c2: f32 = 0.0; // type: float32
@ -33,7 +57,10 @@ override c6: i32 = 0; // default override
override c7: i32 = 7; // default
override c8: u32 = 0u; // type: uint32
override c9: u32 = 0u; // default override
@id(1000) override c10: u32 = 10u; // default
override c10: u32 = 10u; // default
override c11: f16 = 0.0h; // type: float16
override c12: f16 = 0.0h; // default override
@id(1000) override c13: f16 = 4.0h; // default
@compute @workgroup_size(1) fn main() {
// make sure the overridable constants are not optimized out
@ -48,11 +75,16 @@ override c9: u32 = 0u; // default override
_ = u32(c8);
_ = u32(c9);
_ = u32(c10);
_ = u32(c11);
_ = u32(c12);
_ = u32(c13);
})");
}
void SetUpShadersWithUninitializedConstants() {
computeModule = utils::CreateShaderModule(device, R"(
enable f16;
override c0: bool; // type: bool
override c1: bool = false; // default override
override c2: f32; // type: float32
@ -63,7 +95,10 @@ override c6: i32 = 0; // default override
override c7: i32 = 7; // default
override c8: u32; // type: uint32
override c9: u32 = 0u; // default override
@id(1000) override c10: u32 = 10u; // default
override c10: u32 = 10u; // default
override c11: f16; // type: float16
override c12: f16 = 0.0h; // default override
@id(1000) override c13: f16 = 4.0h; // default
@compute @workgroup_size(1) fn main() {
// make sure the overridable constants are not optimized out
@ -78,6 +113,9 @@ override c9: u32 = 0u; // default override
_ = u32(c8);
_ = u32(c9);
_ = u32(c10);
_ = u32(c11);
_ = u32(c12);
_ = u32(c13);
})");
}
@ -94,6 +132,11 @@ override c9: u32 = 0u; // default override
wgpu::Buffer buffer;
};
// Basic constants lookup tests
TEST_F(ComputePipelineOverridableConstantsValidationTest, CreateShaderWithOverride) {
SetUpShadersWithUninitializedConstants();
}
// Basic constants lookup tests
TEST_F(ComputePipelineOverridableConstantsValidationTest, ConstantsIdentifierLookUp) {
SetUpShadersWithDefaultValueConstants();
@ -122,7 +165,7 @@ TEST_F(ComputePipelineOverridableConstantsValidationTest, ConstantsIdentifierLoo
}
{
// Error: c10 already has a constant numeric id specified
std::vector<wgpu::ConstantEntry> constants{{nullptr, "c10", 0}};
std::vector<wgpu::ConstantEntry> constants{{nullptr, "c13", 0}};
ASSERT_DEVICE_ERROR(TestCreatePipeline(constants));
}
{
@ -152,16 +195,15 @@ TEST_F(ComputePipelineOverridableConstantsValidationTest, UninitializedConstants
{nullptr, "c2", 1},
// c5 is missing
{nullptr, "c8", 1},
{nullptr, "c11", 1},
};
ASSERT_DEVICE_ERROR(TestCreatePipeline(constants));
}
{
// Valid: all constants initialized
std::vector<wgpu::ConstantEntry> constants{
{nullptr, "c0", false},
{nullptr, "c2", 1},
{nullptr, "c5", 1},
{nullptr, "c8", 1},
{nullptr, "c0", false}, {nullptr, "c2", 1}, {nullptr, "c5", 1},
{nullptr, "c8", 1}, {nullptr, "c11", 1},
};
TestCreatePipeline(constants);
}
@ -169,7 +211,7 @@ TEST_F(ComputePipelineOverridableConstantsValidationTest, UninitializedConstants
// Error: duplicate initializations
std::vector<wgpu::ConstantEntry> constants{
{nullptr, "c0", false}, {nullptr, "c2", 1}, {nullptr, "c5", 1},
{nullptr, "c8", 1}, {nullptr, "c2", 2},
{nullptr, "c8", 1}, {nullptr, "c11", 1}, {nullptr, "c2", 2},
};
ASSERT_DEVICE_ERROR(TestCreatePipeline(constants));
}
@ -214,7 +256,7 @@ TEST_F(ComputePipelineOverridableConstantsValidationTest, ConstantsIdentifierUni
}
{
// Error: constant with numeric id cannot be referenced with variable name
std::vector<wgpu::ConstantEntry> constants{{nullptr, "c10", 0}};
std::vector<wgpu::ConstantEntry> constants{{nullptr, "c13", 0}};
ASSERT_DEVICE_ERROR(TestCreatePipeline(constants));
}
}
@ -267,6 +309,34 @@ TEST_F(ComputePipelineOverridableConstantsValidationTest, OutofRangeValue) {
{nullptr, "c3", std::numeric_limits<double>::max()}};
ASSERT_DEVICE_ERROR(TestCreatePipeline(constants));
}
{
// Valid: max f32 representable value
std::vector<wgpu::ConstantEntry> constants{
{nullptr, "c3", std::numeric_limits<float>::max()}};
TestCreatePipeline(constants);
}
{
// Error: one ULP higher than max f32 representable value
std::vector<wgpu::ConstantEntry> constants{
{nullptr, "c3",
std::nextafter<double>(std::numeric_limits<float>::max(),
std::numeric_limits<double>::max())}};
ASSERT_DEVICE_ERROR(TestCreatePipeline(constants));
}
{
// Valid: lowest f32 representable value
std::vector<wgpu::ConstantEntry> constants{
{nullptr, "c3", std::numeric_limits<float>::lowest()}};
TestCreatePipeline(constants);
}
{
// Error: one ULP lower than lowest f32 representable value
std::vector<wgpu::ConstantEntry> constants{
{nullptr, "c3",
std::nextafter<double>(std::numeric_limits<float>::lowest(),
std::numeric_limits<double>::lowest())}};
ASSERT_DEVICE_ERROR(TestCreatePipeline(constants));
}
{
// Error: i32 out of range
std::vector<wgpu::ConstantEntry> constants{
@ -291,4 +361,27 @@ TEST_F(ComputePipelineOverridableConstantsValidationTest, OutofRangeValue) {
{nullptr, "c0", static_cast<double>(std::numeric_limits<int32_t>::max()) + 1.0}};
TestCreatePipeline(constants);
}
{
// Valid: max f16 representable value
std::vector<wgpu::ConstantEntry> constants{{nullptr, "c11", 65504.0}};
TestCreatePipeline(constants);
}
{
// Error: one ULP higher than max f16 representable value
std::vector<wgpu::ConstantEntry> constants{
{nullptr, "c11", std::nextafter<double>(65504.0, std::numeric_limits<double>::max())}};
ASSERT_DEVICE_ERROR(TestCreatePipeline(constants));
}
{
// Valid: lowest f16 representable value
std::vector<wgpu::ConstantEntry> constants{{nullptr, "c11", -65504.0}};
TestCreatePipeline(constants);
}
{
// Error: one ULP lower than lowest f16 representable value
std::vector<wgpu::ConstantEntry> constants{
{nullptr, "c11",
std::nextafter<double>(-65504.0, std::numeric_limits<double>::lowest())}};
ASSERT_DEVICE_ERROR(TestCreatePipeline(constants));
}
}

View File

@ -505,6 +505,8 @@ std::string OverrideTypeToString(tint::inspector::Override::Type type) {
return "bool";
case tint::inspector::Override::Type::kFloat32:
return "f32";
case tint::inspector::Override::Type::kFloat16:
return "f16";
case tint::inspector::Override::Type::kUint32:
return "u32";
case tint::inspector::Override::Type::kInt32:

View File

@ -92,6 +92,7 @@ struct Override {
kFloat32,
kUint32,
kInt32,
kFloat16,
};
/// Type of the scalar

View File

@ -204,7 +204,11 @@ EntryPoint Inspector::GetEntryPoint(const tint::ast::Function* func) {
if (type->is_bool_scalar_or_vector()) {
override.type = Override::Type::kBool;
} else if (type->is_float_scalar()) {
if (type->Is<type::F16>()) {
override.type = Override::Type::kFloat16;
} else {
override.type = Override::Type::kFloat32;
}
} else if (type->is_signed_integer_scalar()) {
override.type = Override::Type::kInt32;
} else if (type->is_unsigned_integer_scalar()) {
@ -270,6 +274,10 @@ std::map<OverrideId, Scalar> Inspector::GetOverrideDefaultValues() {
[&](const type::I32*) { return Scalar(value->ValueAs<i32>()); },
[&](const type::U32*) { return Scalar(value->ValueAs<u32>()); },
[&](const type::F32*) { return Scalar(value->ValueAs<f32>()); },
[&](const type::F16*) {
// Default value of f16 override is also stored as float scalar.
return Scalar(static_cast<float>(value->ValueAs<f16>()));
},
[&](const type::Bool*) { return Scalar(value->ValueAs<bool>()); });
continue;
}

View File

@ -908,18 +908,23 @@ TEST_F(InspectorGetEntryPointTest, OverrideReferencedByArraySizeViaAlias) {
}
TEST_F(InspectorGetEntryPointTest, OverrideTypes) {
Enable(builtin::Extension::kF16);
Override("bool_var", ty.bool_());
Override("float_var", ty.f32());
Override("u32_var", ty.u32());
Override("i32_var", ty.i32());
Override("f16_var", ty.f16());
MakePlainGlobalReferenceBodyFunction("bool_func", "bool_var", ty.bool_(), utils::Empty);
MakePlainGlobalReferenceBodyFunction("float_func", "float_var", ty.f32(), utils::Empty);
MakePlainGlobalReferenceBodyFunction("u32_func", "u32_var", ty.u32(), utils::Empty);
MakePlainGlobalReferenceBodyFunction("i32_func", "i32_var", ty.i32(), utils::Empty);
MakePlainGlobalReferenceBodyFunction("f16_func", "f16_var", ty.f16(), utils::Empty);
MakeCallerBodyFunction(
"ep_func", utils::Vector{std::string("bool_func"), "float_func", "u32_func", "i32_func"},
"ep_func",
utils::Vector{std::string("bool_func"), "float_func", "u32_func", "i32_func", "f16_func"},
utils::Vector{
Stage(ast::PipelineStage::kCompute),
WorkgroupSize(1_i),
@ -930,7 +935,7 @@ TEST_F(InspectorGetEntryPointTest, OverrideTypes) {
auto result = inspector.GetEntryPoints();
ASSERT_EQ(1u, result.size());
ASSERT_EQ(4u, result[0].overrides.size());
ASSERT_EQ(5u, result[0].overrides.size());
EXPECT_EQ("bool_var", result[0].overrides[0].name);
EXPECT_EQ(inspector::Override::Type::kBool, result[0].overrides[0].type);
EXPECT_EQ("float_var", result[0].overrides[1].name);
@ -939,6 +944,8 @@ TEST_F(InspectorGetEntryPointTest, OverrideTypes) {
EXPECT_EQ(inspector::Override::Type::kUint32, result[0].overrides[2].type);
EXPECT_EQ("i32_var", result[0].overrides[3].name);
EXPECT_EQ(inspector::Override::Type::kInt32, result[0].overrides[3].type);
EXPECT_EQ("f16_var", result[0].overrides[4].name);
EXPECT_EQ(inspector::Override::Type::kFloat16, result[0].overrides[4].type);
}
TEST_F(InspectorGetEntryPointTest, OverrideInitialized) {
@ -1572,7 +1579,7 @@ TEST_F(InspectorGetOverrideDefaultValuesTest, I32) {
EXPECT_EQ(100, result[OverrideId{6000}].AsI32());
}
TEST_F(InspectorGetOverrideDefaultValuesTest, Float) {
TEST_F(InspectorGetOverrideDefaultValuesTest, F32) {
Override("a", ty.f32(), Id(1_a));
Override("b", ty.f32(), Expr(0_f), Id(20_a));
Override("c", ty.f32(), Expr(-10_f), Id(300_a));
@ -1609,6 +1616,46 @@ TEST_F(InspectorGetOverrideDefaultValuesTest, Float) {
EXPECT_FLOAT_EQ(150.0f, result[OverrideId{6000}].AsFloat());
}
TEST_F(InspectorGetOverrideDefaultValuesTest, F16) {
Enable(builtin::Extension::kF16);
Override("a", ty.f16(), Id(1_a));
Override("b", ty.f16(), Expr(0_h), Id(20_a));
Override("c", ty.f16(), Expr(-10_h), Id(300_a));
Override("d", Expr(15_h), Id(4000_a));
Override("3", Expr(42.0_h), Id(5000_a));
Override("e", ty.f16(), Mul(15_h, 10_a), Id(6000_a));
Inspector& inspector = Build();
auto result = inspector.GetOverrideDefaultValues();
ASSERT_EQ(6u, result.size());
ASSERT_TRUE(result.find(OverrideId{1}) != result.end());
EXPECT_TRUE(result[OverrideId{1}].IsNull());
ASSERT_TRUE(result.find(OverrideId{20}) != result.end());
// Default value of f16 override is also stored as float scalar.
EXPECT_TRUE(result[OverrideId{20}].IsFloat());
EXPECT_FLOAT_EQ(0.0f, result[OverrideId{20}].AsFloat());
ASSERT_TRUE(result.find(OverrideId{300}) != result.end());
EXPECT_TRUE(result[OverrideId{300}].IsFloat());
EXPECT_FLOAT_EQ(-10.0f, result[OverrideId{300}].AsFloat());
ASSERT_TRUE(result.find(OverrideId{4000}) != result.end());
EXPECT_TRUE(result[OverrideId{4000}].IsFloat());
EXPECT_FLOAT_EQ(15.0f, result[OverrideId{4000}].AsFloat());
ASSERT_TRUE(result.find(OverrideId{5000}) != result.end());
EXPECT_TRUE(result[OverrideId{5000}].IsFloat());
EXPECT_FLOAT_EQ(42.0f, result[OverrideId{5000}].AsFloat());
ASSERT_TRUE(result.find(OverrideId{6000}) != result.end());
EXPECT_TRUE(result[OverrideId{6000}].IsFloat());
EXPECT_FLOAT_EQ(150.0f, result[OverrideId{6000}].AsFloat());
}
TEST_F(InspectorGetConstantNameToIdMapTest, WithAndWithoutIds) {
Override("v1", ty.f32(), Id(1_a));
Override("v20", ty.f32(), Id(20_a));

View File

@ -66,10 +66,12 @@ TEST_F(ResolverOverrideTest, WithoutId) {
}
TEST_F(ResolverOverrideTest, WithAndWithoutIds) {
Enable(builtin::Extension::kF16);
auto* a = Override("a", ty.f32(), Expr(1_f));
auto* b = Override("b", ty.f32(), Expr(1_f));
auto* c = Override("c", ty.f32(), Expr(1_f), Id(2_u));
auto* d = Override("d", ty.f32(), Expr(1_f), Id(4_u));
auto* b = Override("b", ty.f16(), Expr(1_h));
auto* c = Override("c", ty.i32(), Expr(1_i), Id(2_u));
auto* d = Override("d", ty.u32(), Expr(1_u), Id(4_u));
auto* e = Override("e", ty.f32(), Expr(1_f));
auto* f = Override("f", ty.f32(), Expr(1_f), Id(1_u));
@ -102,16 +104,6 @@ TEST_F(ResolverOverrideTest, IdTooLarge) {
EXPECT_EQ(r()->error(), "12:34 error: @id value must be between 0 and 65535");
}
TEST_F(ResolverOverrideTest, F16_TemporallyBan) {
Enable(builtin::Extension::kF16);
Override(Source{{12, 34}}, "a", ty.f16(), Expr(1_h), Id(1_u));
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: 'override' of type f16 is not implemented yet");
}
TEST_F(ResolverOverrideTest, TransitiveReferences_DirectUse) {
auto* a = Override("a", ty.f32());
auto* b = Override("b", ty.f32(), Expr(1_f));

View File

@ -788,11 +788,6 @@ bool Validator::Override(
return false;
}
if (storage_ty->Is<type::F16>()) {
AddError("'override' of type f16 is not implemented yet", decl->source);
return false;
}
return true;
}

View File

@ -84,15 +84,16 @@ fn main() -> @builtin(position) vec4<f32> {
TEST_F(SubstituteOverrideTest, ImplicitId) {
auto* src = R"(
enable f16;
override i_width: i32;
override i_height = 1i;
override f_width: f32;
override f_height = 1.f;
// TODO(crbug.com/tint/1473)
// override h_width: f16;
// override h_height = 1.h;
override h_width: f16;
override h_height = 1.h;
override b_width: bool;
override b_height = true;
@ -106,6 +107,8 @@ fn main() -> @builtin(position) vec4<f32> {
)";
auto* expect = R"(
enable f16;
const i_width : i32 = 42i;
const i_height = 11i;
@ -114,6 +117,10 @@ const f_width : f32 = 22.299999237060546875f;
const f_height = 12.3999996185302734375f;
const h_width : f16 = 9.3984375h;
const h_height = 3.3984375h;
const b_width : bool = true;
const b_height = false;
@ -131,10 +138,10 @@ fn main() -> @builtin(position) vec4<f32> {
cfg.map.insert({OverrideId{1}, 11.0});
cfg.map.insert({OverrideId{2}, 22.3});
cfg.map.insert({OverrideId{3}, 12.4});
// cfg.map.insert({OverrideId{4}, 9.4});
// cfg.map.insert({OverrideId{5}, 3.4});
cfg.map.insert({OverrideId{4}, 1.0});
cfg.map.insert({OverrideId{5}, 0.0});
cfg.map.insert({OverrideId{4}, 9.4});
cfg.map.insert({OverrideId{5}, 3.4});
cfg.map.insert({OverrideId{6}, 1.0});
cfg.map.insert({OverrideId{7}, 0.0});
DataMap data;
data.Add<SubstituteOverride::Config>(cfg);
@ -153,9 +160,8 @@ enable f16;
@id(1) override f_width: f32;
@id(9) override f_height = 1.f;
// TODO(crbug.com/tint/1473)
// @id(2) override h_width: f16;
// @id(8) override h_height = 1.h;
@id(2) override h_width: f16;
@id(8) override h_height = 1.h;
@id(3) override b_width: bool;
@id(7) override b_height = true;
@ -179,6 +185,10 @@ const f_width : f32 = 22.299999237060546875f;
const f_height = 12.3999996185302734375f;
const h_width : f16 = 9.3984375h;
const h_height = 3.3984375h;
const b_width : bool = true;
const b_height = false;