From b7c7f62829b0ed8297d5f35539c862d94d098a64 Mon Sep 17 00:00:00 2001 From: Jiawei Shao Date: Fri, 30 Jul 2021 00:40:26 +0000 Subject: [PATCH] Implement inter-stage variable matching rules - Part II This patch implements the inter-stage variable matching rules on the interpolation attributes ('interpolation type' and 'interpolation sampling'). WebGPU SPEC requires that the interpolation attributes must match between vertex outputs and fragment inputs with the same location assignment within the same pipeline. BUG=dawn:802 TEST=dawn_unittests Change-Id: Ied38d68f73868c30b0392954683963a801e3f3aa Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/60160 Commit-Queue: Jiawei Shao Reviewed-by: Corentin Wallez Reviewed-by: Austin Eng --- src/dawn_native/RenderPipeline.cpp | 7 + src/dawn_native/ShaderModule.cpp | 46 ++++++ src/dawn_native/ShaderModule.h | 15 ++ .../RenderPipelineValidationTests.cpp | 139 ++++++++++++++++++ 4 files changed, 207 insertions(+) diff --git a/src/dawn_native/RenderPipeline.cpp b/src/dawn_native/RenderPipeline.cpp index 7eff93db3d..9b18105955 100644 --- a/src/dawn_native/RenderPipeline.cpp +++ b/src/dawn_native/RenderPipeline.cpp @@ -351,6 +351,13 @@ namespace dawn_native { if (vertexOutputInfo.componentCount != fragmentInputInfo.componentCount) { return DAWN_VALIDATION_ERROR(generateErrorString("componentCount", i)); } + if (vertexOutputInfo.interpolationType != fragmentInputInfo.interpolationType) { + return DAWN_VALIDATION_ERROR(generateErrorString("interpolation type", i)); + } + if (vertexOutputInfo.interpolationSampling != + fragmentInputInfo.interpolationSampling) { + return DAWN_VALIDATION_ERROR(generateErrorString("interpolation sampling", i)); + } } return {}; diff --git a/src/dawn_native/ShaderModule.cpp b/src/dawn_native/ShaderModule.cpp index 2fe2f5ccfb..79a9514b29 100644 --- a/src/dawn_native/ShaderModule.cpp +++ b/src/dawn_native/ShaderModule.cpp @@ -370,6 +370,38 @@ namespace dawn_native { } } + ResultOrError TintInterpolationTypeToInterpolationType( + tint::inspector::InterpolationType type) { + switch (type) { + case tint::inspector::InterpolationType::kPerspective: + return InterpolationType::Perspective; + case tint::inspector::InterpolationType::kLinear: + return InterpolationType::Linear; + case tint::inspector::InterpolationType::kFlat: + return InterpolationType::Flat; + case tint::inspector::InterpolationType::kUnknown: + return DAWN_VALIDATION_ERROR( + "Attempted to convert 'Unknown' interpolation type from Tint"); + } + } + + ResultOrError TintInterpolationSamplingToInterpolationSamplingType( + tint::inspector::InterpolationSampling type) { + switch (type) { + case tint::inspector::InterpolationSampling::kNone: + return InterpolationSampling::None; + case tint::inspector::InterpolationSampling::kCenter: + return InterpolationSampling::Center; + case tint::inspector::InterpolationSampling::kCentroid: + return InterpolationSampling::Centroid; + case tint::inspector::InterpolationSampling::kSample: + return InterpolationSampling::Sample; + case tint::inspector::InterpolationSampling::kUnknown: + return DAWN_VALIDATION_ERROR( + "Attempted to convert 'Unknown' interpolation sampling type from Tint"); + } + } + MaybeError ValidateSpirv(const uint32_t* code, uint32_t codeSize) { spvtools::SpirvTools spirvTools(SPV_ENV_VULKAN_1_1); @@ -1041,6 +1073,13 @@ namespace dawn_native { DAWN_TRY_ASSIGN(metadata->interStageVariables[location].componentCount, TintCompositionTypeToInterStageComponentCount( output_var.composition_type)); + DAWN_TRY_ASSIGN(metadata->interStageVariables[location].interpolationType, + TintInterpolationTypeToInterpolationType( + output_var.interpolation_type)); + DAWN_TRY_ASSIGN( + metadata->interStageVariables[location].interpolationSampling, + TintInterpolationSamplingToInterpolationSamplingType( + output_var.interpolation_sampling)); } } @@ -1063,6 +1102,13 @@ namespace dawn_native { DAWN_TRY_ASSIGN(metadata->interStageVariables[location].componentCount, TintCompositionTypeToInterStageComponentCount( input_var.composition_type)); + DAWN_TRY_ASSIGN( + metadata->interStageVariables[location].interpolationType, + TintInterpolationTypeToInterpolationType(input_var.interpolation_type)); + DAWN_TRY_ASSIGN( + metadata->interStageVariables[location].interpolationSampling, + TintInterpolationSamplingToInterpolationSamplingType( + input_var.interpolation_sampling)); } for (const auto& output_var : entryPoint.output_variables) { diff --git a/src/dawn_native/ShaderModule.h b/src/dawn_native/ShaderModule.h index 2e501537f0..613033aeb5 100644 --- a/src/dawn_native/ShaderModule.h +++ b/src/dawn_native/ShaderModule.h @@ -60,6 +60,19 @@ namespace dawn_native { Float, }; + enum class InterpolationType { + Perspective, + Linear, + Flat, + }; + + enum class InterpolationSampling { + None, + Center, + Centroid, + Sample, + }; + using PipelineLayoutEntryPointPair = std::pair; struct PipelineLayoutEntryPointPairHashFunc { size_t operator()(const PipelineLayoutEntryPointPair& pair) const; @@ -169,6 +182,8 @@ namespace dawn_native { struct InterStageVariableInfo { InterStageComponentType baseType; uint32_t componentCount; + InterpolationType interpolationType; + InterpolationSampling interpolationSampling; }; // Now that we only support vertex and fragment stages, there can't be both inter-stage // inputs and outputs in one shader stage. diff --git a/src/tests/unittests/validation/RenderPipelineValidationTests.cpp b/src/tests/unittests/validation/RenderPipelineValidationTests.cpp index 628a37a7a2..d6e60ede4e 100644 --- a/src/tests/unittests/validation/RenderPipelineValidationTests.cpp +++ b/src/tests/unittests/validation/RenderPipelineValidationTests.cpp @@ -1087,3 +1087,142 @@ TEST_F(InterStageVariableMatchingValidationTest, DifferentTypeAtSameLocation) { } } } + +// Tests that creating render pipeline should fail when the interpolation attribute of a vertex +// stage output variable doesn't match the type of the fragment stage input variable at the same +// location. +TEST_F(InterStageVariableMatchingValidationTest, DifferentInterpolationAttributeAtSameLocation) { + enum class InterpolationType : uint8_t { + None = 0, + Perspective, + Linear, + Flat, + Count, + }; + enum class InterpolationSampling : uint8_t { + None = 0, + Center, + Centroid, + Sample, + Count, + }; + constexpr std::array(InterpolationType::Count)> + kInterpolationTypeString = {{"", "perspective", "linear", "flat"}}; + constexpr std::array(InterpolationSampling::Count)> + kInterpolationSamplingString = {{"", "center", "centroid", "sample"}}; + + struct InterpolationAttribute { + InterpolationType interpolationType; + InterpolationSampling interpolationSampling; + }; + + // Interpolation sampling is not used with flat interpolation. + constexpr std::array validInterpolationAttributes = {{ + {InterpolationType::None, InterpolationSampling::None}, + {InterpolationType::Flat, InterpolationSampling::None}, + {InterpolationType::Linear, InterpolationSampling::None}, + {InterpolationType::Linear, InterpolationSampling::Center}, + {InterpolationType::Linear, InterpolationSampling::Centroid}, + {InterpolationType::Linear, InterpolationSampling::Sample}, + {InterpolationType::Perspective, InterpolationSampling::None}, + {InterpolationType::Perspective, InterpolationSampling::Center}, + {InterpolationType::Perspective, InterpolationSampling::Centroid}, + {InterpolationType::Perspective, InterpolationSampling::Sample}, + }}; + + std::vector vertexModules(validInterpolationAttributes.size()); + std::vector fragmentModules(validInterpolationAttributes.size()); + for (uint32_t i = 0; i < validInterpolationAttributes.size(); ++i) { + std::string interfaceDeclaration; + { + const auto& interpolationAttribute = validInterpolationAttributes[i]; + std::ostringstream sstream; + sstream << "struct A { [[location(0)"; + if (interpolationAttribute.interpolationType != InterpolationType::None) { + sstream << ", interpolate(" + << kInterpolationTypeString[static_cast( + interpolationAttribute.interpolationType)]; + if (interpolationAttribute.interpolationSampling != InterpolationSampling::None) { + sstream << ", " + << kInterpolationSamplingString[static_cast( + interpolationAttribute.interpolationSampling)]; + } + sstream << ")"; + } + sstream << " ]] a : vec4;" << std::endl; + interfaceDeclaration = sstream.str(); + } + { + std::ostringstream vertexStream; + vertexStream << interfaceDeclaration << R"( + [[builtin(position)]] pos: vec4; + }; + [[stage(vertex)]] fn main() -> A { + var vertexOut: A; + vertexOut.pos = vec4(0.0, 0.0, 0.0, 1.0); + return vertexOut; + })"; + vertexModules[i] = utils::CreateShaderModule(device, vertexStream.str().c_str()); + } + { + std::ostringstream fragmentStream; + fragmentStream << interfaceDeclaration << R"( + }; + [[stage(fragment)]] fn main(fragmentIn: A) -> [[location(0)]] vec4 { + return fragmentIn.a; + })"; + fragmentModules[i] = utils::CreateShaderModule(device, fragmentStream.str().c_str()); + } + } + + auto GetAppliedInterpolationAttribute = [](const InterpolationAttribute& attribute) { + InterpolationAttribute appliedAttribute = {attribute.interpolationType, + attribute.interpolationSampling}; + switch (attribute.interpolationType) { + // If the interpolation attribute is not specified, then + // [[interpolate(perspective, center)]] or [[interpolate(perspective)]] is assumed. + case InterpolationType::None: + appliedAttribute.interpolationType = InterpolationType::Perspective; + appliedAttribute.interpolationSampling = InterpolationSampling::Center; + break; + + // If the interpolation type is perspective or linear, and the interpolation + // sampling is not specified, then 'center' is assumed. + case InterpolationType::Perspective: + case InterpolationType::Linear: + if (appliedAttribute.interpolationSampling == InterpolationSampling::None) { + appliedAttribute.interpolationSampling = InterpolationSampling::Center; + } + break; + + case InterpolationType::Flat: + break; + default: + UNREACHABLE(); + } + return appliedAttribute; + }; + + auto InterpolationAttributeMatch = [GetAppliedInterpolationAttribute]( + const InterpolationAttribute& attribute1, + const InterpolationAttribute& attribute2) { + InterpolationAttribute appliedAttribute1 = GetAppliedInterpolationAttribute(attribute1); + InterpolationAttribute appliedAttribute2 = GetAppliedInterpolationAttribute(attribute2); + + return appliedAttribute1.interpolationType == appliedAttribute2.interpolationType && + appliedAttribute1.interpolationSampling == appliedAttribute2.interpolationSampling; + }; + + for (uint32_t vertexModuleIndex = 0; vertexModuleIndex < validInterpolationAttributes.size(); + ++vertexModuleIndex) { + wgpu::ShaderModule vertexModule = vertexModules[vertexModuleIndex]; + for (uint32_t fragmentModuleIndex = 0; + fragmentModuleIndex < validInterpolationAttributes.size(); ++fragmentModuleIndex) { + wgpu::ShaderModule fragmentModule = fragmentModules[fragmentModuleIndex]; + bool shouldSuccess = + InterpolationAttributeMatch(validInterpolationAttributes[vertexModuleIndex], + validInterpolationAttributes[fragmentModuleIndex]); + CheckCreatingRenderPipeline(vertexModule, fragmentModule, shouldSuccess); + } + } +}