diff --git a/src/dawn/native/Pipeline.cpp b/src/dawn/native/Pipeline.cpp index 73a46bec76..344d948aba 100644 --- a/src/dawn/native/Pipeline.cpp +++ b/src/dawn/native/Pipeline.cpp @@ -37,6 +37,15 @@ namespace dawn::native { const EntryPointMetadata& metadata = module->GetEntryPoint(entryPoint); + if (!metadata.infringedLimitErrors.empty()) { + std::ostringstream out; + out << "Entry point \"" << entryPoint << "\" infringes limits:\n"; + for (const std::string& limit : metadata.infringedLimitErrors) { + out << " - " << limit << "\n"; + } + return DAWN_VALIDATION_ERROR(out.str()); + } + DAWN_INVALID_IF(metadata.stage != stage, "The stage (%s) of the entry point \"%s\" isn't the expected one (%s).", metadata.stage, entryPoint, stage); diff --git a/src/dawn/native/ShaderModule.cpp b/src/dawn/native/ShaderModule.cpp index afacee75de..285e4db3c3 100644 --- a/src/dawn/native/ShaderModule.cpp +++ b/src/dawn/native/ShaderModule.cpp @@ -606,6 +606,17 @@ namespace dawn::native { std::unique_ptr metadata = std::make_unique(); + // Returns the invalid argument, and if it is true additionally store the formatted + // error in metadata.infringedLimits. This is to delay the emission of these validation + // errors until the entry point is used. +#define DelayedInvalidIf(invalid, ...) \ + ([&]() { \ + if (invalid) { \ + metadata->infringedLimitErrors.push_back(absl::StrFormat(__VA_ARGS__)); \ + } \ + return invalid; \ + })() + if (!entryPoint.overridable_constants.empty()) { DAWN_INVALID_IF(device->IsToggleEnabled(Toggle::DisallowUnsafeAPIs), "Pipeline overridable constants are disallowed because they " @@ -657,7 +668,7 @@ namespace dawn::native { DAWN_TRY_ASSIGN(metadata->stage, TintPipelineStageToShaderStage(entryPoint.stage)); if (metadata->stage == SingleShaderStage::Compute) { - DAWN_INVALID_IF( + DelayedInvalidIf( entryPoint.workgroup_size_x > limits.v1.maxComputeWorkgroupSizeX || entryPoint.workgroup_size_y > limits.v1.maxComputeWorkgroupSizeY || entryPoint.workgroup_size_z > limits.v1.maxComputeWorkgroupSizeZ, @@ -671,17 +682,17 @@ namespace dawn::native { // Cast to uint64_t to avoid overflow in this multiplication. uint64_t numInvocations = static_cast(entryPoint.workgroup_size_x) * entryPoint.workgroup_size_y * entryPoint.workgroup_size_z; - DAWN_INVALID_IF(numInvocations > limits.v1.maxComputeInvocationsPerWorkgroup, - "The total number of workgroup invocations (%u) exceeds the " - "maximum allowed (%u).", - numInvocations, limits.v1.maxComputeInvocationsPerWorkgroup); + DelayedInvalidIf(numInvocations > limits.v1.maxComputeInvocationsPerWorkgroup, + "The total number of workgroup invocations (%u) exceeds the " + "maximum allowed (%u).", + numInvocations, limits.v1.maxComputeInvocationsPerWorkgroup); const size_t workgroupStorageSize = inspector->GetWorkgroupStorageSize(entryPoint.name); - DAWN_INVALID_IF(workgroupStorageSize > limits.v1.maxComputeWorkgroupStorageSize, - "The total use of workgroup storage (%u bytes) is larger than " - "the maximum allowed (%u bytes).", - workgroupStorageSize, limits.v1.maxComputeWorkgroupStorageSize); + DelayedInvalidIf(workgroupStorageSize > limits.v1.maxComputeWorkgroupStorageSize, + "The total use of workgroup storage (%u bytes) is larger than " + "the maximum allowed (%u bytes).", + workgroupStorageSize, limits.v1.maxComputeWorkgroupStorageSize); metadata->localWorkgroupSize.x = entryPoint.workgroup_size_x; metadata->localWorkgroupSize.y = entryPoint.workgroup_size_y; @@ -698,12 +709,15 @@ namespace dawn::native { inputVar.name); uint32_t unsanitizedLocation = inputVar.location_decoration; - DAWN_INVALID_IF(unsanitizedLocation >= kMaxVertexAttributes, - "Vertex input variable \"%s\" has a location (%u) that " - "exceeds the maximum (%u)", - inputVar.name, unsanitizedLocation, kMaxVertexAttributes); - VertexAttributeLocation location(static_cast(unsanitizedLocation)); + if (DelayedInvalidIf(unsanitizedLocation >= kMaxVertexAttributes, + "Vertex input variable \"%s\" has a location (%u) that " + "exceeds the maximum (%u)", + inputVar.name, unsanitizedLocation, + kMaxVertexAttributes)) { + continue; + } + VertexAttributeLocation location(static_cast(unsanitizedLocation)); DAWN_TRY_ASSIGN( metadata->vertexInputBaseTypes[location], TintComponentTypeToVertexFormatBaseType(inputVar.component_type)); @@ -714,36 +728,38 @@ namespace dawn::native { // output variable by Tint so we directly add its components to the total. uint32_t totalInterStageShaderComponents = 4; for (const auto& outputVar : entryPoint.output_variables) { + EntryPointMetadata::InterStageVariableInfo variable; + DAWN_TRY_ASSIGN(variable.baseType, TintComponentTypeToInterStageComponentType( + outputVar.component_type)); + DAWN_TRY_ASSIGN( + variable.componentCount, + TintCompositionTypeToInterStageComponentCount(outputVar.composition_type)); + DAWN_TRY_ASSIGN( + variable.interpolationType, + TintInterpolationTypeToInterpolationType(outputVar.interpolation_type)); + DAWN_TRY_ASSIGN(variable.interpolationSampling, + TintInterpolationSamplingToInterpolationSamplingType( + outputVar.interpolation_sampling)); + totalInterStageShaderComponents += variable.componentCount; + DAWN_INVALID_IF( !outputVar.has_location_decoration, "Vertex ouput variable \"%s\" doesn't have a location decoration.", outputVar.name); uint32_t location = outputVar.location_decoration; - DAWN_INVALID_IF(location > kMaxInterStageShaderLocation, - "Vertex output variable \"%s\" has a location (%u) that " - "exceeds the maximum (%u).", - outputVar.name, location, kMaxInterStageShaderLocation); + if (DelayedInvalidIf(location > kMaxInterStageShaderLocation, + "Vertex output variable \"%s\" has a location (%u) that " + "exceeds the maximum (%u).", + outputVar.name, location, kMaxInterStageShaderLocation)) { + continue; + } metadata->usedInterStageVariables.set(location); - DAWN_TRY_ASSIGN( - metadata->interStageVariables[location].baseType, - TintComponentTypeToInterStageComponentType(outputVar.component_type)); - DAWN_TRY_ASSIGN( - metadata->interStageVariables[location].componentCount, - TintCompositionTypeToInterStageComponentCount(outputVar.composition_type)); - DAWN_TRY_ASSIGN( - metadata->interStageVariables[location].interpolationType, - TintInterpolationTypeToInterpolationType(outputVar.interpolation_type)); - DAWN_TRY_ASSIGN(metadata->interStageVariables[location].interpolationSampling, - TintInterpolationSamplingToInterpolationSamplingType( - outputVar.interpolation_sampling)); - - totalInterStageShaderComponents += - metadata->interStageVariables[location].componentCount; + metadata->interStageVariables[location] = variable; } - DAWN_INVALID_IF( + DelayedInvalidIf( totalInterStageShaderComponents > kMaxInterStageShaderComponents, "Total vertex output components count (%u) exceeds the maximum (%u).", totalInterStageShaderComponents, kMaxInterStageShaderComponents); @@ -752,33 +768,35 @@ namespace dawn::native { if (metadata->stage == SingleShaderStage::Fragment) { uint32_t totalInterStageShaderComponents = 0; for (const auto& inputVar : entryPoint.input_variables) { + EntryPointMetadata::InterStageVariableInfo variable; + DAWN_TRY_ASSIGN(variable.baseType, TintComponentTypeToInterStageComponentType( + inputVar.component_type)); + DAWN_TRY_ASSIGN( + variable.componentCount, + TintCompositionTypeToInterStageComponentCount(inputVar.composition_type)); + DAWN_TRY_ASSIGN( + variable.interpolationType, + TintInterpolationTypeToInterpolationType(inputVar.interpolation_type)); + DAWN_TRY_ASSIGN(variable.interpolationSampling, + TintInterpolationSamplingToInterpolationSamplingType( + inputVar.interpolation_sampling)); + totalInterStageShaderComponents += variable.componentCount; + DAWN_INVALID_IF( !inputVar.has_location_decoration, "Fragment input variable \"%s\" doesn't have a location decoration.", inputVar.name); uint32_t location = inputVar.location_decoration; - DAWN_INVALID_IF(location > kMaxInterStageShaderLocation, - "Fragment input variable \"%s\" has a location (%u) that " - "exceeds the maximum (%u).", - inputVar.name, location, kMaxInterStageShaderLocation); + if (DelayedInvalidIf(location > kMaxInterStageShaderLocation, + "Fragment input variable \"%s\" has a location (%u) that " + "exceeds the maximum (%u).", + inputVar.name, location, kMaxInterStageShaderLocation)) { + continue; + } metadata->usedInterStageVariables.set(location); - DAWN_TRY_ASSIGN( - metadata->interStageVariables[location].baseType, - TintComponentTypeToInterStageComponentType(inputVar.component_type)); - DAWN_TRY_ASSIGN( - metadata->interStageVariables[location].componentCount, - TintCompositionTypeToInterStageComponentCount(inputVar.composition_type)); - DAWN_TRY_ASSIGN( - metadata->interStageVariables[location].interpolationType, - TintInterpolationTypeToInterpolationType(inputVar.interpolation_type)); - DAWN_TRY_ASSIGN(metadata->interStageVariables[location].interpolationSampling, - TintInterpolationSamplingToInterpolationSamplingType( - inputVar.interpolation_sampling)); - - totalInterStageShaderComponents += - metadata->interStageVariables[location].componentCount; + metadata->interStageVariables[location] = variable; } if (entryPoint.front_facing_used) { @@ -794,91 +812,77 @@ namespace dawn::native { totalInterStageShaderComponents += 4; } - DAWN_INVALID_IF( + DelayedInvalidIf( totalInterStageShaderComponents > kMaxInterStageShaderComponents, "Total fragment input components count (%u) exceeds the maximum (%u).", totalInterStageShaderComponents, kMaxInterStageShaderComponents); for (const auto& outputVar : entryPoint.output_variables) { + EntryPointMetadata::FragmentOutputVariableInfo variable; + DAWN_TRY_ASSIGN(variable.baseType, TintComponentTypeToTextureComponentType( + outputVar.component_type)); + DAWN_TRY_ASSIGN( + variable.componentCount, + TintCompositionTypeToInterStageComponentCount(outputVar.composition_type)); + ASSERT(variable.componentCount <= 4); + DAWN_INVALID_IF( !outputVar.has_location_decoration, "Fragment input variable \"%s\" doesn't have a location decoration.", outputVar.name); uint32_t unsanitizedAttachment = outputVar.location_decoration; - DAWN_INVALID_IF(unsanitizedAttachment >= kMaxColorAttachments, - "Fragment output variable \"%s\" has a location (%u) that " - "exceeds the maximum (%u).", - outputVar.name, unsanitizedAttachment, kMaxColorAttachments); - ColorAttachmentIndex attachment(static_cast(unsanitizedAttachment)); + if (DelayedInvalidIf(unsanitizedAttachment >= kMaxColorAttachments, + "Fragment output variable \"%s\" has a location (%u) that " + "exceeds the maximum (%u).", + outputVar.name, unsanitizedAttachment, + kMaxColorAttachments)) { + continue; + } - DAWN_TRY_ASSIGN( - metadata->fragmentOutputVariables[attachment].baseType, - TintComponentTypeToTextureComponentType(outputVar.component_type)); - uint32_t componentCount; - DAWN_TRY_ASSIGN(componentCount, TintCompositionTypeToInterStageComponentCount( - outputVar.composition_type)); - // componentCount should be no larger than 4u - ASSERT(componentCount <= 4u); - metadata->fragmentOutputVariables[attachment].componentCount = componentCount; + ColorAttachmentIndex attachment(static_cast(unsanitizedAttachment)); + metadata->fragmentOutputVariables[attachment] = variable; metadata->fragmentOutputsWritten.set(attachment); } } for (const tint::inspector::ResourceBinding& resource : inspector->GetResourceBindings(entryPoint.name)) { - DAWN_INVALID_IF(resource.bind_group >= kMaxBindGroups, - "The entry-point uses a binding with a group decoration (%u) " - "that exceeds the maximum (%u).", - resource.bind_group, kMaxBindGroups); + ShaderBindingInfo info; - BindingNumber bindingNumber(resource.binding); - BindGroupIndex bindGroupIndex(resource.bind_group); + info.bindingType = TintResourceTypeToBindingInfoType(resource.resource_type); - DAWN_INVALID_IF(bindingNumber > kMaxBindingNumberTyped, - "Binding number (%u) exceeds the maximum binding number (%u).", - uint32_t(bindingNumber), uint32_t(kMaxBindingNumberTyped)); - - const auto& [binding, inserted] = - metadata->bindings[bindGroupIndex].emplace(bindingNumber, ShaderBindingInfo{}); - DAWN_INVALID_IF(!inserted, - "Entry-point has a duplicate binding for (group:%u, binding:%u).", - resource.binding, resource.bind_group); - - ShaderBindingInfo* info = &binding->second; - info->bindingType = TintResourceTypeToBindingInfoType(resource.resource_type); - - switch (info->bindingType) { + switch (info.bindingType) { case BindingInfoType::Buffer: - info->buffer.minBindingSize = resource.size_no_padding; - DAWN_TRY_ASSIGN(info->buffer.type, TintResourceTypeToBufferBindingType( - resource.resource_type)); + info.buffer.minBindingSize = resource.size_no_padding; + DAWN_TRY_ASSIGN(info.buffer.type, TintResourceTypeToBufferBindingType( + resource.resource_type)); break; case BindingInfoType::Sampler: switch (resource.resource_type) { case tint::inspector::ResourceBinding::ResourceType::kSampler: - info->sampler.isComparison = false; + info.sampler.isComparison = false; break; case tint::inspector::ResourceBinding::ResourceType::kComparisonSampler: - info->sampler.isComparison = true; + info.sampler.isComparison = true; break; default: UNREACHABLE(); } break; case BindingInfoType::Texture: - info->texture.viewDimension = + info.texture.viewDimension = TintTextureDimensionToTextureViewDimension(resource.dim); if (resource.resource_type == tint::inspector::ResourceBinding::ResourceType::kDepthTexture || resource.resource_type == tint::inspector::ResourceBinding:: ResourceType::kDepthMultisampledTexture) { - info->texture.compatibleSampleTypes = SampleTypeBit::Depth; + info.texture.compatibleSampleTypes = SampleTypeBit::Depth; } else { - info->texture.compatibleSampleTypes = + info.texture.compatibleSampleTypes = TintSampledKindToSampleTypeBit(resource.sampled_kind); } - info->texture.multisampled = + info.texture.multisampled = resource.resource_type == tint::inspector::ResourceBinding:: ResourceType::kMultisampledTexture || resource.resource_type == tint::inspector::ResourceBinding:: @@ -887,11 +891,11 @@ namespace dawn::native { break; case BindingInfoType::StorageTexture: DAWN_TRY_ASSIGN( - info->storageTexture.access, + info.storageTexture.access, TintResourceTypeToStorageTextureAccess(resource.resource_type)); - info->storageTexture.format = + info.storageTexture.format = TintImageFormatToTextureFormat(resource.image_format); - info->storageTexture.viewDimension = + info.storageTexture.viewDimension = TintTextureDimensionToTextureViewDimension(resource.dim); break; @@ -900,6 +904,25 @@ namespace dawn::native { default: return DAWN_VALIDATION_ERROR("Unknown binding type in Shader"); } + + BindingNumber bindingNumber(resource.binding); + BindGroupIndex bindGroupIndex(resource.bind_group); + + if (DelayedInvalidIf(bindGroupIndex >= kMaxBindGroupsTyped, + "The entry-point uses a binding with a group decoration (%u) " + "that exceeds the maximum (%u).", + resource.bind_group, kMaxBindGroups) || + DelayedInvalidIf(bindingNumber > kMaxBindingNumberTyped, + "Binding number (%u) exceeds the maximum binding number (%u).", + uint32_t(bindingNumber), uint32_t(kMaxBindingNumberTyped))) { + continue; + } + + const auto& [binding, inserted] = + metadata->bindings[bindGroupIndex].emplace(bindingNumber, info); + DAWN_INVALID_IF(!inserted, + "Entry-point has a duplicate binding for (group:%u, binding:%u).", + resource.binding, resource.bind_group); } std::vector samplerTextureUses = @@ -916,6 +939,7 @@ namespace dawn::native { return result; }); +#undef DelayedInvalidIf return std::move(metadata); } diff --git a/src/dawn/native/ShaderModule.h b/src/dawn/native/ShaderModule.h index 1e8095ea10..ff643eb9f8 100644 --- a/src/dawn/native/ShaderModule.h +++ b/src/dawn/native/ShaderModule.h @@ -165,6 +165,12 @@ namespace dawn::native { // pointers to EntryPointMetadata are safe to store as long as you also keep a Ref to the // ShaderModuleBase. struct EntryPointMetadata { + // It is valid for a shader to contain entry points that go over limits. To keep this + // structure with packed arrays and bitsets, we still validate against limits when + // doing reflection, but store the errors in this vector, for later use if the application + // tries to use the entry point. + std::vector infringedLimitErrors; + // bindings[G][B] is the reflection data for the binding defined with // @group(G) @binding(B) in WGSL / SPIRV. BindingInfoArray bindings; diff --git a/src/dawn/tests/unittests/validation/ShaderModuleValidationTests.cpp b/src/dawn/tests/unittests/validation/ShaderModuleValidationTests.cpp index 0c5e6a0cba..ba702fb18a 100644 --- a/src/dawn/tests/unittests/validation/ShaderModuleValidationTests.cpp +++ b/src/dawn/tests/unittests/validation/ShaderModuleValidationTests.cpp @@ -12,12 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "dawn/common/Constants.h" - -#include "dawn/native/ShaderModule.h" - #include "dawn/tests/unittests/validation/ValidationTest.h" +#include "dawn/common/Constants.h" +#include "dawn/native/ShaderModule.h" +#include "dawn/utils/ComboRenderPipelineDescriptor.h" #include "dawn/utils/WGPUHelpers.h" #include @@ -214,85 +213,114 @@ TEST_F(ShaderModuleValidationTest, GetCompilationMessages) { // Validate the maximum location of effective inter-stage variables cannot be greater than 14 // (kMaxInterStageShaderComponents / 4 - 1). TEST_F(ShaderModuleValidationTest, MaximumShaderIOLocations) { - auto generateShaderForTest = [](uint32_t maximumOutputLocation, wgpu::ShaderStage shaderStage) { + auto CheckTestPipeline = [&](bool success, uint32_t maximumOutputLocation, + wgpu::ShaderStage failingShaderStage) { + // Build the ShaderIO struct containing variables up to maximumOutputLocation. std::ostringstream stream; stream << "struct ShaderIO {" << std::endl; for (uint32_t location = 1; location <= maximumOutputLocation; ++location) { - stream << "@location(" << location << ") var" << location << ": f32;" << std::endl; + stream << "@location(" << location << ") var" << location << ": f32," << std::endl; } - switch (shaderStage) { + + if (failingShaderStage == wgpu::ShaderStage::Vertex) { + stream << " @builtin(position) pos: vec4,"; + } + stream << "}\n"; + + std::string ioStruct = stream.str(); + + // Build the test pipeline. Note that it's not possible with just ASSERT_DEVICE_ERROR + // whether it is the vertex or fragment shader that fails. So instead we will look for the + // string "failingVertex" or "failingFragment" in the error message. + utils::ComboRenderPipelineDescriptor pDesc; + pDesc.cTargets[0].format = wgpu::TextureFormat::RGBA8Unorm; + + const char* errorMatcher = nullptr; + switch (failingShaderStage) { case wgpu::ShaderStage::Vertex: { - stream << R"( - @builtin(position) pos: vec4; - }; - @stage(vertex) fn main() -> ShaderIO { - var shaderIO : ShaderIO; - shaderIO.pos = vec4(0.0, 0.0, 0.0, 1.0); - return shaderIO; - })"; - } break; + errorMatcher = "failingVertex"; + pDesc.vertex.entryPoint = "failingVertex"; + pDesc.vertex.module = utils::CreateShaderModule(device, (ioStruct + R"( + @stage(vertex) fn failingVertex() -> ShaderIO { + var shaderIO : ShaderIO; + shaderIO.pos = vec4(0.0, 0.0, 0.0, 1.0); + return shaderIO; + } + )") + .c_str()); + pDesc.cFragment.module = utils::CreateShaderModule(device, R"( + @stage(fragment) fn main() -> @location(0) vec4 { + return vec4(0.0); + } + )"); + break; + } case wgpu::ShaderStage::Fragment: { - stream << R"( - }; - @stage(fragment) fn main(shaderIO: ShaderIO) -> @location(0) vec4 { - return vec4(0.0, 0.0, 0.0, 1.0); - })"; - } break; + errorMatcher = "failingFragment"; + pDesc.cFragment.entryPoint = "failingFragment"; + pDesc.cFragment.module = utils::CreateShaderModule(device, (ioStruct + R"( + @stage(fragment) fn failingFragment(io : ShaderIO) -> @location(0) vec4 { + return vec4(0.0); + } + )") + .c_str()); + pDesc.vertex.module = utils::CreateShaderModule(device, R"( + @stage(vertex) fn main() -> @builtin(position) vec4 { + return vec4(0.0); + } + )"); + break; + } - case wgpu::ShaderStage::Compute: default: UNREACHABLE(); } - return stream.str(); + if (success) { + ASSERT_DEVICE_ERROR( + device.CreateRenderPipeline(&pDesc), + testing::HasSubstr( + "One or more fragment inputs and vertex outputs are not one-to-one matching")); + } else { + ASSERT_DEVICE_ERROR(device.CreateRenderPipeline(&pDesc), + testing::HasSubstr(errorMatcher)); + } }; constexpr uint32_t kMaxInterShaderIOLocation = kMaxInterStageShaderComponents / 4 - 1; // It is allowed to create a shader module with the maximum active vertex output location == 14; - { - std::string vertexShader = - generateShaderForTest(kMaxInterShaderIOLocation, wgpu::ShaderStage::Vertex); - utils::CreateShaderModule(device, vertexShader.c_str()); - } + CheckTestPipeline(true, kMaxInterShaderIOLocation, wgpu::ShaderStage::Vertex); // It isn't allowed to create a shader module with the maximum active vertex output location > // 14; - { - std::string vertexShader = - generateShaderForTest(kMaxInterShaderIOLocation + 1, wgpu::ShaderStage::Vertex); - ASSERT_DEVICE_ERROR(utils::CreateShaderModule(device, vertexShader.c_str())); - } + CheckTestPipeline(false, kMaxInterShaderIOLocation + 1, wgpu::ShaderStage::Vertex); // It is allowed to create a shader module with the maximum active fragment input location == // 14; - { - std::string fragmentShader = - generateShaderForTest(kMaxInterShaderIOLocation, wgpu::ShaderStage::Fragment); - utils::CreateShaderModule(device, fragmentShader.c_str()); - } + CheckTestPipeline(true, kMaxInterShaderIOLocation, wgpu::ShaderStage::Fragment); // It is allowed to create a shader module with the maximum active vertex output location > 14; - { - std::string fragmentShader = - generateShaderForTest(kMaxInterShaderIOLocation + 1, wgpu::ShaderStage::Fragment); - ASSERT_DEVICE_ERROR(utils::CreateShaderModule(device, fragmentShader.c_str())); - } + CheckTestPipeline(false, kMaxInterShaderIOLocation + 1, wgpu::ShaderStage::Fragment); } // Validate the maximum number of total inter-stage user-defined variable component count and // built-in variables cannot exceed kMaxInterStageShaderComponents. TEST_F(ShaderModuleValidationTest, MaximumInterStageShaderComponents) { - auto generateShaderForTest = [](uint32_t totalUserDefinedInterStageShaderComponentCount, - wgpu::ShaderStage shaderStage, - const char* builtInDeclarations) { + auto CheckTestPipeline = [&](bool success, + uint32_t totalUserDefinedInterStageShaderComponentCount, + wgpu::ShaderStage failingShaderStage, + const char* extraBuiltInDeclarations = "") { + // Build the ShaderIO struct containing totalUserDefinedInterStageShaderComponentCount + // components. Components are added in two parts, a bunch of vec4s, then one additional + // variable for the remaining components. std::ostringstream stream; - stream << "struct ShaderIO {" << std::endl << builtInDeclarations << std::endl; + stream << "struct ShaderIO {" << std::endl << extraBuiltInDeclarations << std::endl; uint32_t vec4InputLocations = totalUserDefinedInterStageShaderComponentCount / 4; for (uint32_t location = 0; location < vec4InputLocations; ++location) { - stream << "@location(" << location << ") var" << location << ": vec4;" + stream << "@location(" << location << ") var" << location << ": vec4," << std::endl; } @@ -300,163 +328,161 @@ TEST_F(ShaderModuleValidationTest, MaximumInterStageShaderComponents) { if (lastComponentCount > 0) { stream << "@location(" << vec4InputLocations << ") var" << vec4InputLocations << ": "; if (lastComponentCount == 1) { - stream << "f32;"; + stream << "f32,"; } else { - stream << " vec" << lastComponentCount << ";"; + stream << " vec" << lastComponentCount << ","; } stream << std::endl; } - switch (shaderStage) { + if (failingShaderStage == wgpu::ShaderStage::Vertex) { + stream << " @builtin(position) pos: vec4,"; + } + stream << "}\n"; + + std::string ioStruct = stream.str(); + + // Build the test pipeline. Note that it's not possible with just ASSERT_DEVICE_ERROR + // whether it is the vertex or fragment shader that fails. So instead we will look for the + // string "failingVertex" or "failingFragment" in the error message. + utils::ComboRenderPipelineDescriptor pDesc; + pDesc.cTargets[0].format = wgpu::TextureFormat::RGBA8Unorm; + + const char* errorMatcher = nullptr; + switch (failingShaderStage) { case wgpu::ShaderStage::Vertex: { - stream << R"( - @builtin(position) pos: vec4; - }; - @stage(vertex) fn main() -> ShaderIO { - var shaderIO : ShaderIO; - shaderIO.pos = vec4(0.0, 0.0, 0.0, 1.0); - return shaderIO; - })"; - } break; + errorMatcher = "failingVertex"; + pDesc.vertex.entryPoint = "failingVertex"; + pDesc.vertex.module = utils::CreateShaderModule(device, (ioStruct + R"( + @stage(vertex) fn failingVertex() -> ShaderIO { + var shaderIO : ShaderIO; + shaderIO.pos = vec4(0.0, 0.0, 0.0, 1.0); + return shaderIO; + } + )") + .c_str()); + pDesc.cFragment.module = utils::CreateShaderModule(device, R"( + @stage(fragment) fn main() -> @location(0) vec4 { + return vec4(0.0); + } + )"); + break; + } case wgpu::ShaderStage::Fragment: { - stream << R"( - }; - @stage(fragment) fn main(shaderIO: ShaderIO) -> @location(0) vec4 { - return vec4(0.0, 0.0, 0.0, 1.0); - })"; - } break; + errorMatcher = "failingFragment"; + pDesc.cFragment.entryPoint = "failingFragment"; + pDesc.cFragment.module = utils::CreateShaderModule(device, (ioStruct + R"( + @stage(fragment) fn failingFragment(io : ShaderIO) -> @location(0) vec4 { + return vec4(0.0); + } + )") + .c_str()); + pDesc.vertex.module = utils::CreateShaderModule(device, R"( + @stage(vertex) fn main() -> @builtin(position) vec4 { + return vec4(0.0); + } + )"); + break; + } - case wgpu::ShaderStage::Compute: default: UNREACHABLE(); } - return stream.str(); + if (success) { + ASSERT_DEVICE_ERROR( + device.CreateRenderPipeline(&pDesc), + testing::HasSubstr( + "One or more fragment inputs and vertex outputs are not one-to-one matching")); + } else { + ASSERT_DEVICE_ERROR(device.CreateRenderPipeline(&pDesc), + testing::HasSubstr(errorMatcher)); + } }; // Verify when there is no input builtin variable in a fragment shader, the total user-defined // input component count must be less than kMaxInterStageShaderComponents. { - constexpr uint32_t kInterStageShaderComponentCount = kMaxInterStageShaderComponents; - std::string correctFragmentShader = - generateShaderForTest(kInterStageShaderComponentCount, wgpu::ShaderStage::Fragment, ""); - utils::CreateShaderModule(device, correctFragmentShader.c_str()); - - std::string errorFragmentShader = generateShaderForTest(kInterStageShaderComponentCount + 1, - wgpu::ShaderStage::Fragment, ""); - ASSERT_DEVICE_ERROR(utils::CreateShaderModule(device, errorFragmentShader.c_str())); + CheckTestPipeline(true, kMaxInterStageShaderComponents, wgpu::ShaderStage::Fragment); + CheckTestPipeline(false, kMaxInterStageShaderComponents + 1, wgpu::ShaderStage::Fragment); } - // @position should be counted into the maximum inter-stage component count. + // @builtin(position) should be counted into the maximum inter-stage component count. // Note that in vertex shader we always have @position so we don't need to specify it // again in the parameter "builtInDeclarations" of generateShaderForTest(). { - constexpr uint32_t kInterStageShaderComponentCount = kMaxInterStageShaderComponents - 4; - std::string vertexShader = - generateShaderForTest(kInterStageShaderComponentCount, wgpu::ShaderStage::Vertex, ""); - utils::CreateShaderModule(device, vertexShader.c_str()); - - std::string fragmentShader = - generateShaderForTest(kInterStageShaderComponentCount, wgpu::ShaderStage::Fragment, - "@builtin(position) fragCoord: vec4;"); - utils::CreateShaderModule(device, fragmentShader.c_str()); + CheckTestPipeline(true, kMaxInterStageShaderComponents - 4, wgpu::ShaderStage::Vertex); + CheckTestPipeline(false, kMaxInterStageShaderComponents - 3, wgpu::ShaderStage::Vertex); } + // @builtin(position) in fragment shaders should be counted into the maximum inter-stage + // component count. { - constexpr uint32_t kInterStageShaderComponentCount = kMaxInterStageShaderComponents - 3; - std::string vertexShader = - generateShaderForTest(kInterStageShaderComponentCount, wgpu::ShaderStage::Vertex, ""); - ASSERT_DEVICE_ERROR(utils::CreateShaderModule(device, vertexShader.c_str())); - - std::string fragmentShader = - generateShaderForTest(kInterStageShaderComponentCount, wgpu::ShaderStage::Fragment, - "@builtin(position) fragCoord: vec4;"); - ASSERT_DEVICE_ERROR(utils::CreateShaderModule(device, fragmentShader.c_str())); + CheckTestPipeline(true, kMaxInterStageShaderComponents - 4, wgpu::ShaderStage::Fragment, + "@builtin(position) fragCoord : vec4,"); + CheckTestPipeline(false, kMaxInterStageShaderComponents - 3, wgpu::ShaderStage::Fragment, + "@builtin(position) fragCoord : vec4,"); } - // front_facing should be counted into the maximum inter-stage component count. + // @builtin(front_facing) should be counted into the maximum inter-stage component count. { - const char* builtinDeclaration = "@builtin(front_facing) frontFacing : bool;"; - - { - std::string fragmentShader = - generateShaderForTest(kMaxInterStageShaderComponents - 1, - wgpu::ShaderStage::Fragment, builtinDeclaration); - utils::CreateShaderModule(device, fragmentShader.c_str()); - } - - { - std::string fragmentShader = generateShaderForTest( - kMaxInterStageShaderComponents, wgpu::ShaderStage::Fragment, builtinDeclaration); - ASSERT_DEVICE_ERROR(utils::CreateShaderModule(device, fragmentShader.c_str())); - } + CheckTestPipeline(true, kMaxInterStageShaderComponents - 1, wgpu::ShaderStage::Fragment, + "@builtin(front_facing) frontFacing : bool,"); + CheckTestPipeline(false, kMaxInterStageShaderComponents, wgpu::ShaderStage::Fragment, + "@builtin(front_facing) frontFacing : bool,"); } - // @sample_index should be counted into the maximum inter-stage component count. + // @builtin(sample_index) should be counted into the maximum inter-stage component count. { - const char* builtinDeclaration = "@builtin(sample_index) sampleIndex: u32;"; - - { - std::string fragmentShader = - generateShaderForTest(kMaxInterStageShaderComponents - 1, - wgpu::ShaderStage::Fragment, builtinDeclaration); - utils::CreateShaderModule(device, fragmentShader.c_str()); - } - - { - std::string fragmentShader = generateShaderForTest( - kMaxInterStageShaderComponents, wgpu::ShaderStage::Fragment, builtinDeclaration); - ASSERT_DEVICE_ERROR(utils::CreateShaderModule(device, fragmentShader.c_str())); - } + CheckTestPipeline(true, kMaxInterStageShaderComponents - 1, wgpu::ShaderStage::Fragment, + "@builtin(sample_index) sampleIndex : u32,"); + CheckTestPipeline(false, kMaxInterStageShaderComponents, wgpu::ShaderStage::Fragment, + "@builtin(sample_index) sampleIndex : u32,"); } - // @sample_mask should be counted into the maximum inter-stage component count. + // @builtin(sample_mask) should be counted into the maximum inter-stage component count. { - const char* builtinDeclaration = "@builtin(front_facing) frontFacing : bool;"; - - { - std::string fragmentShader = - generateShaderForTest(kMaxInterStageShaderComponents - 1, - wgpu::ShaderStage::Fragment, builtinDeclaration); - utils::CreateShaderModule(device, fragmentShader.c_str()); - } - - { - std::string fragmentShader = generateShaderForTest( - kMaxInterStageShaderComponents, wgpu::ShaderStage::Fragment, builtinDeclaration); - ASSERT_DEVICE_ERROR(utils::CreateShaderModule(device, fragmentShader.c_str())); - } + CheckTestPipeline(true, kMaxInterStageShaderComponents - 1, wgpu::ShaderStage::Fragment, + "@builtin(sample_mask) sampleMask : u32,"); + CheckTestPipeline(false, kMaxInterStageShaderComponents, wgpu::ShaderStage::Fragment, + "@builtin(sample_mask) sampleMask : u32,"); } } // Tests that we validate workgroup size limits. TEST_F(ShaderModuleValidationTest, ComputeWorkgroupSizeLimits) { - auto MakeShaderWithWorkgroupSize = [this](uint32_t x, uint32_t y, uint32_t z) { + auto CheckShaderWithWorkgroupSize = [this](bool success, uint32_t x, uint32_t y, uint32_t z) { std::ostringstream ss; ss << "@stage(compute) @workgroup_size(" << x << "," << y << "," << z << ") fn main() {}"; - utils::CreateShaderModule(device, ss.str().c_str()); + + wgpu::ComputePipelineDescriptor desc; + desc.compute.entryPoint = "main"; + desc.compute.module = utils::CreateShaderModule(device, ss.str().c_str()); + + if (success) { + device.CreateComputePipeline(&desc); + } else { + ASSERT_DEVICE_ERROR(device.CreateComputePipeline(&desc)); + } }; wgpu::Limits supportedLimits = GetSupportedLimits().limits; - MakeShaderWithWorkgroupSize(1, 1, 1); - MakeShaderWithWorkgroupSize(supportedLimits.maxComputeWorkgroupSizeX, 1, 1); - MakeShaderWithWorkgroupSize(1, supportedLimits.maxComputeWorkgroupSizeY, 1); - MakeShaderWithWorkgroupSize(1, 1, supportedLimits.maxComputeWorkgroupSizeZ); + CheckShaderWithWorkgroupSize(true, 1, 1, 1); + CheckShaderWithWorkgroupSize(true, supportedLimits.maxComputeWorkgroupSizeX, 1, 1); + CheckShaderWithWorkgroupSize(true, 1, supportedLimits.maxComputeWorkgroupSizeY, 1); + CheckShaderWithWorkgroupSize(true, 1, 1, supportedLimits.maxComputeWorkgroupSizeZ); - ASSERT_DEVICE_ERROR( - MakeShaderWithWorkgroupSize(supportedLimits.maxComputeWorkgroupSizeX + 1, 1, 1)); - ASSERT_DEVICE_ERROR( - MakeShaderWithWorkgroupSize(1, supportedLimits.maxComputeWorkgroupSizeY + 1, 1)); - ASSERT_DEVICE_ERROR( - MakeShaderWithWorkgroupSize(1, 1, supportedLimits.maxComputeWorkgroupSizeZ + 1)); + CheckShaderWithWorkgroupSize(false, supportedLimits.maxComputeWorkgroupSizeX + 1, 1, 1); + CheckShaderWithWorkgroupSize(false, 1, supportedLimits.maxComputeWorkgroupSizeY + 1, 1); + CheckShaderWithWorkgroupSize(false, 1, 1, supportedLimits.maxComputeWorkgroupSizeZ + 1); // No individual dimension exceeds its limit, but the combined size should definitely exceed the // total invocation limit. - ASSERT_DEVICE_ERROR(MakeShaderWithWorkgroupSize(supportedLimits.maxComputeWorkgroupSizeX, - supportedLimits.maxComputeWorkgroupSizeY, - supportedLimits.maxComputeWorkgroupSizeZ)); + CheckShaderWithWorkgroupSize(false, supportedLimits.maxComputeWorkgroupSizeX, + supportedLimits.maxComputeWorkgroupSizeY, + supportedLimits.maxComputeWorkgroupSizeZ); } // Tests that we validate workgroup storage size limits. @@ -468,7 +494,8 @@ TEST_F(ShaderModuleValidationTest, ComputeWorkgroupStorageSizeLimits) { constexpr uint32_t kMat4Size = 64; const uint32_t maxMat4Count = supportedLimits.maxComputeWorkgroupStorageSize / kMat4Size; - auto MakeShaderWithWorkgroupStorage = [this](uint32_t vec4_count, uint32_t mat4_count) { + auto CheckPipelineWithWorkgroupStorage = [this](bool success, uint32_t vec4_count, + uint32_t mat4_count) { std::ostringstream ss; std::ostringstream body; if (vec4_count > 0) { @@ -480,18 +507,28 @@ TEST_F(ShaderModuleValidationTest, ComputeWorkgroupStorageSizeLimits) { body << "_ = mat4_data;"; } ss << "@stage(compute) @workgroup_size(1) fn main() { " << body.str() << " }"; - utils::CreateShaderModule(device, ss.str().c_str()); + + wgpu::ComputePipelineDescriptor desc; + desc.compute.entryPoint = "main"; + desc.compute.module = utils::CreateShaderModule(device, ss.str().c_str()); + + if (success) { + device.CreateComputePipeline(&desc); + } else { + ASSERT_DEVICE_ERROR(device.CreateComputePipeline(&desc)); + } }; - MakeShaderWithWorkgroupStorage(1, 1); - MakeShaderWithWorkgroupStorage(maxVec4Count, 0); - MakeShaderWithWorkgroupStorage(0, maxMat4Count); - MakeShaderWithWorkgroupStorage(maxVec4Count - 4, 1); - MakeShaderWithWorkgroupStorage(4, maxMat4Count - 1); - ASSERT_DEVICE_ERROR(MakeShaderWithWorkgroupStorage(maxVec4Count + 1, 0)); - ASSERT_DEVICE_ERROR(MakeShaderWithWorkgroupStorage(maxVec4Count - 3, 1)); - ASSERT_DEVICE_ERROR(MakeShaderWithWorkgroupStorage(0, maxMat4Count + 1)); - ASSERT_DEVICE_ERROR(MakeShaderWithWorkgroupStorage(4, maxMat4Count)); + CheckPipelineWithWorkgroupStorage(true, 1, 1); + CheckPipelineWithWorkgroupStorage(true, maxVec4Count, 0); + CheckPipelineWithWorkgroupStorage(true, 0, maxMat4Count); + CheckPipelineWithWorkgroupStorage(true, maxVec4Count - 4, 1); + CheckPipelineWithWorkgroupStorage(true, 4, maxMat4Count - 1); + + CheckPipelineWithWorkgroupStorage(false, maxVec4Count + 1, 0); + CheckPipelineWithWorkgroupStorage(false, maxVec4Count - 3, 1); + CheckPipelineWithWorkgroupStorage(false, 0, maxMat4Count + 1); + CheckPipelineWithWorkgroupStorage(false, 4, maxMat4Count); } // Test that numeric ID must be unique @@ -517,21 +554,24 @@ struct Buf { TEST_F(ShaderModuleValidationTest, MaxBindingNumber) { static_assert(kMaxBindingNumber == 65535); + wgpu::ComputePipelineDescriptor desc; + desc.compute.entryPoint = "main"; + // kMaxBindingNumber is valid. - utils::CreateShaderModule(device, R"( + desc.compute.module = utils::CreateShaderModule(device, R"( @group(0) @binding(65535) var s : sampler; - @stage(fragment) fn main() -> @location(0) u32 { + @stage(compute) @workgroup_size(1) fn main() { _ = s; - return 0u; } )"); + device.CreateComputePipeline(&desc); // kMaxBindingNumber + 1 is an error - ASSERT_DEVICE_ERROR(utils::CreateShaderModule(device, R"( + desc.compute.module = utils::CreateShaderModule(device, R"( @group(0) @binding(65536) var s : sampler; - @stage(fragment) fn main() -> @location(0) u32 { + @stage(compute) @workgroup_size(1) fn main() { _ = s; - return 0u; } - )")); + )"); + ASSERT_DEVICE_ERROR(device.CreateComputePipeline(&desc)); }