diff --git a/dawn.json b/dawn.json index de13f0d5a1..03cf47d0fe 100644 --- a/dawn.json +++ b/dawn.json @@ -2426,6 +2426,15 @@ {"name": "code", "type": "char", "annotation": "const*", "length": "strlen", "tags": ["upstream"]} ] }, + "dawn shader module SPIRV options descriptor": { + "category": "structure", + "chained": "in", + "chain roots": ["shader module descriptor"], + "tags": ["dawn"], + "members": [ + {"name": "allow non uniform derivatives", "type": "bool", "default": "false"} + ] + }, "shader stage": { "category": "bitmask", "values": [ @@ -2619,7 +2628,8 @@ {"value": 1005, "name": "dawn cache device descriptor", "tags": ["dawn", "native"]}, {"value": 1006, "name": "dawn adapter properties power preference", "tags": ["dawn", "native"]}, {"value": 1007, "name": "dawn buffer descriptor error info from wire client", "tags": ["dawn"]}, - {"value": 1008, "name": "dawn toggles descriptor", "tags": ["dawn", "native"]} + {"value": 1008, "name": "dawn toggles descriptor", "tags": ["dawn", "native"]}, + {"value": 1009, "name": "dawn shader module SPIRV options descriptor", "tags": ["dawn"]} ] }, "texture": { diff --git a/src/dawn/native/ShaderModule.cpp b/src/dawn/native/ShaderModule.cpp index c2ea13efdf..7db19a232b 100644 --- a/src/dawn/native/ShaderModule.cpp +++ b/src/dawn/native/ShaderModule.cpp @@ -315,8 +315,13 @@ ResultOrError ParseWGSL(const tint::Source::File* file, #if TINT_BUILD_SPV_READER ResultOrError ParseSPIRV(const std::vector& spirv, - OwnedCompilationMessages* outMessages) { - tint::Program program = tint::reader::spirv::Parse(spirv); + OwnedCompilationMessages* outMessages, + const DawnShaderModuleSPIRVOptionsDescriptor* optionsDesc) { + tint::reader::spirv::Options options; + if (optionsDesc) { + options.allow_non_uniform_derivatives = optionsDesc->allowNonUniformDerivatives; + } + tint::Program program = tint::reader::spirv::Parse(spirv, options); if (outMessages != nullptr) { DAWN_TRY(outMessages->AddMessages(program.Diagnostics())); } @@ -905,10 +910,13 @@ MaybeError ValidateAndParseShaderModule(DeviceBase* device, DAWN_INVALID_IF(chainedDescriptor == nullptr, "Shader module descriptor missing chained descriptor"); -// For now only a single WGSL (or SPIRV, if enabled) subdescriptor is allowed. +// A WGSL (or SPIR-V, if enabled) subdescriptor is required, and a Dawn-specific SPIR-V options +// descriptor is allowed when using SPIR-V. #if TINT_BUILD_SPV_READER - DAWN_TRY(ValidateSingleSType(chainedDescriptor, wgpu::SType::ShaderModuleSPIRVDescriptor, - wgpu::SType::ShaderModuleWGSLDescriptor)); + DAWN_TRY(ValidateSTypes( + chainedDescriptor, + {{wgpu::SType::ShaderModuleSPIRVDescriptor, wgpu::SType::ShaderModuleWGSLDescriptor}, + {wgpu::SType::DawnShaderModuleSPIRVOptionsDescriptor}})); #else DAWN_TRY(ValidateSingleSType(chainedDescriptor, wgpu::SType::ShaderModuleWGSLDescriptor)); #endif @@ -918,10 +926,19 @@ MaybeError ValidateAndParseShaderModule(DeviceBase* device, const ShaderModuleWGSLDescriptor* wgslDesc = nullptr; FindInChain(chainedDescriptor, &wgslDesc); + const DawnShaderModuleSPIRVOptionsDescriptor* spirvOptions = nullptr; + FindInChain(chainedDescriptor, &spirvOptions); + + DAWN_INVALID_IF(wgslDesc != nullptr && spirvOptions != nullptr, + "SPIR-V options descriptor not valid with WGSL descriptor"); + #if TINT_BUILD_SPV_READER const ShaderModuleSPIRVDescriptor* spirvDesc = nullptr; FindInChain(chainedDescriptor, &spirvDesc); + DAWN_INVALID_IF(spirvOptions != nullptr && spirvDesc == nullptr, + "SPIR-V options descriptor can only be used with SPIR-V input"); + // We have a temporary toggle to force the SPIRV ingestion to go through a WGSL // intermediate step. It is done by switching the spirvDesc for a wgslDesc below. ShaderModuleWGSLDescriptor newWgslDesc; @@ -930,7 +947,7 @@ MaybeError ValidateAndParseShaderModule(DeviceBase* device, #if TINT_BUILD_WGSL_WRITER std::vector spirv(spirvDesc->code, spirvDesc->code + spirvDesc->codeSize); tint::Program program; - DAWN_TRY_ASSIGN(program, ParseSPIRV(spirv, outMessages)); + DAWN_TRY_ASSIGN(program, ParseSPIRV(spirv, outMessages, spirvOptions)); tint::writer::wgsl::Options options; auto result = tint::writer::wgsl::Generate(&program, options); @@ -953,7 +970,7 @@ MaybeError ValidateAndParseShaderModule(DeviceBase* device, std::vector spirv(spirvDesc->code, spirvDesc->code + spirvDesc->codeSize); tint::Program program; - DAWN_TRY_ASSIGN(program, ParseSPIRV(spirv, outMessages)); + DAWN_TRY_ASSIGN(program, ParseSPIRV(spirv, outMessages, spirvOptions)); parseResult->tintProgram = std::make_unique(std::move(program)); return {}; diff --git a/src/dawn/tests/unittests/validation/ShaderModuleValidationTests.cpp b/src/dawn/tests/unittests/validation/ShaderModuleValidationTests.cpp index 0f8e4f70f8..9efb35fd92 100644 --- a/src/dawn/tests/unittests/validation/ShaderModuleValidationTests.cpp +++ b/src/dawn/tests/unittests/validation/ShaderModuleValidationTests.cpp @@ -137,6 +137,56 @@ TEST_F(ShaderModuleValidationTest, MultisampledArrayTexture) { ASSERT_DEVICE_ERROR(utils::CreateShaderModuleFromASM(device, shader)); } + +const char* kShaderWithNonUniformDerivative = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %foo "foo" %x + OpExecutionMode %foo OriginUpperLeft + OpDecorate %x Location 0 + %float = OpTypeFloat 32 +%_ptr_Input_float = OpTypePointer Input %float + %x = OpVariable %_ptr_Input_float Input + %void = OpTypeVoid + %float_0 = OpConstantNull %float + %bool = OpTypeBool + %func_type = OpTypeFunction %void + %foo = OpFunction %void None %func_type + %foo_start = OpLabel + %x_value = OpLoad %float %x + %condition = OpFOrdGreaterThan %bool %x_value %float_0 + OpSelectionMerge %merge None + OpBranchConditional %condition %true_branch %merge +%true_branch = OpLabel + %result = OpDPdx %float %x_value + OpBranch %merge + %merge = OpLabel + OpReturn + OpFunctionEnd)"; + +// Test that creating a module with a SPIR-V shader that has a uniformity violation fails when no +// SPIR-V options descriptor is used. +TEST_F(ShaderModuleValidationTest, NonUniformDerivatives_NoOptions) { + ASSERT_DEVICE_ERROR(utils::CreateShaderModuleFromASM(device, kShaderWithNonUniformDerivative)); +} + +// Test that creating a module with a SPIR-V shader that has a uniformity violation fails when +// passing a SPIR-V options descriptor with the `allowNonUniformDerivatives` flag set to `false`. +TEST_F(ShaderModuleValidationTest, NonUniformDerivatives_FlagSetToFalse) { + wgpu::DawnShaderModuleSPIRVOptionsDescriptor spirv_options_desc = {}; + spirv_options_desc.allowNonUniformDerivatives = false; + ASSERT_DEVICE_ERROR(utils::CreateShaderModuleFromASM(device, kShaderWithNonUniformDerivative, + &spirv_options_desc)); +} + +// Test that creating a module with a SPIR-V shader that has a uniformity violation succeeds when +// passing a SPIR-V options descriptor with the `allowNonUniformDerivatives` flag set to `true`. +TEST_F(ShaderModuleValidationTest, NonUniformDerivatives_FlagSetToTrue) { + wgpu::DawnShaderModuleSPIRVOptionsDescriptor spirv_options_desc = {}; + spirv_options_desc.allowNonUniformDerivatives = true; + utils::CreateShaderModuleFromASM(device, kShaderWithNonUniformDerivative, &spirv_options_desc); +} + #endif // TINT_BUILD_SPV_READER // Test that it is invalid to create a shader module with no chained descriptor. (It must be @@ -146,6 +196,47 @@ TEST_F(ShaderModuleValidationTest, NoChainedDescriptor) { ASSERT_DEVICE_ERROR(device.CreateShaderModule(&desc)); } +// Test that it is invalid to create a shader module that uses both the WGSL descriptor and the +// SPIRV descriptor. +TEST_F(ShaderModuleValidationTest, MultipleChainedDescriptor_WgslAndSpirv) { + uint32_t code = 42; + wgpu::ShaderModuleDescriptor desc = {}; + wgpu::ShaderModuleSPIRVDescriptor spirv_desc = {}; + spirv_desc.code = &code; + spirv_desc.codeSize = 1; + wgpu::ShaderModuleWGSLDescriptor wgsl_desc = {}; + wgsl_desc.source = ""; + wgsl_desc.nextInChain = &spirv_desc; + desc.nextInChain = &wgsl_desc; + ASSERT_DEVICE_ERROR(device.CreateShaderModule(&desc), + testing::HasSubstr("is part of a group of exclusive sTypes")); +} + +// Test that it is invalid to create a shader module that uses both the WGSL descriptor and the +// Dawn SPIRV options descriptor. +TEST_F(ShaderModuleValidationTest, MultipleChainedDescriptor_WgslAndDawnSpirvOptions) { + wgpu::ShaderModuleDescriptor desc = {}; + wgpu::DawnShaderModuleSPIRVOptionsDescriptor spirv_options_desc = {}; + wgpu::ShaderModuleWGSLDescriptor wgsl_desc = {}; + wgsl_desc.nextInChain = &spirv_options_desc; + wgsl_desc.source = ""; + desc.nextInChain = &wgsl_desc; + ASSERT_DEVICE_ERROR( + device.CreateShaderModule(&desc), + testing::HasSubstr("SPIR-V options descriptor not valid with WGSL descriptor")); +} + +// Test that it is invalid to create a shader module that only uses the Dawn SPIRV options +// descriptor without the SPIRV descriptor. +TEST_F(ShaderModuleValidationTest, OnlySpirvOptionsDescriptor) { + wgpu::ShaderModuleDescriptor desc = {}; + wgpu::DawnShaderModuleSPIRVOptionsDescriptor spirv_options_desc = {}; + desc.nextInChain = &spirv_options_desc; + ASSERT_DEVICE_ERROR( + device.CreateShaderModule(&desc), + testing::HasSubstr("SPIR-V options descriptor can only be used with SPIR-V input")); +} + // Tests that shader module compilation messages can be queried. TEST_F(ShaderModuleValidationTest, GetCompilationMessages) { // This test works assuming ShaderModule is backed by a dawn::native::ShaderModuleBase, which diff --git a/src/dawn/utils/WGPUHelpers.cpp b/src/dawn/utils/WGPUHelpers.cpp index d6b1a4a9a9..e517945fc2 100644 --- a/src/dawn/utils/WGPUHelpers.cpp +++ b/src/dawn/utils/WGPUHelpers.cpp @@ -41,7 +41,10 @@ std::array kGammaEncodeSrgb = {1 / 2.4, 1.137119, 0.0, 12.92, 0.003130 namespace utils { #if TINT_BUILD_SPV_READER -wgpu::ShaderModule CreateShaderModuleFromASM(const wgpu::Device& device, const char* source) { +wgpu::ShaderModule CreateShaderModuleFromASM( + const wgpu::Device& device, + const char* source, + wgpu::DawnShaderModuleSPIRVOptionsDescriptor* spirv_options) { // Use SPIRV-Tools's C API to assemble the SPIR-V assembly text to binary. Because the types // aren't RAII, we don't return directly on success and instead always go through the code // path that destroys the SPIRV-Tools objects. @@ -59,6 +62,7 @@ wgpu::ShaderModule CreateShaderModuleFromASM(const wgpu::Device& device, const c wgpu::ShaderModuleSPIRVDescriptor spirvDesc; spirvDesc.codeSize = static_cast(spirv->wordCount); spirvDesc.code = spirv->code; + spirvDesc.nextInChain = spirv_options; wgpu::ShaderModuleDescriptor descriptor; descriptor.nextInChain = &spirvDesc; diff --git a/src/dawn/utils/WGPUHelpers.h b/src/dawn/utils/WGPUHelpers.h index d56dc59ad3..f05e323be9 100644 --- a/src/dawn/utils/WGPUHelpers.h +++ b/src/dawn/utils/WGPUHelpers.h @@ -28,7 +28,10 @@ namespace utils { enum Expectation { Success, Failure }; #if TINT_BUILD_SPV_READER -wgpu::ShaderModule CreateShaderModuleFromASM(const wgpu::Device& device, const char* source); +wgpu::ShaderModule CreateShaderModuleFromASM( + const wgpu::Device& device, + const char* source, + wgpu::DawnShaderModuleSPIRVOptionsDescriptor* spirv_options = nullptr); #endif wgpu::ShaderModule CreateShaderModule(const wgpu::Device& device, const char* source);