Implement render pipeline vertex format base type validation.

Bug: dawn:1008

Change-Id: I04d1ff1d46c1106147a8c50415c989db5789cbfc
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/59031
Auto-Submit: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Jiawei Shao <jiawei.shao@intel.com>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Commit-Queue: Corentin Wallez <cwallez@chromium.org>
This commit is contained in:
Corentin Wallez 2021-07-22 11:43:10 +00:00 committed by Dawn LUCI CQ
parent 1384e8b163
commit e9c8225f54
7 changed files with 161 additions and 25 deletions

View File

@ -28,16 +28,19 @@ namespace dawn_native {
// Helper functions // Helper functions
namespace { namespace {
MaybeError ValidateVertexAttribute(DeviceBase* device, MaybeError ValidateVertexAttribute(
const VertexAttribute* attribute, DeviceBase* device,
uint64_t vertexBufferStride, const VertexAttribute* attribute,
std::bitset<kMaxVertexAttributes>* attributesSetMask) { const EntryPointMetadata& metadata,
uint64_t vertexBufferStride,
ityp::bitset<VertexAttributeLocation, kMaxVertexAttributes>* attributesSetMask) {
DAWN_TRY(ValidateVertexFormat(attribute->format)); DAWN_TRY(ValidateVertexFormat(attribute->format));
const VertexFormatInfo& formatInfo = GetVertexFormatInfo(attribute->format); const VertexFormatInfo& formatInfo = GetVertexFormatInfo(attribute->format);
if (attribute->shaderLocation >= kMaxVertexAttributes) { if (attribute->shaderLocation >= kMaxVertexAttributes) {
return DAWN_VALIDATION_ERROR("Setting attribute out of bounds"); return DAWN_VALIDATION_ERROR("Setting attribute out of bounds");
} }
VertexAttributeLocation location(static_cast<uint8_t>(attribute->shaderLocation));
// No underflow is possible because the max vertex format size is smaller than // No underflow is possible because the max vertex format size is smaller than
// kMaxVertexBufferArrayStride. // kMaxVertexBufferArrayStride.
@ -59,18 +62,25 @@ namespace dawn_native {
"Attribute offset needs to be a multiple of the size format's components"); "Attribute offset needs to be a multiple of the size format's components");
} }
if ((*attributesSetMask)[attribute->shaderLocation]) { if (metadata.usedVertexInputs[location] &&
formatInfo.baseType != metadata.vertexInputBaseTypes[location]) {
return DAWN_VALIDATION_ERROR(
"Attribute base type must match the base type in the shader.");
}
if ((*attributesSetMask)[location]) {
return DAWN_VALIDATION_ERROR("Setting already set attribute"); return DAWN_VALIDATION_ERROR("Setting already set attribute");
} }
attributesSetMask->set(attribute->shaderLocation); attributesSetMask->set(location);
return {}; return {};
} }
MaybeError ValidateVertexBufferLayout( MaybeError ValidateVertexBufferLayout(
DeviceBase* device, DeviceBase* device,
const VertexBufferLayout* buffer, const VertexBufferLayout* buffer,
std::bitset<kMaxVertexAttributes>* attributesSetMask) { const EntryPointMetadata& metadata,
ityp::bitset<VertexAttributeLocation, kMaxVertexAttributes>* attributesSetMask) {
DAWN_TRY(ValidateInputStepMode(buffer->stepMode)); DAWN_TRY(ValidateInputStepMode(buffer->stepMode));
if (buffer->arrayStride > kMaxVertexBufferArrayStride) { if (buffer->arrayStride > kMaxVertexBufferArrayStride) {
return DAWN_VALIDATION_ERROR("Setting arrayStride out of bounds"); return DAWN_VALIDATION_ERROR("Setting arrayStride out of bounds");
@ -82,7 +92,7 @@ namespace dawn_native {
} }
for (uint32_t i = 0; i < buffer->attributeCount; ++i) { for (uint32_t i = 0; i < buffer->attributeCount; ++i) {
DAWN_TRY(ValidateVertexAttribute(device, &buffer->attributes[i], DAWN_TRY(ValidateVertexAttribute(device, &buffer->attributes[i], metadata,
buffer->arrayStride, attributesSetMask)); buffer->arrayStride, attributesSetMask));
} }
@ -100,10 +110,15 @@ namespace dawn_native {
return DAWN_VALIDATION_ERROR("Vertex buffer count exceeds maximum"); return DAWN_VALIDATION_ERROR("Vertex buffer count exceeds maximum");
} }
std::bitset<kMaxVertexAttributes> attributesSetMask; DAWN_TRY(ValidateProgrammableStage(device, descriptor->module, descriptor->entryPoint,
layout, SingleShaderStage::Vertex));
const EntryPointMetadata& vertexMetadata =
descriptor->module->GetEntryPoint(descriptor->entryPoint);
ityp::bitset<VertexAttributeLocation, kMaxVertexAttributes> attributesSetMask;
uint32_t totalAttributesNum = 0; uint32_t totalAttributesNum = 0;
for (uint32_t i = 0; i < descriptor->bufferCount; ++i) { for (uint32_t i = 0; i < descriptor->bufferCount; ++i) {
DAWN_TRY(ValidateVertexBufferLayout(device, &descriptor->buffers[i], DAWN_TRY(ValidateVertexBufferLayout(device, &descriptor->buffers[i], vertexMetadata,
&attributesSetMask)); &attributesSetMask));
totalAttributesNum += descriptor->buffers[i].attributeCount; totalAttributesNum += descriptor->buffers[i].attributeCount;
} }
@ -114,11 +129,7 @@ namespace dawn_native {
// attribute number never exceed kMaxVertexAttributes. // attribute number never exceed kMaxVertexAttributes.
ASSERT(totalAttributesNum <= kMaxVertexAttributes); ASSERT(totalAttributesNum <= kMaxVertexAttributes);
DAWN_TRY(ValidateProgrammableStage(device, descriptor->module, descriptor->entryPoint, if (!IsSubset(vertexMetadata.usedVertexInputs, attributesSetMask)) {
layout, SingleShaderStage::Vertex));
const EntryPointMetadata& vertexMetadata =
descriptor->module->GetEntryPoint(descriptor->entryPoint);
if (!IsSubset(vertexMetadata.usedVertexAttributes, attributesSetMask)) {
return DAWN_VALIDATION_ERROR( return DAWN_VALIDATION_ERROR(
"Pipeline vertex stage uses vertex buffers not in the vertex state"); "Pipeline vertex stage uses vertex buffers not in the vertex state");
} }

View File

@ -295,6 +295,21 @@ namespace dawn_native {
} }
} }
ResultOrError<VertexFormatBaseType> TintComponentTypeToVertexFormatBaseType(
tint::inspector::ComponentType type) {
switch (type) {
case tint::inspector::ComponentType::kFloat:
return VertexFormatBaseType::Float;
case tint::inspector::ComponentType::kSInt:
return VertexFormatBaseType::Sint;
case tint::inspector::ComponentType::kUInt:
return VertexFormatBaseType::Uint;
case tint::inspector::ComponentType::kUnknown:
return DAWN_VALIDATION_ERROR(
"Attempted to convert 'Unknown' component type from Tint");
}
}
ResultOrError<wgpu::BufferBindingType> TintResourceTypeToBufferBindingType( ResultOrError<wgpu::BufferBindingType> TintResourceTypeToBufferBindingType(
tint::inspector::ResourceBinding::ResourceType resource_type) { tint::inspector::ResourceBinding::ResourceType resource_type) {
switch (resource_type) { switch (resource_type) {
@ -811,13 +826,19 @@ namespace dawn_native {
return DAWN_VALIDATION_ERROR( return DAWN_VALIDATION_ERROR(
"Unable to find Location decoration for Vertex input"); "Unable to find Location decoration for Vertex input");
} }
uint32_t location = compiler.get_decoration(attrib.id, spv::DecorationLocation); uint32_t unsanitizedLocation =
compiler.get_decoration(attrib.id, spv::DecorationLocation);
if (location >= kMaxVertexAttributes) { if (unsanitizedLocation >= kMaxVertexAttributes) {
return DAWN_VALIDATION_ERROR("Attribute location over limits in the SPIRV"); return DAWN_VALIDATION_ERROR("Attribute location over limits in the SPIRV");
} }
VertexAttributeLocation location(static_cast<uint8_t>(unsanitizedLocation));
metadata->usedVertexAttributes.set(location); spirv_cross::SPIRType::BaseType inputBaseType =
compiler.get_type(attrib.base_type_id).basetype;
metadata->vertexInputBaseTypes[location] =
SpirvBaseTypeToVertexFormatBaseType(inputBaseType);
metadata->usedVertexInputs.set(location);
} }
// Without a location qualifier on vertex outputs, spirv_cross::CompilerMSL gives // Without a location qualifier on vertex outputs, spirv_cross::CompilerMSL gives
@ -846,6 +867,7 @@ namespace dawn_native {
} }
uint32_t unsanitizedAttachment = uint32_t unsanitizedAttachment =
compiler.get_decoration(fragmentOutput.id, spv::DecorationLocation); compiler.get_decoration(fragmentOutput.id, spv::DecorationLocation);
if (unsanitizedAttachment >= kMaxColorAttachments) { if (unsanitizedAttachment >= kMaxColorAttachments) {
return DAWN_VALIDATION_ERROR( return DAWN_VALIDATION_ERROR(
"Fragment output index must be less than max number of color " "Fragment output index must be less than max number of color "
@ -958,13 +980,17 @@ namespace dawn_native {
return DAWN_VALIDATION_ERROR( return DAWN_VALIDATION_ERROR(
"Need Location decoration on Vertex input"); "Need Location decoration on Vertex input");
} }
uint32_t location = input_var.location_decoration; uint32_t unsanitizedLocation = input_var.location_decoration;
if (DAWN_UNLIKELY(location >= kMaxVertexAttributes)) { if (DAWN_UNLIKELY(unsanitizedLocation >= kMaxVertexAttributes)) {
std::stringstream ss; std::stringstream ss;
ss << "Attribute location (" << location << ") over limits"; ss << "Attribute location (" << unsanitizedLocation << ") over limits";
return DAWN_VALIDATION_ERROR(ss.str()); return DAWN_VALIDATION_ERROR(ss.str());
} }
metadata->usedVertexAttributes.set(location); VertexAttributeLocation location(static_cast<uint8_t>(unsanitizedLocation));
DAWN_TRY_ASSIGN(
metadata->vertexInputBaseTypes[location],
TintComponentTypeToVertexFormatBaseType(input_var.component_type));
metadata->usedVertexInputs.set(location);
} }
for (const auto& output_var : entryPoint.output_variables) { for (const auto& output_var : entryPoint.output_variables) {

View File

@ -25,6 +25,7 @@
#include "dawn_native/Forward.h" #include "dawn_native/Forward.h"
#include "dawn_native/IntegerTypes.h" #include "dawn_native/IntegerTypes.h"
#include "dawn_native/PerStage.h" #include "dawn_native/PerStage.h"
#include "dawn_native/VertexFormat.h"
#include "dawn_native/dawn_platform.h" #include "dawn_native/dawn_platform.h"
#include <bitset> #include <bitset>
@ -147,7 +148,9 @@ namespace dawn_native {
std::vector<SamplerTexturePair> samplerTexturePairs; std::vector<SamplerTexturePair> samplerTexturePairs;
// The set of vertex attributes this entryPoint uses. // The set of vertex attributes this entryPoint uses.
std::bitset<kMaxVertexAttributes> usedVertexAttributes; ityp::array<VertexAttributeLocation, VertexFormatBaseType, kMaxVertexAttributes>
vertexInputBaseTypes;
ityp::bitset<VertexAttributeLocation, kMaxVertexAttributes> usedVertexInputs;
// An array to record the basic types (float, int and uint) of the fragment shader outputs. // An array to record the basic types (float, int and uint) of the fragment shader outputs.
ityp::array<ColorAttachmentIndex, wgpu::TextureComponentType, kMaxColorAttachments> ityp::array<ColorAttachmentIndex, wgpu::TextureComponentType, kMaxColorAttachments>

View File

@ -161,4 +161,18 @@ namespace dawn_native {
} }
} }
VertexFormatBaseType SpirvBaseTypeToVertexFormatBaseType(
spirv_cross::SPIRType::BaseType spirvBaseType) {
switch (spirvBaseType) {
case spirv_cross::SPIRType::Float:
return VertexFormatBaseType::Float;
case spirv_cross::SPIRType::Int:
return VertexFormatBaseType::Sint;
case spirv_cross::SPIRType::UInt:
return VertexFormatBaseType::Uint;
default:
UNREACHABLE();
}
}
} // namespace dawn_native } // namespace dawn_native

View File

@ -20,6 +20,7 @@
#include "dawn_native/Format.h" #include "dawn_native/Format.h"
#include "dawn_native/PerStage.h" #include "dawn_native/PerStage.h"
#include "dawn_native/VertexFormat.h"
#include "dawn_native/dawn_platform.h" #include "dawn_native/dawn_platform.h"
#include <spirv_cross.hpp> #include <spirv_cross.hpp>
@ -41,6 +42,10 @@ namespace dawn_native {
spirv_cross::SPIRType::BaseType spirvBaseType); spirv_cross::SPIRType::BaseType spirvBaseType);
SampleTypeBit SpirvBaseTypeToSampleTypeBit(spirv_cross::SPIRType::BaseType spirvBaseType); SampleTypeBit SpirvBaseTypeToSampleTypeBit(spirv_cross::SPIRType::BaseType spirvBaseType);
// Returns the VertexFormatBaseType corresponding to the SPIRV base type.
VertexFormatBaseType SpirvBaseTypeToVertexFormatBaseType(
spirv_cross::SPIRType::BaseType spirvBaseType);
} // namespace dawn_native } // namespace dawn_native
#endif // DAWNNATIVE_SPIRV_UTILS_H_ #endif // DAWNNATIVE_SPIRV_UTILS_H_

View File

@ -356,7 +356,7 @@ namespace dawn_native { namespace metal {
out->function = AcquireNSPRef([*library newFunctionWithName:name.Get()]); out->function = AcquireNSPRef([*library newFunctionWithName:name.Get()]);
if (GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling) && if (GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling) &&
GetEntryPoint(entryPointName).usedVertexAttributes.any()) { GetEntryPoint(entryPointName).usedVertexInputs.any()) {
out->needsStorageBufferLength = true; out->needsStorageBufferLength = true;
} }

View File

@ -306,7 +306,7 @@ TEST_F(VertexStateTest, SetOffsetNotAligned) {
state.cAttributes[0].offset = 2; state.cAttributes[0].offset = 2;
CreatePipeline(true, state, kDummyVertexShader); CreatePipeline(true, state, kDummyVertexShader);
state.cAttributes[0].format = wgpu::VertexFormat::Uint8x2; state.cAttributes[0].format = wgpu::VertexFormat::Unorm8x2;
state.cAttributes[0].offset = 1; state.cAttributes[0].offset = 1;
CreatePipeline(true, state, kDummyVertexShader); CreatePipeline(true, state, kDummyVertexShader);
@ -338,3 +338,80 @@ TEST_F(VertexStateTest, VertexFormatLargerThanNonZeroStride) {
state.cAttributes[0].format = wgpu::VertexFormat::Float32x4; state.cAttributes[0].format = wgpu::VertexFormat::Float32x4;
CreatePipeline(false, state, kDummyVertexShader); CreatePipeline(false, state, kDummyVertexShader);
} }
// Check that the vertex format base type must match the shader's variable base type.
TEST_F(VertexStateTest, BaseTypeMatching) {
auto DoTest = [&](wgpu::VertexFormat format, std::string shaderType, bool success) {
utils::ComboVertexStateDescriptor state;
state.vertexBufferCount = 1;
state.cVertexBuffers[0].arrayStride = 16;
state.cVertexBuffers[0].attributeCount = 1;
state.cAttributes[0].format = format;
std::string shader = "[[stage(vertex)]] fn main([[location(0)]] attrib : " + shaderType +
R"() -> [[builtin(position)]] vec4<f32> {
return vec4<f32>(0.0, 0.0, 0.0, 0.0);
})";
CreatePipeline(success, state, shader.c_str());
};
// Test that a float format is compatible only with f32 base type.
DoTest(wgpu::VertexFormat::Float32, "f32", true);
DoTest(wgpu::VertexFormat::Float32, "i32", false);
DoTest(wgpu::VertexFormat::Float32, "u32", false);
// Test that an unorm format is compatible only with f32.
DoTest(wgpu::VertexFormat::Unorm16x2, "f32", true);
DoTest(wgpu::VertexFormat::Unorm16x2, "i32", false);
DoTest(wgpu::VertexFormat::Unorm16x2, "u32", false);
// Test that an snorm format is compatible only with f32.
DoTest(wgpu::VertexFormat::Snorm16x4, "f32", true);
DoTest(wgpu::VertexFormat::Snorm16x4, "i32", false);
DoTest(wgpu::VertexFormat::Snorm16x4, "u32", false);
// Test that an uint format is compatible only with u32.
DoTest(wgpu::VertexFormat::Uint32x3, "f32", false);
DoTest(wgpu::VertexFormat::Uint32x3, "i32", false);
DoTest(wgpu::VertexFormat::Uint32x3, "u32", true);
// Test that an sint format is compatible only with u32.
DoTest(wgpu::VertexFormat::Sint8x4, "f32", false);
DoTest(wgpu::VertexFormat::Sint8x4, "i32", true);
DoTest(wgpu::VertexFormat::Sint8x4, "u32", false);
// Test that formats are compatible with any width of vectors.
DoTest(wgpu::VertexFormat::Float32, "f32", true);
DoTest(wgpu::VertexFormat::Float32, "vec2<f32>", true);
DoTest(wgpu::VertexFormat::Float32, "vec3<f32>", true);
DoTest(wgpu::VertexFormat::Float32, "vec4<f32>", true);
DoTest(wgpu::VertexFormat::Float32x4, "f32", true);
DoTest(wgpu::VertexFormat::Float32x4, "vec2<f32>", true);
DoTest(wgpu::VertexFormat::Float32x4, "vec3<f32>", true);
DoTest(wgpu::VertexFormat::Float32x4, "vec4<f32>", true);
}
// Check that we only check base type compatibility for vertex inputs the shader uses.
TEST_F(VertexStateTest, BaseTypeMatchingForInexistentInput) {
auto DoTest = [&](wgpu::VertexFormat format) {
utils::ComboVertexStateDescriptor state;
state.vertexBufferCount = 1;
state.cVertexBuffers[0].arrayStride = 16;
state.cVertexBuffers[0].attributeCount = 1;
state.cAttributes[0].format = format;
std::string shader = R"([[stage(vertex)]] fn main() -> [[builtin(position)]] vec4<f32> {
return vec4<f32>(0.0, 0.0, 0.0, 0.0);
})";
CreatePipeline(true, state, shader.c_str());
};
DoTest(wgpu::VertexFormat::Float32);
DoTest(wgpu::VertexFormat::Unorm16x2);
DoTest(wgpu::VertexFormat::Snorm16x4);
DoTest(wgpu::VertexFormat::Uint8x4);
DoTest(wgpu::VertexFormat::Sint32x2);
}