Fix use-after-free issue in Create*PipelineAsyncTasks::Run()

This patch fixes a use-after-free issue in Create*PipelineAsyncTasks
that when pipeline->Initialize() returns error, the pipeline object
will be deleted, while we still attempt to call its member function
after it is deleted.

BUG=dawn:1310
TEST=dawn_unittests

Change-Id: I57d5ca98d6c97b14df1d7c3bf2941c9cc87adeff
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/81800
Reviewed-by: Loko Kung <lokokung@google.com>
Reviewed-by: Austin Eng <enga@chromium.org>
Commit-Queue: Austin Eng <enga@chromium.org>
This commit is contained in:
Jiawei Shao 2022-02-25 16:28:39 +00:00 committed by Dawn LUCI CQ
parent d606233833
commit ec3e4b0510
3 changed files with 88 additions and 10 deletions

View File

@ -114,11 +114,13 @@ namespace dawn::native {
void CreateComputePipelineAsyncTask::Run() {
const char* eventLabel = utils::GetLabelForTrace(mComputePipeline->GetLabel().c_str());
TRACE_EVENT_FLOW_END1(mComputePipeline->GetDevice()->GetPlatform(), General,
DeviceBase* device = mComputePipeline->GetDevice();
TRACE_EVENT_FLOW_END1(device->GetPlatform(), General,
"CreateComputePipelineAsyncTask::RunAsync", this, "label",
eventLabel);
TRACE_EVENT1(mComputePipeline->GetDevice()->GetPlatform(), General,
"CreateComputePipelineAsyncTask::Run", "label", eventLabel);
TRACE_EVENT1(device->GetPlatform(), General, "CreateComputePipelineAsyncTask::Run", "label",
eventLabel);
MaybeError maybeError = mComputePipeline->Initialize();
std::string errorMessage;
@ -127,8 +129,8 @@ namespace dawn::native {
errorMessage = maybeError.AcquireError()->GetMessage();
}
mComputePipeline->GetDevice()->AddComputePipelineAsyncCallbackTask(
mComputePipeline, errorMessage, mCallback, mUserdata);
device->AddComputePipelineAsyncCallbackTask(mComputePipeline, errorMessage, mCallback,
mUserdata);
}
void CreateComputePipelineAsyncTask::RunAsync(
@ -164,10 +166,12 @@ namespace dawn::native {
void CreateRenderPipelineAsyncTask::Run() {
const char* eventLabel = utils::GetLabelForTrace(mRenderPipeline->GetLabel().c_str());
TRACE_EVENT_FLOW_END1(mRenderPipeline->GetDevice()->GetPlatform(), General,
DeviceBase* device = mRenderPipeline->GetDevice();
TRACE_EVENT_FLOW_END1(device->GetPlatform(), General,
"CreateRenderPipelineAsyncTask::RunAsync", this, "label", eventLabel);
TRACE_EVENT1(mRenderPipeline->GetDevice()->GetPlatform(), General,
"CreateRenderPipelineAsyncTask::Run", "label", eventLabel);
TRACE_EVENT1(device->GetPlatform(), General, "CreateRenderPipelineAsyncTask::Run", "label",
eventLabel);
MaybeError maybeError = mRenderPipeline->Initialize();
std::string errorMessage;
@ -176,8 +180,8 @@ namespace dawn::native {
errorMessage = maybeError.AcquireError()->GetMessage();
}
mRenderPipeline->GetDevice()->AddRenderPipelineAsyncCallbackTask(
mRenderPipeline, errorMessage, mCallback, mUserdata);
device->AddRenderPipelineAsyncCallbackTask(mRenderPipeline, errorMessage, mCallback,
mUserdata);
}
void CreateRenderPipelineAsyncTask::RunAsync(

View File

@ -239,6 +239,7 @@ dawn_test("dawn_unittests") {
"unittests/ToBackendTests.cpp",
"unittests/TypedIntegerTests.cpp",
"unittests/native/CommandBufferEncodingTests.cpp",
"unittests/native/CreatePipelineAsyncTaskTests.cpp",
"unittests/native/DestroyObjectTests.cpp",
"unittests/native/DeviceCreationTests.cpp",
"unittests/validation/BindGroupValidationTests.cpp",

View File

@ -0,0 +1,73 @@
// Copyright 2022 The Dawn Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "dawn/tests/DawnNativeTest.h"
#include "dawn/native/CreatePipelineAsyncTask.h"
#include "mocks/ComputePipelineMock.h"
#include "mocks/RenderPipelineMock.h"
class CreatePipelineAsyncTaskTests : public DawnNativeTest {};
// A regression test for a null pointer issue in CreateRenderPipelineAsyncTask::Run().
// See crbug.com/dawn/1310 for more details.
TEST_F(CreatePipelineAsyncTaskTests, InitializationErrorInCreateRenderPipelineAsync) {
dawn::native::DeviceBase* deviceBase =
reinterpret_cast<dawn::native::DeviceBase*>(device.Get());
Ref<dawn::native::RenderPipelineMock> renderPipelineMock =
AcquireRef(new dawn::native::RenderPipelineMock(deviceBase));
ON_CALL(*renderPipelineMock.Get(), Initialize)
.WillByDefault(testing::Return(testing::ByMove(
DAWN_MAKE_ERROR(dawn::native::InternalErrorType::Validation, "Initialization Error"))));
dawn::native::CreateRenderPipelineAsyncTask asyncTask(
renderPipelineMock,
[](WGPUCreatePipelineAsyncStatus status, WGPURenderPipeline returnPipeline,
const char* message, void* userdata) {
EXPECT_EQ(WGPUCreatePipelineAsyncStatus::WGPUCreatePipelineAsyncStatus_Error, status);
},
nullptr);
asyncTask.Run();
device.Tick();
EXPECT_CALL(*renderPipelineMock.Get(), DestroyImpl).Times(1);
}
// A regression test for a null pointer issue in CreateComputePipelineAsyncTask::Run().
// See crbug.com/dawn/1310 for more details.
TEST_F(CreatePipelineAsyncTaskTests, InitializationErrorInCreateComputePipelineAsync) {
dawn::native::DeviceBase* deviceBase =
reinterpret_cast<dawn::native::DeviceBase*>(device.Get());
Ref<dawn::native::ComputePipelineMock> computePipelineMock =
AcquireRef(new dawn::native::ComputePipelineMock(deviceBase));
ON_CALL(*computePipelineMock.Get(), Initialize)
.WillByDefault(testing::Return(testing::ByMove(
DAWN_MAKE_ERROR(dawn::native::InternalErrorType::Validation, "Initialization Error"))));
dawn::native::CreateComputePipelineAsyncTask asyncTask(
computePipelineMock,
[](WGPUCreatePipelineAsyncStatus status, WGPUComputePipeline returnPipeline,
const char* message, void* userdata) {
EXPECT_EQ(WGPUCreatePipelineAsyncStatus::WGPUCreatePipelineAsyncStatus_Error, status);
},
nullptr);
asyncTask.Run();
device.Tick();
EXPECT_CALL(*computePipelineMock.Get(), DestroyImpl).Times(1);
}