Use unique_ptr where applicable.

Change-Id: Icb29f6f9760f0ea36528e8ea6890713c2fb3b965
This commit is contained in:
Corentin Wallez 2018-09-06 15:26:48 +02:00 committed by Corentin Wallez
parent 21d8438ad6
commit cca9c698a0
13 changed files with 108 additions and 131 deletions

View File

@ -50,11 +50,10 @@ namespace dawn_native {
// DeviceBase // DeviceBase
DeviceBase::DeviceBase() { DeviceBase::DeviceBase() {
mCaches = new DeviceBase::Caches(); mCaches = std::make_unique<DeviceBase::Caches>();
} }
DeviceBase::~DeviceBase() { DeviceBase::~DeviceBase() {
delete mCaches;
} }
void DeviceBase::HandleError(const char* message) { void DeviceBase::HandleError(const char* message) {

View File

@ -21,6 +21,8 @@
#include "dawn_native/dawn_platform.h" #include "dawn_native/dawn_platform.h"
#include <memory>
namespace dawn_native { namespace dawn_native {
using ErrorCallback = void (*)(const char* errorMessage, void* userData); using ErrorCallback = void (*)(const char* errorMessage, void* userData);
@ -137,7 +139,7 @@ namespace dawn_native {
// The object caches aren't exposed in the header as they would require a lot of // The object caches aren't exposed in the header as they would require a lot of
// additional includes. // additional includes.
struct Caches; struct Caches;
Caches* mCaches = nullptr; std::unique_ptr<Caches> mCaches;
dawn::DeviceErrorCallback mErrorCallback = nullptr; dawn::DeviceErrorCallback mErrorCallback = nullptr;
dawn::CallbackUserdata mErrorUserdata = 0; dawn::CallbackUserdata mErrorUserdata = 0;

View File

@ -116,7 +116,7 @@ namespace dawn_native { namespace d3d12 {
} // anonymous namespace } // anonymous namespace
Device::Device() { Device::Device() {
mFunctions = new PlatformFunctions(); mFunctions = std::make_unique<PlatformFunctions>();
{ {
MaybeError status = mFunctions->LoadFunctions(); MaybeError status = mFunctions->LoadFunctions();
@ -124,10 +124,10 @@ namespace dawn_native { namespace d3d12 {
} }
// Create the connection to DXGI and the D3D12 device // Create the connection to DXGI and the D3D12 device
mFactory = CreateFactory(mFunctions); mFactory = CreateFactory(mFunctions.get());
ASSERT(mFactory.Get() != nullptr); ASSERT(mFactory.Get() != nullptr);
mHardwareAdapter = GetHardwareAdapter(mFactory, mFunctions); mHardwareAdapter = GetHardwareAdapter(mFactory, mFunctions.get());
ASSERT(mHardwareAdapter.Get() != nullptr); ASSERT(mHardwareAdapter.Get() != nullptr);
ASSERT_SUCCESS(mFunctions->d3d12CreateDevice(mHardwareAdapter.Get(), D3D_FEATURE_LEVEL_11_0, ASSERT_SUCCESS(mFunctions->d3d12CreateDevice(mHardwareAdapter.Get(), D3D_FEATURE_LEVEL_11_0,
@ -145,11 +145,11 @@ namespace dawn_native { namespace d3d12 {
ASSERT(mFenceEvent != nullptr); ASSERT(mFenceEvent != nullptr);
// Initialize backend services // Initialize backend services
mCommandAllocatorManager = new CommandAllocatorManager(this); mCommandAllocatorManager = std::make_unique<CommandAllocatorManager>(this);
mDescriptorHeapAllocator = new DescriptorHeapAllocator(this); mDescriptorHeapAllocator = std::make_unique<DescriptorHeapAllocator>(this);
mMapRequestTracker = new MapRequestTracker(this); mMapRequestTracker = std::make_unique<MapRequestTracker>(this);
mResourceAllocator = new ResourceAllocator(this); mResourceAllocator = std::make_unique<ResourceAllocator>(this);
mResourceUploader = new ResourceUploader(this); mResourceUploader = std::make_unique<ResourceUploader>(this);
NextSerial(); NextSerial();
} }
@ -159,22 +159,9 @@ namespace dawn_native { namespace d3d12 {
NextSerial(); NextSerial();
WaitForSerial(currentSerial); // Wait for all in-flight commands to finish executing WaitForSerial(currentSerial); // Wait for all in-flight commands to finish executing
TickImpl(); // Call tick one last time so resources are cleaned up TickImpl(); // Call tick one last time so resources are cleaned up
ASSERT(mUsedComObjectRefs.Empty()); ASSERT(mUsedComObjectRefs.Empty());
ASSERT(mPendingCommands.commandList == nullptr); ASSERT(mPendingCommands.commandList == nullptr);
// Free all D3D12 and DXGI objects before unloading the DLLs
mFence = nullptr;
mFactory = nullptr;
mHardwareAdapter = nullptr;
mD3d12Device = nullptr;
mCommandQueue = nullptr;
delete mCommandAllocatorManager;
delete mDescriptorHeapAllocator;
delete mMapRequestTracker;
delete mResourceAllocator;
delete mResourceUploader;
delete mFunctions;
} }
ComPtr<IDXGIFactory4> Device::GetFactory() { ComPtr<IDXGIFactory4> Device::GetFactory() {
@ -190,23 +177,23 @@ namespace dawn_native { namespace d3d12 {
} }
DescriptorHeapAllocator* Device::GetDescriptorHeapAllocator() { DescriptorHeapAllocator* Device::GetDescriptorHeapAllocator() {
return mDescriptorHeapAllocator; return mDescriptorHeapAllocator.get();
} }
const PlatformFunctions* Device::GetFunctions() { const PlatformFunctions* Device::GetFunctions() {
return mFunctions; return mFunctions.get();
} }
MapRequestTracker* Device::GetMapRequestTracker() const { MapRequestTracker* Device::GetMapRequestTracker() const {
return mMapRequestTracker; return mMapRequestTracker.get();
} }
ResourceAllocator* Device::GetResourceAllocator() { ResourceAllocator* Device::GetResourceAllocator() {
return mResourceAllocator; return mResourceAllocator.get();
} }
ResourceUploader* Device::GetResourceUploader() { ResourceUploader* Device::GetResourceUploader() {
return mResourceUploader; return mResourceUploader.get();
} }
void Device::OpenCommandList(ComPtr<ID3D12GraphicsCommandList>* commandList) { void Device::OpenCommandList(ComPtr<ID3D12GraphicsCommandList>* commandList) {

View File

@ -22,6 +22,8 @@
#include "dawn_native/d3d12/Forward.h" #include "dawn_native/d3d12/Forward.h"
#include "dawn_native/d3d12/d3d12_platform.h" #include "dawn_native/d3d12/d3d12_platform.h"
#include <memory>
namespace dawn_native { namespace d3d12 { namespace dawn_native { namespace d3d12 {
class CommandAllocatorManager; class CommandAllocatorManager;
@ -88,6 +90,10 @@ namespace dawn_native { namespace d3d12 {
const ShaderModuleDescriptor* descriptor) override; const ShaderModuleDescriptor* descriptor) override;
ResultOrError<TextureBase*> CreateTextureImpl(const TextureDescriptor* descriptor) override; ResultOrError<TextureBase*> CreateTextureImpl(const TextureDescriptor* descriptor) override;
// Keep mFunctions as the first member so that in the destructor it is freed. Otherwise the
// D3D12 DLLs are unloaded before we are done using it.
std::unique_ptr<PlatformFunctions> mFunctions;
uint64_t mSerial = 0; uint64_t mSerial = 0;
ComPtr<ID3D12Fence> mFence; ComPtr<ID3D12Fence> mFence;
HANDLE mFenceEvent; HANDLE mFenceEvent;
@ -97,19 +103,18 @@ namespace dawn_native { namespace d3d12 {
ComPtr<ID3D12Device> mD3d12Device; ComPtr<ID3D12Device> mD3d12Device;
ComPtr<ID3D12CommandQueue> mCommandQueue; ComPtr<ID3D12CommandQueue> mCommandQueue;
CommandAllocatorManager* mCommandAllocatorManager = nullptr;
DescriptorHeapAllocator* mDescriptorHeapAllocator = nullptr;
MapRequestTracker* mMapRequestTracker = nullptr;
PlatformFunctions* mFunctions = nullptr;
ResourceAllocator* mResourceAllocator = nullptr;
ResourceUploader* mResourceUploader = nullptr;
struct PendingCommandList { struct PendingCommandList {
ComPtr<ID3D12GraphicsCommandList> commandList; ComPtr<ID3D12GraphicsCommandList> commandList;
bool open = false; bool open = false;
} mPendingCommands; } mPendingCommands;
SerialQueue<ComPtr<IUnknown>> mUsedComObjectRefs; SerialQueue<ComPtr<IUnknown>> mUsedComObjectRefs;
std::unique_ptr<CommandAllocatorManager> mCommandAllocatorManager;
std::unique_ptr<DescriptorHeapAllocator> mDescriptorHeapAllocator;
std::unique_ptr<MapRequestTracker> mMapRequestTracker;
std::unique_ptr<ResourceAllocator> mResourceAllocator;
std::unique_ptr<ResourceUploader> mResourceUploader;
}; };
}} // namespace dawn_native::d3d12 }} // namespace dawn_native::d3d12

View File

@ -23,6 +23,8 @@
#import <Metal/Metal.h> #import <Metal/Metal.h>
#import <QuartzCore/CAMetalLayer.h> #import <QuartzCore/CAMetalLayer.h>
#include <memory>
#include <type_traits> #include <type_traits>
namespace dawn_native { namespace metal { namespace dawn_native { namespace metal {
@ -76,8 +78,8 @@ namespace dawn_native { namespace metal {
id<MTLDevice> mMtlDevice = nil; id<MTLDevice> mMtlDevice = nil;
id<MTLCommandQueue> mCommandQueue = nil; id<MTLCommandQueue> mCommandQueue = nil;
MapRequestTracker* mMapTracker; std::unique_ptr<MapRequestTracker> mMapTracker;
ResourceUploader* mResourceUploader; std::unique_ptr<ResourceUploader> mResourceUploader;
Serial mFinishedCommandSerial = 0; Serial mFinishedCommandSerial = 0;
Serial mPendingCommandSerial = 1; Serial mPendingCommandSerial = 1;

View File

@ -65,10 +65,7 @@ namespace dawn_native { namespace metal {
[mPendingCommands release]; [mPendingCommands release];
mPendingCommands = nil; mPendingCommands = nil;
delete mMapTracker;
mMapTracker = nullptr; mMapTracker = nullptr;
delete mResourceUploader;
mResourceUploader = nullptr; mResourceUploader = nullptr;
[mMtlDevice release]; [mMtlDevice release];
@ -189,11 +186,11 @@ namespace dawn_native { namespace metal {
} }
MapRequestTracker* Device::GetMapTracker() const { MapRequestTracker* Device::GetMapTracker() const {
return mMapTracker; return mMapTracker.get();
} }
ResourceUploader* Device::GetResourceUploader() const { ResourceUploader* Device::GetResourceUploader() const {
return mResourceUploader; return mResourceUploader.get();
} }
}} // namespace dawn_native::metal }} // namespace dawn_native::metal

View File

@ -146,11 +146,11 @@ namespace dawn_native { namespace vulkan {
GatherQueueFromDevice(); GatherQueueFromDevice();
mBufferUploader = new BufferUploader(this); mBufferUploader = std::make_unique<BufferUploader>(this);
mDeleter = new FencedDeleter(this); mDeleter = std::make_unique<FencedDeleter>(this);
mMapRequestTracker = new MapRequestTracker(this); mMapRequestTracker = std::make_unique<MapRequestTracker>(this);
mMemoryAllocator = new MemoryAllocator(this); mMemoryAllocator = std::make_unique<MemoryAllocator>(this);
mRenderPassCache = new RenderPassCache(this); mRenderPassCache = std::make_unique<RenderPassCache>(this);
} }
Device::~Device() { Device::~Device() {
@ -182,21 +182,14 @@ namespace dawn_native { namespace vulkan {
} }
mUnusedFences.clear(); mUnusedFences.clear();
delete mBufferUploader; // Free services explicitly so that they can free Vulkan objects before vkDestroyDevice
mBufferUploader = nullptr; mBufferUploader = nullptr;
delete mDeleter;
mDeleter = nullptr; mDeleter = nullptr;
delete mMapRequestTracker;
mMapRequestTracker = nullptr; mMapRequestTracker = nullptr;
delete mMemoryAllocator;
mMemoryAllocator = nullptr; mMemoryAllocator = nullptr;
// The VkRenderPasses in the cache can be destroyed immediately since all commands referring // The VkRenderPasses in the cache can be destroyed immediately since all commands referring
// to them are guaranteed to be finished executing. // to them are guaranteed to be finished executing.
delete mRenderPassCache;
mRenderPassCache = nullptr; mRenderPassCache = nullptr;
// VkQueues are destroyed when the VkDevice is destroyed // VkQueues are destroyed when the VkDevice is destroyed
@ -322,23 +315,23 @@ namespace dawn_native { namespace vulkan {
} }
MapRequestTracker* Device::GetMapRequestTracker() const { MapRequestTracker* Device::GetMapRequestTracker() const {
return mMapRequestTracker; return mMapRequestTracker.get();
} }
MemoryAllocator* Device::GetMemoryAllocator() const { MemoryAllocator* Device::GetMemoryAllocator() const {
return mMemoryAllocator; return mMemoryAllocator.get();
} }
BufferUploader* Device::GetBufferUploader() const { BufferUploader* Device::GetBufferUploader() const {
return mBufferUploader; return mBufferUploader.get();
} }
FencedDeleter* Device::GetFencedDeleter() const { FencedDeleter* Device::GetFencedDeleter() const {
return mDeleter; return mDeleter.get();
} }
RenderPassCache* Device::GetRenderPassCache() const { RenderPassCache* Device::GetRenderPassCache() const {
return mRenderPassCache; return mRenderPassCache.get();
} }
Serial Device::GetSerial() const { Serial Device::GetSerial() const {

View File

@ -25,6 +25,7 @@
#include "dawn_native/vulkan/VulkanFunctions.h" #include "dawn_native/vulkan/VulkanFunctions.h"
#include "dawn_native/vulkan/VulkanInfo.h" #include "dawn_native/vulkan/VulkanInfo.h"
#include <memory>
#include <queue> #include <queue>
namespace dawn_native { namespace vulkan { namespace dawn_native { namespace vulkan {
@ -123,11 +124,11 @@ namespace dawn_native { namespace vulkan {
VkQueue mQueue = VK_NULL_HANDLE; VkQueue mQueue = VK_NULL_HANDLE;
VkDebugReportCallbackEXT mDebugReportCallback = VK_NULL_HANDLE; VkDebugReportCallbackEXT mDebugReportCallback = VK_NULL_HANDLE;
BufferUploader* mBufferUploader = nullptr; std::unique_ptr<BufferUploader> mBufferUploader;
FencedDeleter* mDeleter = nullptr; std::unique_ptr<FencedDeleter> mDeleter;
MapRequestTracker* mMapRequestTracker = nullptr; std::unique_ptr<MapRequestTracker> mMapRequestTracker;
MemoryAllocator* mMemoryAllocator = nullptr; std::unique_ptr<MemoryAllocator> mMemoryAllocator;
RenderPassCache* mRenderPassCache = nullptr; std::unique_ptr<RenderPassCache> mRenderPassCache;
VkFence GetUnusedFence(); VkFence GetUnusedFence();
void CheckPassedFences(); void CheckPassedFences();

View File

@ -97,6 +97,8 @@ namespace {
}; };
} // namespace } // namespace
DawnTest::DawnTest() = default;
DawnTest::~DawnTest() { DawnTest::~DawnTest() {
// We need to destroy child objects before the Device // We need to destroy child objects before the Device
mReadbackSlots.clear(); mReadbackSlots.clear();
@ -104,9 +106,6 @@ DawnTest::~DawnTest() {
swapchain = dawn::SwapChain(); swapchain = dawn::SwapChain();
device = dawn::Device(); device = dawn::Device();
delete mBinding;
mBinding = nullptr;
dawnSetProcs(nullptr); dawnSetProcs(nullptr);
} }
@ -129,10 +128,10 @@ bool DawnTest::IsVulkan() const {
bool gTestUsesWire = false; bool gTestUsesWire = false;
void DawnTest::SetUp() { void DawnTest::SetUp() {
mBinding = utils::CreateBinding(ParamToBackendType(GetParam())); mBinding.reset(utils::CreateBinding(ParamToBackendType(GetParam())));
DAWN_ASSERT(mBinding != nullptr); DAWN_ASSERT(mBinding != nullptr);
GLFWwindow* testWindow = GetWindowForBackend(mBinding, GetParam()); GLFWwindow* testWindow = GetWindowForBackend(mBinding.get(), GetParam());
DAWN_ASSERT(testWindow != nullptr); DAWN_ASSERT(testWindow != nullptr);
mBinding->SetWindow(testWindow); mBinding->SetWindow(testWindow);
@ -145,16 +144,17 @@ void DawnTest::SetUp() {
dawnProcTable procs; dawnProcTable procs;
if (gTestUsesWire) { if (gTestUsesWire) {
mC2sBuf = new utils::TerribleCommandBuffer(); mC2sBuf = std::make_unique<utils::TerribleCommandBuffer>();
mS2cBuf = new utils::TerribleCommandBuffer(); mS2cBuf = std::make_unique<utils::TerribleCommandBuffer>();
mWireServer = dawn_wire::NewServerCommandHandler(backendDevice, backendProcs, mS2cBuf); mWireServer.reset(
mC2sBuf->SetHandler(mWireServer); dawn_wire::NewServerCommandHandler(backendDevice, backendProcs, mS2cBuf.get()));
mC2sBuf->SetHandler(mWireServer.get());
dawnDevice clientDevice; dawnDevice clientDevice;
dawnProcTable clientProcs; dawnProcTable clientProcs;
mWireClient = dawn_wire::NewClientDevice(&clientProcs, &clientDevice, mC2sBuf); mWireClient.reset(dawn_wire::NewClientDevice(&clientProcs, &clientDevice, mC2sBuf.get()));
mS2cBuf->SetHandler(mWireClient); mS2cBuf->SetHandler(mWireClient.get());
procs = clientProcs; procs = clientProcs;
cDevice = clientDevice; cDevice = clientDevice;
@ -191,18 +191,6 @@ void DawnTest::TearDown() {
for (size_t i = 0; i < mReadbackSlots.size(); ++i) { for (size_t i = 0; i < mReadbackSlots.size(); ++i) {
mReadbackSlots[i].buffer.Unmap(); mReadbackSlots[i].buffer.Unmap();
} }
for (auto& expectation : mDeferredExpectations) {
delete expectation.expectation;
expectation.expectation = nullptr;
}
if (gTestUsesWire) {
delete mC2sBuf;
delete mS2cBuf;
delete mWireClient;
delete mWireServer;
}
} }
std::ostringstream& DawnTest::AddBufferExpectation(const char* file, std::ostringstream& DawnTest::AddBufferExpectation(const char* file,
@ -232,7 +220,7 @@ std::ostringstream& DawnTest::AddBufferExpectation(const char* file,
deferred.size = size; deferred.size = size;
deferred.rowBytes = size; deferred.rowBytes = size;
deferred.rowPitch = size; deferred.rowPitch = size;
deferred.expectation = expectation; deferred.expectation.reset(expectation);
mDeferredExpectations.push_back(std::move(deferred)); mDeferredExpectations.push_back(std::move(deferred));
mDeferredExpectations.back().message = std::make_unique<std::ostringstream>(); mDeferredExpectations.back().message = std::make_unique<std::ostringstream>();
@ -273,7 +261,7 @@ std::ostringstream& DawnTest::AddTextureExpectation(const char* file,
deferred.size = size; deferred.size = size;
deferred.rowBytes = width * pixelSize; deferred.rowBytes = width * pixelSize;
deferred.rowPitch = rowPitch; deferred.rowPitch = rowPitch;
deferred.expectation = expectation; deferred.expectation.reset(expectation);
mDeferredExpectations.push_back(std::move(deferred)); mDeferredExpectations.push_back(std::move(deferred));
mDeferredExpectations.back().message = std::make_unique<std::ostringstream>(); mDeferredExpectations.back().message = std::make_unique<std::ostringstream>();

View File

@ -79,6 +79,7 @@ namespace dawn_wire {
class DawnTest : public ::testing::TestWithParam<BackendType> { class DawnTest : public ::testing::TestWithParam<BackendType> {
public: public:
DawnTest();
~DawnTest(); ~DawnTest();
void SetUp() override; void SetUp() override;
@ -118,10 +119,10 @@ class DawnTest : public ::testing::TestWithParam<BackendType> {
private: private:
// Things used to set up testing through the Wire. // Things used to set up testing through the Wire.
dawn_wire::CommandHandler* mWireServer = nullptr; std::unique_ptr<dawn_wire::CommandHandler> mWireServer;
dawn_wire::CommandHandler* mWireClient = nullptr; std::unique_ptr<dawn_wire::CommandHandler> mWireClient;
utils::TerribleCommandBuffer* mC2sBuf = nullptr; std::unique_ptr<utils::TerribleCommandBuffer> mC2sBuf;
utils::TerribleCommandBuffer* mS2cBuf = nullptr; std::unique_ptr<utils::TerribleCommandBuffer> mS2cBuf;
void FlushWire(); void FlushWire();
// MapRead buffers used to get data for the expectations // MapRead buffers used to get data for the expectations
@ -155,7 +156,7 @@ class DawnTest : public ::testing::TestWithParam<BackendType> {
uint32_t size; uint32_t size;
uint32_t rowBytes; uint32_t rowBytes;
uint32_t rowPitch; uint32_t rowPitch;
detail::Expectation* expectation; std::unique_ptr<detail::Expectation> expectation;
// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=54316 // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=54316
// Use unique_ptr because of missing move/copy constructors on std::basic_ostringstream // Use unique_ptr because of missing move/copy constructors on std::basic_ostringstream
std::unique_ptr<std::ostringstream> message; std::unique_ptr<std::ostringstream> message;
@ -165,7 +166,7 @@ class DawnTest : public ::testing::TestWithParam<BackendType> {
// Assuming the data is mapped, checks all expectations // Assuming the data is mapped, checks all expectations
void ResolveExpectations(); void ResolveExpectations();
utils::BackendBinding* mBinding = nullptr; std::unique_ptr<utils::BackendBinding> mBinding;
}; };
// Instantiate the test once for each backend provided after the first argument. Use it like this: // Instantiate the test once for each backend provided after the first argument. Use it like this:

View File

@ -18,7 +18,6 @@
class BufferMapReadTests : public DawnTest { class BufferMapReadTests : public DawnTest {
protected: protected:
static void MapReadCallback(dawnBufferMapAsyncStatus status, const void* data, dawnCallbackUserdata userdata) { static void MapReadCallback(dawnBufferMapAsyncStatus status, const void* data, dawnCallbackUserdata userdata) {
ASSERT_EQ(DAWN_BUFFER_MAP_ASYNC_STATUS_SUCCESS, status); ASSERT_EQ(DAWN_BUFFER_MAP_ASYNC_STATUS_SUCCESS, status);
ASSERT_NE(nullptr, data); ASSERT_NE(nullptr, data);

View File

@ -19,6 +19,8 @@
#include "dawn_wire/Wire.h" #include "dawn_wire/Wire.h"
#include "utils/TerribleCommandBuffer.h" #include "utils/TerribleCommandBuffer.h"
#include <memory>
using namespace testing; using namespace testing;
using namespace dawn_wire; using namespace dawn_wire;
@ -76,7 +78,7 @@ class MockDeviceErrorCallback {
MOCK_METHOD2(Call, void(const char* message, dawnCallbackUserdata userdata)); MOCK_METHOD2(Call, void(const char* message, dawnCallbackUserdata userdata));
}; };
static MockDeviceErrorCallback* mockDeviceErrorCallback = nullptr; static std::unique_ptr<MockDeviceErrorCallback> mockDeviceErrorCallback;
static void ToMockDeviceErrorCallback(const char* message, dawnCallbackUserdata userdata) { static void ToMockDeviceErrorCallback(const char* message, dawnCallbackUserdata userdata) {
mockDeviceErrorCallback->Call(message, userdata); mockDeviceErrorCallback->Call(message, userdata);
} }
@ -86,7 +88,7 @@ class MockBuilderErrorCallback {
MOCK_METHOD4(Call, void(dawnBuilderErrorStatus status, const char* message, dawnCallbackUserdata userdata1, dawnCallbackUserdata userdata2)); MOCK_METHOD4(Call, void(dawnBuilderErrorStatus status, const char* message, dawnCallbackUserdata userdata1, dawnCallbackUserdata userdata2));
}; };
static MockBuilderErrorCallback* mockBuilderErrorCallback = nullptr; static std::unique_ptr<MockBuilderErrorCallback> mockBuilderErrorCallback;
static void ToMockBuilderErrorCallback(dawnBuilderErrorStatus status, const char* message, dawnCallbackUserdata userdata1, dawnCallbackUserdata userdata2) { static void ToMockBuilderErrorCallback(dawnBuilderErrorStatus status, const char* message, dawnCallbackUserdata userdata1, dawnCallbackUserdata userdata2) {
mockBuilderErrorCallback->Call(status, message, userdata1, userdata2); mockBuilderErrorCallback->Call(status, message, userdata1, userdata2);
} }
@ -96,7 +98,7 @@ class MockBufferMapReadCallback {
MOCK_METHOD3(Call, void(dawnBufferMapAsyncStatus status, const uint32_t* ptr, dawnCallbackUserdata userdata)); MOCK_METHOD3(Call, void(dawnBufferMapAsyncStatus status, const uint32_t* ptr, dawnCallbackUserdata userdata));
}; };
static MockBufferMapReadCallback* mockBufferMapReadCallback = nullptr; static std::unique_ptr<MockBufferMapReadCallback> mockBufferMapReadCallback;
static void ToMockBufferMapReadCallback(dawnBufferMapAsyncStatus status, const void* ptr, dawnCallbackUserdata userdata) { static void ToMockBufferMapReadCallback(dawnBufferMapAsyncStatus status, const void* ptr, dawnCallbackUserdata userdata) {
// Assume the data is uint32_t to make writing matchers easier // Assume the data is uint32_t to make writing matchers easier
mockBufferMapReadCallback->Call(status, static_cast<const uint32_t*>(ptr), userdata); mockBufferMapReadCallback->Call(status, static_cast<const uint32_t*>(ptr), userdata);
@ -107,7 +109,7 @@ class MockBufferMapWriteCallback {
MOCK_METHOD3(Call, void(dawnBufferMapAsyncStatus status, uint32_t* ptr, dawnCallbackUserdata userdata)); MOCK_METHOD3(Call, void(dawnBufferMapAsyncStatus status, uint32_t* ptr, dawnCallbackUserdata userdata));
}; };
static MockBufferMapWriteCallback* mockBufferMapWriteCallback = nullptr; static std::unique_ptr<MockBufferMapWriteCallback> mockBufferMapWriteCallback;
uint32_t* lastMapWritePointer = nullptr; uint32_t* lastMapWritePointer = nullptr;
static void ToMockBufferMapWriteCallback(dawnBufferMapAsyncStatus status, void* ptr, dawnCallbackUserdata userdata) { static void ToMockBufferMapWriteCallback(dawnBufferMapAsyncStatus status, void* ptr, dawnCallbackUserdata userdata) {
// Assume the data is uint32_t to make writing matchers easier // Assume the data is uint32_t to make writing matchers easier
@ -122,10 +124,10 @@ class WireTestsBase : public Test {
} }
void SetUp() override { void SetUp() override {
mockDeviceErrorCallback = new MockDeviceErrorCallback; mockDeviceErrorCallback = std::make_unique<MockDeviceErrorCallback>();
mockBuilderErrorCallback = new MockBuilderErrorCallback; mockBuilderErrorCallback = std::make_unique<MockBuilderErrorCallback>();
mockBufferMapReadCallback = new MockBufferMapReadCallback; mockBufferMapReadCallback = std::make_unique<MockBufferMapReadCallback>();
mockBufferMapWriteCallback = new MockBufferMapWriteCallback; mockBufferMapWriteCallback = std::make_unique<MockBufferMapWriteCallback>();
dawnProcTable mockProcs; dawnProcTable mockProcs;
dawnDevice mockDevice; dawnDevice mockDevice;
@ -138,30 +140,28 @@ class WireTestsBase : public Test {
} }
EXPECT_CALL(api, DeviceTick(_)).Times(AnyNumber()); EXPECT_CALL(api, DeviceTick(_)).Times(AnyNumber());
mS2cBuf = new utils::TerribleCommandBuffer(); mS2cBuf = std::make_unique<utils::TerribleCommandBuffer>();
mC2sBuf = new utils::TerribleCommandBuffer(mWireServer); mC2sBuf = std::make_unique<utils::TerribleCommandBuffer>(mWireServer.get());
mWireServer = NewServerCommandHandler(mockDevice, mockProcs, mS2cBuf); mWireServer.reset(NewServerCommandHandler(mockDevice, mockProcs, mS2cBuf.get()));
mC2sBuf->SetHandler(mWireServer); mC2sBuf->SetHandler(mWireServer.get());
dawnProcTable clientProcs; dawnProcTable clientProcs;
mWireClient = NewClientDevice(&clientProcs, &device, mC2sBuf); mWireClient.reset(NewClientDevice(&clientProcs, &device, mC2sBuf.get()));
dawnSetProcs(&clientProcs); dawnSetProcs(&clientProcs);
mS2cBuf->SetHandler(mWireClient); mS2cBuf->SetHandler(mWireClient.get());
apiDevice = mockDevice; apiDevice = mockDevice;
} }
void TearDown() override { void TearDown() override {
dawnSetProcs(nullptr); dawnSetProcs(nullptr);
delete mWireServer;
delete mWireClient; // Delete mocks so that expectations are checked
delete mC2sBuf; mockDeviceErrorCallback = nullptr;
delete mS2cBuf; mockBuilderErrorCallback = nullptr;
delete mockDeviceErrorCallback; mockBufferMapReadCallback = nullptr;
delete mockBuilderErrorCallback; mockBufferMapWriteCallback = nullptr;
delete mockBufferMapReadCallback;
delete mockBufferMapWriteCallback;
} }
void FlushClient() { void FlushClient() {
@ -179,10 +179,10 @@ class WireTestsBase : public Test {
private: private:
bool mIgnoreSetCallbackCalls = false; bool mIgnoreSetCallbackCalls = false;
CommandHandler* mWireServer = nullptr; std::unique_ptr<CommandHandler> mWireServer;
CommandHandler* mWireClient = nullptr; std::unique_ptr<CommandHandler> mWireClient;
utils::TerribleCommandBuffer* mS2cBuf = nullptr; std::unique_ptr<utils::TerribleCommandBuffer> mS2cBuf;
utils::TerribleCommandBuffer* mC2sBuf = nullptr; std::unique_ptr<utils::TerribleCommandBuffer> mC2sBuf;
}; };
class WireTests : public WireTestsBase { class WireTests : public WireTestsBase {

View File

@ -16,6 +16,8 @@
#include <gmock/gmock.h> #include <gmock/gmock.h>
#include <memory>
using namespace testing; using namespace testing;
class MockBufferMapReadCallback { class MockBufferMapReadCallback {
@ -23,7 +25,7 @@ class MockBufferMapReadCallback {
MOCK_METHOD3(Call, void(dawnBufferMapAsyncStatus status, const uint32_t* ptr, dawnCallbackUserdata userdata)); MOCK_METHOD3(Call, void(dawnBufferMapAsyncStatus status, const uint32_t* ptr, dawnCallbackUserdata userdata));
}; };
static MockBufferMapReadCallback* mockBufferMapReadCallback = nullptr; static std::unique_ptr<MockBufferMapReadCallback> mockBufferMapReadCallback;
static void ToMockBufferMapReadCallback(dawnBufferMapAsyncStatus status, const void* ptr, dawnCallbackUserdata userdata) { static void ToMockBufferMapReadCallback(dawnBufferMapAsyncStatus status, const void* ptr, dawnCallbackUserdata userdata) {
// Assume the data is uint32_t to make writing matchers easier // Assume the data is uint32_t to make writing matchers easier
mockBufferMapReadCallback->Call(status, reinterpret_cast<const uint32_t*>(ptr), userdata); mockBufferMapReadCallback->Call(status, reinterpret_cast<const uint32_t*>(ptr), userdata);
@ -34,7 +36,7 @@ class MockBufferMapWriteCallback {
MOCK_METHOD3(Call, void(dawnBufferMapAsyncStatus status, uint32_t* ptr, dawnCallbackUserdata userdata)); MOCK_METHOD3(Call, void(dawnBufferMapAsyncStatus status, uint32_t* ptr, dawnCallbackUserdata userdata));
}; };
static MockBufferMapWriteCallback* mockBufferMapWriteCallback = nullptr; static std::unique_ptr<MockBufferMapWriteCallback> mockBufferMapWriteCallback;
static void ToMockBufferMapWriteCallback(dawnBufferMapAsyncStatus status, void* ptr, dawnCallbackUserdata userdata) { static void ToMockBufferMapWriteCallback(dawnBufferMapAsyncStatus status, void* ptr, dawnCallbackUserdata userdata) {
// Assume the data is uint32_t to make writing matchers easier // Assume the data is uint32_t to make writing matchers easier
mockBufferMapWriteCallback->Call(status, reinterpret_cast<uint32_t*>(ptr), userdata); mockBufferMapWriteCallback->Call(status, reinterpret_cast<uint32_t*>(ptr), userdata);
@ -70,14 +72,15 @@ class BufferValidationTest : public ValidationTest {
void SetUp() override { void SetUp() override {
ValidationTest::SetUp(); ValidationTest::SetUp();
mockBufferMapReadCallback = new MockBufferMapReadCallback; mockBufferMapReadCallback = std::make_unique<MockBufferMapReadCallback>();
mockBufferMapWriteCallback = new MockBufferMapWriteCallback; mockBufferMapWriteCallback = std::make_unique<MockBufferMapWriteCallback>();
queue = device.CreateQueue(); queue = device.CreateQueue();
} }
void TearDown() override { void TearDown() override {
delete mockBufferMapReadCallback; // Delete mocks so that expectations are checked
delete mockBufferMapWriteCallback; mockBufferMapReadCallback = nullptr;
mockBufferMapWriteCallback = nullptr;
ValidationTest::TearDown(); ValidationTest::TearDown();
} }