diff --git a/src/dawn_native/RenderPipeline.cpp b/src/dawn_native/RenderPipeline.cpp index 3f59535964..7392f1d9b8 100644 --- a/src/dawn_native/RenderPipeline.cpp +++ b/src/dawn_native/RenderPipeline.cpp @@ -28,16 +28,19 @@ namespace dawn_native { // Helper functions namespace { - MaybeError ValidateVertexAttribute(DeviceBase* device, - const VertexAttribute* attribute, - uint64_t vertexBufferStride, - std::bitset* attributesSetMask) { + MaybeError ValidateVertexAttribute( + DeviceBase* device, + const VertexAttribute* attribute, + const EntryPointMetadata& metadata, + uint64_t vertexBufferStride, + ityp::bitset* attributesSetMask) { DAWN_TRY(ValidateVertexFormat(attribute->format)); const VertexFormatInfo& formatInfo = GetVertexFormatInfo(attribute->format); if (attribute->shaderLocation >= kMaxVertexAttributes) { return DAWN_VALIDATION_ERROR("Setting attribute out of bounds"); } + VertexAttributeLocation location(static_cast(attribute->shaderLocation)); // No underflow is possible because the max vertex format size is smaller than // kMaxVertexBufferArrayStride. @@ -59,18 +62,25 @@ namespace dawn_native { "Attribute offset needs to be a multiple of the size format's components"); } - if ((*attributesSetMask)[attribute->shaderLocation]) { + if (metadata.usedVertexInputs[location] && + formatInfo.baseType != metadata.vertexInputBaseTypes[location]) { + return DAWN_VALIDATION_ERROR( + "Attribute base type must match the base type in the shader."); + } + + if ((*attributesSetMask)[location]) { return DAWN_VALIDATION_ERROR("Setting already set attribute"); } - attributesSetMask->set(attribute->shaderLocation); + attributesSetMask->set(location); return {}; } MaybeError ValidateVertexBufferLayout( DeviceBase* device, const VertexBufferLayout* buffer, - std::bitset* attributesSetMask) { + const EntryPointMetadata& metadata, + ityp::bitset* attributesSetMask) { DAWN_TRY(ValidateInputStepMode(buffer->stepMode)); if (buffer->arrayStride > kMaxVertexBufferArrayStride) { return DAWN_VALIDATION_ERROR("Setting arrayStride out of bounds"); @@ -82,7 +92,7 @@ namespace dawn_native { } for (uint32_t i = 0; i < buffer->attributeCount; ++i) { - DAWN_TRY(ValidateVertexAttribute(device, &buffer->attributes[i], + DAWN_TRY(ValidateVertexAttribute(device, &buffer->attributes[i], metadata, buffer->arrayStride, attributesSetMask)); } @@ -100,10 +110,15 @@ namespace dawn_native { return DAWN_VALIDATION_ERROR("Vertex buffer count exceeds maximum"); } - std::bitset attributesSetMask; + DAWN_TRY(ValidateProgrammableStage(device, descriptor->module, descriptor->entryPoint, + layout, SingleShaderStage::Vertex)); + const EntryPointMetadata& vertexMetadata = + descriptor->module->GetEntryPoint(descriptor->entryPoint); + + ityp::bitset attributesSetMask; uint32_t totalAttributesNum = 0; for (uint32_t i = 0; i < descriptor->bufferCount; ++i) { - DAWN_TRY(ValidateVertexBufferLayout(device, &descriptor->buffers[i], + DAWN_TRY(ValidateVertexBufferLayout(device, &descriptor->buffers[i], vertexMetadata, &attributesSetMask)); totalAttributesNum += descriptor->buffers[i].attributeCount; } @@ -114,11 +129,7 @@ namespace dawn_native { // attribute number never exceed kMaxVertexAttributes. ASSERT(totalAttributesNum <= kMaxVertexAttributes); - DAWN_TRY(ValidateProgrammableStage(device, descriptor->module, descriptor->entryPoint, - layout, SingleShaderStage::Vertex)); - const EntryPointMetadata& vertexMetadata = - descriptor->module->GetEntryPoint(descriptor->entryPoint); - if (!IsSubset(vertexMetadata.usedVertexAttributes, attributesSetMask)) { + if (!IsSubset(vertexMetadata.usedVertexInputs, attributesSetMask)) { return DAWN_VALIDATION_ERROR( "Pipeline vertex stage uses vertex buffers not in the vertex state"); } diff --git a/src/dawn_native/ShaderModule.cpp b/src/dawn_native/ShaderModule.cpp index 9bd3b7e45d..55b771dad8 100644 --- a/src/dawn_native/ShaderModule.cpp +++ b/src/dawn_native/ShaderModule.cpp @@ -295,6 +295,21 @@ namespace dawn_native { } } + ResultOrError TintComponentTypeToVertexFormatBaseType( + tint::inspector::ComponentType type) { + switch (type) { + case tint::inspector::ComponentType::kFloat: + return VertexFormatBaseType::Float; + case tint::inspector::ComponentType::kSInt: + return VertexFormatBaseType::Sint; + case tint::inspector::ComponentType::kUInt: + return VertexFormatBaseType::Uint; + case tint::inspector::ComponentType::kUnknown: + return DAWN_VALIDATION_ERROR( + "Attempted to convert 'Unknown' component type from Tint"); + } + } + ResultOrError TintResourceTypeToBufferBindingType( tint::inspector::ResourceBinding::ResourceType resource_type) { switch (resource_type) { @@ -811,13 +826,19 @@ namespace dawn_native { return DAWN_VALIDATION_ERROR( "Unable to find Location decoration for Vertex input"); } - uint32_t location = compiler.get_decoration(attrib.id, spv::DecorationLocation); + uint32_t unsanitizedLocation = + compiler.get_decoration(attrib.id, spv::DecorationLocation); - if (location >= kMaxVertexAttributes) { + if (unsanitizedLocation >= kMaxVertexAttributes) { return DAWN_VALIDATION_ERROR("Attribute location over limits in the SPIRV"); } + VertexAttributeLocation location(static_cast(unsanitizedLocation)); - metadata->usedVertexAttributes.set(location); + spirv_cross::SPIRType::BaseType inputBaseType = + compiler.get_type(attrib.base_type_id).basetype; + metadata->vertexInputBaseTypes[location] = + SpirvBaseTypeToVertexFormatBaseType(inputBaseType); + metadata->usedVertexInputs.set(location); } // Without a location qualifier on vertex outputs, spirv_cross::CompilerMSL gives @@ -846,6 +867,7 @@ namespace dawn_native { } uint32_t unsanitizedAttachment = compiler.get_decoration(fragmentOutput.id, spv::DecorationLocation); + if (unsanitizedAttachment >= kMaxColorAttachments) { return DAWN_VALIDATION_ERROR( "Fragment output index must be less than max number of color " @@ -958,13 +980,17 @@ namespace dawn_native { return DAWN_VALIDATION_ERROR( "Need Location decoration on Vertex input"); } - uint32_t location = input_var.location_decoration; - if (DAWN_UNLIKELY(location >= kMaxVertexAttributes)) { + uint32_t unsanitizedLocation = input_var.location_decoration; + if (DAWN_UNLIKELY(unsanitizedLocation >= kMaxVertexAttributes)) { std::stringstream ss; - ss << "Attribute location (" << location << ") over limits"; + ss << "Attribute location (" << unsanitizedLocation << ") over limits"; return DAWN_VALIDATION_ERROR(ss.str()); } - metadata->usedVertexAttributes.set(location); + VertexAttributeLocation location(static_cast(unsanitizedLocation)); + DAWN_TRY_ASSIGN( + metadata->vertexInputBaseTypes[location], + TintComponentTypeToVertexFormatBaseType(input_var.component_type)); + metadata->usedVertexInputs.set(location); } for (const auto& output_var : entryPoint.output_variables) { diff --git a/src/dawn_native/ShaderModule.h b/src/dawn_native/ShaderModule.h index 271704295e..da948e3900 100644 --- a/src/dawn_native/ShaderModule.h +++ b/src/dawn_native/ShaderModule.h @@ -25,6 +25,7 @@ #include "dawn_native/Forward.h" #include "dawn_native/IntegerTypes.h" #include "dawn_native/PerStage.h" +#include "dawn_native/VertexFormat.h" #include "dawn_native/dawn_platform.h" #include @@ -147,7 +148,9 @@ namespace dawn_native { std::vector samplerTexturePairs; // The set of vertex attributes this entryPoint uses. - std::bitset usedVertexAttributes; + ityp::array + vertexInputBaseTypes; + ityp::bitset usedVertexInputs; // An array to record the basic types (float, int and uint) of the fragment shader outputs. ityp::array diff --git a/src/dawn_native/SpirvUtils.cpp b/src/dawn_native/SpirvUtils.cpp index 9472508c38..01749de1dd 100644 --- a/src/dawn_native/SpirvUtils.cpp +++ b/src/dawn_native/SpirvUtils.cpp @@ -161,4 +161,18 @@ namespace dawn_native { } } + VertexFormatBaseType SpirvBaseTypeToVertexFormatBaseType( + spirv_cross::SPIRType::BaseType spirvBaseType) { + switch (spirvBaseType) { + case spirv_cross::SPIRType::Float: + return VertexFormatBaseType::Float; + case spirv_cross::SPIRType::Int: + return VertexFormatBaseType::Sint; + case spirv_cross::SPIRType::UInt: + return VertexFormatBaseType::Uint; + default: + UNREACHABLE(); + } + } + } // namespace dawn_native diff --git a/src/dawn_native/SpirvUtils.h b/src/dawn_native/SpirvUtils.h index 158b165e7c..ff356df837 100644 --- a/src/dawn_native/SpirvUtils.h +++ b/src/dawn_native/SpirvUtils.h @@ -20,6 +20,7 @@ #include "dawn_native/Format.h" #include "dawn_native/PerStage.h" +#include "dawn_native/VertexFormat.h" #include "dawn_native/dawn_platform.h" #include @@ -41,6 +42,10 @@ namespace dawn_native { spirv_cross::SPIRType::BaseType spirvBaseType); SampleTypeBit SpirvBaseTypeToSampleTypeBit(spirv_cross::SPIRType::BaseType spirvBaseType); + // Returns the VertexFormatBaseType corresponding to the SPIRV base type. + VertexFormatBaseType SpirvBaseTypeToVertexFormatBaseType( + spirv_cross::SPIRType::BaseType spirvBaseType); + } // namespace dawn_native #endif // DAWNNATIVE_SPIRV_UTILS_H_ diff --git a/src/dawn_native/metal/ShaderModuleMTL.mm b/src/dawn_native/metal/ShaderModuleMTL.mm index d6510e1e58..03652c68b3 100644 --- a/src/dawn_native/metal/ShaderModuleMTL.mm +++ b/src/dawn_native/metal/ShaderModuleMTL.mm @@ -356,7 +356,7 @@ namespace dawn_native { namespace metal { out->function = AcquireNSPRef([*library newFunctionWithName:name.Get()]); if (GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling) && - GetEntryPoint(entryPointName).usedVertexAttributes.any()) { + GetEntryPoint(entryPointName).usedVertexInputs.any()) { out->needsStorageBufferLength = true; } diff --git a/src/tests/unittests/validation/VertexStateValidationTests.cpp b/src/tests/unittests/validation/VertexStateValidationTests.cpp index 9ebea1e592..974dacb6b0 100644 --- a/src/tests/unittests/validation/VertexStateValidationTests.cpp +++ b/src/tests/unittests/validation/VertexStateValidationTests.cpp @@ -306,7 +306,7 @@ TEST_F(VertexStateTest, SetOffsetNotAligned) { state.cAttributes[0].offset = 2; CreatePipeline(true, state, kDummyVertexShader); - state.cAttributes[0].format = wgpu::VertexFormat::Uint8x2; + state.cAttributes[0].format = wgpu::VertexFormat::Unorm8x2; state.cAttributes[0].offset = 1; CreatePipeline(true, state, kDummyVertexShader); @@ -338,3 +338,80 @@ TEST_F(VertexStateTest, VertexFormatLargerThanNonZeroStride) { state.cAttributes[0].format = wgpu::VertexFormat::Float32x4; CreatePipeline(false, state, kDummyVertexShader); } + +// Check that the vertex format base type must match the shader's variable base type. +TEST_F(VertexStateTest, BaseTypeMatching) { + auto DoTest = [&](wgpu::VertexFormat format, std::string shaderType, bool success) { + utils::ComboVertexStateDescriptor state; + state.vertexBufferCount = 1; + state.cVertexBuffers[0].arrayStride = 16; + state.cVertexBuffers[0].attributeCount = 1; + state.cAttributes[0].format = format; + + std::string shader = "[[stage(vertex)]] fn main([[location(0)]] attrib : " + shaderType + + R"() -> [[builtin(position)]] vec4 { + return vec4(0.0, 0.0, 0.0, 0.0); + })"; + + CreatePipeline(success, state, shader.c_str()); + }; + + // Test that a float format is compatible only with f32 base type. + DoTest(wgpu::VertexFormat::Float32, "f32", true); + DoTest(wgpu::VertexFormat::Float32, "i32", false); + DoTest(wgpu::VertexFormat::Float32, "u32", false); + + // Test that an unorm format is compatible only with f32. + DoTest(wgpu::VertexFormat::Unorm16x2, "f32", true); + DoTest(wgpu::VertexFormat::Unorm16x2, "i32", false); + DoTest(wgpu::VertexFormat::Unorm16x2, "u32", false); + + // Test that an snorm format is compatible only with f32. + DoTest(wgpu::VertexFormat::Snorm16x4, "f32", true); + DoTest(wgpu::VertexFormat::Snorm16x4, "i32", false); + DoTest(wgpu::VertexFormat::Snorm16x4, "u32", false); + + // Test that an uint format is compatible only with u32. + DoTest(wgpu::VertexFormat::Uint32x3, "f32", false); + DoTest(wgpu::VertexFormat::Uint32x3, "i32", false); + DoTest(wgpu::VertexFormat::Uint32x3, "u32", true); + + // Test that an sint format is compatible only with u32. + DoTest(wgpu::VertexFormat::Sint8x4, "f32", false); + DoTest(wgpu::VertexFormat::Sint8x4, "i32", true); + DoTest(wgpu::VertexFormat::Sint8x4, "u32", false); + + // Test that formats are compatible with any width of vectors. + DoTest(wgpu::VertexFormat::Float32, "f32", true); + DoTest(wgpu::VertexFormat::Float32, "vec2", true); + DoTest(wgpu::VertexFormat::Float32, "vec3", true); + DoTest(wgpu::VertexFormat::Float32, "vec4", true); + + DoTest(wgpu::VertexFormat::Float32x4, "f32", true); + DoTest(wgpu::VertexFormat::Float32x4, "vec2", true); + DoTest(wgpu::VertexFormat::Float32x4, "vec3", true); + DoTest(wgpu::VertexFormat::Float32x4, "vec4", true); +} + +// Check that we only check base type compatibility for vertex inputs the shader uses. +TEST_F(VertexStateTest, BaseTypeMatchingForInexistentInput) { + auto DoTest = [&](wgpu::VertexFormat format) { + utils::ComboVertexStateDescriptor state; + state.vertexBufferCount = 1; + state.cVertexBuffers[0].arrayStride = 16; + state.cVertexBuffers[0].attributeCount = 1; + state.cAttributes[0].format = format; + + std::string shader = R"([[stage(vertex)]] fn main() -> [[builtin(position)]] vec4 { + return vec4(0.0, 0.0, 0.0, 0.0); + })"; + + CreatePipeline(true, state, shader.c_str()); + }; + + DoTest(wgpu::VertexFormat::Float32); + DoTest(wgpu::VertexFormat::Unorm16x2); + DoTest(wgpu::VertexFormat::Snorm16x4); + DoTest(wgpu::VertexFormat::Uint8x4); + DoTest(wgpu::VertexFormat::Sint32x2); +}