diff --git a/src/dawn/native/Pipeline.cpp b/src/dawn/native/Pipeline.cpp index 37d0928bad..cffeaa300f 100644 --- a/src/dawn/native/Pipeline.cpp +++ b/src/dawn/native/Pipeline.cpp @@ -25,6 +25,14 @@ #include "dawn/native/PipelineLayout.h" #include "dawn/native/ShaderModule.h" +namespace { +bool IsDoubleValueRepresentableAsF16(double value) { + constexpr double kLowestF16 = -65504.0; + constexpr double kMaxF16 = 65504.0; + return kLowestF16 <= value && value <= kMaxF16; +} +} // namespace + namespace dawn::native { MaybeError ValidateProgrammableStage(DeviceBase* device, const ShaderModuleBase* module, @@ -80,6 +88,12 @@ MaybeError ValidateProgrammableStage(DeviceBase* device, "representable in type (%s)", constants[i].key, constants[i].value, "f32"); break; + case EntryPointMetadata::Override::Type::Float16: + DAWN_INVALID_IF(!IsDoubleValueRepresentableAsF16(constants[i].value), + "Pipeline overridable constant \"%s\" with value (%f) is not " + "representable in type (%s)", + constants[i].key, constants[i].value, "f16"); + break; case EntryPointMetadata::Override::Type::Int32: DAWN_INVALID_IF(!IsDoubleValueRepresentable(constants[i].value), "Pipeline overridable constant \"%s\" with value (%f) is not " diff --git a/src/dawn/native/ShaderModule.cpp b/src/dawn/native/ShaderModule.cpp index 4143457cc1..ae45a32d08 100644 --- a/src/dawn/native/ShaderModule.cpp +++ b/src/dawn/native/ShaderModule.cpp @@ -288,6 +288,8 @@ EntryPointMetadata::Override::Type FromTintOverrideType(tint::inspector::Overrid return EntryPointMetadata::Override::Type::Boolean; case tint::inspector::Override::Type::kFloat32: return EntryPointMetadata::Override::Type::Float32; + case tint::inspector::Override::Type::kFloat16: + return EntryPointMetadata::Override::Type::Float16; case tint::inspector::Override::Type::kInt32: return EntryPointMetadata::Override::Type::Int32; case tint::inspector::Override::Type::kUint32: diff --git a/src/dawn/native/ShaderModule.h b/src/dawn/native/ShaderModule.h index fd604a82f1..5d33cd22ce 100644 --- a/src/dawn/native/ShaderModule.h +++ b/src/dawn/native/ShaderModule.h @@ -221,7 +221,7 @@ struct EntryPointMetadata { // Match tint::inspector::Override::Type // Bool is defined as a macro on linux X11 and cannot compile - enum class Type { Boolean, Float32, Uint32, Int32 } type; + enum class Type { Boolean, Float32, Uint32, Int32, Float16 } type; // If the constant doesn't not have an initializer in the shader // Then it is required for the pipeline stage to have a constant record to initialize a diff --git a/src/dawn/tests/unittests/validation/OverridableConstantsValidationTests.cpp b/src/dawn/tests/unittests/validation/OverridableConstantsValidationTests.cpp index 542bb06742..8a518f6a73 100644 --- a/src/dawn/tests/unittests/validation/OverridableConstantsValidationTests.cpp +++ b/src/dawn/tests/unittests/validation/OverridableConstantsValidationTests.cpp @@ -21,19 +21,46 @@ class ComputePipelineOverridableConstantsValidationTest : public ValidationTest { protected: + WGPUDevice CreateTestDevice(dawn::native::Adapter dawnAdapter) override { + std::vector enabledToggles; + std::vector disabledToggles; + + disabledToggles.push_back("disallow_unsafe_apis"); + + wgpu::DawnTogglesDescriptor deviceTogglesDesc; + deviceTogglesDesc.enabledToggles = enabledToggles.data(); + deviceTogglesDesc.enabledTogglesCount = enabledToggles.size(); + deviceTogglesDesc.disabledToggles = disabledToggles.data(); + deviceTogglesDesc.disabledTogglesCount = disabledToggles.size(); + + const wgpu::FeatureName requiredFeatures[] = {wgpu::FeatureName::ShaderF16}; + + wgpu::DeviceDescriptor deviceDescriptor; + deviceDescriptor.nextInChain = &deviceTogglesDesc; + deviceDescriptor.requiredFeatures = requiredFeatures; + deviceDescriptor.requiredFeaturesCount = 1; + + return dawnAdapter.CreateDevice(&deviceDescriptor); + } + void SetUpShadersWithDefaultValueConstants() { computeModule = utils::CreateShaderModule(device, R"( -override c0: bool = true; // type: bool -override c1: bool = false; // default override -override c2: f32 = 0.0; // type: float32 -override c3: f32 = 0.0; // default override -override c4: f32 = 4.0; // default -override c5: i32 = 0; // type: int32 -override c6: i32 = 0; // default override -override c7: i32 = 7; // default -override c8: u32 = 0u; // type: uint32 -override c9: u32 = 0u; // default override -@id(1000) override c10: u32 = 10u; // default +enable f16; + +override c0: bool = true; // type: bool +override c1: bool = false; // default override +override c2: f32 = 0.0; // type: float32 +override c3: f32 = 0.0; // default override +override c4: f32 = 4.0; // default +override c5: i32 = 0; // type: int32 +override c6: i32 = 0; // default override +override c7: i32 = 7; // default +override c8: u32 = 0u; // type: uint32 +override c9: u32 = 0u; // default override +override c10: u32 = 10u; // default +override c11: f16 = 0.0h; // type: float16 +override c12: f16 = 0.0h; // default override +@id(1000) override c13: f16 = 4.0h; // default @compute @workgroup_size(1) fn main() { // make sure the overridable constants are not optimized out @@ -48,22 +75,30 @@ override c9: u32 = 0u; // default override _ = u32(c8); _ = u32(c9); _ = u32(c10); + _ = u32(c11); + _ = u32(c12); + _ = u32(c13); })"); } void SetUpShadersWithUninitializedConstants() { computeModule = utils::CreateShaderModule(device, R"( -override c0: bool; // type: bool -override c1: bool = false; // default override -override c2: f32; // type: float32 -override c3: f32 = 0.0; // default override -override c4: f32 = 4.0; // default -override c5: i32; // type: int32 -override c6: i32 = 0; // default override -override c7: i32 = 7; // default -override c8: u32; // type: uint32 -override c9: u32 = 0u; // default override -@id(1000) override c10: u32 = 10u; // default +enable f16; + +override c0: bool; // type: bool +override c1: bool = false; // default override +override c2: f32; // type: float32 +override c3: f32 = 0.0; // default override +override c4: f32 = 4.0; // default +override c5: i32; // type: int32 +override c6: i32 = 0; // default override +override c7: i32 = 7; // default +override c8: u32; // type: uint32 +override c9: u32 = 0u; // default override +override c10: u32 = 10u; // default +override c11: f16; // type: float16 +override c12: f16 = 0.0h; // default override +@id(1000) override c13: f16 = 4.0h; // default @compute @workgroup_size(1) fn main() { // make sure the overridable constants are not optimized out @@ -78,6 +113,9 @@ override c9: u32 = 0u; // default override _ = u32(c8); _ = u32(c9); _ = u32(c10); + _ = u32(c11); + _ = u32(c12); + _ = u32(c13); })"); } @@ -94,6 +132,11 @@ override c9: u32 = 0u; // default override wgpu::Buffer buffer; }; +// Basic constants lookup tests +TEST_F(ComputePipelineOverridableConstantsValidationTest, CreateShaderWithOverride) { + SetUpShadersWithUninitializedConstants(); +} + // Basic constants lookup tests TEST_F(ComputePipelineOverridableConstantsValidationTest, ConstantsIdentifierLookUp) { SetUpShadersWithDefaultValueConstants(); @@ -122,7 +165,7 @@ TEST_F(ComputePipelineOverridableConstantsValidationTest, ConstantsIdentifierLoo } { // Error: c10 already has a constant numeric id specified - std::vector constants{{nullptr, "c10", 0}}; + std::vector constants{{nullptr, "c13", 0}}; ASSERT_DEVICE_ERROR(TestCreatePipeline(constants)); } { @@ -152,24 +195,23 @@ TEST_F(ComputePipelineOverridableConstantsValidationTest, UninitializedConstants {nullptr, "c2", 1}, // c5 is missing {nullptr, "c8", 1}, + {nullptr, "c11", 1}, }; ASSERT_DEVICE_ERROR(TestCreatePipeline(constants)); } { // Valid: all constants initialized std::vector constants{ - {nullptr, "c0", false}, - {nullptr, "c2", 1}, - {nullptr, "c5", 1}, - {nullptr, "c8", 1}, + {nullptr, "c0", false}, {nullptr, "c2", 1}, {nullptr, "c5", 1}, + {nullptr, "c8", 1}, {nullptr, "c11", 1}, }; TestCreatePipeline(constants); } { // Error: duplicate initializations std::vector constants{ - {nullptr, "c0", false}, {nullptr, "c2", 1}, {nullptr, "c5", 1}, - {nullptr, "c8", 1}, {nullptr, "c2", 2}, + {nullptr, "c0", false}, {nullptr, "c2", 1}, {nullptr, "c5", 1}, + {nullptr, "c8", 1}, {nullptr, "c11", 1}, {nullptr, "c2", 2}, }; ASSERT_DEVICE_ERROR(TestCreatePipeline(constants)); } @@ -214,7 +256,7 @@ TEST_F(ComputePipelineOverridableConstantsValidationTest, ConstantsIdentifierUni } { // Error: constant with numeric id cannot be referenced with variable name - std::vector constants{{nullptr, "c10", 0}}; + std::vector constants{{nullptr, "c13", 0}}; ASSERT_DEVICE_ERROR(TestCreatePipeline(constants)); } } @@ -267,6 +309,34 @@ TEST_F(ComputePipelineOverridableConstantsValidationTest, OutofRangeValue) { {nullptr, "c3", std::numeric_limits::max()}}; ASSERT_DEVICE_ERROR(TestCreatePipeline(constants)); } + { + // Valid: max f32 representable value + std::vector constants{ + {nullptr, "c3", std::numeric_limits::max()}}; + TestCreatePipeline(constants); + } + { + // Error: one ULP higher than max f32 representable value + std::vector constants{ + {nullptr, "c3", + std::nextafter(std::numeric_limits::max(), + std::numeric_limits::max())}}; + ASSERT_DEVICE_ERROR(TestCreatePipeline(constants)); + } + { + // Valid: lowest f32 representable value + std::vector constants{ + {nullptr, "c3", std::numeric_limits::lowest()}}; + TestCreatePipeline(constants); + } + { + // Error: one ULP lower than lowest f32 representable value + std::vector constants{ + {nullptr, "c3", + std::nextafter(std::numeric_limits::lowest(), + std::numeric_limits::lowest())}}; + ASSERT_DEVICE_ERROR(TestCreatePipeline(constants)); + } { // Error: i32 out of range std::vector constants{ @@ -291,4 +361,27 @@ TEST_F(ComputePipelineOverridableConstantsValidationTest, OutofRangeValue) { {nullptr, "c0", static_cast(std::numeric_limits::max()) + 1.0}}; TestCreatePipeline(constants); } + { + // Valid: max f16 representable value + std::vector constants{{nullptr, "c11", 65504.0}}; + TestCreatePipeline(constants); + } + { + // Error: one ULP higher than max f16 representable value + std::vector constants{ + {nullptr, "c11", std::nextafter(65504.0, std::numeric_limits::max())}}; + ASSERT_DEVICE_ERROR(TestCreatePipeline(constants)); + } + { + // Valid: lowest f16 representable value + std::vector constants{{nullptr, "c11", -65504.0}}; + TestCreatePipeline(constants); + } + { + // Error: one ULP lower than lowest f16 representable value + std::vector constants{ + {nullptr, "c11", + std::nextafter(-65504.0, std::numeric_limits::lowest())}}; + ASSERT_DEVICE_ERROR(TestCreatePipeline(constants)); + } } diff --git a/src/tint/cmd/helper.cc b/src/tint/cmd/helper.cc index 49b970a583..6ff1e7a241 100644 --- a/src/tint/cmd/helper.cc +++ b/src/tint/cmd/helper.cc @@ -505,6 +505,8 @@ std::string OverrideTypeToString(tint::inspector::Override::Type type) { return "bool"; case tint::inspector::Override::Type::kFloat32: return "f32"; + case tint::inspector::Override::Type::kFloat16: + return "f16"; case tint::inspector::Override::Type::kUint32: return "u32"; case tint::inspector::Override::Type::kInt32: diff --git a/src/tint/inspector/entry_point.h b/src/tint/inspector/entry_point.h index fd17ba064c..8fd003c66f 100644 --- a/src/tint/inspector/entry_point.h +++ b/src/tint/inspector/entry_point.h @@ -92,6 +92,7 @@ struct Override { kFloat32, kUint32, kInt32, + kFloat16, }; /// Type of the scalar diff --git a/src/tint/inspector/inspector.cc b/src/tint/inspector/inspector.cc index 93e5c585e0..0b8b22c648 100644 --- a/src/tint/inspector/inspector.cc +++ b/src/tint/inspector/inspector.cc @@ -204,7 +204,11 @@ EntryPoint Inspector::GetEntryPoint(const tint::ast::Function* func) { if (type->is_bool_scalar_or_vector()) { override.type = Override::Type::kBool; } else if (type->is_float_scalar()) { - override.type = Override::Type::kFloat32; + if (type->Is()) { + override.type = Override::Type::kFloat16; + } else { + override.type = Override::Type::kFloat32; + } } else if (type->is_signed_integer_scalar()) { override.type = Override::Type::kInt32; } else if (type->is_unsigned_integer_scalar()) { @@ -270,6 +274,10 @@ std::map Inspector::GetOverrideDefaultValues() { [&](const type::I32*) { return Scalar(value->ValueAs()); }, [&](const type::U32*) { return Scalar(value->ValueAs()); }, [&](const type::F32*) { return Scalar(value->ValueAs()); }, + [&](const type::F16*) { + // Default value of f16 override is also stored as float scalar. + return Scalar(static_cast(value->ValueAs())); + }, [&](const type::Bool*) { return Scalar(value->ValueAs()); }); continue; } diff --git a/src/tint/inspector/inspector_test.cc b/src/tint/inspector/inspector_test.cc index b3590743d4..de09369572 100644 --- a/src/tint/inspector/inspector_test.cc +++ b/src/tint/inspector/inspector_test.cc @@ -908,18 +908,23 @@ TEST_F(InspectorGetEntryPointTest, OverrideReferencedByArraySizeViaAlias) { } TEST_F(InspectorGetEntryPointTest, OverrideTypes) { + Enable(builtin::Extension::kF16); + Override("bool_var", ty.bool_()); Override("float_var", ty.f32()); Override("u32_var", ty.u32()); Override("i32_var", ty.i32()); + Override("f16_var", ty.f16()); MakePlainGlobalReferenceBodyFunction("bool_func", "bool_var", ty.bool_(), utils::Empty); MakePlainGlobalReferenceBodyFunction("float_func", "float_var", ty.f32(), utils::Empty); MakePlainGlobalReferenceBodyFunction("u32_func", "u32_var", ty.u32(), utils::Empty); MakePlainGlobalReferenceBodyFunction("i32_func", "i32_var", ty.i32(), utils::Empty); + MakePlainGlobalReferenceBodyFunction("f16_func", "f16_var", ty.f16(), utils::Empty); MakeCallerBodyFunction( - "ep_func", utils::Vector{std::string("bool_func"), "float_func", "u32_func", "i32_func"}, + "ep_func", + utils::Vector{std::string("bool_func"), "float_func", "u32_func", "i32_func", "f16_func"}, utils::Vector{ Stage(ast::PipelineStage::kCompute), WorkgroupSize(1_i), @@ -930,7 +935,7 @@ TEST_F(InspectorGetEntryPointTest, OverrideTypes) { auto result = inspector.GetEntryPoints(); ASSERT_EQ(1u, result.size()); - ASSERT_EQ(4u, result[0].overrides.size()); + ASSERT_EQ(5u, result[0].overrides.size()); EXPECT_EQ("bool_var", result[0].overrides[0].name); EXPECT_EQ(inspector::Override::Type::kBool, result[0].overrides[0].type); EXPECT_EQ("float_var", result[0].overrides[1].name); @@ -939,6 +944,8 @@ TEST_F(InspectorGetEntryPointTest, OverrideTypes) { EXPECT_EQ(inspector::Override::Type::kUint32, result[0].overrides[2].type); EXPECT_EQ("i32_var", result[0].overrides[3].name); EXPECT_EQ(inspector::Override::Type::kInt32, result[0].overrides[3].type); + EXPECT_EQ("f16_var", result[0].overrides[4].name); + EXPECT_EQ(inspector::Override::Type::kFloat16, result[0].overrides[4].type); } TEST_F(InspectorGetEntryPointTest, OverrideInitialized) { @@ -1572,7 +1579,7 @@ TEST_F(InspectorGetOverrideDefaultValuesTest, I32) { EXPECT_EQ(100, result[OverrideId{6000}].AsI32()); } -TEST_F(InspectorGetOverrideDefaultValuesTest, Float) { +TEST_F(InspectorGetOverrideDefaultValuesTest, F32) { Override("a", ty.f32(), Id(1_a)); Override("b", ty.f32(), Expr(0_f), Id(20_a)); Override("c", ty.f32(), Expr(-10_f), Id(300_a)); @@ -1609,6 +1616,46 @@ TEST_F(InspectorGetOverrideDefaultValuesTest, Float) { EXPECT_FLOAT_EQ(150.0f, result[OverrideId{6000}].AsFloat()); } +TEST_F(InspectorGetOverrideDefaultValuesTest, F16) { + Enable(builtin::Extension::kF16); + + Override("a", ty.f16(), Id(1_a)); + Override("b", ty.f16(), Expr(0_h), Id(20_a)); + Override("c", ty.f16(), Expr(-10_h), Id(300_a)); + Override("d", Expr(15_h), Id(4000_a)); + Override("3", Expr(42.0_h), Id(5000_a)); + Override("e", ty.f16(), Mul(15_h, 10_a), Id(6000_a)); + + Inspector& inspector = Build(); + + auto result = inspector.GetOverrideDefaultValues(); + ASSERT_EQ(6u, result.size()); + + ASSERT_TRUE(result.find(OverrideId{1}) != result.end()); + EXPECT_TRUE(result[OverrideId{1}].IsNull()); + + ASSERT_TRUE(result.find(OverrideId{20}) != result.end()); + // Default value of f16 override is also stored as float scalar. + EXPECT_TRUE(result[OverrideId{20}].IsFloat()); + EXPECT_FLOAT_EQ(0.0f, result[OverrideId{20}].AsFloat()); + + ASSERT_TRUE(result.find(OverrideId{300}) != result.end()); + EXPECT_TRUE(result[OverrideId{300}].IsFloat()); + EXPECT_FLOAT_EQ(-10.0f, result[OverrideId{300}].AsFloat()); + + ASSERT_TRUE(result.find(OverrideId{4000}) != result.end()); + EXPECT_TRUE(result[OverrideId{4000}].IsFloat()); + EXPECT_FLOAT_EQ(15.0f, result[OverrideId{4000}].AsFloat()); + + ASSERT_TRUE(result.find(OverrideId{5000}) != result.end()); + EXPECT_TRUE(result[OverrideId{5000}].IsFloat()); + EXPECT_FLOAT_EQ(42.0f, result[OverrideId{5000}].AsFloat()); + + ASSERT_TRUE(result.find(OverrideId{6000}) != result.end()); + EXPECT_TRUE(result[OverrideId{6000}].IsFloat()); + EXPECT_FLOAT_EQ(150.0f, result[OverrideId{6000}].AsFloat()); +} + TEST_F(InspectorGetConstantNameToIdMapTest, WithAndWithoutIds) { Override("v1", ty.f32(), Id(1_a)); Override("v20", ty.f32(), Id(20_a)); diff --git a/src/tint/resolver/override_test.cc b/src/tint/resolver/override_test.cc index fd0f649d52..c8259690d6 100644 --- a/src/tint/resolver/override_test.cc +++ b/src/tint/resolver/override_test.cc @@ -66,10 +66,12 @@ TEST_F(ResolverOverrideTest, WithoutId) { } TEST_F(ResolverOverrideTest, WithAndWithoutIds) { + Enable(builtin::Extension::kF16); + auto* a = Override("a", ty.f32(), Expr(1_f)); - auto* b = Override("b", ty.f32(), Expr(1_f)); - auto* c = Override("c", ty.f32(), Expr(1_f), Id(2_u)); - auto* d = Override("d", ty.f32(), Expr(1_f), Id(4_u)); + auto* b = Override("b", ty.f16(), Expr(1_h)); + auto* c = Override("c", ty.i32(), Expr(1_i), Id(2_u)); + auto* d = Override("d", ty.u32(), Expr(1_u), Id(4_u)); auto* e = Override("e", ty.f32(), Expr(1_f)); auto* f = Override("f", ty.f32(), Expr(1_f), Id(1_u)); @@ -102,16 +104,6 @@ TEST_F(ResolverOverrideTest, IdTooLarge) { EXPECT_EQ(r()->error(), "12:34 error: @id value must be between 0 and 65535"); } -TEST_F(ResolverOverrideTest, F16_TemporallyBan) { - Enable(builtin::Extension::kF16); - - Override(Source{{12, 34}}, "a", ty.f16(), Expr(1_h), Id(1_u)); - - EXPECT_FALSE(r()->Resolve()); - - EXPECT_EQ(r()->error(), "12:34 error: 'override' of type f16 is not implemented yet"); -} - TEST_F(ResolverOverrideTest, TransitiveReferences_DirectUse) { auto* a = Override("a", ty.f32()); auto* b = Override("b", ty.f32(), Expr(1_f)); diff --git a/src/tint/resolver/validator.cc b/src/tint/resolver/validator.cc index c3b8a3d884..b005779d08 100644 --- a/src/tint/resolver/validator.cc +++ b/src/tint/resolver/validator.cc @@ -788,11 +788,6 @@ bool Validator::Override( return false; } - if (storage_ty->Is()) { - AddError("'override' of type f16 is not implemented yet", decl->source); - return false; - } - return true; } diff --git a/src/tint/transform/substitute_override_test.cc b/src/tint/transform/substitute_override_test.cc index 1b4646fa45..5deab72655 100644 --- a/src/tint/transform/substitute_override_test.cc +++ b/src/tint/transform/substitute_override_test.cc @@ -84,15 +84,16 @@ fn main() -> @builtin(position) vec4 { TEST_F(SubstituteOverrideTest, ImplicitId) { auto* src = R"( +enable f16; + override i_width: i32; override i_height = 1i; override f_width: f32; override f_height = 1.f; -// TODO(crbug.com/tint/1473) -// override h_width: f16; -// override h_height = 1.h; +override h_width: f16; +override h_height = 1.h; override b_width: bool; override b_height = true; @@ -106,6 +107,8 @@ fn main() -> @builtin(position) vec4 { )"; auto* expect = R"( +enable f16; + const i_width : i32 = 42i; const i_height = 11i; @@ -114,6 +117,10 @@ const f_width : f32 = 22.299999237060546875f; const f_height = 12.3999996185302734375f; +const h_width : f16 = 9.3984375h; + +const h_height = 3.3984375h; + const b_width : bool = true; const b_height = false; @@ -131,10 +138,10 @@ fn main() -> @builtin(position) vec4 { cfg.map.insert({OverrideId{1}, 11.0}); cfg.map.insert({OverrideId{2}, 22.3}); cfg.map.insert({OverrideId{3}, 12.4}); - // cfg.map.insert({OverrideId{4}, 9.4}); - // cfg.map.insert({OverrideId{5}, 3.4}); - cfg.map.insert({OverrideId{4}, 1.0}); - cfg.map.insert({OverrideId{5}, 0.0}); + cfg.map.insert({OverrideId{4}, 9.4}); + cfg.map.insert({OverrideId{5}, 3.4}); + cfg.map.insert({OverrideId{6}, 1.0}); + cfg.map.insert({OverrideId{7}, 0.0}); DataMap data; data.Add(cfg); @@ -153,9 +160,8 @@ enable f16; @id(1) override f_width: f32; @id(9) override f_height = 1.f; -// TODO(crbug.com/tint/1473) -// @id(2) override h_width: f16; -// @id(8) override h_height = 1.h; +@id(2) override h_width: f16; +@id(8) override h_height = 1.h; @id(3) override b_width: bool; @id(7) override b_height = true; @@ -179,6 +185,10 @@ const f_width : f32 = 22.299999237060546875f; const f_height = 12.3999996185302734375f; +const h_width : f16 = 9.3984375h; + +const h_height = 3.3984375h; + const b_width : bool = true; const b_height = false;