diff --git a/src/dawn/native/Pipeline.cpp b/src/dawn/native/Pipeline.cpp index dc55f637d1..78604f3aa5 100644 --- a/src/dawn/native/Pipeline.cpp +++ b/src/dawn/native/Pipeline.cpp @@ -67,6 +67,9 @@ MaybeError ValidateProgrammableStage(DeviceBase* device, DAWN_INVALID_IF(metadata.overrides.count(constants[i].key) == 0, "Pipeline overridable constant \"%s\" not found in %s.", constants[i].key, module); + DAWN_INVALID_IF(!std::isfinite(constants[i].value), + "Pipeline overridable constant \"%s\" with value (%f) is not finite", + constants[i].key, constants[i].value); if (stageInitializedConstantIdentifiers.count(constants[i].key) == 0) { if (metadata.uninitializedOverrides.count(constants[i].key) > 0) { diff --git a/src/dawn/tests/unittests/validation/OverridableConstantsValidationTests.cpp b/src/dawn/tests/unittests/validation/OverridableConstantsValidationTests.cpp index c0f9eec3b5..4fd31508c8 100644 --- a/src/dawn/tests/unittests/validation/OverridableConstantsValidationTests.cpp +++ b/src/dawn/tests/unittests/validation/OverridableConstantsValidationTests.cpp @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include "dawn/common/Constants.h" @@ -217,3 +218,36 @@ TEST_F(ComputePipelineOverridableConstantsValidationTest, ConstantsIdentifierUni ASSERT_DEVICE_ERROR(TestCreatePipeline(constants)); } } + +// Test that values like NaN and Inf are treated as invalid. +TEST_F(ComputePipelineOverridableConstantsValidationTest, InvalidValue) { + SetUpShadersWithDefaultValueConstants(); + { + // Error:: NaN + std::vector constants{{nullptr, "c3", std::nan("")}}; + ASSERT_DEVICE_ERROR(TestCreatePipeline(constants)); + } + { + // Error:: -NaN + std::vector constants{{nullptr, "c3", -std::nan("")}}; + ASSERT_DEVICE_ERROR(TestCreatePipeline(constants)); + } + { + // Error:: Inf + std::vector constants{ + {nullptr, "c3", std::numeric_limits::infinity()}}; + ASSERT_DEVICE_ERROR(TestCreatePipeline(constants)); + } + { + // Error:: -Inf + std::vector constants{ + {nullptr, "c3", -std::numeric_limits::infinity()}}; + ASSERT_DEVICE_ERROR(TestCreatePipeline(constants)); + } + { + // Valid:: Max + std::vector constants{ + {nullptr, "c3", std::numeric_limits::max()}}; + TestCreatePipeline(constants); + } +}