diff --git a/src/tests/DawnTest.cpp b/src/tests/DawnTest.cpp index 9e2cef3558..2db8691234 100644 --- a/src/tests/DawnTest.cpp +++ b/src/tests/DawnTest.cpp @@ -29,6 +29,7 @@ #include "utils/TerribleCommandBuffer.h" #include "utils/TestUtils.h" #include "utils/WGPUHelpers.h" +#include "utils/WireHelper.h" #include #include @@ -309,14 +310,7 @@ void DawnTestEnvironment::ParseArgs(int argc, char** argv) { constexpr const char kWireTraceDirArg[] = "--wire-trace-dir="; argLen = sizeof(kWireTraceDirArg) - 1; if (strncmp(argv[i], kWireTraceDirArg, argLen) == 0) { - const char* wireTraceDir = argv[i] + argLen; - if (wireTraceDir[0] != '\0') { - const char* sep = GetPathSeparator(); - mWireTraceDir = wireTraceDir; - if (mWireTraceDir.back() != *sep) { - mWireTraceDir += sep; - } - } + mWireTraceDir = argv[i] + argLen; continue; } @@ -597,26 +591,11 @@ const std::vector& DawnTestEnvironment::GetDisabledToggles() const return mDisabledToggles; } -class WireServerTraceLayer : public dawn_wire::CommandHandler { - public: - WireServerTraceLayer(const char* file, dawn_wire::CommandHandler* handler) - : dawn_wire::CommandHandler(), mHandler(handler) { - mFile.open(file, std::ios_base::out | std::ios_base::binary | std::ios_base::trunc); - } - - const volatile char* HandleCommands(const volatile char* commands, size_t size) override { - mFile.write(const_cast(commands), size); - return mHandler->HandleCommands(commands, size); - } - - private: - dawn_wire::CommandHandler* mHandler; - std::ofstream mFile; -}; - // Implementation of DawnTest -DawnTestBase::DawnTestBase(const AdapterTestParam& param) : mParam(param) { +DawnTestBase::DawnTestBase(const AdapterTestParam& param) + : mParam(param), + mWireHelper(utils::CreateWireHelper(gTestEnv->UsesWire(), gTestEnv->GetWireTraceDir())) { } DawnTestBase::~DawnTestBase() { @@ -625,13 +604,7 @@ DawnTestBase::~DawnTestBase() { queue = wgpu::Queue(); device = wgpu::Device(); - mWireClient = nullptr; - mWireServer = nullptr; - if (gTestEnv->UsesWire()) { - backendProcs.deviceRelease(backendDevice); - } - - dawnProcSetProcs(nullptr); + mWireHelper.reset(); } bool DawnTestBase::IsD3D12() const { @@ -840,58 +813,15 @@ void DawnTestBase::SetUp() { deviceDescriptor.forceDisabledToggles.push_back(info->name); } - backendDevice = mBackendAdapter.CreateDevice(&deviceDescriptor); + std::tie(device, backendDevice) = + mWireHelper->RegisterDevice(mBackendAdapter.CreateDevice(&deviceDescriptor)); ASSERT_NE(nullptr, backendDevice); - backendProcs = dawn_native::GetProcs(); + std::string traceName = + std::string(::testing::UnitTest::GetInstance()->current_test_info()->test_suite_name()) + + "_" + ::testing::UnitTest::GetInstance()->current_test_info()->name(); + mWireHelper->BeginWireTrace(traceName.c_str()); - // Choose whether to use the backend procs and devices directly, or set up the wire. - WGPUDevice cDevice = nullptr; - DawnProcTable procs; - - if (gTestEnv->UsesWire()) { - mC2sBuf = std::make_unique(); - mS2cBuf = std::make_unique(); - - dawn_wire::WireServerDescriptor serverDesc = {}; - serverDesc.device = backendDevice; - serverDesc.procs = &backendProcs; - serverDesc.serializer = mS2cBuf.get(); - - mWireServer.reset(new dawn_wire::WireServer(serverDesc)); - mC2sBuf->SetHandler(mWireServer.get()); - - if (gTestEnv->GetWireTraceDir() != nullptr) { - std::string file = - std::string( - ::testing::UnitTest::GetInstance()->current_test_info()->test_suite_name()) + - "_" + ::testing::UnitTest::GetInstance()->current_test_info()->name(); - // Replace slashes in gtest names with underscores so everything is in one directory. - std::replace(file.begin(), file.end(), '/', '_'); - - std::string fullPath = gTestEnv->GetWireTraceDir() + file; - - mWireServerTraceLayer.reset( - new WireServerTraceLayer(fullPath.c_str(), mWireServer.get())); - mC2sBuf->SetHandler(mWireServerTraceLayer.get()); - } - - dawn_wire::WireClientDescriptor clientDesc = {}; - clientDesc.serializer = mC2sBuf.get(); - - mWireClient.reset(new dawn_wire::WireClient(clientDesc)); - cDevice = mWireClient->GetDevice(); - procs = dawn_wire::client::GetProcs(); - mS2cBuf->SetHandler(mWireClient.get()); - } else { - procs = backendProcs; - cDevice = backendDevice; - } - - // Set up the device and queue because all tests need them, and DawnTestBase needs them too for - // the deferred expectations. - dawnProcSetProcs(&procs); - device = wgpu::Device::Acquire(cDevice); queue = device.GetDefaultQueue(); device.SetUncapturedErrorCallback(OnDeviceError, this); @@ -1050,8 +980,8 @@ void DawnTestBase::WaitABit() { void DawnTestBase::FlushWire() { if (gTestEnv->UsesWire()) { - bool C2SFlushed = mC2sBuf->Flush(); - bool S2CFlushed = mS2cBuf->Flush(); + bool C2SFlushed = mWireHelper->FlushClient(); + bool S2CFlushed = mWireHelper->FlushServer(); ASSERT(C2SFlushed); ASSERT(S2CFlushed); } diff --git a/src/tests/DawnTest.h b/src/tests/DawnTest.h index 6d90c94152..b185c04b1d 100644 --- a/src/tests/DawnTest.h +++ b/src/tests/DawnTest.h @@ -171,6 +171,7 @@ struct GLFWwindow; namespace utils { class PlatformDebugLogger; class TerribleCommandBuffer; + class WireHelper; } // namespace utils namespace detail { @@ -376,14 +377,7 @@ class DawnTestBase { private: AdapterTestParam mParam; - - // Things used to set up testing through the Wire. - std::unique_ptr mWireServer; - std::unique_ptr mWireClient; - std::unique_ptr mC2sBuf; - std::unique_ptr mS2cBuf; - - std::unique_ptr mWireServerTraceLayer; + std::unique_ptr mWireHelper; // Tracking for validation errors static void OnDeviceError(WGPUErrorType type, const char* message, void* userdata); diff --git a/src/utils/BUILD.gn b/src/utils/BUILD.gn index 38fd58797b..1884d8ffb1 100644 --- a/src/utils/BUILD.gn +++ b/src/utils/BUILD.gn @@ -82,9 +82,12 @@ static_library("dawn_utils") { "Timer.h", "WGPUHelpers.cpp", "WGPUHelpers.h", + "WireHelper.cpp", + "WireHelper.h", ] deps = [ "${dawn_root}/src/common", + "${dawn_root}/src/dawn:dawn_proc", "${dawn_root}/src/dawn_native", "${dawn_root}/src/dawn_wire", "${dawn_shaderc_dir}:libshaderc", diff --git a/src/utils/CMakeLists.txt b/src/utils/CMakeLists.txt index e215f51637..f553d97438 100644 --- a/src/utils/CMakeLists.txt +++ b/src/utils/CMakeLists.txt @@ -34,12 +34,15 @@ target_sources(dawn_utils PRIVATE "Timer.h" "WGPUHelpers.cpp" "WGPUHelpers.h" + "WireHelper.cpp" + "WireHelper.h" ) target_link_libraries(dawn_utils PUBLIC dawncpp_headers PRIVATE dawn_internal_config dawn_common dawn_native + dawn_proc dawn_wire shaderc glfw diff --git a/src/utils/WireHelper.cpp b/src/utils/WireHelper.cpp new file mode 100644 index 0000000000..be420cab29 --- /dev/null +++ b/src/utils/WireHelper.cpp @@ -0,0 +1,170 @@ +// Copyright 2021 The Dawn Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "utils/WireHelper.h" + +#include "common/Assert.h" +#include "common/Log.h" +#include "common/SystemUtils.h" +#include "dawn/dawn_proc.h" +#include "dawn_native/DawnNative.h" +#include "dawn_wire/WireClient.h" +#include "dawn_wire/WireServer.h" +#include "utils/TerribleCommandBuffer.h" + +#include +#include +#include +#include + +namespace utils { + + namespace { + + class WireServerTraceLayer : public dawn_wire::CommandHandler { + public: + WireServerTraceLayer(const char* dir, dawn_wire::CommandHandler* handler) + : dawn_wire::CommandHandler(), mDir(dir), mHandler(handler) { + const char* sep = GetPathSeparator(); + if (mDir.back() != *sep) { + mDir += sep; + } + } + + void BeginWireTrace(const char* name) { + std::string filename = name; + // Replace slashes in gtest names with underscores so everything is in one + // directory. + std::replace(filename.begin(), filename.end(), '/', '_'); + std::replace(filename.begin(), filename.end(), '\\', '_'); + + // Prepend the filename with the directory. + filename = mDir + filename; + + ASSERT(!mFile.is_open()); + mFile.open(filename, + std::ios_base::out | std::ios_base::binary | std::ios_base::trunc); + } + + const volatile char* HandleCommands(const volatile char* commands, + size_t size) override { + if (mFile.is_open()) { + mFile.write(const_cast(commands), size); + } + return mHandler->HandleCommands(commands, size); + } + + private: + std::string mDir; + dawn_wire::CommandHandler* mHandler; + std::ofstream mFile; + }; + + class WireHelperDirect : public WireHelper { + public: + WireHelperDirect() { + dawnProcSetProcs(&dawn_native::GetProcs()); + } + + std::pair RegisterDevice(WGPUDevice backendDevice) override { + ASSERT(backendDevice != nullptr); + return std::make_pair(wgpu::Device::Acquire(backendDevice), backendDevice); + } + + void BeginWireTrace(const char* name) override { + } + + bool FlushClient() override { + return true; + } + + bool FlushServer() override { + return true; + } + }; + + class WireHelperProxy : public WireHelper { + public: + explicit WireHelperProxy(const char* wireTraceDir) { + mC2sBuf = std::make_unique(); + mS2cBuf = std::make_unique(); + + dawn_wire::WireServerDescriptor serverDesc = {}; + serverDesc.procs = &dawn_native::GetProcs(); + serverDesc.serializer = mS2cBuf.get(); + + mWireServer.reset(new dawn_wire::WireServer(serverDesc)); + mC2sBuf->SetHandler(mWireServer.get()); + + if (wireTraceDir != nullptr && strlen(wireTraceDir) > 0) { + mWireServerTraceLayer.reset( + new WireServerTraceLayer(wireTraceDir, mWireServer.get())); + mC2sBuf->SetHandler(mWireServerTraceLayer.get()); + } + + dawn_wire::WireClientDescriptor clientDesc = {}; + clientDesc.serializer = mC2sBuf.get(); + + mWireClient.reset(new dawn_wire::WireClient(clientDesc)); + mS2cBuf->SetHandler(mWireClient.get()); + dawnProcSetProcs(&dawn_wire::client::GetProcs()); + } + + std::pair RegisterDevice(WGPUDevice backendDevice) override { + ASSERT(backendDevice != nullptr); + + auto reservation = mWireClient->ReserveDevice(); + mWireServer->InjectDevice(backendDevice, reservation.id, reservation.generation); + dawn_native::GetProcs().deviceRelease(backendDevice); + + return std::make_pair(wgpu::Device::Acquire(reservation.device), backendDevice); + } + + void BeginWireTrace(const char* name) override { + if (mWireServerTraceLayer) { + return mWireServerTraceLayer->BeginWireTrace(name); + } + } + + bool FlushClient() override { + return mC2sBuf->Flush(); + } + + bool FlushServer() override { + return mS2cBuf->Flush(); + } + + private: + std::unique_ptr mC2sBuf; + std::unique_ptr mS2cBuf; + std::unique_ptr mWireServerTraceLayer; + std::unique_ptr mWireServer; + std::unique_ptr mWireClient; + }; + + } // anonymous namespace + + std::unique_ptr CreateWireHelper(bool useWire, const char* wireTraceDir) { + if (useWire) { + return std::unique_ptr(new WireHelperProxy(wireTraceDir)); + } else { + return std::unique_ptr(new WireHelperDirect()); + } + } + + WireHelper::~WireHelper() { + dawnProcSetProcs(nullptr); + } + +} // namespace utils diff --git a/src/utils/WireHelper.h b/src/utils/WireHelper.h new file mode 100644 index 0000000000..78aa802616 --- /dev/null +++ b/src/utils/WireHelper.h @@ -0,0 +1,44 @@ +// Copyright 2021 The Dawn Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef UTILS_WIREHELPER_H_ +#define UTILS_WIREHELPER_H_ + +#include "dawn/webgpu_cpp.h" + +#include +#include + +namespace utils { + + class WireHelper { + public: + virtual ~WireHelper(); + + // Registers the device on the wire, if present. + // Returns a pair of the client device and backend device. + // The function should take ownership of |backendDevice|. + virtual std::pair RegisterDevice(WGPUDevice backendDevice) = 0; + + virtual void BeginWireTrace(const char* name) = 0; + + virtual bool FlushClient() = 0; + virtual bool FlushServer() = 0; + }; + + std::unique_ptr CreateWireHelper(bool useWire, const char* wireTraceDir = nullptr); + +} // namespace utils + +#endif // UTILS_WIREHELPER_H_