From 64f4dd71278a76e1cffb4578e6441a09fd231283 Mon Sep 17 00:00:00 2001 From: Jiawei Shao Date: Thu, 26 Sep 2019 00:12:41 +0000 Subject: [PATCH] Add check between color state format and fragment shader output This patch adds the validation on the compatibility between the format of the color states and the fragment shader output when we create a render pipeline state object as is required in Vulkan (Vulkan SPEC Chapter 14.3 "Fragment Output Interface"): "if the type of the values written by the fragment shader do not match the format of the corresponding color attachment, the resulting values are undefined for those components". BUG=dawn:202 TEST=dawn_unittests Change-Id: I3a72baa11999bd07c69050c42b094720ef4708b2 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/11461 Reviewed-by: Kai Ninomiya Reviewed-by: Corentin Wallez Commit-Queue: Jiawei Shao --- src/dawn_native/RenderPipeline.cpp | 14 ++++++-- src/dawn_native/ShaderModule.cpp | 30 ++++++++++++++++ src/dawn_native/ShaderModule.h | 8 +++++ .../validation/BindGroupValidationTests.cpp | 2 +- .../RenderPipelineValidationTests.cpp | 35 +++++++++++++++++++ 5 files changed, 86 insertions(+), 3 deletions(-) diff --git a/src/dawn_native/RenderPipeline.cpp b/src/dawn_native/RenderPipeline.cpp index 4492817781..b2a44240b9 100644 --- a/src/dawn_native/RenderPipeline.cpp +++ b/src/dawn_native/RenderPipeline.cpp @@ -117,7 +117,8 @@ namespace dawn_native { } MaybeError ValidateColorStateDescriptor(const DeviceBase* device, - const ColorStateDescriptor& descriptor) { + const ColorStateDescriptor& descriptor, + Format::Type fragmentOutputBaseType) { if (descriptor.nextInChain != nullptr) { return DAWN_VALIDATION_ERROR("nextInChain must be nullptr"); } @@ -134,6 +135,11 @@ namespace dawn_native { if (!format->IsColor() || !format->isRenderable) { return DAWN_VALIDATION_ERROR("Color format must be color renderable"); } + if (fragmentOutputBaseType != Format::Type::Other && + fragmentOutputBaseType != format->type) { + return DAWN_VALIDATION_ERROR( + "Color format must match the fragment stage output type"); + } return {}; } @@ -310,8 +316,12 @@ namespace dawn_native { return DAWN_VALIDATION_ERROR("Should have at least one attachment"); } + ASSERT(descriptor->fragmentStage != nullptr); + const ShaderModuleBase::FragmentOutputBaseTypes& fragmentOutputBaseTypes = + descriptor->fragmentStage->module->GetFragmentOutputBaseTypes(); for (uint32_t i = 0; i < descriptor->colorStateCount; ++i) { - DAWN_TRY(ValidateColorStateDescriptor(device, descriptor->colorStates[i])); + DAWN_TRY(ValidateColorStateDescriptor(device, descriptor->colorStates[i], + fragmentOutputBaseTypes[i])); } if (descriptor->depthStencilState) { diff --git a/src/dawn_native/ShaderModule.cpp b/src/dawn_native/ShaderModule.cpp index e3a95a038b..f19a123b34 100644 --- a/src/dawn_native/ShaderModule.cpp +++ b/src/dawn_native/ShaderModule.cpp @@ -27,6 +27,22 @@ namespace dawn_native { + namespace { + Format::Type SpirvCrossBaseTypeToFormatType(spirv_cross::SPIRType::BaseType spirvBaseType) { + switch (spirvBaseType) { + case spirv_cross::SPIRType::Float: + return Format::Float; + case spirv_cross::SPIRType::Int: + return Format::Sint; + case spirv_cross::SPIRType::UInt: + return Format::Uint; + default: + UNREACHABLE(); + return Format::Other; + } + } + } // anonymous namespace + MaybeError ValidateShaderModuleDescriptor(DeviceBase*, const ShaderModuleDescriptor* descriptor) { if (descriptor->nextInChain != nullptr) { @@ -74,6 +90,7 @@ namespace dawn_native { : ObjectBase(device), mCode(descriptor->code, descriptor->code + descriptor->codeSize), mIsBlueprint(blueprint) { + mFragmentOutputFormatBaseTypes.fill(Format::Other); } ShaderModuleBase::ShaderModuleBase(DeviceBase* device, ObjectBase::ErrorTag tag) @@ -201,6 +218,13 @@ namespace dawn_native { "Fragment output location over limits in the SPIRV"); return; } + + spirv_cross::SPIRType::BaseType shaderFragmentOutputBaseType = + compiler.get_type(fragmentOutput.base_type_id).basetype; + Format::Type formatType = + SpirvCrossBaseTypeToFormatType(shaderFragmentOutputBaseType); + ASSERT(formatType != Format::Type::Other); + mFragmentOutputFormatBaseTypes[location] = formatType; } } } @@ -215,6 +239,12 @@ namespace dawn_native { return mUsedVertexAttributes; } + const ShaderModuleBase::FragmentOutputBaseTypes& ShaderModuleBase::GetFragmentOutputBaseTypes() + const { + ASSERT(!IsError()); + return mFragmentOutputFormatBaseTypes; + } + SingleShaderStage ShaderModuleBase::GetExecutionModel() const { ASSERT(!IsError()); return mExecutionModel; diff --git a/src/dawn_native/ShaderModule.h b/src/dawn_native/ShaderModule.h index 35c90207c4..f2c133c835 100644 --- a/src/dawn_native/ShaderModule.h +++ b/src/dawn_native/ShaderModule.h @@ -17,6 +17,7 @@ #include "common/Constants.h" #include "dawn_native/Error.h" +#include "dawn_native/Format.h" #include "dawn_native/Forward.h" #include "dawn_native/ObjectBase.h" #include "dawn_native/PerStage.h" @@ -61,6 +62,11 @@ namespace dawn_native { const std::bitset& GetUsedVertexAttributes() const; SingleShaderStage GetExecutionModel() const; + // An array to record the basic types (float, int and uint) of the fragment shader outputs + // or Format::Type::Other means the fragment shader output is unused. + using FragmentOutputBaseTypes = std::array; + const FragmentOutputBaseTypes& GetFragmentOutputBaseTypes() const; + bool IsCompatibleWithPipelineLayout(const PipelineLayoutBase* layout); // Functors necessary for the unordered_set-based cache. @@ -84,6 +90,8 @@ namespace dawn_native { ModuleBindingInfo mBindingInfo; std::bitset mUsedVertexAttributes; SingleShaderStage mExecutionModel; + + FragmentOutputBaseTypes mFragmentOutputFormatBaseTypes; }; } // namespace dawn_native diff --git a/src/tests/unittests/validation/BindGroupValidationTests.cpp b/src/tests/unittests/validation/BindGroupValidationTests.cpp index 6d464bf2d1..bb732a14f5 100644 --- a/src/tests/unittests/validation/BindGroupValidationTests.cpp +++ b/src/tests/unittests/validation/BindGroupValidationTests.cpp @@ -649,7 +649,7 @@ class SetBindGroupValidationTest : public ValidationTest { layout(std140, set = 0, binding = 1) buffer SBuffer { vec2 value2; } sBuffer; - layout(location = 0) out uvec4 fragColor; + layout(location = 0) out vec4 fragColor; void main() { })"); diff --git a/src/tests/unittests/validation/RenderPipelineValidationTests.cpp b/src/tests/unittests/validation/RenderPipelineValidationTests.cpp index da11021c9d..4be8f671e4 100644 --- a/src/tests/unittests/validation/RenderPipelineValidationTests.cpp +++ b/src/tests/unittests/validation/RenderPipelineValidationTests.cpp @@ -18,6 +18,8 @@ #include "utils/ComboRenderPipelineDescriptor.h" #include "utils/DawnHelpers.h" +#include + class RenderPipelineValidationTest : public ValidationTest { protected: void SetUp() override { @@ -114,6 +116,39 @@ TEST_F(RenderPipelineValidationTest, NonRenderableFormat) { } } +// Tests that the format of the color state descriptor must match the output of the fragment shader. +TEST_F(RenderPipelineValidationTest, FragmentOutputFormatCompatibility) { + constexpr uint32_t kNumTextureFormatBaseType = 3u; + std::array kVecPreFix = {{"", "i", "u"}}; + std::array kColorFormats = { + {dawn::TextureFormat::RGBA8Unorm, dawn::TextureFormat::RGBA8Sint, + dawn::TextureFormat::RGBA8Uint}}; + + for (size_t i = 0; i < kNumTextureFormatBaseType; ++i) { + for (size_t j = 0; j < kNumTextureFormatBaseType; ++j) { + utils::ComboRenderPipelineDescriptor descriptor(device); + descriptor.vertexStage.module = vsModule; + descriptor.cColorStates[0].format = kColorFormats[j]; + + std::ostringstream stream; + stream << R"( + #version 450 + layout(location = 0) out )" + << kVecPreFix[i] << R"(vec4 fragColor; + void main() { + })"; + descriptor.cFragmentStage.module = utils::CreateShaderModule( + device, utils::SingleShaderStage::Fragment, stream.str().c_str()); + + if (i == j) { + device.CreateRenderPipeline(&descriptor); + } else { + ASSERT_DEVICE_ERROR(device.CreateRenderPipeline(&descriptor)); + } + } + } +} + /// Tests that the sample count of the render pipeline must be valid. TEST_F(RenderPipelineValidationTest, SampleCount) { {