diff --git a/src/dawn_native/RenderPipeline.cpp b/src/dawn_native/RenderPipeline.cpp index 4492817781..b2a44240b9 100644 --- a/src/dawn_native/RenderPipeline.cpp +++ b/src/dawn_native/RenderPipeline.cpp @@ -117,7 +117,8 @@ namespace dawn_native { } MaybeError ValidateColorStateDescriptor(const DeviceBase* device, - const ColorStateDescriptor& descriptor) { + const ColorStateDescriptor& descriptor, + Format::Type fragmentOutputBaseType) { if (descriptor.nextInChain != nullptr) { return DAWN_VALIDATION_ERROR("nextInChain must be nullptr"); } @@ -134,6 +135,11 @@ namespace dawn_native { if (!format->IsColor() || !format->isRenderable) { return DAWN_VALIDATION_ERROR("Color format must be color renderable"); } + if (fragmentOutputBaseType != Format::Type::Other && + fragmentOutputBaseType != format->type) { + return DAWN_VALIDATION_ERROR( + "Color format must match the fragment stage output type"); + } return {}; } @@ -310,8 +316,12 @@ namespace dawn_native { return DAWN_VALIDATION_ERROR("Should have at least one attachment"); } + ASSERT(descriptor->fragmentStage != nullptr); + const ShaderModuleBase::FragmentOutputBaseTypes& fragmentOutputBaseTypes = + descriptor->fragmentStage->module->GetFragmentOutputBaseTypes(); for (uint32_t i = 0; i < descriptor->colorStateCount; ++i) { - DAWN_TRY(ValidateColorStateDescriptor(device, descriptor->colorStates[i])); + DAWN_TRY(ValidateColorStateDescriptor(device, descriptor->colorStates[i], + fragmentOutputBaseTypes[i])); } if (descriptor->depthStencilState) { diff --git a/src/dawn_native/ShaderModule.cpp b/src/dawn_native/ShaderModule.cpp index e3a95a038b..f19a123b34 100644 --- a/src/dawn_native/ShaderModule.cpp +++ b/src/dawn_native/ShaderModule.cpp @@ -27,6 +27,22 @@ namespace dawn_native { + namespace { + Format::Type SpirvCrossBaseTypeToFormatType(spirv_cross::SPIRType::BaseType spirvBaseType) { + switch (spirvBaseType) { + case spirv_cross::SPIRType::Float: + return Format::Float; + case spirv_cross::SPIRType::Int: + return Format::Sint; + case spirv_cross::SPIRType::UInt: + return Format::Uint; + default: + UNREACHABLE(); + return Format::Other; + } + } + } // anonymous namespace + MaybeError ValidateShaderModuleDescriptor(DeviceBase*, const ShaderModuleDescriptor* descriptor) { if (descriptor->nextInChain != nullptr) { @@ -74,6 +90,7 @@ namespace dawn_native { : ObjectBase(device), mCode(descriptor->code, descriptor->code + descriptor->codeSize), mIsBlueprint(blueprint) { + mFragmentOutputFormatBaseTypes.fill(Format::Other); } ShaderModuleBase::ShaderModuleBase(DeviceBase* device, ObjectBase::ErrorTag tag) @@ -201,6 +218,13 @@ namespace dawn_native { "Fragment output location over limits in the SPIRV"); return; } + + spirv_cross::SPIRType::BaseType shaderFragmentOutputBaseType = + compiler.get_type(fragmentOutput.base_type_id).basetype; + Format::Type formatType = + SpirvCrossBaseTypeToFormatType(shaderFragmentOutputBaseType); + ASSERT(formatType != Format::Type::Other); + mFragmentOutputFormatBaseTypes[location] = formatType; } } } @@ -215,6 +239,12 @@ namespace dawn_native { return mUsedVertexAttributes; } + const ShaderModuleBase::FragmentOutputBaseTypes& ShaderModuleBase::GetFragmentOutputBaseTypes() + const { + ASSERT(!IsError()); + return mFragmentOutputFormatBaseTypes; + } + SingleShaderStage ShaderModuleBase::GetExecutionModel() const { ASSERT(!IsError()); return mExecutionModel; diff --git a/src/dawn_native/ShaderModule.h b/src/dawn_native/ShaderModule.h index 35c90207c4..f2c133c835 100644 --- a/src/dawn_native/ShaderModule.h +++ b/src/dawn_native/ShaderModule.h @@ -17,6 +17,7 @@ #include "common/Constants.h" #include "dawn_native/Error.h" +#include "dawn_native/Format.h" #include "dawn_native/Forward.h" #include "dawn_native/ObjectBase.h" #include "dawn_native/PerStage.h" @@ -61,6 +62,11 @@ namespace dawn_native { const std::bitset& GetUsedVertexAttributes() const; SingleShaderStage GetExecutionModel() const; + // An array to record the basic types (float, int and uint) of the fragment shader outputs + // or Format::Type::Other means the fragment shader output is unused. + using FragmentOutputBaseTypes = std::array; + const FragmentOutputBaseTypes& GetFragmentOutputBaseTypes() const; + bool IsCompatibleWithPipelineLayout(const PipelineLayoutBase* layout); // Functors necessary for the unordered_set-based cache. @@ -84,6 +90,8 @@ namespace dawn_native { ModuleBindingInfo mBindingInfo; std::bitset mUsedVertexAttributes; SingleShaderStage mExecutionModel; + + FragmentOutputBaseTypes mFragmentOutputFormatBaseTypes; }; } // namespace dawn_native diff --git a/src/tests/unittests/validation/BindGroupValidationTests.cpp b/src/tests/unittests/validation/BindGroupValidationTests.cpp index 6d464bf2d1..bb732a14f5 100644 --- a/src/tests/unittests/validation/BindGroupValidationTests.cpp +++ b/src/tests/unittests/validation/BindGroupValidationTests.cpp @@ -649,7 +649,7 @@ class SetBindGroupValidationTest : public ValidationTest { layout(std140, set = 0, binding = 1) buffer SBuffer { vec2 value2; } sBuffer; - layout(location = 0) out uvec4 fragColor; + layout(location = 0) out vec4 fragColor; void main() { })"); diff --git a/src/tests/unittests/validation/RenderPipelineValidationTests.cpp b/src/tests/unittests/validation/RenderPipelineValidationTests.cpp index da11021c9d..4be8f671e4 100644 --- a/src/tests/unittests/validation/RenderPipelineValidationTests.cpp +++ b/src/tests/unittests/validation/RenderPipelineValidationTests.cpp @@ -18,6 +18,8 @@ #include "utils/ComboRenderPipelineDescriptor.h" #include "utils/DawnHelpers.h" +#include + class RenderPipelineValidationTest : public ValidationTest { protected: void SetUp() override { @@ -114,6 +116,39 @@ TEST_F(RenderPipelineValidationTest, NonRenderableFormat) { } } +// Tests that the format of the color state descriptor must match the output of the fragment shader. +TEST_F(RenderPipelineValidationTest, FragmentOutputFormatCompatibility) { + constexpr uint32_t kNumTextureFormatBaseType = 3u; + std::array kVecPreFix = {{"", "i", "u"}}; + std::array kColorFormats = { + {dawn::TextureFormat::RGBA8Unorm, dawn::TextureFormat::RGBA8Sint, + dawn::TextureFormat::RGBA8Uint}}; + + for (size_t i = 0; i < kNumTextureFormatBaseType; ++i) { + for (size_t j = 0; j < kNumTextureFormatBaseType; ++j) { + utils::ComboRenderPipelineDescriptor descriptor(device); + descriptor.vertexStage.module = vsModule; + descriptor.cColorStates[0].format = kColorFormats[j]; + + std::ostringstream stream; + stream << R"( + #version 450 + layout(location = 0) out )" + << kVecPreFix[i] << R"(vec4 fragColor; + void main() { + })"; + descriptor.cFragmentStage.module = utils::CreateShaderModule( + device, utils::SingleShaderStage::Fragment, stream.str().c_str()); + + if (i == j) { + device.CreateRenderPipeline(&descriptor); + } else { + ASSERT_DEVICE_ERROR(device.CreateRenderPipeline(&descriptor)); + } + } + } +} + /// Tests that the sample count of the render pipeline must be valid. TEST_F(RenderPipelineValidationTest, SampleCount) { {