Add Metal vertex pulling behind a flag

Implements vertex pulling on the Metal backend, hidden behind a flag
until ready for use (we are missing support for more complicated vertex
input types).

Bug: dawn:480
Change-Id: I38028b80673693ebf21309ad5336561fb99f40dc
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/26522
Commit-Queue: Idan Raiter <idanr@google.com>
Reviewed-by: Austin Eng <enga@chromium.org>
This commit is contained in:
Idan Raiter 2020-08-13 23:53:59 +00:00 committed by Commit Bot service account
parent e1604b9a64
commit d315022be5
13 changed files with 515 additions and 23 deletions

View File

@ -295,6 +295,82 @@ namespace dawn_native {
<< " binding " << static_cast<uint32_t>(binding); << " binding " << static_cast<uint32_t>(binding);
return ostream.str(); return ostream.str();
} }
#ifdef DAWN_ENABLE_WGSL
// Translates a Dawn vertex attribute format into the equivalent format enumerator
// used by Tint's vertex pulling transform. Each case maps one wgpu::VertexFormat
// value to the Tint value with the same component count, width and normalization.
// NOTE(review): nothing is returned after the switch, so control flows off the end
// of this function if |format| is not one of the enumerators handled below.
tint::ast::transform::VertexFormat ToTintVertexFormat(wgpu::VertexFormat format) {
    using DawnFormat = wgpu::VertexFormat;
    using TintFormat = tint::ast::transform::VertexFormat;

    switch (format) {
        // 8-bit integer formats
        case DawnFormat::UChar2:
            return TintFormat::kVec2U8;
        case DawnFormat::UChar4:
            return TintFormat::kVec4U8;
        case DawnFormat::Char2:
            return TintFormat::kVec2I8;
        case DawnFormat::Char4:
            return TintFormat::kVec4I8;

        // 8-bit normalized formats
        case DawnFormat::UChar2Norm:
            return TintFormat::kVec2U8Norm;
        case DawnFormat::UChar4Norm:
            return TintFormat::kVec4U8Norm;
        case DawnFormat::Char2Norm:
            return TintFormat::kVec2I8Norm;
        case DawnFormat::Char4Norm:
            return TintFormat::kVec4I8Norm;

        // 16-bit integer formats
        case DawnFormat::UShort2:
            return TintFormat::kVec2U16;
        case DawnFormat::UShort4:
            return TintFormat::kVec4U16;
        case DawnFormat::Short2:
            return TintFormat::kVec2I16;
        case DawnFormat::Short4:
            return TintFormat::kVec4I16;

        // 16-bit normalized formats
        case DawnFormat::UShort2Norm:
            return TintFormat::kVec2U16Norm;
        case DawnFormat::UShort4Norm:
            return TintFormat::kVec4U16Norm;
        case DawnFormat::Short2Norm:
            return TintFormat::kVec2I16Norm;
        case DawnFormat::Short4Norm:
            return TintFormat::kVec4I16Norm;

        // 16-bit float formats
        case DawnFormat::Half2:
            return TintFormat::kVec2F16;
        case DawnFormat::Half4:
            return TintFormat::kVec4F16;

        // 32-bit float formats
        case DawnFormat::Float:
            return TintFormat::kF32;
        case DawnFormat::Float2:
            return TintFormat::kVec2F32;
        case DawnFormat::Float3:
            return TintFormat::kVec3F32;
        case DawnFormat::Float4:
            return TintFormat::kVec4F32;

        // 32-bit unsigned integer formats
        case DawnFormat::UInt:
            return TintFormat::kU32;
        case DawnFormat::UInt2:
            return TintFormat::kVec2U32;
        case DawnFormat::UInt3:
            return TintFormat::kVec3U32;
        case DawnFormat::UInt4:
            return TintFormat::kVec4U32;

        // 32-bit signed integer formats
        case DawnFormat::Int:
            return TintFormat::kI32;
        case DawnFormat::Int2:
            return TintFormat::kVec2I32;
        case DawnFormat::Int3:
            return TintFormat::kVec3I32;
        case DawnFormat::Int4:
            return TintFormat::kVec4I32;
    }
}
// Translates a Dawn vertex buffer step mode (per-vertex or per-instance) into the
// matching enumerator used by Tint's vertex pulling transform.
// NOTE(review): nothing is returned after the switch, so control flows off the end
// of this function if |mode| is neither Vertex nor Instance.
tint::ast::transform::InputStepMode ToTintInputStepMode(wgpu::InputStepMode mode) {
    using TintStepMode = tint::ast::transform::InputStepMode;

    switch (mode) {
        case wgpu::InputStepMode::Vertex:
            return TintStepMode::kVertex;
        case wgpu::InputStepMode::Instance:
            return TintStepMode::kInstance;
    }
}
#endif
} // anonymous namespace } // anonymous namespace
MaybeError ValidateSpirv(DeviceBase*, const uint32_t* code, uint32_t codeSize) { MaybeError ValidateSpirv(DeviceBase*, const uint32_t* code, uint32_t codeSize) {
@ -400,6 +476,75 @@ namespace dawn_native {
std::vector<uint32_t> spirv = generator.result(); std::vector<uint32_t> spirv = generator.result();
return std::move(spirv); return std::move(spirv);
} }
// Compiles WGSL |source| to SPIR-V after running Tint's vertex pulling transform on it.
//
// |vertexState| is copied field by field into a Tint VertexStateDescriptor, and the
// transform is configured with |entryPoint| and |pullingBufferBindingSet| before it runs.
// A failure while parsing, validating the module, running the transform, determining
// types or generating SPIR-V is returned as a validation error whose message starts
// with "Tint WGSL->SPIR-V failure:".
ResultOrError<std::vector<uint32_t>> ConvertWGSLToSPIRVWithPulling(
    const char* source,
    const VertexStateDescriptor& vertexState,
    const std::string& entryPoint,
    uint32_t pullingBufferBindingSet) {
    // Accumulates the message reported by whichever stage fails.
    std::ostringstream failureLog;
    failureLog << "Tint WGSL->SPIR-V failure:" << '\n';

    tint::Context tintContext;
    tint::reader::wgsl::Parser wgslParser(&tintContext, source);

    // TODO: This is a duplicate parse with ValidateWGSL, need to store
    // state between calls to avoid this.
    if (!wgslParser.Parse()) {
        failureLog << "Parser: " << wgslParser.error() << '\n';
        return DAWN_VALIDATION_ERROR(failureLog.str().c_str());
    }

    tint::ast::Module parsedModule = wgslParser.module();
    if (!parsedModule.IsValid()) {
        failureLog << "Invalid module generated..." << '\n';
        return DAWN_VALIDATION_ERROR(failureLog.str().c_str());
    }

    tint::ast::transform::VertexPullingTransform pullingTransform(&tintContext, &parsedModule);

    // Mirror Dawn's vertex state (buffers and their attributes) into Tint's descriptor.
    auto tintVertexState = std::make_unique<tint::ast::transform::VertexStateDescriptor>();
    for (uint32_t bufferIndex = 0; bufferIndex < vertexState.vertexBufferCount; ++bufferIndex) {
        auto& dawnBuffer = vertexState.vertexBuffers[bufferIndex];

        tint::ast::transform::VertexBufferLayoutDescriptor tintLayout;
        tintLayout.step_mode = ToTintInputStepMode(dawnBuffer.stepMode);
        tintLayout.array_stride = dawnBuffer.arrayStride;

        for (uint32_t attributeIndex = 0; attributeIndex < dawnBuffer.attributeCount;
             ++attributeIndex) {
            auto& dawnAttribute = dawnBuffer.attributes[attributeIndex];

            tint::ast::transform::VertexAttributeDescriptor tintAttribute;
            tintAttribute.shader_location = dawnAttribute.shaderLocation;
            tintAttribute.offset = dawnAttribute.offset;
            tintAttribute.format = ToTintVertexFormat(dawnAttribute.format);

            tintLayout.attributes.push_back(std::move(tintAttribute));
        }

        tintVertexState->vertex_buffers.push_back(std::move(tintLayout));
    }

    pullingTransform.SetVertexState(std::move(tintVertexState));
    pullingTransform.SetEntryPoint(entryPoint);
    pullingTransform.SetPullingBufferBindingSet(pullingBufferBindingSet);

    if (!pullingTransform.Run()) {
        failureLog << "Vertex pulling transform: " << pullingTransform.GetError();
        return DAWN_VALIDATION_ERROR(failureLog.str().c_str());
    }

    // Type determination runs on the module after the transform has modified it.
    tint::TypeDeterminer typeDeterminer(&tintContext, &parsedModule);
    if (!typeDeterminer.Determine()) {
        failureLog << "Type Determination: " << typeDeterminer.error();
        return DAWN_VALIDATION_ERROR(failureLog.str().c_str());
    }

    tint::writer::spirv::Generator spirvGenerator(std::move(parsedModule));
    if (!spirvGenerator.Generate()) {
        failureLog << "Generator: " << spirvGenerator.error() << '\n';
        return DAWN_VALIDATION_ERROR(failureLog.str().c_str());
    }

    std::vector<uint32_t> spirvWords = spirvGenerator.result();
    return std::move(spirvWords);
}
#endif // DAWN_ENABLE_WGSL #endif // DAWN_ENABLE_WGSL
MaybeError ValidateShaderModuleDescriptor(DeviceBase* device, MaybeError ValidateShaderModuleDescriptor(DeviceBase* device,
@ -1094,10 +1239,22 @@ namespace dawn_native {
return mSpirv; return mSpirv;
} }
#ifdef DAWN_ENABLE_WGSL
// Generates SPIR-V for this module's stored WGSL source (mWgsl) with vertex pulling
// applied. All arguments are forwarded unchanged to ConvertWGSLToSPIRVWithPulling, and
// its result (SPIR-V words or an error) is returned as is.
ResultOrError<std::vector<uint32_t>> ShaderModuleBase::GeneratePullingSpirv(
    const VertexStateDescriptor& vertexState,
    const std::string& entryPoint,
    uint32_t pullingBufferBindingSet) const {
    const char* wgslSource = mWgsl.c_str();
    return ConvertWGSLToSPIRVWithPulling(wgslSource, vertexState, entryPoint,
                                         pullingBufferBindingSet);
}
#endif
shaderc_spvc::CompileOptions ShaderModuleBase::GetCompileOptions() const { shaderc_spvc::CompileOptions ShaderModuleBase::GetCompileOptions() const {
shaderc_spvc::CompileOptions options; shaderc_spvc::CompileOptions options;
options.SetValidate(GetDevice()->IsValidationEnabled()); options.SetValidate(GetDevice()->IsValidationEnabled());
options.SetRobustBufferAccessPass(GetDevice()->IsRobustnessEnabled()); options.SetRobustBufferAccessPass(GetDevice()->IsRobustnessEnabled());
options.SetSourceEnvironment(shaderc_target_env_vulkan, shaderc_env_version_vulkan_1_1);
options.SetTargetEnvironment(shaderc_target_env_vulkan, shaderc_env_version_vulkan_1_1);
return options; return options;
} }

View File

@ -91,6 +91,13 @@ namespace dawn_native {
shaderc_spvc::Context* GetContext(); shaderc_spvc::Context* GetContext();
const std::vector<uint32_t>& GetSpirv() const; const std::vector<uint32_t>& GetSpirv() const;
#ifdef DAWN_ENABLE_WGSL
// Returns SPIR-V generated from this module's WGSL source with vertex pulling applied
// for |entryPoint|, or an error if the conversion fails. |vertexState| describes the
// vertex buffers and attributes to pull, and |pullingBufferBindingSet| is handed to the
// transform as the binding set for the pulled buffers.
// Only declared when DAWN_ENABLE_WGSL is defined.
ResultOrError<std::vector<uint32_t>> GeneratePullingSpirv(
const VertexStateDescriptor& vertexState,
const std::string& entryPoint,
uint32_t pullingBufferBindingSet) const;
#endif
protected: protected:
static MaybeError CheckSpvcSuccess(shaderc_spvc_status status, const char* error_msg); static MaybeError CheckSpvcSuccess(shaderc_spvc_status status, const char* error_msg);
shaderc_spvc::CompileOptions GetCompileOptions() const; shaderc_spvc::CompileOptions GetCompileOptions() const;

View File

@ -138,7 +138,11 @@ namespace dawn_native {
"Clear buffers on their first use. This is a temporary toggle only for the " "Clear buffers on their first use. This is a temporary toggle only for the "
"development of buffer lazy initialization and will be removed after buffer lazy " "development of buffer lazy initialization and will be removed after buffer lazy "
"initialization is completely implemented.", "initialization is completely implemented.",
"https://crbug.com/dawn/414"}}}}; "https://crbug.com/dawn/414"}},
{Toggle::MetalEnableVertexPulling,
{"metal_enable_vertex_pulling",
"Uses vertex pulling to protect out-of-bounds reads on Metal",
"https://crbug.com/dawn/480"}}}};
} // anonymous namespace } // anonymous namespace

View File

@ -44,6 +44,7 @@ namespace dawn_native {
UseDXC, UseDXC,
DisableRobustness, DisableRobustness,
LazyClearBufferOnFirstUse, LazyClearBufferOnFirstUse,
MetalEnableVertexPulling,
EnumCount, EnumCount,
InvalidEnum = EnumCount, InvalidEnum = EnumCount,

View File

@ -263,7 +263,9 @@ namespace dawn_native { namespace metal {
// MSL code generated by SPIRV-Cross expects. // MSL code generated by SPIRV-Cross expects.
PerStage<std::array<uint32_t, kGenericMetalBufferSlots>> data; PerStage<std::array<uint32_t, kGenericMetalBufferSlots>> data;
void Apply(id<MTLRenderCommandEncoder> render, RenderPipeline* pipeline) { void Apply(id<MTLRenderCommandEncoder> render,
RenderPipeline* pipeline,
bool enableVertexPulling) {
wgpu::ShaderStage stagesToApply = wgpu::ShaderStage stagesToApply =
dirtyStages & pipeline->GetStagesRequiringStorageBufferLength(); dirtyStages & pipeline->GetStagesRequiringStorageBufferLength();
@ -274,6 +276,11 @@ namespace dawn_native { namespace metal {
if (stagesToApply & wgpu::ShaderStage::Vertex) { if (stagesToApply & wgpu::ShaderStage::Vertex) {
uint32_t bufferCount = ToBackend(pipeline->GetLayout()) uint32_t bufferCount = ToBackend(pipeline->GetLayout())
->GetBufferBindingCount(SingleShaderStage::Vertex); ->GetBufferBindingCount(SingleShaderStage::Vertex);
if (enableVertexPulling) {
bufferCount += pipeline->GetVertexStateDescriptor()->vertexBufferCount;
}
[render setVertexBytes:data[SingleShaderStage::Vertex].data() [render setVertexBytes:data[SingleShaderStage::Vertex].data()
length:sizeof(uint32_t) * bufferCount length:sizeof(uint32_t) * bufferCount
atIndex:kBufferLengthBufferSlot]; atIndex:kBufferLengthBufferSlot];
@ -483,10 +490,17 @@ namespace dawn_native { namespace metal {
// all the relevant state. // all the relevant state.
class VertexBufferTracker { class VertexBufferTracker {
public: public:
// Stores a non-owning pointer to the storage buffer length tracker. Apply() writes the
// vertex buffer binding sizes into it when vertex pulling is enabled.
explicit VertexBufferTracker(StorageBufferLengthTracker* lengthTracker)
    : mLengthTracker{lengthTracker} {
}
void OnSetVertexBuffer(uint32_t slot, Buffer* buffer, uint64_t offset) { void OnSetVertexBuffer(uint32_t slot, Buffer* buffer, uint64_t offset) {
mVertexBuffers[slot] = buffer->GetMTLBuffer(); mVertexBuffers[slot] = buffer->GetMTLBuffer();
mVertexBufferOffsets[slot] = offset; mVertexBufferOffsets[slot] = offset;
ASSERT(buffer->GetSize() < std::numeric_limits<uint32_t>::max());
mVertexBufferBindingSizes[slot] = static_cast<uint32_t>(buffer->GetSize() - offset);
// Use 64 bit masks and make sure there are no shift UB // Use 64 bit masks and make sure there are no shift UB
static_assert(kMaxVertexBuffers <= 8 * sizeof(unsigned long long) - 1, ""); static_assert(kMaxVertexBuffers <= 8 * sizeof(unsigned long long) - 1, "");
mDirtyVertexBuffers |= 1ull << slot; mDirtyVertexBuffers |= 1ull << slot;
@ -499,13 +513,22 @@ namespace dawn_native { namespace metal {
mDirtyVertexBuffers |= pipeline->GetVertexBufferSlotsUsed(); mDirtyVertexBuffers |= pipeline->GetVertexBufferSlotsUsed();
} }
void Apply(id<MTLRenderCommandEncoder> encoder, RenderPipeline* pipeline) { void Apply(id<MTLRenderCommandEncoder> encoder,
RenderPipeline* pipeline,
bool enableVertexPulling) {
std::bitset<kMaxVertexBuffers> vertexBuffersToApply = std::bitset<kMaxVertexBuffers> vertexBuffersToApply =
mDirtyVertexBuffers & pipeline->GetVertexBufferSlotsUsed(); mDirtyVertexBuffers & pipeline->GetVertexBufferSlotsUsed();
for (uint32_t dawnIndex : IterateBitSet(vertexBuffersToApply)) { for (uint32_t dawnIndex : IterateBitSet(vertexBuffersToApply)) {
uint32_t metalIndex = pipeline->GetMtlVertexBufferIndex(dawnIndex); uint32_t metalIndex = pipeline->GetMtlVertexBufferIndex(dawnIndex);
if (enableVertexPulling) {
// Insert lengths for vertex buffers bound as storage buffers
mLengthTracker->data[SingleShaderStage::Vertex][metalIndex] =
mVertexBufferBindingSizes[dawnIndex];
mLengthTracker->dirtyStages |= wgpu::ShaderStage::Vertex;
}
[encoder setVertexBuffers:&mVertexBuffers[dawnIndex] [encoder setVertexBuffers:&mVertexBuffers[dawnIndex]
offsets:&mVertexBufferOffsets[dawnIndex] offsets:&mVertexBufferOffsets[dawnIndex]
withRange:NSMakeRange(metalIndex, 1)]; withRange:NSMakeRange(metalIndex, 1)];
@ -519,6 +542,9 @@ namespace dawn_native { namespace metal {
std::bitset<kMaxVertexBuffers> mDirtyVertexBuffers; std::bitset<kMaxVertexBuffers> mDirtyVertexBuffers;
std::array<id<MTLBuffer>, kMaxVertexBuffers> mVertexBuffers; std::array<id<MTLBuffer>, kMaxVertexBuffers> mVertexBuffers;
std::array<NSUInteger, kMaxVertexBuffers> mVertexBufferOffsets; std::array<NSUInteger, kMaxVertexBuffers> mVertexBufferOffsets;
std::array<uint32_t, kMaxVertexBuffers> mVertexBufferBindingSizes;
StorageBufferLengthTracker* mLengthTracker;
}; };
} // anonymous namespace } // anonymous namespace
@ -949,11 +975,12 @@ namespace dawn_native { namespace metal {
MTLRenderPassDescriptor* mtlRenderPass, MTLRenderPassDescriptor* mtlRenderPass,
uint32_t width, uint32_t width,
uint32_t height) { uint32_t height) {
bool enableVertexPulling = GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling);
RenderPipeline* lastPipeline = nullptr; RenderPipeline* lastPipeline = nullptr;
id<MTLBuffer> indexBuffer = nil; id<MTLBuffer> indexBuffer = nil;
uint32_t indexBufferBaseOffset = 0; uint32_t indexBufferBaseOffset = 0;
VertexBufferTracker vertexBuffers;
StorageBufferLengthTracker storageBufferLengths = {}; StorageBufferLengthTracker storageBufferLengths = {};
VertexBufferTracker vertexBuffers(&storageBufferLengths);
BindGroupTracker bindGroups(&storageBufferLengths); BindGroupTracker bindGroups(&storageBufferLengths);
id<MTLRenderCommandEncoder> encoder = commandContext->BeginRender(mtlRenderPass); id<MTLRenderCommandEncoder> encoder = commandContext->BeginRender(mtlRenderPass);
@ -963,9 +990,9 @@ namespace dawn_native { namespace metal {
case Command::Draw: { case Command::Draw: {
DrawCmd* draw = iter->NextCommand<DrawCmd>(); DrawCmd* draw = iter->NextCommand<DrawCmd>();
vertexBuffers.Apply(encoder, lastPipeline); vertexBuffers.Apply(encoder, lastPipeline, enableVertexPulling);
bindGroups.Apply(encoder); bindGroups.Apply(encoder);
storageBufferLengths.Apply(encoder, lastPipeline); storageBufferLengths.Apply(encoder, lastPipeline, enableVertexPulling);
// The instance count must be non-zero, otherwise no-op // The instance count must be non-zero, otherwise no-op
if (draw->instanceCount != 0) { if (draw->instanceCount != 0) {
@ -991,9 +1018,9 @@ namespace dawn_native { namespace metal {
size_t formatSize = size_t formatSize =
IndexFormatSize(lastPipeline->GetVertexStateDescriptor()->indexFormat); IndexFormatSize(lastPipeline->GetVertexStateDescriptor()->indexFormat);
vertexBuffers.Apply(encoder, lastPipeline); vertexBuffers.Apply(encoder, lastPipeline, enableVertexPulling);
bindGroups.Apply(encoder); bindGroups.Apply(encoder);
storageBufferLengths.Apply(encoder, lastPipeline); storageBufferLengths.Apply(encoder, lastPipeline, enableVertexPulling);
// The index and instance count must be non-zero, otherwise no-op // The index and instance count must be non-zero, otherwise no-op
if (draw->indexCount != 0 && draw->instanceCount != 0) { if (draw->indexCount != 0 && draw->instanceCount != 0) {
@ -1025,9 +1052,9 @@ namespace dawn_native { namespace metal {
case Command::DrawIndirect: { case Command::DrawIndirect: {
DrawIndirectCmd* draw = iter->NextCommand<DrawIndirectCmd>(); DrawIndirectCmd* draw = iter->NextCommand<DrawIndirectCmd>();
vertexBuffers.Apply(encoder, lastPipeline); vertexBuffers.Apply(encoder, lastPipeline, enableVertexPulling);
bindGroups.Apply(encoder); bindGroups.Apply(encoder);
storageBufferLengths.Apply(encoder, lastPipeline); storageBufferLengths.Apply(encoder, lastPipeline, enableVertexPulling);
Buffer* buffer = ToBackend(draw->indirectBuffer.Get()); Buffer* buffer = ToBackend(draw->indirectBuffer.Get());
id<MTLBuffer> indirectBuffer = buffer->GetMTLBuffer(); id<MTLBuffer> indirectBuffer = buffer->GetMTLBuffer();
@ -1040,9 +1067,9 @@ namespace dawn_native { namespace metal {
case Command::DrawIndexedIndirect: { case Command::DrawIndexedIndirect: {
DrawIndirectCmd* draw = iter->NextCommand<DrawIndirectCmd>(); DrawIndirectCmd* draw = iter->NextCommand<DrawIndirectCmd>();
vertexBuffers.Apply(encoder, lastPipeline); vertexBuffers.Apply(encoder, lastPipeline, enableVertexPulling);
bindGroups.Apply(encoder); bindGroups.Apply(encoder);
storageBufferLengths.Apply(encoder, lastPipeline); storageBufferLengths.Apply(encoder, lastPipeline, enableVertexPulling);
Buffer* buffer = ToBackend(draw->indirectBuffer.Get()); Buffer* buffer = ToBackend(draw->indirectBuffer.Get());
id<MTLBuffer> indirectBuffer = buffer->GetMTLBuffer(); id<MTLBuffer> indirectBuffer = buffer->GetMTLBuffer();

View File

@ -61,6 +61,11 @@ namespace dawn_native { namespace metal {
MaybeError Device::Initialize() { MaybeError Device::Initialize() {
InitTogglesFromDriver(); InitTogglesFromDriver();
if (!IsRobustnessEnabled() || !IsToggleEnabled(Toggle::UseSpvc)) {
ForceSetToggle(Toggle::MetalEnableVertexPulling, false);
}
mCommandQueue = [mMtlDevice newCommandQueue]; mCommandQueue = [mMtlDevice newCommandQueue];
return DeviceBase::Initialize(new Queue(this)); return DeviceBase::Initialize(new Queue(this));

View File

@ -329,11 +329,24 @@ namespace dawn_native { namespace metal {
MTLRenderPipelineDescriptor* descriptorMTL = [MTLRenderPipelineDescriptor new]; MTLRenderPipelineDescriptor* descriptorMTL = [MTLRenderPipelineDescriptor new];
// TODO: MakeVertexDesc should be const in the future, so we don't need to call it here when
// vertex pulling is enabled
MTLVertexDescriptor* vertexDesc = MakeVertexDesc();
descriptorMTL.vertexDescriptor = vertexDesc;
[vertexDesc release];
if (GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling)) {
// Calling MakeVertexDesc first is important since it sets indices for packed bindings
MTLVertexDescriptor* emptyVertexDesc = [MTLVertexDescriptor new];
descriptorMTL.vertexDescriptor = emptyVertexDesc;
[emptyVertexDesc release];
}
ShaderModule* vertexModule = ToBackend(descriptor->vertexStage.module); ShaderModule* vertexModule = ToBackend(descriptor->vertexStage.module);
const char* vertexEntryPoint = descriptor->vertexStage.entryPoint; const char* vertexEntryPoint = descriptor->vertexStage.entryPoint;
ShaderModule::MetalFunctionData vertexData; ShaderModule::MetalFunctionData vertexData;
DAWN_TRY(vertexModule->GetFunction(vertexEntryPoint, SingleShaderStage::Vertex, DAWN_TRY(vertexModule->GetFunction(vertexEntryPoint, SingleShaderStage::Vertex,
ToBackend(GetLayout()), &vertexData)); ToBackend(GetLayout()), &vertexData, 0xFFFFFFFF, this));
descriptorMTL.vertexFunction = vertexData.function; descriptorMTL.vertexFunction = vertexData.function;
if (vertexData.needsStorageBufferLength) { if (vertexData.needsStorageBufferLength) {
@ -377,11 +390,6 @@ namespace dawn_native { namespace metal {
} }
descriptorMTL.inputPrimitiveTopology = MTLInputPrimitiveTopology(GetPrimitiveTopology()); descriptorMTL.inputPrimitiveTopology = MTLInputPrimitiveTopology(GetPrimitiveTopology());
MTLVertexDescriptor* vertexDesc = MakeVertexDesc();
descriptorMTL.vertexDescriptor = vertexDesc;
[vertexDesc release];
descriptorMTL.sampleCount = GetSampleCount(); descriptorMTL.sampleCount = GetSampleCount();
descriptorMTL.alphaToCoverageEnabled = descriptor->alphaToCoverageEnabled; descriptorMTL.alphaToCoverageEnabled = descriptor->alphaToCoverageEnabled;

View File

@ -29,6 +29,7 @@ namespace dawn_native { namespace metal {
class Device; class Device;
class PipelineLayout; class PipelineLayout;
class RenderPipeline;
class ShaderModule final : public ShaderModuleBase { class ShaderModule final : public ShaderModuleBase {
public: public:
@ -47,7 +48,8 @@ namespace dawn_native { namespace metal {
SingleShaderStage functionStage, SingleShaderStage functionStage,
const PipelineLayout* layout, const PipelineLayout* layout,
MetalFunctionData* out, MetalFunctionData* out,
uint32_t sampleMask = 0xFFFFFFFF); uint32_t sampleMask = 0xFFFFFFFF,
const RenderPipeline* renderPipeline = nullptr);
private: private:
ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor); ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);

View File

@ -17,6 +17,7 @@
#include "dawn_native/BindGroupLayout.h" #include "dawn_native/BindGroupLayout.h"
#include "dawn_native/metal/DeviceMTL.h" #include "dawn_native/metal/DeviceMTL.h"
#include "dawn_native/metal/PipelineLayoutMTL.h" #include "dawn_native/metal/PipelineLayoutMTL.h"
#include "dawn_native/metal/RenderPipelineMTL.h"
#include <spirv_msl.hpp> #include <spirv_msl.hpp>
@ -92,10 +93,24 @@ namespace dawn_native { namespace metal {
SingleShaderStage functionStage, SingleShaderStage functionStage,
const PipelineLayout* layout, const PipelineLayout* layout,
ShaderModule::MetalFunctionData* out, ShaderModule::MetalFunctionData* out,
uint32_t sampleMask) { uint32_t sampleMask,
const RenderPipeline* renderPipeline) {
ASSERT(!IsError()); ASSERT(!IsError());
ASSERT(out); ASSERT(out);
const std::vector<uint32_t>& spirv = GetSpirv(); const std::vector<uint32_t>* spirv = &GetSpirv();
#ifdef DAWN_ENABLE_WGSL
// Use set 4 since it is bigger than what users can access currently
static const uint32_t kPullingBufferBindingSet = 4;
std::vector<uint32_t> pullingSpirv;
if (GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling) &&
functionStage == SingleShaderStage::Vertex) {
DAWN_TRY_ASSIGN(pullingSpirv,
GeneratePullingSpirv(*renderPipeline->GetVertexStateDescriptor(),
functionName, kPullingBufferBindingSet));
spirv = &pullingSpirv;
}
#endif
std::unique_ptr<spirv_cross::CompilerMSL> compilerImpl; std::unique_ptr<spirv_cross::CompilerMSL> compilerImpl;
spirv_cross::CompilerMSL* compiler; spirv_cross::CompilerMSL* compiler;
@ -103,7 +118,7 @@ namespace dawn_native { namespace metal {
// Initializing the compiler is needed every call, because this method uses reflection // Initializing the compiler is needed every call, because this method uses reflection
// to mutate the compiler's IR. // to mutate the compiler's IR.
DAWN_TRY( DAWN_TRY(
CheckSpvcSuccess(mSpvcContext.InitializeForMsl(spirv.data(), spirv.size(), CheckSpvcSuccess(mSpvcContext.InitializeForMsl(spirv->data(), spirv->size(),
GetMSLCompileOptions(sampleMask)), GetMSLCompileOptions(sampleMask)),
"Unable to initialize instance of spvc")); "Unable to initialize instance of spvc"));
DAWN_TRY(CheckSpvcSuccess(mSpvcContext.GetCompiler(reinterpret_cast<void**>(&compiler)), DAWN_TRY(CheckSpvcSuccess(mSpvcContext.GetCompiler(reinterpret_cast<void**>(&compiler)),
@ -126,7 +141,7 @@ namespace dawn_native { namespace metal {
options_msl.additional_fixed_sample_mask = sampleMask; options_msl.additional_fixed_sample_mask = sampleMask;
compilerImpl = std::make_unique<spirv_cross::CompilerMSL>(spirv); compilerImpl = std::make_unique<spirv_cross::CompilerMSL>(*spirv);
compiler = compilerImpl.get(); compiler = compilerImpl.get();
compiler->set_msl_options(options_msl); compiler->set_msl_options(options_msl);
} }
@ -172,6 +187,22 @@ namespace dawn_native { namespace metal {
} }
} }
// Add vertex buffers bound as storage buffers
if (GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling) &&
functionStage == SingleShaderStage::Vertex) {
for (uint32_t dawnIndex : IterateBitSet(renderPipeline->GetVertexBufferSlotsUsed())) {
uint32_t metalIndex = renderPipeline->GetMtlVertexBufferIndex(dawnIndex);
shaderc_spvc_msl_resource_binding mslBinding;
mslBinding.stage = ToSpvcExecutionModel(SingleShaderStage::Vertex);
mslBinding.desc_set = kPullingBufferBindingSet;
mslBinding.binding = dawnIndex;
mslBinding.msl_buffer = metalIndex;
DAWN_TRY(CheckSpvcSuccess(mSpvcContext.AddMSLResourceBinding(mslBinding),
"Unable to add MSL Resource Binding"));
}
}
{ {
if (GetDevice()->IsToggleEnabled(Toggle::UseSpvc)) { if (GetDevice()->IsToggleEnabled(Toggle::UseSpvc)) {
shaderc_spvc_execution_model executionModel = ToSpvcExecutionModel(functionStage); shaderc_spvc_execution_model executionModel = ToSpvcExecutionModel(functionStage);
@ -245,6 +276,11 @@ namespace dawn_native { namespace metal {
out->needsStorageBufferLength = compiler->needs_buffer_size_buffer(); out->needsStorageBufferLength = compiler->needs_buffer_size_buffer();
} }
if (GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling) &&
functionStage == SingleShaderStage::Vertex && GetUsedVertexAttributes().any()) {
out->needsStorageBufferLength = true;
}
return {}; return {};
} }

View File

@ -337,6 +337,10 @@ source_set("dawn_end2end_tests_sources") {
frameworks = [ "IOSurface.framework" ] frameworks = [ "IOSurface.framework" ]
} }
# The vertex buffer robustness tests are only compiled when WGSL support is enabled.
if (dawn_enable_wgsl) {
sources += [ "end2end/VertexBufferRobustnessTests.cpp" ]
}
if (dawn_enable_opengl) { if (dawn_enable_opengl) {
assert(dawn_supports_glfw_for_windowing) assert(dawn_supports_glfw_for_windowing)
} }

View File

@ -0,0 +1,231 @@
// Copyright 2020 The Dawn Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "common/Assert.h"
#include "common/Constants.h"
#include "common/Math.h"
#include "tests/DawnTest.h"
#include "utils/ComboRenderPipelineDescriptor.h"
#include "utils/WGPUHelpers.h"
// Vertex buffer robustness tests verify that clamping is applied to vertex attribute reads.
// This happens on backends where vertex pulling is enabled, such as Metal.
class VertexBufferRobustnessTest : public DawnTest {
protected:
void SetUp() override {
DawnTest::SetUp();
// SPVC must be used currently, since we rely on the robustness pass in it
DAWN_SKIP_TEST_IF(!IsSpvcBeingUsed());
}
// Creates a vertex module that tests an expression with given attributes. If successful, the
// point drawn would be moved out of the viewport. On failure, the point is kept inside the
// viewport.
wgpu::ShaderModule CreateVertexModule(const std::string& attributes,
const std::string& successExpression) {
return utils::CreateShaderModuleFromWGSL(device, (R"(
entry_point vertex as "main" = vtx_main;
)" + attributes + R"(
[[builtin position]] var<out> Position : vec4<f32>;
fn vtx_main() -> void {
if ()" + successExpression + R"() {
# Success case, move the vertex out of the viewport
Position = vec4<f32>(-10.0, 0.0, 0.0, 1.0);
} else {
# Failure case, move the vertex inside the viewport
Position = vec4<f32>(0.0, 0.0, 0.0, 1.0);
}
return;
}
)")
.c_str());
}
// Runs the test, a true |expectation| meaning success
void DoTest(const std::string& attributes,
const std::string& successExpression,
utils::ComboVertexStateDescriptor vertexState,
wgpu::Buffer vertexBuffer,
uint64_t bufferOffset,
bool expectation) {
wgpu::ShaderModule vsModule = CreateVertexModule(attributes, successExpression);
wgpu::ShaderModule fsModule = utils::CreateShaderModuleFromWGSL(device, R"(
entry_point fragment as "main" = frag_main;
[[location 0]] var<out> outColor : vec4<f32>;
fn frag_main() -> void {
outColor = vec4<f32>(1.0, 1.0, 1.0, 1.0);
return;
}
)");
utils::BasicRenderPass renderPass = utils::CreateBasicRenderPass(device, 1, 1);
utils::ComboRenderPipelineDescriptor descriptor(device);
descriptor.vertexStage.module = vsModule;
descriptor.cFragmentStage.module = fsModule;
descriptor.primitiveTopology = wgpu::PrimitiveTopology::PointList;
descriptor.cVertexState = std::move(vertexState);
descriptor.cColorStates[0].format = renderPass.colorFormat;
renderPass.renderPassInfo.cColorAttachments[0].clearColor = {0, 0, 0, 1};
wgpu::RenderPipeline pipeline = device.CreateRenderPipeline(&descriptor);
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&renderPass.renderPassInfo);
pass.SetPipeline(pipeline);
pass.SetVertexBuffer(0, vertexBuffer, bufferOffset);
pass.Draw(1000);
pass.EndPass();
wgpu::CommandBuffer commands = encoder.Finish();
queue.Submit(1, &commands);
RGBA8 noOutput(0, 0, 0, 255);
RGBA8 someOutput(255, 255, 255, 255);
EXPECT_PIXEL_RGBA8_EQ(expectation ? noOutput : someOutput, renderPass.color, 0, 0);
}
};
// Sanity check for the fixture: with the buffer bound at offset 0 the shader sees 111.0,
// so the success expression fails and DoTest is told to expect a drawn pixel.
TEST_P(VertexBufferRobustnessTest, DetectInvalidValues) {
    float vertexData[] = {111.0, 473.0, 473.0};
    wgpu::Buffer buffer = utils::CreateBufferFromData(device, vertexData, sizeof(vertexData),
                                                      wgpu::BufferUsage::Vertex);

    // One tightly packed buffer holding a single float attribute at location 0.
    utils::ComboVertexStateDescriptor state;
    state.vertexBufferCount = 1;
    state.cVertexBuffers[0].attributeCount = 1;
    state.cVertexBuffers[0].arrayStride = sizeof(float);
    state.cAttributes[0].shaderLocation = 0;
    state.cAttributes[0].offset = 0;
    state.cAttributes[0].format = wgpu::VertexFormat::Float;

    // Bind at an offset of 0, so we see 111.0, leading to failure
    DoTest("[[location 0]] var<in> a : f32;", "a == 473.0", std::move(state), buffer, 0, false);
}
// Float attribute: with the buffer bound 4 bytes in, reads are expected to be clamped to
// the elements holding 473.0, so the success expression holds.
TEST_P(VertexBufferRobustnessTest, FloatClamp) {
    float vertexData[] = {111.0, 473.0, 473.0};
    wgpu::Buffer buffer = utils::CreateBufferFromData(device, vertexData, sizeof(vertexData),
                                                      wgpu::BufferUsage::Vertex);

    // One tightly packed buffer holding a single float attribute at location 0.
    utils::ComboVertexStateDescriptor state;
    state.vertexBufferCount = 1;
    state.cVertexBuffers[0].attributeCount = 1;
    state.cVertexBuffers[0].arrayStride = sizeof(float);
    state.cAttributes[0].shaderLocation = 0;
    state.cAttributes[0].offset = 0;
    state.cAttributes[0].format = wgpu::VertexFormat::Float;

    // Bind at an offset of 4, so we clamp to only values containing 473.0
    DoTest("[[location 0]] var<in> a : f32;", "a == 473.0", std::move(state), buffer, 4, true);
}
// Signed integer attribute: with the buffer bound 4 bytes in, reads are expected to be
// clamped to the elements holding 473, so the success expression holds.
TEST_P(VertexBufferRobustnessTest, IntClamp) {
    int32_t vertexData[] = {111, 473, 473};
    wgpu::Buffer buffer = utils::CreateBufferFromData(device, vertexData, sizeof(vertexData),
                                                      wgpu::BufferUsage::Vertex);

    // One tightly packed buffer holding a single i32 attribute at location 0.
    utils::ComboVertexStateDescriptor state;
    state.vertexBufferCount = 1;
    state.cVertexBuffers[0].attributeCount = 1;
    state.cVertexBuffers[0].arrayStride = sizeof(int32_t);
    state.cAttributes[0].shaderLocation = 0;
    state.cAttributes[0].offset = 0;
    state.cAttributes[0].format = wgpu::VertexFormat::Int;

    // Bind at an offset of 4, so we clamp to only values containing 473
    DoTest("[[location 0]] var<in> a : i32;", "a == 473", std::move(state), buffer, 4, true);
}
// Scalar u32 attribute: with the buffer bound past the leading 111, every read is expected to
// be clamped onto elements holding 473.
TEST_P(VertexBufferRobustnessTest, UIntClamp) {
    utils::ComboVertexStateDescriptor state;

    // A single unsigned 32-bit integer attribute at shader location 0.
    auto& attribute = state.cAttributes[0];
    attribute.shaderLocation = 0;
    attribute.offset = 0;
    attribute.format = wgpu::VertexFormat::UInt;

    // One vertex buffer whose elements are exactly one uint32_t apart.
    auto& layout = state.cVertexBuffers[0];
    layout.attributeCount = 1;
    layout.arrayStride = sizeof(uint32_t);
    state.vertexBufferCount = 1;

    // Bound at an offset of 4 bytes, so we clamp to only values containing 473.
    uint32_t vertexData[] = {111, 473, 473};
    wgpu::Buffer buffer = utils::CreateBufferFromData(device, vertexData, sizeof(vertexData),
                                                      wgpu::BufferUsage::Vertex);

    DoTest("[[location 0]] var<in> a : u32;", "a == 473", std::move(state), buffer, 4, true);
}
// vec2<f32> attribute: with the buffer bound past the leading pair of 111.0 values, every
// read is expected to be clamped onto elements whose components are all 473.0.
TEST_P(VertexBufferRobustnessTest, Float2Clamp) {
    utils::ComboVertexStateDescriptor state;

    // A single two-component float attribute at shader location 0.
    auto& attribute = state.cAttributes[0];
    attribute.shaderLocation = 0;
    attribute.offset = 0;
    attribute.format = wgpu::VertexFormat::Float2;

    // One vertex buffer whose elements are exactly two floats apart.
    auto& layout = state.cVertexBuffers[0];
    layout.attributeCount = 1;
    layout.arrayStride = sizeof(float) * 2;
    state.vertexBufferCount = 1;

    // Bound at an offset of 8 bytes, so we clamp to only values containing 473.0.
    float vertexData[] = {111.0f, 111.0f, 473.0f, 473.0f};
    wgpu::Buffer buffer = utils::CreateBufferFromData(device, vertexData, sizeof(vertexData),
                                                      wgpu::BufferUsage::Vertex);

    DoTest("[[location 0]] var<in> a : vec2<f32>;", "a[0] == 473.0 && a[1] == 473.0",
           std::move(state), buffer, 8, true);
}
// vec3<f32> attribute: with the buffer bound past the leading triple of 111.0 values, every
// read is expected to be clamped onto elements whose components are all 473.0.
TEST_P(VertexBufferRobustnessTest, Float3Clamp) {
    utils::ComboVertexStateDescriptor state;

    // A single three-component float attribute at shader location 0.
    auto& attribute = state.cAttributes[0];
    attribute.shaderLocation = 0;
    attribute.offset = 0;
    attribute.format = wgpu::VertexFormat::Float3;

    // One vertex buffer whose elements are exactly three floats apart.
    auto& layout = state.cVertexBuffers[0];
    layout.attributeCount = 1;
    layout.arrayStride = sizeof(float) * 3;
    state.vertexBufferCount = 1;

    // Bound at an offset of 12 bytes, so we clamp to only values containing 473.0.
    float vertexData[] = {111.0f, 111.0f, 111.0f, 473.0f, 473.0f, 473.0f};
    wgpu::Buffer buffer = utils::CreateBufferFromData(device, vertexData, sizeof(vertexData),
                                                      wgpu::BufferUsage::Vertex);

    DoTest("[[location 0]] var<in> a : vec3<f32>;",
           "a[0] == 473.0 && a[1] == 473.0 && a[2] == 473.0", std::move(state), buffer, 12,
           true);
}
// vec4<f32> attribute: with the buffer bound past the leading quadruple of 111.0 values,
// every read is expected to be clamped onto elements whose components are all 473.0.
TEST_P(VertexBufferRobustnessTest, Float4Clamp) {
    utils::ComboVertexStateDescriptor state;

    // A single four-component float attribute at shader location 0.
    auto& attribute = state.cAttributes[0];
    attribute.shaderLocation = 0;
    attribute.offset = 0;
    attribute.format = wgpu::VertexFormat::Float4;

    // One vertex buffer whose elements are exactly four floats apart.
    auto& layout = state.cVertexBuffers[0];
    layout.attributeCount = 1;
    layout.arrayStride = sizeof(float) * 4;
    state.vertexBufferCount = 1;

    // Bound at an offset of 16 bytes, so we clamp to only values containing 473.0.
    float vertexData[] = {111.0f, 111.0f, 111.0f, 111.0f, 473.0f, 473.0f, 473.0f, 473.0f};
    wgpu::Buffer buffer = utils::CreateBufferFromData(device, vertexData, sizeof(vertexData),
                                                      wgpu::BufferUsage::Vertex);

    DoTest("[[location 0]] var<in> a : vec4<f32>;",
           "a[0] == 473.0 && a[1] == 473.0 && a[2] == 473.0 && a[3] == 473.0",
           std::move(state), buffer, 16, true);
}
// Instantiates the suite for the Metal backend only, passing "metal_enable_vertex_pulling" to
// MetalBackend.
// NOTE(review): presumably this list force-enables the toggle so the tests exercise the
// vertex pulling path — confirm against MetalBackend's signature.
DAWN_INSTANTIATE_TEST(VertexBufferRobustnessTest, MetalBackend({"metal_enable_vertex_pulling"}));

View File

@ -144,6 +144,14 @@ namespace utils {
return CreateShaderModuleFromResult(device, result); return CreateShaderModuleFromResult(device, result);
} }
// Creates a shader module on |device| from the WGSL text in |source|.
// The WGSL source is supplied through a ShaderModuleWGSLDescriptor chained onto an otherwise
// default ShaderModuleDescriptor.
wgpu::ShaderModule CreateShaderModuleFromWGSL(const wgpu::Device& device, const char* source) {
    // Chained extension struct carrying the WGSL text. The base descriptor only stores its
    // address, so it has to remain alive for the duration of the CreateShaderModule call.
    wgpu::ShaderModuleWGSLDescriptor wgslPayload;
    wgslPayload.source = source;

    wgpu::ShaderModuleDescriptor moduleDesc;
    moduleDesc.nextInChain = &wgslPayload;

    wgpu::ShaderModule shaderModule = device.CreateShaderModule(&moduleDesc);
    return shaderModule;
}
std::vector<uint32_t> CompileGLSLToSpirv(SingleShaderStage stage, const char* source) { std::vector<uint32_t> CompileGLSLToSpirv(SingleShaderStage stage, const char* source) {
shaderc_shader_kind kind = ShadercShaderKind(stage); shaderc_shader_kind kind = ShadercShaderKind(stage);

View File

@ -34,6 +34,8 @@ namespace utils {
SingleShaderStage stage, SingleShaderStage stage,
const char* source); const char* source);
wgpu::ShaderModule CreateShaderModuleFromASM(const wgpu::Device& device, const char* source); wgpu::ShaderModule CreateShaderModuleFromASM(const wgpu::Device& device, const char* source);
// Creates a shader module on |device| from the WGSL source text in |source|.
wgpu::ShaderModule CreateShaderModuleFromWGSL(const wgpu::Device& device, const char* source);
std::vector<uint32_t> CompileGLSLToSpirv(SingleShaderStage stage, const char* source); std::vector<uint32_t> CompileGLSLToSpirv(SingleShaderStage stage, const char* source);
wgpu::Buffer CreateBufferFromData(const wgpu::Device& device, wgpu::Buffer CreateBufferFromData(const wgpu::Device& device,