diff --git a/generator/dawn_json_generator.py b/generator/dawn_json_generator.py index b7d9003e9d..02c572337c 100644 --- a/generator/dawn_json_generator.py +++ b/generator/dawn_json_generator.py @@ -687,6 +687,10 @@ class MultiGeneratorFromDawnJSON(Generator): renders.append( FileRender('dawn_proc.c', 'src/dawn/dawn_proc.c', [base_params, api_params])) + renders.append( + FileRender('dawn_thread_dispatch_proc.cpp', + 'src/dawn/dawn_thread_dispatch_proc.cpp', + [base_params, api_params])) if 'dawncpp' in targets: renders.append( diff --git a/generator/templates/dawn_native/ProcTable.cpp b/generator/templates/dawn_native/ProcTable.cpp index 88b780c907..defae2df3f 100644 --- a/generator/templates/dawn_native/ProcTable.cpp +++ b/generator/templates/dawn_native/ProcTable.cpp @@ -135,16 +135,17 @@ namespace dawn_native { return result; } - DawnProcTable GetProcsAutogen() { - DawnProcTable table; - table.getProcAddress = NativeGetProcAddress; - table.createInstance = NativeCreateInstance; + static DawnProcTable gProcTable = { + NativeGetProcAddress, + NativeCreateInstance, {% for type in by_category["object"] %} {% for method in c_methods(type) %} - table.{{as_varName(type.name, method.name)}} = Native{{as_MethodSuffix(type.name, method.name)}}; + Native{{as_MethodSuffix(type.name, method.name)}}, {% endfor %} {% endfor %} - return table; - } + }; + const DawnProcTable& GetProcsAutogen() { + return gProcTable; + } } diff --git a/generator/templates/dawn_proc_table.h b/generator/templates/dawn_proc_table.h index 197f3001bb..1da1f73a36 100644 --- a/generator/templates/dawn_proc_table.h +++ b/generator/templates/dawn_proc_table.h @@ -17,6 +17,7 @@ #include "dawn/webgpu.h" +// Note: Often allocated as a static global. Do not add a complex constructor. typedef struct DawnProcTable { WGPUProcGetProcAddress getProcAddress; WGPUProcCreateInstance createInstance; diff --git a/generator/templates/dawn_thread_dispatch_proc.cpp b/generator/templates/dawn_thread_dispatch_proc.cpp new file mode 100644 index 0000000000..bfc7794806 --- /dev/null +++ b/generator/templates/dawn_thread_dispatch_proc.cpp @@ -0,0 +1,52 @@ +#include "dawn/dawn_thread_dispatch_proc.h" + +#include + +static DawnProcTable nullProcs; +thread_local DawnProcTable perThreadProcs; + +void dawnProcSetPerThreadProcs(const DawnProcTable* procs) { + if (procs) { + perThreadProcs = *procs; + } else { + perThreadProcs = nullProcs; + } +} + +static WGPUProc ThreadDispatchGetProcAddress(WGPUDevice device, const char* procName) { + return perThreadProcs.getProcAddress(device, procName); +} + +static WGPUInstance ThreadDispatchCreateInstance(WGPUInstanceDescriptor const * descriptor) { + return perThreadProcs.createInstance(descriptor); +} + +{% for type in by_category["object"] %} + {% for method in c_methods(type) %} + static {{as_cType(method.return_type.name)}} ThreadDispatch{{as_MethodSuffix(type.name, method.name)}}( + {{-as_cType(type.name)}} {{as_varName(type.name)}} + {%- for arg in method.arguments -%} + , {{as_annotated_cType(arg)}} + {%- endfor -%} + ) { + {% if method.return_type.name.canonical_case() != "void" %}return {% endif %} + perThreadProcs.{{as_varName(type.name, method.name)}}({{as_varName(type.name)}} + {%- for arg in method.arguments -%} + , {{as_varName(arg.name)}} + {%- endfor -%} + ); + } + {% endfor %} +{% endfor %} + +extern "C" { + DawnProcTable dawnThreadDispatchProcTable = { + ThreadDispatchGetProcAddress, + ThreadDispatchCreateInstance, +{% for type in by_category["object"] %} + {% for method in c_methods(type) %} + ThreadDispatch{{as_MethodSuffix(type.name, method.name)}}, + {% endfor %} +{% endfor %} + }; +} diff --git a/generator/templates/dawn_wire/client/ApiProcs.cpp b/generator/templates/dawn_wire/client/ApiProcs.cpp index 3edba1a22b..0ad9a77297 100644 --- a/generator/templates/dawn_wire/client/ApiProcs.cpp +++ b/generator/templates/dawn_wire/client/ApiProcs.cpp @@ -293,21 +293,16 @@ namespace dawn_wire { namespace client { return result; } - //* Some commands don't have a custom wire format, but need to be handled manually to update - //* some client-side state tracking. For these we have two functions: - //* - An autogenerated Client{{suffix}} method that sends the command on the wire - //* - A manual ProxyClient{{suffix}} method that will be inserted in the proctable instead of - //* the autogenerated one, and that will have to call Client{{suffix}} - DawnProcTable GetProcs() { - DawnProcTable table; - table.getProcAddress = ClientGetProcAddress; - table.createInstance = ClientCreateInstance; + static DawnProcTable gProcTable = { + ClientGetProcAddress, + ClientCreateInstance, {% for type in by_category["object"] %} {% for method in c_methods(type) %} - {% set suffix = as_MethodSuffix(type.name, method.name) %} - table.{{as_varName(type.name, method.name)}} = Client{{suffix}}; + Client{{as_MethodSuffix(type.name, method.name)}}, {% endfor %} {% endfor %} - return table; + }; + const DawnProcTable& GetProcs() { + return gProcTable; } }} // namespace dawn_wire::client diff --git a/src/dawn/BUILD.gn b/src/dawn/BUILD.gn index 9034be436f..ad48712844 100644 --- a/src/dawn/BUILD.gn +++ b/src/dawn/BUILD.gn @@ -87,7 +87,10 @@ source_set("dawncpp") { dawn_json_generator("dawn_proc_gen") { target = "dawn_proc" - outputs = [ "src/dawn/dawn_proc.c" ] + outputs = [ + "src/dawn/dawn_proc.c", + "src/dawn/dawn_thread_dispatch_proc.cpp", + ] } dawn_component("dawn_proc") { @@ -96,5 +99,8 @@ dawn_component("dawn_proc") { public_deps = [ ":dawn_headers" ] deps = [ ":dawn_proc_gen" ] sources = get_target_outputs(":dawn_proc_gen") - sources += [ "${dawn_root}/src/include/dawn/dawn_proc.h" ] + sources += [ + "${dawn_root}/src/include/dawn/dawn_proc.h", + "${dawn_root}/src/include/dawn/dawn_thread_dispatch_proc.h", + ] } diff --git a/src/dawn_native/DawnNative.cpp b/src/dawn_native/DawnNative.cpp index bfc47db68d..52efa18f95 100644 --- a/src/dawn_native/DawnNative.cpp +++ b/src/dawn_native/DawnNative.cpp @@ -22,9 +22,9 @@ namespace dawn_native { - DawnProcTable GetProcsAutogen(); + const DawnProcTable& GetProcsAutogen(); - DawnProcTable GetProcs() { + const DawnProcTable& GetProcs() { return GetProcsAutogen(); } diff --git a/src/dawn_wire/WireClient.cpp b/src/dawn_wire/WireClient.cpp index e6fe263767..430a55cb97 100644 --- a/src/dawn_wire/WireClient.cpp +++ b/src/dawn_wire/WireClient.cpp @@ -26,7 +26,7 @@ namespace dawn_wire { } // static - DawnProcTable WireClient::GetProcs() { + const DawnProcTable& WireClient::GetProcs() { return client::GetProcs(); } diff --git a/src/dawn_wire/client/Client.h b/src/dawn_wire/client/Client.h index d8df86d8c9..be70e75873 100644 --- a/src/dawn_wire/client/Client.h +++ b/src/dawn_wire/client/Client.h @@ -68,7 +68,7 @@ namespace dawn_wire { namespace client { bool mIsDisconnected = false; }; - DawnProcTable GetProcs(); + const DawnProcTable& GetProcs(); std::unique_ptr CreateInlineMemoryTransferService(); diff --git a/src/include/dawn/dawn_thread_dispatch_proc.h b/src/include/dawn/dawn_thread_dispatch_proc.h new file mode 100644 index 0000000000..4d08ba8adc --- /dev/null +++ b/src/include/dawn/dawn_thread_dispatch_proc.h @@ -0,0 +1,33 @@ +// Copyright 2020 The Dawn Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DAWN_DAWN_THREAD_DISPATCH_PROC_H_ +#define DAWN_DAWN_THREAD_DISPATCH_PROC_H_ + +#include "dawn/dawn_proc.h" + +#ifdef __cplusplus +extern "C" { +#endif + +// Call dawnProcSetProcs(&dawnThreadDispatchProcTable) and then use dawnProcSetPerThreadProcs +// to set per-thread procs. +WGPU_EXPORT extern DawnProcTable dawnThreadDispatchProcTable; +WGPU_EXPORT void dawnProcSetPerThreadProcs(const DawnProcTable* procs); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // DAWN_DAWN_THREAD_DISPATCH_PROC_H_ diff --git a/src/include/dawn_native/DawnNative.h b/src/include/dawn_native/DawnNative.h index 2199efa85b..6498646e49 100644 --- a/src/include/dawn_native/DawnNative.h +++ b/src/include/dawn_native/DawnNative.h @@ -170,7 +170,7 @@ namespace dawn_native { }; // Backend-agnostic API for dawn_native - DAWN_NATIVE_EXPORT DawnProcTable GetProcs(); + DAWN_NATIVE_EXPORT const DawnProcTable& GetProcs(); // Query the names of all the toggles that are enabled in device DAWN_NATIVE_EXPORT std::vector GetTogglesUsed(WGPUDevice device); diff --git a/src/include/dawn_wire/WireClient.h b/src/include/dawn_wire/WireClient.h index 815b66b877..50da913462 100644 --- a/src/include/dawn_wire/WireClient.h +++ b/src/include/dawn_wire/WireClient.h @@ -26,6 +26,8 @@ namespace dawn_wire { namespace client { class Client; class MemoryTransferService; + + DAWN_WIRE_EXPORT const DawnProcTable& GetProcs(); } // namespace client struct ReservedTexture { @@ -44,7 +46,8 @@ namespace dawn_wire { WireClient(const WireClientDescriptor& descriptor); ~WireClient() override; - static DawnProcTable GetProcs(); + // TODO(enga): Remove this and use dawn_wire::client::GetProcs() instead + static const DawnProcTable& GetProcs(); WGPUDevice GetDevice() const; const volatile char* HandleCommands(const volatile char* commands, diff --git a/src/tests/BUILD.gn b/src/tests/BUILD.gn index 0077695768..9c998f5609 100644 --- a/src/tests/BUILD.gn +++ b/src/tests/BUILD.gn @@ -168,6 +168,7 @@ test("dawn_unittests") { "unittests/MathTests.cpp", "unittests/ObjectBaseTests.cpp", "unittests/PerStageTests.cpp", + "unittests/PerThreadProcTests.cpp", "unittests/PlacementAllocatedTests.cpp", "unittests/RefCountedTests.cpp", "unittests/ResultTests.cpp", diff --git a/src/tests/DawnTest.cpp b/src/tests/DawnTest.cpp index 5873da83d2..945d9c1668 100644 --- a/src/tests/DawnTest.cpp +++ b/src/tests/DawnTest.cpp @@ -744,12 +744,9 @@ void DawnTestBase::SetUp() { clientDesc.serializer = mC2sBuf.get(); mWireClient.reset(new dawn_wire::WireClient(clientDesc)); - WGPUDevice clientDevice = mWireClient->GetDevice(); - DawnProcTable clientProcs = dawn_wire::WireClient::GetProcs(); + cDevice = mWireClient->GetDevice(); + procs = dawn_wire::client::GetProcs(); mS2cBuf->SetHandler(mWireClient.get()); - - procs = clientProcs; - cDevice = clientDevice; } else { procs = backendProcs; cDevice = backendDevice; diff --git a/src/tests/end2end/WindowSurfaceTests.cpp b/src/tests/end2end/WindowSurfaceTests.cpp index 8aaa8ccc23..03e954e5eb 100644 --- a/src/tests/end2end/WindowSurfaceTests.cpp +++ b/src/tests/end2end/WindowSurfaceTests.cpp @@ -50,8 +50,7 @@ class WindowSurfaceInstanceTests : public testing::Test { }); DAWN_SKIP_TEST_IF(!glfwInit()); - DawnProcTable procs = dawn_native::GetProcs(); - dawnProcSetProcs(&procs); + dawnProcSetProcs(&dawn_native::GetProcs()); mInstance = wgpu::CreateInstance(); } diff --git a/src/tests/unittests/GetProcAddressTests.cpp b/src/tests/unittests/GetProcAddressTests.cpp index f5ac8c699a..10f9c5ca51 100644 --- a/src/tests/unittests/GetProcAddressTests.cpp +++ b/src/tests/unittests/GetProcAddressTests.cpp @@ -72,7 +72,7 @@ namespace { mWireClient = std::make_unique(clientDesc); mDevice = wgpu::Device::Acquire(mWireClient->GetDevice()); - mProcs = dawn_wire::WireClient::GetProcs(); + mProcs = dawn_wire::client::GetProcs(); break; } diff --git a/src/tests/unittests/PerThreadProcTests.cpp b/src/tests/unittests/PerThreadProcTests.cpp new file mode 100644 index 0000000000..38ce981eab --- /dev/null +++ b/src/tests/unittests/PerThreadProcTests.cpp @@ -0,0 +1,118 @@ +// Copyright 2020 The Dawn Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "dawn/dawn_thread_dispatch_proc.h" +#include "dawn/webgpu_cpp.h" +#include "dawn_native/DawnNative.h" +#include "dawn_native/Instance.h" +#include "dawn_native/null/DeviceNull.h" + +#include +#include +#include + +class PerThreadProcTests : public testing::Test { + public: + PerThreadProcTests() + : mNativeInstance(dawn_native::InstanceBase::Create()), + mNativeAdapter(mNativeInstance.Get()) { + } + ~PerThreadProcTests() override = default; + + protected: + Ref mNativeInstance; + dawn_native::null::Adapter mNativeAdapter; +}; + +// Test that procs can be set per thread. This test overrides deviceCreateBuffer with a dummy proc +// for each thread that increments a counter. Because each thread has their own proc and counter, +// there should be no data races. The per-thread procs also check that the current thread id is +// exactly equal to the expected thread id. +TEST_F(PerThreadProcTests, DispatchesPerThread) { + dawnProcSetProcs(&dawnThreadDispatchProcTable); + + // Threads will block on this atomic to be sure we set procs on both threads before + // either thread calls the procs. + std::atomic ready(false); + + static int threadACounter = 0; + static int threadBCounter = 0; + + static std::atomic threadIdA; + static std::atomic threadIdB; + + constexpr int kThreadATargetCount = 28347; + constexpr int kThreadBTargetCount = 40420; + + // Note: Acquire doesn't call reference or release. + wgpu::Device deviceA = + wgpu::Device::Acquire(reinterpret_cast(mNativeAdapter.CreateDevice(nullptr))); + + wgpu::Device deviceB = + wgpu::Device::Acquire(reinterpret_cast(mNativeAdapter.CreateDevice(nullptr))); + + std::thread threadA([&]() { + DawnProcTable procs = dawn_native::GetProcs(); + procs.deviceCreateBuffer = [](WGPUDevice device, + WGPUBufferDescriptor const* descriptor) -> WGPUBuffer { + EXPECT_EQ(std::this_thread::get_id(), threadIdA); + threadACounter++; + return nullptr; + }; + dawnProcSetPerThreadProcs(&procs); + + while (!ready) { + } // Should be fast, so just spin. + + for (int i = 0; i < kThreadATargetCount; ++i) { + deviceA.CreateBuffer(nullptr); + } + + deviceA = nullptr; + dawnProcSetPerThreadProcs(nullptr); + }); + + std::thread threadB([&]() { + DawnProcTable procs = dawn_native::GetProcs(); + procs.deviceCreateBuffer = [](WGPUDevice device, + WGPUBufferDescriptor const* bufferDesc) -> WGPUBuffer { + EXPECT_EQ(std::this_thread::get_id(), threadIdB); + threadBCounter++; + return nullptr; + }; + dawnProcSetPerThreadProcs(&procs); + + while (!ready) { + } // Should be fast, so just spin. + + for (int i = 0; i < kThreadBTargetCount; ++i) { + deviceB.CreateBuffer(nullptr); + } + + deviceB = nullptr; + dawnProcSetPerThreadProcs(nullptr); + }); + + threadIdA = threadA.get_id(); + threadIdB = threadB.get_id(); + + ready = true; + threadA.join(); + threadB.join(); + + EXPECT_EQ(threadACounter, kThreadATargetCount); + EXPECT_EQ(threadBCounter, kThreadBTargetCount); + + dawnProcSetProcs(nullptr); +} diff --git a/src/tests/unittests/validation/ValidationTest.cpp b/src/tests/unittests/validation/ValidationTest.cpp index 058fd3b921..1713171626 100644 --- a/src/tests/unittests/validation/ValidationTest.cpp +++ b/src/tests/unittests/validation/ValidationTest.cpp @@ -40,8 +40,7 @@ ValidationTest::ValidationTest() { ASSERT(foundNullAdapter); - DawnProcTable procs = dawn_native::GetProcs(); - dawnProcSetProcs(&procs); + dawnProcSetProcs(&dawn_native::GetProcs()); device = CreateDeviceFromAdapter(adapter, std::vector()); } diff --git a/src/tests/unittests/wire/WireMultipleDeviceTests.cpp b/src/tests/unittests/wire/WireMultipleDeviceTests.cpp index 3ba0da0ab4..b442794d56 100644 --- a/src/tests/unittests/wire/WireMultipleDeviceTests.cpp +++ b/src/tests/unittests/wire/WireMultipleDeviceTests.cpp @@ -29,8 +29,7 @@ using namespace dawn_wire; class WireMultipleDeviceTests : public testing::Test { protected: void SetUp() override { - DawnProcTable procs = dawn_wire::WireClient::GetProcs(); - dawnProcSetProcs(&procs); + dawnProcSetProcs(&dawn_wire::client::GetProcs()); } void TearDown() override { diff --git a/src/tests/unittests/wire/WireTest.cpp b/src/tests/unittests/wire/WireTest.cpp index bdac99f2f7..d23709c88c 100644 --- a/src/tests/unittests/wire/WireTest.cpp +++ b/src/tests/unittests/wire/WireTest.cpp @@ -66,8 +66,7 @@ void WireTest::SetUp() { mS2cBuf->SetHandler(mWireClient.get()); device = mWireClient->GetDevice(); - DawnProcTable clientProcs = dawn_wire::WireClient::GetProcs(); - dawnProcSetProcs(&clientProcs); + dawnProcSetProcs(&dawn_wire::client::GetProcs()); apiDevice = mockDevice;