diff --git a/src/dawn_native/CreatePipelineAsyncTracker.cpp b/src/dawn_native/CreatePipelineAsyncTracker.cpp index a9fcf620a1..23b8310c43 100644 --- a/src/dawn_native/CreatePipelineAsyncTracker.cpp +++ b/src/dawn_native/CreatePipelineAsyncTracker.cpp @@ -128,11 +128,7 @@ namespace dawn_native { mCreatePipelineAsyncTasksInFlight.ClearUpTo(finishedSerial); for (auto& task : tasks) { - if (mDevice->IsLost()) { - task->HandleDeviceLoss(); - } else { - task->Finish(); - } + task->Finish(); } } @@ -143,4 +139,11 @@ namespace dawn_native { mCreatePipelineAsyncTasksInFlight.Clear(); } + void CreatePipelineAsyncTracker::ClearForDeviceLoss() { + for (auto& task : mCreatePipelineAsyncTasksInFlight.IterateAll()) { + task->HandleDeviceLoss(); + } + mCreatePipelineAsyncTasksInFlight.Clear(); + } + } // namespace dawn_native diff --git a/src/dawn_native/CreatePipelineAsyncTracker.h b/src/dawn_native/CreatePipelineAsyncTracker.h index b84daed2e6..738d71930f 100644 --- a/src/dawn_native/CreatePipelineAsyncTracker.h +++ b/src/dawn_native/CreatePipelineAsyncTracker.h @@ -80,6 +80,7 @@ namespace dawn_native { void TrackTask(std::unique_ptr task, ExecutionSerial serial); void Tick(ExecutionSerial finishedSerial); void ClearForShutDown(); + void ClearForDeviceLoss(); private: DeviceBase* mDevice; diff --git a/src/dawn_native/Device.cpp b/src/dawn_native/Device.cpp index 2a493c7a68..713b88eb2f 100644 --- a/src/dawn_native/Device.cpp +++ b/src/dawn_native/Device.cpp @@ -238,6 +238,7 @@ namespace dawn_native { } mQueue->HandleDeviceLoss(); + mCreatePipelineAsyncTracker->ClearForDeviceLoss(); // Still forward device loss errors to the error scopes so they all reject. mErrorScopeStack->HandleError(ToWGPUErrorType(type), message); diff --git a/src/tests/end2end/DeviceLostTests.cpp b/src/tests/end2end/DeviceLostTests.cpp index d420a0a559..6989caf2a5 100644 --- a/src/tests/end2end/DeviceLostTests.cpp +++ b/src/tests/end2end/DeviceLostTests.cpp @@ -554,6 +554,26 @@ TEST_P(DeviceLostTest, DeviceLostDoesntCallUncapturedError) { device.LoseForTesting(); } +// Test that WGPUCreatePipelineAsyncStatus_DeviceLost can be correctly returned when device is lost +// before the callback of Create*PipelineAsync() is called. +TEST_P(DeviceLostTest, DeviceLostBeforeCreatePipelineAsyncCallback) { + wgpu::ShaderModule csModule = utils::CreateShaderModule(device, R"( + [[stage(compute)]] fn main() { + })"); + + wgpu::ComputePipelineDescriptor descriptor; + descriptor.computeStage.module = csModule; + descriptor.computeStage.entryPoint = "main"; + + auto callback = [](WGPUCreatePipelineAsyncStatus status, WGPUComputePipeline returnPipeline, + const char* message, void* userdata) { + EXPECT_EQ(WGPUCreatePipelineAsyncStatus::WGPUCreatePipelineAsyncStatus_DeviceLost, status); + }; + + device.CreateComputePipelineAsync(&descriptor, callback, nullptr); + SetCallbackAndLoseForTesting(); +} + DAWN_INSTANTIATE_TEST(DeviceLostTest, D3D12Backend(), MetalBackend(),