diff --git a/BUILD.gn b/BUILD.gn index 111ada7e31..0d06dafd42 100644 --- a/BUILD.gn +++ b/BUILD.gn @@ -773,6 +773,7 @@ test("dawn_unittests") { "src/tests/unittests/EnumClassBitmasksTests.cpp", "src/tests/unittests/ErrorTests.cpp", "src/tests/unittests/ExtensionTests.cpp", + "src/tests/unittests/GetProcAddressTests.cpp", "src/tests/unittests/MathTests.cpp", "src/tests/unittests/ObjectBaseTests.cpp", "src/tests/unittests/PerStageTests.cpp", diff --git a/generator/dawn_json_generator.py b/generator/dawn_json_generator.py index 2489588868..f92ad92ec1 100644 --- a/generator/dawn_json_generator.py +++ b/generator/dawn_json_generator.py @@ -448,6 +448,12 @@ def c_native_methods(types, typ): Method(Name('release'), types['void'], []), ] +def get_methods_sorted_by_name(api_params): + unsorted = [(as_MethodSuffix(typ.name, method.name), typ, method) \ + for typ in api_params['by_category']['object'] \ + for method in c_native_methods(api_params['types'], typ) ] + return [(typ, method) for (_, typ, method) in sorted(unsorted)] + class MultiGeneratorFromDawnJSON(Generator): def get_description(self): return 'Generates code for various target from Dawn.json.' @@ -486,6 +492,7 @@ class MultiGeneratorFromDawnJSON(Generator): 'convert_cType_to_cppType': convert_cType_to_cppType, 'as_varName': as_varName, 'decorate': decorate, + 'methods_sorted_by_name': get_methods_sorted_by_name(api_params), } renders = [] diff --git a/generator/templates/api.h b/generator/templates/api.h index 733c387da1..bb71141c23 100644 --- a/generator/templates/api.h +++ b/generator/templates/api.h @@ -54,7 +54,6 @@ extern "C" { #endif // Custom types depending on the target language -typedef void (*DawnErrorCallback)(DawnErrorType type, const char* message, void* userdata); typedef void (*DawnBufferCreateMappedCallback)(DawnBufferMapAsyncStatus status, DawnCreateBufferMappedResult result, void* userdata); @@ -67,9 +66,14 @@ typedef void (*DawnBufferMapWriteCallback)(DawnBufferMapAsyncStatus status, uint64_t dataLength, void* userdata); typedef void (*DawnFenceOnCompletionCallback)(DawnFenceCompletionStatus status, void* userdata); +typedef void (*DawnErrorCallback)(DawnErrorType type, const char* message, void* userdata); + +typedef void (*DawnProc)(); #if !defined(DAWN_SKIP_PROCS) +typedef DawnProc (*DawnProcGetProcAddress)(DawnDevice device, const char* procName); + {% for type in by_category["object"] %} // Procs of {{type.name.CamelCase()}} {% for method in native_methods(type) %} @@ -86,6 +90,8 @@ typedef void (*DawnFenceOnCompletionCallback)(DawnFenceCompletionStatus status, #if !defined(DAWN_SKIP_DECLARATIONS) +DAWN_EXPORT DawnProc DawnGetProcAddress(DawnDevice device, const char* procName); + {% for type in by_category["object"] %} // Methods of {{type.name.CamelCase()}} {% for method in native_methods(type) %} diff --git a/generator/templates/api_proc.c b/generator/templates/api_proc.c index a09175b495..a6225912cd 100644 --- a/generator/templates/api_proc.c +++ b/generator/templates/api_proc.c @@ -26,6 +26,10 @@ void dawnProcSetProcs(const DawnProcTable* procs_) { } } +DawnProc DawnGetProcAddress(DawnDevice device, const char* procName) { + return procs.getProcAddress(device, procName); +} + {% for type in by_category["object"] %} {% for method in native_methods(type) %} {{as_cType(method.return_type.name)}} {{as_cMethod(type.name, method.name)}}( diff --git a/generator/templates/api_proc_table.h b/generator/templates/api_proc_table.h index 1f1eb3caa0..9fb850c7f1 100644 --- a/generator/templates/api_proc_table.h +++ b/generator/templates/api_proc_table.h @@ -18,6 +18,8 @@ #include "dawn/dawn.h" typedef struct DawnProcTable { + DawnProcGetProcAddress getProcAddress; + {% for type in by_category["object"] %} {% for method in native_methods(type) %} {{as_cProc(type.name, method.name)}} {{as_varName(type.name, method.name)}}; diff --git a/generator/templates/apicpp.cpp b/generator/templates/apicpp.cpp index ad468d175c..b43fb4fe8e 100644 --- a/generator/templates/apicpp.cpp +++ b/generator/templates/apicpp.cpp @@ -112,4 +112,8 @@ namespace dawn { {% endfor %} + Proc GetProcAddress(Device const& device, const char* procName) { + return reinterpret_cast(DawnGetProcAddress(device.Get(), procName)); + } + } diff --git a/generator/templates/apicpp.h b/generator/templates/apicpp.h index d5e9ea108e..a0122c90fc 100644 --- a/generator/templates/apicpp.h +++ b/generator/templates/apicpp.h @@ -48,6 +48,7 @@ namespace dawn { {% endfor %} + using Proc = DawnProc; {% for type in by_category["natively defined"] %} using {{as_cppType(type.name)}} = {{as_cType(type.name)}}; {% endfor %} @@ -175,6 +176,8 @@ namespace dawn { {% endfor %} + DAWN_EXPORT Proc GetProcAddress(Device const& device, const char* procName); + {% for type in by_category["structure"] %} struct {{as_cppType(type.name)}} { {% if type.extensible %} diff --git a/generator/templates/dawn_native/ProcTable.cpp b/generator/templates/dawn_native/ProcTable.cpp index 186b334c14..4f1bc7c33c 100644 --- a/generator/templates/dawn_native/ProcTable.cpp +++ b/generator/templates/dawn_native/ProcTable.cpp @@ -12,12 +12,10 @@ //* See the License for the specific language governing permissions and //* limitations under the License. -#include "common/Assert.h" - #include "dawn_native/dawn_platform.h" #include "dawn_native/DawnNative.h" -#include "dawn_native/ErrorData.h" -#include "dawn_native/ValidationUtils_autogen.h" + +#include {% for type in by_category["object"] %} {% if type.name.canonical_case() not in ["texture view"] %} @@ -28,11 +26,12 @@ namespace dawn_native { namespace { + {% for type in by_category["object"] %} {% for method in native_methods(type) %} {% set suffix = as_MethodSuffix(type.name, method.name) %} - {{as_cType(method.return_type.name)}} CToCpp{{suffix}}( + {{as_cType(method.return_type.name)}} Native{{suffix}}( {{-as_cType(type.name)}} cSelf {%- for arg in method.arguments -%} , {{as_annotated_cType(arg)}} @@ -71,13 +70,56 @@ namespace dawn_native { } {% endfor %} {% endfor %} + + struct ProcEntry { + DawnProc proc; + const char* name; + }; + static const ProcEntry sProcMap[] = { + {% for (type, method) in methods_sorted_by_name %} + { reinterpret_cast(Native{{as_MethodSuffix(type.name, method.name)}}), "{{as_cMethod(type.name, method.name)}}" }, + {% endfor %} + }; + static constexpr size_t sProcMapSize = sizeof(sProcMap) / sizeof(sProcMap[0]); + } + + DawnProc NativeGetProcAddress(DawnDevice, const char* procName) { + if (procName == nullptr) { + return nullptr; + } + + const ProcEntry* entry = std::lower_bound(&sProcMap[0], &sProcMap[sProcMapSize], procName, + [](const ProcEntry &a, const char *b) -> bool { + return strcmp(a.name, b) < 0; + } + ); + + if (entry != &sProcMap[sProcMapSize] && strcmp(entry->name, procName) == 0) { + return entry->proc; + } + + if (strcmp(procName, "dawnGetProcAddress") == 0) { + return reinterpret_cast(NativeGetProcAddress); + } + + return nullptr; + } + + std::vector GetProcMapNamesForTesting() { + std::vector result; + result.reserve(sProcMapSize); + for (const ProcEntry& entry : sProcMap) { + result.push_back(entry.name); + } + return result; } DawnProcTable GetProcsAutogen() { DawnProcTable table; + table.getProcAddress = NativeGetProcAddress; {% for type in by_category["object"] %} {% for method in native_methods(type) %} - table.{{as_varName(type.name, method.name)}} = CToCpp{{as_MethodSuffix(type.name, method.name)}}; + table.{{as_varName(type.name, method.name)}} = Native{{as_MethodSuffix(type.name, method.name)}}; {% endfor %} {% endfor %} return table; diff --git a/generator/templates/dawn_wire/client/ApiProcs.cpp b/generator/templates/dawn_wire/client/ApiProcs.cpp index 5b3fa95444..1bdccbbdc9 100644 --- a/generator/templates/dawn_wire/client/ApiProcs.cpp +++ b/generator/templates/dawn_wire/client/ApiProcs.cpp @@ -16,6 +16,9 @@ #include "dawn_wire/client/ApiProcs_autogen.h" #include "dawn_wire/client/Client.h" +#include +#include + namespace dawn_wire { namespace client { //* Implementation of the client API functions. {% for type in by_category["object"] %} @@ -89,6 +92,50 @@ namespace dawn_wire { namespace client { {% endif %} {% endfor %} + namespace { + struct ProcEntry { + DawnProc proc; + const char* name; + }; + static const ProcEntry sProcMap[] = { + {% for (type, method) in methods_sorted_by_name %} + { reinterpret_cast(Client{{as_MethodSuffix(type.name, method.name)}}), "{{as_cMethod(type.name, method.name)}}" }, + {% endfor %} + }; + static constexpr size_t sProcMapSize = sizeof(sProcMap) / sizeof(sProcMap[0]); + } // anonymous namespace + + DawnProc ClientGetProcAddress(DawnDevice, const char* procName) { + if (procName == nullptr) { + return nullptr; + } + + const ProcEntry* entry = std::lower_bound(&sProcMap[0], &sProcMap[sProcMapSize], procName, + [](const ProcEntry &a, const char *b) -> bool { + return strcmp(a.name, b) < 0; + } + ); + + if (entry != &sProcMap[sProcMapSize] && strcmp(entry->name, procName) == 0) { + return entry->proc; + } + + if (strcmp(procName, "dawnGetProcAddress") == 0) { + return reinterpret_cast(ClientGetProcAddress); + } + + return nullptr; + } + + std::vector GetProcMapNamesForTesting() { + std::vector result; + result.reserve(sProcMapSize); + for (const ProcEntry& entry : sProcMap) { + result.push_back(entry.name); + } + return result; + } + //* Some commands don't have a custom wire format, but need to be handled manually to update //* some client-side state tracking. For these we have two functions: //* - An autogenerated Client{{suffix}} method that sends the command on the wire @@ -96,6 +143,7 @@ namespace dawn_wire { namespace client { //* the autogenerated one, and that will have to call Client{{suffix}} DawnProcTable GetProcs() { DawnProcTable table; + table.getProcAddress = ClientGetProcAddress; {% for type in by_category["object"] %} {% for method in native_methods(type) %} {% set suffix = as_MethodSuffix(type.name, method.name) %} diff --git a/src/include/dawn_native/DawnNative.h b/src/include/dawn_native/DawnNative.h index d6eeebc8c6..b3125ed8f4 100644 --- a/src/include/dawn_native/DawnNative.h +++ b/src/include/dawn_native/DawnNative.h @@ -160,6 +160,9 @@ namespace dawn_native { // Backdoor to get the number of lazy clears for testing DAWN_NATIVE_EXPORT size_t GetLazyClearCountForTesting(DawnDevice device); + + // Backdoor to get the order of the ProcMap for testing + DAWN_NATIVE_EXPORT std::vector GetProcMapNamesForTesting(); } // namespace dawn_native #endif // DAWNNATIVE_DAWNNATIVE_H_ diff --git a/src/include/dawn_wire/WireClient.h b/src/include/dawn_wire/WireClient.h index 215e893974..d56e090065 100644 --- a/src/include/dawn_wire/WireClient.h +++ b/src/include/dawn_wire/WireClient.h @@ -120,8 +120,10 @@ namespace dawn_wire { virtual ~WriteHandle(); }; }; - } // namespace client + // Backdoor to get the order of the ProcMap for testing + DAWN_WIRE_EXPORT std::vector GetProcMapNamesForTesting(); + } // namespace client } // namespace dawn_wire #endif // DAWNWIRE_WIRECLIENT_H_ diff --git a/src/tests/unittests/GetProcAddressTests.cpp b/src/tests/unittests/GetProcAddressTests.cpp new file mode 100644 index 0000000000..e2a93b00e7 --- /dev/null +++ b/src/tests/unittests/GetProcAddressTests.cpp @@ -0,0 +1,166 @@ +// Copyright 2019 The Dawn Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "dawn/dawn_proc.h" +#include "dawn_native/Instance.h" +#include "dawn_native/null/DeviceNull.h" +#include "dawn_wire/WireClient.h" +#include "utils/TerribleCommandBuffer.h" + +namespace { + + // libdawn_wire and libdawn_native contain duplicated code for the handling of GetProcAddress + // so we run the tests against both implementations. This enum is used as a test parameters to + // know which implementation to test. + enum class DawnFlavor { + Native, + Wire, + }; + + std::ostream& operator<<(std::ostream& stream, DawnFlavor flavor) { + switch (flavor) { + case DawnFlavor::Native: + stream << "dawn_native"; + break; + + case DawnFlavor::Wire: + stream << "dawn_wire"; + break; + + default: + UNREACHABLE(); + break; + } + return stream; + } + + class GetProcAddressTests : public testing::TestWithParam { + public: + GetProcAddressTests() + : testing::TestWithParam(), + mNativeInstance(), + mNativeAdapter(&mNativeInstance) { + } + + void SetUp() override { + switch (GetParam()) { + case DawnFlavor::Native: { + mDevice = dawn::Device::Acquire( + reinterpret_cast(mNativeAdapter.CreateDevice(nullptr))); + mProcs = dawn_native::GetProcs(); + break; + } + + case DawnFlavor::Wire: { + mC2sBuf = std::make_unique(); + + dawn_wire::WireClientDescriptor clientDesc = {}; + clientDesc.serializer = mC2sBuf.get(); + mWireClient = std::make_unique(clientDesc); + + mDevice = dawn::Device::Acquire(mWireClient->GetDevice()); + mProcs = mWireClient->GetProcs(); + break; + } + + default: + UNREACHABLE(); + break; + } + + dawnProcSetProcs(&mProcs); + } + + void TearDown() override { + // Destroy the device before freeing the instance or the wire client in the destructor + mDevice = dawn::Device(); + } + + protected: + dawn_native::InstanceBase mNativeInstance; + dawn_native::null::Adapter mNativeAdapter; + + std::unique_ptr mC2sBuf; + std::unique_ptr mWireClient; + + dawn::Device mDevice; + DawnProcTable mProcs; + }; + + // Test GetProcAddress with and without devices on some valid examples + TEST_P(GetProcAddressTests, ValidExamples) { + ASSERT_EQ(mProcs.getProcAddress(nullptr, "dawnDeviceCreateBuffer"), + reinterpret_cast(mProcs.deviceCreateBuffer)); + ASSERT_EQ(mProcs.getProcAddress(mDevice.Get(), "dawnDeviceCreateBuffer"), + reinterpret_cast(mProcs.deviceCreateBuffer)); + ASSERT_EQ(mProcs.getProcAddress(nullptr, "dawnQueueSubmit"), + reinterpret_cast(mProcs.queueSubmit)); + ASSERT_EQ(mProcs.getProcAddress(mDevice.Get(), "dawnQueueSubmit"), + reinterpret_cast(mProcs.queueSubmit)); + } + + // Test GetProcAddress with and without devices on nullptr procName + TEST_P(GetProcAddressTests, Nullptr) { + ASSERT_EQ(mProcs.getProcAddress(nullptr, nullptr), nullptr); + ASSERT_EQ(mProcs.getProcAddress(mDevice.Get(), nullptr), nullptr); + } + + // Test GetProcAddress with and without devices on some invalid + TEST_P(GetProcAddressTests, InvalidExamples) { + ASSERT_EQ(mProcs.getProcAddress(nullptr, "dawnDeviceDoSomething"), nullptr); + ASSERT_EQ(mProcs.getProcAddress(mDevice.Get(), "dawnDeviceDoSomething"), nullptr); + + // Trigger the condition where lower_bound will return the end of the procMap. + ASSERT_EQ(mProcs.getProcAddress(nullptr, "zzzzzzz"), nullptr); + ASSERT_EQ(mProcs.getProcAddress(mDevice.Get(), "zzzzzzz"), nullptr); + ASSERT_EQ(mProcs.getProcAddress(nullptr, "ZZ"), nullptr); + ASSERT_EQ(mProcs.getProcAddress(mDevice.Get(), "ZZ"), nullptr); + + // Some more potential corner cases. + ASSERT_EQ(mProcs.getProcAddress(nullptr, ""), nullptr); + ASSERT_EQ(mProcs.getProcAddress(mDevice.Get(), ""), nullptr); + ASSERT_EQ(mProcs.getProcAddress(nullptr, "0"), nullptr); + ASSERT_EQ(mProcs.getProcAddress(mDevice.Get(), "0"), nullptr); + } + + // Test that GetProcAddress supports itself: it is handled specially because it is a + // freestanding function and not a method on an object. + TEST_P(GetProcAddressTests, GetProcAddressItself) { + ASSERT_EQ(mProcs.getProcAddress(nullptr, "dawnGetProcAddress"), + reinterpret_cast(mProcs.getProcAddress)); + ASSERT_EQ(mProcs.getProcAddress(mDevice.Get(), "dawnGetProcAddress"), + reinterpret_cast(mProcs.getProcAddress)); + } + + INSTANTIATE_TEST_SUITE_P(, + GetProcAddressTests, + testing::Values(DawnFlavor::Native, DawnFlavor::Wire), + testing::PrintToStringParamName()); + + TEST(GetProcAddressInternalTests, CheckDawnNativeProcMapOrder) { + std::vector names = dawn_native::GetProcMapNamesForTesting(); + for (size_t i = 1; i < names.size(); i++) { + ASSERT_LT(std::string(names[i - 1]), std::string(names[i])); + } + } + + TEST(GetProcAddressInternalTests, CheckDawnWireClientProcMapOrder) { + std::vector names = dawn_wire::client::GetProcMapNamesForTesting(); + for (size_t i = 1; i < names.size(); i++) { + ASSERT_LT(std::string(names[i - 1]), std::string(names[i])); + } + } +} // anonymous namespace