Propagating errors out of GetFunction in MTL backend

BUG=dawn:303

Change-Id: Iff1903aecae4c043b222208b3eab5efdf9774b52
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/14501
Commit-Queue: Ryan Harrison <rharrison@chromium.org>
Reviewed-by: Austin Eng <enga@chromium.org>
This commit is contained in:
Ryan Harrison 2019-12-12 17:51:39 +00:00 committed by Commit Bot service account
parent 69cdaf94df
commit 5c413afdc7
8 changed files with 88 additions and 44 deletions

View File

@ -25,7 +25,8 @@ namespace dawn_native { namespace metal {
class ComputePipeline : public ComputePipelineBase { class ComputePipeline : public ComputePipelineBase {
public: public:
ComputePipeline(Device* device, const ComputePipelineDescriptor* descriptor); static ResultOrError<ComputePipeline*> Create(Device* device,
const ComputePipelineDescriptor* descriptor);
~ComputePipeline(); ~ComputePipeline();
void Encode(id<MTLComputeCommandEncoder> encoder); void Encode(id<MTLComputeCommandEncoder> encoder);
@ -33,6 +34,9 @@ namespace dawn_native { namespace metal {
bool RequiresStorageBufferLength() const; bool RequiresStorageBufferLength() const;
private: private:
using ComputePipelineBase::ComputePipelineBase;
MaybeError Initialize(const ComputePipelineDescriptor* descriptor);
id<MTLComputePipelineState> mMtlComputePipelineState = nil; id<MTLComputePipelineState> mMtlComputePipelineState = nil;
MTLSize mLocalWorkgroupSize; MTLSize mLocalWorkgroupSize;
bool mRequiresStorageBufferLength; bool mRequiresStorageBufferLength;

View File

@ -19,27 +19,37 @@
namespace dawn_native { namespace metal { namespace dawn_native { namespace metal {
ComputePipeline::ComputePipeline(Device* device, const ComputePipelineDescriptor* descriptor) // static
: ComputePipelineBase(device, descriptor) { ResultOrError<ComputePipeline*> ComputePipeline::Create(
Device* device,
const ComputePipelineDescriptor* descriptor) {
std::unique_ptr<ComputePipeline> pipeline =
std::make_unique<ComputePipeline>(device, descriptor);
DAWN_TRY(pipeline->Initialize(descriptor));
return pipeline.release();
}
MaybeError ComputePipeline::Initialize(const ComputePipelineDescriptor* descriptor) {
auto mtlDevice = ToBackend(GetDevice())->GetMTLDevice(); auto mtlDevice = ToBackend(GetDevice())->GetMTLDevice();
const ShaderModule* computeModule = ToBackend(descriptor->computeStage.module); const ShaderModule* computeModule = ToBackend(descriptor->computeStage.module);
const char* computeEntryPoint = descriptor->computeStage.entryPoint; const char* computeEntryPoint = descriptor->computeStage.entryPoint;
ShaderModule::MetalFunctionData computeData = computeModule->GetFunction( ShaderModule::MetalFunctionData computeData;
computeEntryPoint, SingleShaderStage::Compute, ToBackend(GetLayout())); DAWN_TRY(computeModule->GetFunction(computeEntryPoint, SingleShaderStage::Compute,
ToBackend(GetLayout()), &computeData));
NSError* error = nil; NSError* error = nil;
mMtlComputePipelineState = mMtlComputePipelineState =
[mtlDevice newComputePipelineStateWithFunction:computeData.function error:&error]; [mtlDevice newComputePipelineStateWithFunction:computeData.function error:&error];
if (error != nil) { if (error != nil) {
NSLog(@" error => %@", error); NSLog(@" error => %@", error);
GetDevice()->HandleError(wgpu::ErrorType::DeviceLost, "Error creating pipeline state"); return DAWN_DEVICE_LOST_ERROR("Error creating pipeline state");
return;
} }
// Copy over the local workgroup size as it is passed to dispatch explicitly in Metal // Copy over the local workgroup size as it is passed to dispatch explicitly in Metal
mLocalWorkgroupSize = computeData.localWorkgroupSize; mLocalWorkgroupSize = computeData.localWorkgroupSize;
mRequiresStorageBufferLength = computeData.needsStorageBufferLength; mRequiresStorageBufferLength = computeData.needsStorageBufferLength;
return {};
} }
ComputePipeline::~ComputePipeline() { ComputePipeline::~ComputePipeline() {

View File

@ -111,7 +111,7 @@ namespace dawn_native { namespace metal {
} }
ResultOrError<ComputePipelineBase*> Device::CreateComputePipelineImpl( ResultOrError<ComputePipelineBase*> Device::CreateComputePipelineImpl(
const ComputePipelineDescriptor* descriptor) { const ComputePipelineDescriptor* descriptor) {
return new ComputePipeline(this, descriptor); return ComputePipeline::Create(this, descriptor);
} }
ResultOrError<PipelineLayoutBase*> Device::CreatePipelineLayoutImpl( ResultOrError<PipelineLayoutBase*> Device::CreatePipelineLayoutImpl(
const PipelineLayoutDescriptor* descriptor) { const PipelineLayoutDescriptor* descriptor) {
@ -122,7 +122,7 @@ namespace dawn_native { namespace metal {
} }
ResultOrError<RenderPipelineBase*> Device::CreateRenderPipelineImpl( ResultOrError<RenderPipelineBase*> Device::CreateRenderPipelineImpl(
const RenderPipelineDescriptor* descriptor) { const RenderPipelineDescriptor* descriptor) {
return new RenderPipeline(this, descriptor); return RenderPipeline::Create(this, descriptor);
} }
ResultOrError<SamplerBase*> Device::CreateSamplerImpl(const SamplerDescriptor* descriptor) { ResultOrError<SamplerBase*> Device::CreateSamplerImpl(const SamplerDescriptor* descriptor) {
return new Sampler(this, descriptor); return new Sampler(this, descriptor);

View File

@ -25,7 +25,8 @@ namespace dawn_native { namespace metal {
class RenderPipeline : public RenderPipelineBase { class RenderPipeline : public RenderPipelineBase {
public: public:
RenderPipeline(Device* device, const RenderPipelineDescriptor* descriptor); static ResultOrError<RenderPipeline*> Create(Device* device,
const RenderPipelineDescriptor* descriptor);
~RenderPipeline(); ~RenderPipeline();
MTLIndexType GetMTLIndexType() const; MTLIndexType GetMTLIndexType() const;
@ -44,6 +45,9 @@ namespace dawn_native { namespace metal {
wgpu::ShaderStage GetStagesRequiringStorageBufferLength() const; wgpu::ShaderStage GetStagesRequiringStorageBufferLength() const;
private: private:
using RenderPipelineBase::RenderPipelineBase;
MaybeError Initialize(const RenderPipelineDescriptor* descriptor);
MTLVertexDescriptor* MakeVertexDesc(); MTLVertexDescriptor* MakeVertexDesc();
MTLIndexType mMtlIndexType; MTLIndexType mMtlIndexType;

View File

@ -311,20 +311,31 @@ namespace dawn_native { namespace metal {
} // anonymous namespace } // anonymous namespace
RenderPipeline::RenderPipeline(Device* device, const RenderPipelineDescriptor* descriptor) // static
: RenderPipelineBase(device, descriptor), ResultOrError<RenderPipeline*> RenderPipeline::Create(
mMtlIndexType(MTLIndexFormat(GetVertexStateDescriptor()->indexFormat)), Device* device,
mMtlPrimitiveTopology(MTLPrimitiveTopology(GetPrimitiveTopology())), const RenderPipelineDescriptor* descriptor) {
mMtlFrontFace(MTLFrontFace(GetFrontFace())), std::unique_ptr<RenderPipeline> pipeline =
mMtlCullMode(ToMTLCullMode(GetCullMode())) { std::make_unique<RenderPipeline>(device, descriptor);
auto mtlDevice = device->GetMTLDevice(); DAWN_TRY(pipeline->Initialize(descriptor));
return pipeline.release();
}
MaybeError RenderPipeline::Initialize(const RenderPipelineDescriptor* descriptor) {
mMtlIndexType = MTLIndexFormat(GetVertexStateDescriptor()->indexFormat);
mMtlPrimitiveTopology = MTLPrimitiveTopology(GetPrimitiveTopology());
mMtlFrontFace = MTLFrontFace(GetFrontFace());
mMtlCullMode = ToMTLCullMode(GetCullMode());
auto mtlDevice = ToBackend(GetDevice())->GetMTLDevice();
MTLRenderPipelineDescriptor* descriptorMTL = [MTLRenderPipelineDescriptor new]; MTLRenderPipelineDescriptor* descriptorMTL = [MTLRenderPipelineDescriptor new];
const ShaderModule* vertexModule = ToBackend(descriptor->vertexStage.module); const ShaderModule* vertexModule = ToBackend(descriptor->vertexStage.module);
const char* vertexEntryPoint = descriptor->vertexStage.entryPoint; const char* vertexEntryPoint = descriptor->vertexStage.entryPoint;
ShaderModule::MetalFunctionData vertexData = vertexModule->GetFunction( ShaderModule::MetalFunctionData vertexData;
vertexEntryPoint, SingleShaderStage::Vertex, ToBackend(GetLayout())); DAWN_TRY(vertexModule->GetFunction(vertexEntryPoint, SingleShaderStage::Vertex,
ToBackend(GetLayout()), &vertexData));
descriptorMTL.vertexFunction = vertexData.function; descriptorMTL.vertexFunction = vertexData.function;
if (vertexData.needsStorageBufferLength) { if (vertexData.needsStorageBufferLength) {
mStagesRequiringStorageBufferLength |= wgpu::ShaderStage::Vertex; mStagesRequiringStorageBufferLength |= wgpu::ShaderStage::Vertex;
@ -332,8 +343,10 @@ namespace dawn_native { namespace metal {
const ShaderModule* fragmentModule = ToBackend(descriptor->fragmentStage->module); const ShaderModule* fragmentModule = ToBackend(descriptor->fragmentStage->module);
const char* fragmentEntryPoint = descriptor->fragmentStage->entryPoint; const char* fragmentEntryPoint = descriptor->fragmentStage->entryPoint;
ShaderModule::MetalFunctionData fragmentData = fragmentModule->GetFunction( ShaderModule::MetalFunctionData fragmentData;
fragmentEntryPoint, SingleShaderStage::Fragment, ToBackend(GetLayout())); DAWN_TRY(fragmentModule->GetFunction(fragmentEntryPoint, SingleShaderStage::Fragment,
ToBackend(GetLayout()), &fragmentData));
descriptorMTL.fragmentFunction = fragmentData.function; descriptorMTL.fragmentFunction = fragmentData.function;
if (fragmentData.needsStorageBufferLength) { if (fragmentData.needsStorageBufferLength) {
mStagesRequiringStorageBufferLength |= wgpu::ShaderStage::Fragment; mStagesRequiringStorageBufferLength |= wgpu::ShaderStage::Fragment;
@ -372,9 +385,7 @@ namespace dawn_native { namespace metal {
[descriptorMTL release]; [descriptorMTL release];
if (error != nil) { if (error != nil) {
NSLog(@" error => %@", error); NSLog(@" error => %@", error);
device->HandleError(wgpu::ErrorType::DeviceLost, return DAWN_DEVICE_LOST_ERROR("Error creating rendering pipeline state");
"Error creating rendering pipeline state");
return;
} }
} }
@ -385,6 +396,8 @@ namespace dawn_native { namespace metal {
MakeDepthStencilDesc(GetDepthStencilStateDescriptor()); MakeDepthStencilDesc(GetDepthStencilStateDescriptor());
mMtlDepthStencilState = [mtlDevice newDepthStencilStateWithDescriptor:depthStencilDesc]; mMtlDepthStencilState = [mtlDevice newDepthStencilStateWithDescriptor:depthStencilDesc];
[depthStencilDesc release]; [depthStencilDesc release];
return {};
} }
RenderPipeline::~RenderPipeline() { RenderPipeline::~RenderPipeline() {

View File

@ -43,9 +43,10 @@ namespace dawn_native { namespace metal {
[function release]; [function release];
} }
}; };
MetalFunctionData GetFunction(const char* functionName, MaybeError GetFunction(const char* functionName,
SingleShaderStage functionStage, SingleShaderStage functionStage,
const PipelineLayout* layout) const; const PipelineLayout* layout,
MetalFunctionData* out) const;
private: private:
ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor); ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);

View File

@ -92,17 +92,22 @@ namespace dawn_native { namespace metal {
return {}; return {};
} }
ShaderModule::MetalFunctionData ShaderModule::GetFunction(const char* functionName, MaybeError ShaderModule::GetFunction(const char* functionName,
SingleShaderStage functionStage, SingleShaderStage functionStage,
const PipelineLayout* layout) const { const PipelineLayout* layout,
ShaderModule::MetalFunctionData* out) const {
ASSERT(!IsError());
ASSERT(out);
std::unique_ptr<spirv_cross::CompilerMSL> compiler_impl; std::unique_ptr<spirv_cross::CompilerMSL> compiler_impl;
spirv_cross::CompilerMSL* compiler; spirv_cross::CompilerMSL* compiler;
if (GetDevice()->IsToggleEnabled(Toggle::UseSpvc)) { if (GetDevice()->IsToggleEnabled(Toggle::UseSpvc)) {
// Initializing the compiler is needed every call, because this method uses reflection // Initializing the compiler is needed every call, because this method uses reflection
// to mutate the compiler's IR. // to mutate the compiler's IR.
mSpvcContext.InitializeForMsl(mSpirv.data(), mSpirv.size(), GetMSLCompileOptions()); if (mSpvcContext.InitializeForMsl(mSpirv.data(), mSpirv.size(),
// TODO(rharrison): Handle initialize failing GetMSLCompileOptions()) !=
shaderc_spvc_status_success) {
return DAWN_DEVICE_LOST_ERROR("Unable to initialize instance of spvc");
}
compiler = reinterpret_cast<spirv_cross::CompilerMSL*>(mSpvcContext.GetCompiler()); compiler = reinterpret_cast<spirv_cross::CompilerMSL*>(mSpvcContext.GetCompiler());
} else { } else {
// If these options are changed, the values in DawnSPIRVCrossMSLFastFuzzer.cpp need to // If these options are changed, the values in DawnSPIRVCrossMSLFastFuzzer.cpp need to
@ -147,12 +152,10 @@ namespace dawn_native { namespace metal {
} }
} }
MetalFunctionData result;
{ {
spv::ExecutionModel executionModel = SpirvExecutionModelForStage(functionStage); spv::ExecutionModel executionModel = SpirvExecutionModelForStage(functionStage);
auto size = compiler->get_entry_point(functionName, executionModel).workgroup_size; auto size = compiler->get_entry_point(functionName, executionModel).workgroup_size;
result.localWorkgroupSize = MTLSizeMake(size.x, size.y, size.z); out->localWorkgroupSize = MTLSizeMake(size.x, size.y, size.z);
} }
{ {
@ -167,9 +170,14 @@ namespace dawn_native { namespace metal {
options:nil options:nil
error:&error]; error:&error];
if (error != nil) { if (error != nil) {
// TODO(cwallez@chromium.org): forward errors to caller // TODO(cwallez@chromium.org): Switch that NSLog to use dawn::InfoLog or even be
// folded in the DAWN_VALIDATION_ERROR
NSLog(@"MTLDevice newLibraryWithSource => %@", error); NSLog(@"MTLDevice newLibraryWithSource => %@", error);
if (error.code != MTLLibraryErrorCompileWarning) {
return DAWN_VALIDATION_ERROR("Unable to create library object");
}
} }
// TODO(kainino@chromium.org): make this somehow more robust; it needs to behave like // TODO(kainino@chromium.org): make this somehow more robust; it needs to behave like
// clean_func_name: // clean_func_name:
// https://github.com/KhronosGroup/SPIRV-Cross/blob/4e915e8c483e319d0dd7a1fa22318bef28f8cca3/spirv_msl.cpp#L1213 // https://github.com/KhronosGroup/SPIRV-Cross/blob/4e915e8c483e319d0dd7a1fa22318bef28f8cca3/spirv_msl.cpp#L1213
@ -178,13 +186,13 @@ namespace dawn_native { namespace metal {
} }
NSString* name = [NSString stringWithFormat:@"%s", functionName]; NSString* name = [NSString stringWithFormat:@"%s", functionName];
result.function = [library newFunctionWithName:name]; out->function = [library newFunctionWithName:name];
[library release]; [library release];
} }
result.needsStorageBufferLength = compiler->needs_buffer_size_buffer(); out->needsStorageBufferLength = compiler->needs_buffer_size_buffer();
return result; return {};
} }
}} // namespace dawn_native::metal }} // namespace dawn_native::metal

View File

@ -139,14 +139,16 @@ TEST_P(ObjectCachingTest, ComputePipelineDeduplicationOnShaderModule) {
wgpu::ShaderModule module = wgpu::ShaderModule module =
utils::CreateShaderModule(device, utils::SingleShaderStage::Compute, R"( utils::CreateShaderModule(device, utils::SingleShaderStage::Compute, R"(
#version 450 #version 450
shared uint i;
void main() { void main() {
int i = 0; i = 0;
})"); })");
wgpu::ShaderModule sameModule = wgpu::ShaderModule sameModule =
utils::CreateShaderModule(device, utils::SingleShaderStage::Compute, R"( utils::CreateShaderModule(device, utils::SingleShaderStage::Compute, R"(
#version 450 #version 450
shared uint i;
void main() { void main() {
int i = 0; i = 0;
})"); })");
wgpu::ShaderModule otherModule = wgpu::ShaderModule otherModule =
utils::CreateShaderModule(device, utils::SingleShaderStage::Compute, R"( utils::CreateShaderModule(device, utils::SingleShaderStage::Compute, R"(
@ -195,8 +197,9 @@ TEST_P(ObjectCachingTest, ComputePipelineDeduplicationOnLayout) {
desc.computeStage.module = desc.computeStage.module =
utils::CreateShaderModule(device, utils::SingleShaderStage::Compute, R"( utils::CreateShaderModule(device, utils::SingleShaderStage::Compute, R"(
#version 450 #version 450
shared uint i;
void main() { void main() {
int i = 0; i = 0;
})"); })");
desc.layout = pl; desc.layout = pl;
@ -311,8 +314,9 @@ TEST_P(ObjectCachingTest, RenderPipelineDeduplicationOnFragmentModule) {
wgpu::ShaderModule otherModule = wgpu::ShaderModule otherModule =
utils::CreateShaderModule(device, utils::SingleShaderStage::Fragment, R"( utils::CreateShaderModule(device, utils::SingleShaderStage::Fragment, R"(
#version 450 #version 450
layout (location = 0) out vec4 color;
void main() { void main() {
int i = 0; color = vec4(0.0);
})"); })");
EXPECT_NE(module.Get(), otherModule.Get()); EXPECT_NE(module.Get(), otherModule.Get());