Implement render pipeline vertex format base type validation.

Bug: dawn:1008

Change-Id: I04d1ff1d46c1106147a8c50415c989db5789cbfc
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/59031
Auto-Submit: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Jiawei Shao <jiawei.shao@intel.com>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Commit-Queue: Corentin Wallez <cwallez@chromium.org>
This commit is contained in:
Corentin Wallez 2021-07-22 11:43:10 +00:00 committed by Dawn LUCI CQ
parent 1384e8b163
commit e9c8225f54
7 changed files with 161 additions and 25 deletions

View File

@ -28,16 +28,19 @@ namespace dawn_native {
// Helper functions
namespace {
MaybeError ValidateVertexAttribute(DeviceBase* device,
MaybeError ValidateVertexAttribute(
DeviceBase* device,
const VertexAttribute* attribute,
const EntryPointMetadata& metadata,
uint64_t vertexBufferStride,
std::bitset<kMaxVertexAttributes>* attributesSetMask) {
ityp::bitset<VertexAttributeLocation, kMaxVertexAttributes>* attributesSetMask) {
DAWN_TRY(ValidateVertexFormat(attribute->format));
const VertexFormatInfo& formatInfo = GetVertexFormatInfo(attribute->format);
if (attribute->shaderLocation >= kMaxVertexAttributes) {
return DAWN_VALIDATION_ERROR("Setting attribute out of bounds");
}
VertexAttributeLocation location(static_cast<uint8_t>(attribute->shaderLocation));
// No underflow is possible because the max vertex format size is smaller than
// kMaxVertexBufferArrayStride.
@ -59,18 +62,25 @@ namespace dawn_native {
"Attribute offset needs to be a multiple of the size format's components");
}
if ((*attributesSetMask)[attribute->shaderLocation]) {
if (metadata.usedVertexInputs[location] &&
formatInfo.baseType != metadata.vertexInputBaseTypes[location]) {
return DAWN_VALIDATION_ERROR(
"Attribute base type must match the base type in the shader.");
}
if ((*attributesSetMask)[location]) {
return DAWN_VALIDATION_ERROR("Setting already set attribute");
}
attributesSetMask->set(attribute->shaderLocation);
attributesSetMask->set(location);
return {};
}
MaybeError ValidateVertexBufferLayout(
DeviceBase* device,
const VertexBufferLayout* buffer,
std::bitset<kMaxVertexAttributes>* attributesSetMask) {
const EntryPointMetadata& metadata,
ityp::bitset<VertexAttributeLocation, kMaxVertexAttributes>* attributesSetMask) {
DAWN_TRY(ValidateInputStepMode(buffer->stepMode));
if (buffer->arrayStride > kMaxVertexBufferArrayStride) {
return DAWN_VALIDATION_ERROR("Setting arrayStride out of bounds");
@ -82,7 +92,7 @@ namespace dawn_native {
}
for (uint32_t i = 0; i < buffer->attributeCount; ++i) {
DAWN_TRY(ValidateVertexAttribute(device, &buffer->attributes[i],
DAWN_TRY(ValidateVertexAttribute(device, &buffer->attributes[i], metadata,
buffer->arrayStride, attributesSetMask));
}
@ -100,10 +110,15 @@ namespace dawn_native {
return DAWN_VALIDATION_ERROR("Vertex buffer count exceeds maximum");
}
std::bitset<kMaxVertexAttributes> attributesSetMask;
DAWN_TRY(ValidateProgrammableStage(device, descriptor->module, descriptor->entryPoint,
layout, SingleShaderStage::Vertex));
const EntryPointMetadata& vertexMetadata =
descriptor->module->GetEntryPoint(descriptor->entryPoint);
ityp::bitset<VertexAttributeLocation, kMaxVertexAttributes> attributesSetMask;
uint32_t totalAttributesNum = 0;
for (uint32_t i = 0; i < descriptor->bufferCount; ++i) {
DAWN_TRY(ValidateVertexBufferLayout(device, &descriptor->buffers[i],
DAWN_TRY(ValidateVertexBufferLayout(device, &descriptor->buffers[i], vertexMetadata,
&attributesSetMask));
totalAttributesNum += descriptor->buffers[i].attributeCount;
}
@ -114,11 +129,7 @@ namespace dawn_native {
// attribute number never exceed kMaxVertexAttributes.
ASSERT(totalAttributesNum <= kMaxVertexAttributes);
DAWN_TRY(ValidateProgrammableStage(device, descriptor->module, descriptor->entryPoint,
layout, SingleShaderStage::Vertex));
const EntryPointMetadata& vertexMetadata =
descriptor->module->GetEntryPoint(descriptor->entryPoint);
if (!IsSubset(vertexMetadata.usedVertexAttributes, attributesSetMask)) {
if (!IsSubset(vertexMetadata.usedVertexInputs, attributesSetMask)) {
return DAWN_VALIDATION_ERROR(
"Pipeline vertex stage uses vertex buffers not in the vertex state");
}

View File

@ -295,6 +295,21 @@ namespace dawn_native {
}
}
ResultOrError<VertexFormatBaseType> TintComponentTypeToVertexFormatBaseType(
tint::inspector::ComponentType type) {
switch (type) {
case tint::inspector::ComponentType::kFloat:
return VertexFormatBaseType::Float;
case tint::inspector::ComponentType::kSInt:
return VertexFormatBaseType::Sint;
case tint::inspector::ComponentType::kUInt:
return VertexFormatBaseType::Uint;
case tint::inspector::ComponentType::kUnknown:
return DAWN_VALIDATION_ERROR(
"Attempted to convert 'Unknown' component type from Tint");
}
}
ResultOrError<wgpu::BufferBindingType> TintResourceTypeToBufferBindingType(
tint::inspector::ResourceBinding::ResourceType resource_type) {
switch (resource_type) {
@ -811,13 +826,19 @@ namespace dawn_native {
return DAWN_VALIDATION_ERROR(
"Unable to find Location decoration for Vertex input");
}
uint32_t location = compiler.get_decoration(attrib.id, spv::DecorationLocation);
uint32_t unsanitizedLocation =
compiler.get_decoration(attrib.id, spv::DecorationLocation);
if (location >= kMaxVertexAttributes) {
if (unsanitizedLocation >= kMaxVertexAttributes) {
return DAWN_VALIDATION_ERROR("Attribute location over limits in the SPIRV");
}
VertexAttributeLocation location(static_cast<uint8_t>(unsanitizedLocation));
metadata->usedVertexAttributes.set(location);
spirv_cross::SPIRType::BaseType inputBaseType =
compiler.get_type(attrib.base_type_id).basetype;
metadata->vertexInputBaseTypes[location] =
SpirvBaseTypeToVertexFormatBaseType(inputBaseType);
metadata->usedVertexInputs.set(location);
}
// Without a location qualifier on vertex outputs, spirv_cross::CompilerMSL gives
@ -846,6 +867,7 @@ namespace dawn_native {
}
uint32_t unsanitizedAttachment =
compiler.get_decoration(fragmentOutput.id, spv::DecorationLocation);
if (unsanitizedAttachment >= kMaxColorAttachments) {
return DAWN_VALIDATION_ERROR(
"Fragment output index must be less than max number of color "
@ -958,13 +980,17 @@ namespace dawn_native {
return DAWN_VALIDATION_ERROR(
"Need Location decoration on Vertex input");
}
uint32_t location = input_var.location_decoration;
if (DAWN_UNLIKELY(location >= kMaxVertexAttributes)) {
uint32_t unsanitizedLocation = input_var.location_decoration;
if (DAWN_UNLIKELY(unsanitizedLocation >= kMaxVertexAttributes)) {
std::stringstream ss;
ss << "Attribute location (" << location << ") over limits";
ss << "Attribute location (" << unsanitizedLocation << ") over limits";
return DAWN_VALIDATION_ERROR(ss.str());
}
metadata->usedVertexAttributes.set(location);
VertexAttributeLocation location(static_cast<uint8_t>(unsanitizedLocation));
DAWN_TRY_ASSIGN(
metadata->vertexInputBaseTypes[location],
TintComponentTypeToVertexFormatBaseType(input_var.component_type));
metadata->usedVertexInputs.set(location);
}
for (const auto& output_var : entryPoint.output_variables) {

View File

@ -25,6 +25,7 @@
#include "dawn_native/Forward.h"
#include "dawn_native/IntegerTypes.h"
#include "dawn_native/PerStage.h"
#include "dawn_native/VertexFormat.h"
#include "dawn_native/dawn_platform.h"
#include <bitset>
@ -147,7 +148,9 @@ namespace dawn_native {
std::vector<SamplerTexturePair> samplerTexturePairs;
// The set of vertex attributes this entryPoint uses.
std::bitset<kMaxVertexAttributes> usedVertexAttributes;
ityp::array<VertexAttributeLocation, VertexFormatBaseType, kMaxVertexAttributes>
vertexInputBaseTypes;
ityp::bitset<VertexAttributeLocation, kMaxVertexAttributes> usedVertexInputs;
// An array to record the basic types (float, int and uint) of the fragment shader outputs.
ityp::array<ColorAttachmentIndex, wgpu::TextureComponentType, kMaxColorAttachments>

View File

@ -161,4 +161,18 @@ namespace dawn_native {
}
}
VertexFormatBaseType SpirvBaseTypeToVertexFormatBaseType(
spirv_cross::SPIRType::BaseType spirvBaseType) {
switch (spirvBaseType) {
case spirv_cross::SPIRType::Float:
return VertexFormatBaseType::Float;
case spirv_cross::SPIRType::Int:
return VertexFormatBaseType::Sint;
case spirv_cross::SPIRType::UInt:
return VertexFormatBaseType::Uint;
default:
UNREACHABLE();
}
}
} // namespace dawn_native

View File

@ -20,6 +20,7 @@
#include "dawn_native/Format.h"
#include "dawn_native/PerStage.h"
#include "dawn_native/VertexFormat.h"
#include "dawn_native/dawn_platform.h"
#include <spirv_cross.hpp>
@ -41,6 +42,10 @@ namespace dawn_native {
spirv_cross::SPIRType::BaseType spirvBaseType);
SampleTypeBit SpirvBaseTypeToSampleTypeBit(spirv_cross::SPIRType::BaseType spirvBaseType);
// Returns the VertexFormatBaseType corresponding to the SPIRV base type.
VertexFormatBaseType SpirvBaseTypeToVertexFormatBaseType(
spirv_cross::SPIRType::BaseType spirvBaseType);
} // namespace dawn_native
#endif // DAWNNATIVE_SPIRV_UTILS_H_

View File

@ -356,7 +356,7 @@ namespace dawn_native { namespace metal {
out->function = AcquireNSPRef([*library newFunctionWithName:name.Get()]);
if (GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling) &&
GetEntryPoint(entryPointName).usedVertexAttributes.any()) {
GetEntryPoint(entryPointName).usedVertexInputs.any()) {
out->needsStorageBufferLength = true;
}

View File

@ -306,7 +306,7 @@ TEST_F(VertexStateTest, SetOffsetNotAligned) {
state.cAttributes[0].offset = 2;
CreatePipeline(true, state, kDummyVertexShader);
state.cAttributes[0].format = wgpu::VertexFormat::Uint8x2;
state.cAttributes[0].format = wgpu::VertexFormat::Unorm8x2;
state.cAttributes[0].offset = 1;
CreatePipeline(true, state, kDummyVertexShader);
@ -338,3 +338,80 @@ TEST_F(VertexStateTest, VertexFormatLargerThanNonZeroStride) {
state.cAttributes[0].format = wgpu::VertexFormat::Float32x4;
CreatePipeline(false, state, kDummyVertexShader);
}
// Check that the vertex format base type must match the shader's variable base type.
TEST_F(VertexStateTest, BaseTypeMatching) {
auto DoTest = [&](wgpu::VertexFormat format, std::string shaderType, bool success) {
utils::ComboVertexStateDescriptor state;
state.vertexBufferCount = 1;
state.cVertexBuffers[0].arrayStride = 16;
state.cVertexBuffers[0].attributeCount = 1;
state.cAttributes[0].format = format;
std::string shader = "[[stage(vertex)]] fn main([[location(0)]] attrib : " + shaderType +
R"() -> [[builtin(position)]] vec4<f32> {
return vec4<f32>(0.0, 0.0, 0.0, 0.0);
})";
CreatePipeline(success, state, shader.c_str());
};
// Test that a float format is compatible only with f32 base type.
DoTest(wgpu::VertexFormat::Float32, "f32", true);
DoTest(wgpu::VertexFormat::Float32, "i32", false);
DoTest(wgpu::VertexFormat::Float32, "u32", false);
// Test that an unorm format is compatible only with f32.
DoTest(wgpu::VertexFormat::Unorm16x2, "f32", true);
DoTest(wgpu::VertexFormat::Unorm16x2, "i32", false);
DoTest(wgpu::VertexFormat::Unorm16x2, "u32", false);
// Test that an snorm format is compatible only with f32.
DoTest(wgpu::VertexFormat::Snorm16x4, "f32", true);
DoTest(wgpu::VertexFormat::Snorm16x4, "i32", false);
DoTest(wgpu::VertexFormat::Snorm16x4, "u32", false);
// Test that an uint format is compatible only with u32.
DoTest(wgpu::VertexFormat::Uint32x3, "f32", false);
DoTest(wgpu::VertexFormat::Uint32x3, "i32", false);
DoTest(wgpu::VertexFormat::Uint32x3, "u32", true);
// Test that an sint format is compatible only with u32.
DoTest(wgpu::VertexFormat::Sint8x4, "f32", false);
DoTest(wgpu::VertexFormat::Sint8x4, "i32", true);
DoTest(wgpu::VertexFormat::Sint8x4, "u32", false);
// Test that formats are compatible with any width of vectors.
DoTest(wgpu::VertexFormat::Float32, "f32", true);
DoTest(wgpu::VertexFormat::Float32, "vec2<f32>", true);
DoTest(wgpu::VertexFormat::Float32, "vec3<f32>", true);
DoTest(wgpu::VertexFormat::Float32, "vec4<f32>", true);
DoTest(wgpu::VertexFormat::Float32x4, "f32", true);
DoTest(wgpu::VertexFormat::Float32x4, "vec2<f32>", true);
DoTest(wgpu::VertexFormat::Float32x4, "vec3<f32>", true);
DoTest(wgpu::VertexFormat::Float32x4, "vec4<f32>", true);
}
// Check that we only check base type compatibility for vertex inputs the shader uses.
TEST_F(VertexStateTest, BaseTypeMatchingForInexistentInput) {
auto DoTest = [&](wgpu::VertexFormat format) {
utils::ComboVertexStateDescriptor state;
state.vertexBufferCount = 1;
state.cVertexBuffers[0].arrayStride = 16;
state.cVertexBuffers[0].attributeCount = 1;
state.cAttributes[0].format = format;
std::string shader = R"([[stage(vertex)]] fn main() -> [[builtin(position)]] vec4<f32> {
return vec4<f32>(0.0, 0.0, 0.0, 0.0);
})";
CreatePipeline(true, state, shader.c_str());
};
DoTest(wgpu::VertexFormat::Float32);
DoTest(wgpu::VertexFormat::Unorm16x2);
DoTest(wgpu::VertexFormat::Snorm16x4);
DoTest(wgpu::VertexFormat::Uint8x4);
DoTest(wgpu::VertexFormat::Sint32x2);
}