Updates DawnInstanceDescriptor to pass in the Platform.

Notes:
- Separates ChainedStruct to be reusable without cpp header. (Also
  updates native structs to directly use it.)
- Manually implements the descriptor in DawnNative.
- Reworks ChainUtils with mapping from struct to STypes.
- Updates the tests to use either SetPlatformForTesting which is still
  required because DawnTest uses a "global" instance for all tests and
  some tests require setting (and cleaning up) a test specific platform.

Bug: dawn:1374
Change-Id: I078c78f22c5137030cf3cf0e8358fe4373ee9c6c
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/132268
Reviewed-by: Austin Eng <enga@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Loko Kung <lokokung@google.com>
This commit is contained in:
Loko Kung 2023-05-16 04:50:32 +00:00 committed by Dawn LUCI CQ
parent f9f9f829e3
commit 14ed533565
15 changed files with 157 additions and 75 deletions

View File

@ -1531,16 +1531,6 @@
"extensible": "in", "extensible": "in",
"members": [] "members": []
}, },
"dawn instance descriptor": {
"tags": ["dawn", "native"],
"category": "structure",
"chained": "in",
"chain roots": ["instance descriptor"],
"members": [
{"name": "additional runtime search paths count", "type": "uint32_t", "default": 0},
{"name": "additional runtime search paths", "type": "char", "annotation": "const*const*", "length": "additional runtime search paths count"}
]
},
"vertex attribute": { "vertex attribute": {
"category": "structure", "category": "structure",
"extensible": false, "extensible": false,

View File

@ -965,6 +965,11 @@ class MultiGeneratorFromDawnJSON(Generator):
'include/dawn/' + api + '_cpp_print.h', 'include/dawn/' + api + '_cpp_print.h',
[RENDER_PARAMS_BASE, params_dawn])) [RENDER_PARAMS_BASE, params_dawn]))
renders.append(
FileRender('api_cpp_chained_struct.h',
'include/dawn/' + api + '_cpp_chained_struct.h',
[RENDER_PARAMS_BASE, params_dawn]))
if 'proc' in targets: if 'proc' in targets:
renders.append( renders.append(
FileRender('dawn_proc.c', 'src/dawn/' + prefix + '_proc.c', FileRender('dawn_proc.c', 'src/dawn/' + prefix + '_proc.c',

View File

@ -13,7 +13,7 @@
//* limitations under the License. //* limitations under the License.
{% set API = metadata.api.upper() %} {% set API = metadata.api.upper() %}
{% set api = API.lower() %} {% set api = API.lower() %}
{% if 'dawn' not in enabled_tags %} {% if 'dawn' in enabled_tags %}
#ifdef __EMSCRIPTEN__ #ifdef __EMSCRIPTEN__
#error "Do not include this header. Emscripten already provides headers needed for {{metadata.api}}." #error "Do not include this header. Emscripten already provides headers needed for {{metadata.api}}."
#endif #endif
@ -22,17 +22,12 @@
#define {{API}}_CPP_H_ #define {{API}}_CPP_H_
#include "dawn/{{api}}.h" #include "dawn/{{api}}.h"
#include "dawn/{{api}}_cpp_chained_struct.h"
#include "dawn/EnumClassBitmasks.h" #include "dawn/EnumClassBitmasks.h"
#include <cmath> #include <cmath>
namespace {{metadata.namespace}} { namespace {{metadata.namespace}} {
namespace detail {
constexpr size_t ConstexprMax(size_t a, size_t b) {
return a > b ? a : b;
}
} // namespace detail
{% set c_prefix = metadata.c_prefix %} {% set c_prefix = metadata.c_prefix %}
{% for constant in by_category["constant"] %} {% for constant in by_category["constant"] %}
{% set type = as_cppType(constant.type.name) %} {% set type = as_cppType(constant.type.name) %}
@ -218,16 +213,6 @@ namespace {{metadata.namespace}} {
); );
{% endfor %} {% endfor %}
struct ChainedStruct {
ChainedStruct const * nextInChain = nullptr;
SType sType = SType::Invalid;
};
struct ChainedStructOut {
ChainedStruct * nextInChain = nullptr;
SType sType = SType::Invalid;
};
{% for type in by_category["structure"] %} {% for type in by_category["structure"] %}
{% set Out = "Out" if type.output else "" %} {% set Out = "Out" if type.output else "" %}
{% set const = "const" if not type.output else "" %} {% set const = "const" if not type.output else "" %}

View File

@ -0,0 +1,48 @@
//* Copyright 2023 The Dawn Authors
//*
//* Licensed under the Apache License, Version 2.0 (the "License");
//* you may not use this file except in compliance with the License.
//* You may obtain a copy of the License at
//*
//* http://www.apache.org/licenses/LICENSE-2.0
//*
//* Unless required by applicable law or agreed to in writing, software
//* distributed under the License is distributed on an "AS IS" BASIS,
//* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//* See the License for the specific language governing permissions and
//* limitations under the License.
{% set API = metadata.api.upper() %}
{% if 'dawn' in enabled_tags %}
#ifdef __EMSCRIPTEN__
#error "Do not include this header. Emscripten already provides headers needed for {{metadata.api}}."
#endif
{% endif %}
#ifndef {{API}}_CPP_CHAINED_STRUCT_H_
#define {{API}}_CPP_CHAINED_STRUCT_H_
// This header file declares the ChainedStruct structures separately from the {{metadata.api}}
// headers so that dependencies can directly extend structures without including the larger header
// which exposes capabilities that may require correctly set proc tables.
namespace {{metadata.namespace}} {
namespace detail {
constexpr size_t ConstexprMax(size_t a, size_t b) {
return a > b ? a : b;
}
} // namespace detail
enum class SType : uint32_t;
struct ChainedStruct {
ChainedStruct const * nextInChain = nullptr;
SType sType = SType(0u);
};
struct ChainedStructOut {
ChainedStructOut * nextInChain = nullptr;
SType sType = SType(0u);
};
} // namespace {{metadata.namespace}}}
#endif // {{API}}_CPP_CHAINED_STRUCT_H_

View File

@ -23,21 +23,6 @@
namespace {{native_namespace}} { namespace {{native_namespace}} {
{% set namespace = metadata.namespace %} {% set namespace = metadata.namespace %}
{% for value in types["s type"].values %}
{% if value.valid %}
{% set const_qualifier = "const " if types[value.name.get()].chained == "in" else "" %}
{% set chained_struct_type = "ChainedStruct" if types[value.name.get()].chained == "in" else "ChainedStructOut" %}
void FindInChain({{const_qualifier}}{{chained_struct_type}}* chain, {{const_qualifier}}{{as_cppEnum(value.name)}}** out) {
for (; chain; chain = chain->nextInChain) {
if (chain->sType == {{namespace}}::SType::{{as_cppEnum(value.name)}}) {
*out = static_cast<{{const_qualifier}}{{as_cppEnum(value.name)}}*>(chain);
break;
}
}
}
{% endif %}
{% endfor %}
MaybeError ValidateSTypes(const ChainedStruct* chain, MaybeError ValidateSTypes(const ChainedStruct* chain,
std::vector<std::vector<{{namespace}}::SType>> oneOfConstraints) { std::vector<std::vector<{{namespace}}::SType>> oneOfConstraints) {
std::unordered_set<{{namespace}}::SType> allSTypes; std::unordered_set<{{namespace}}::SType> allSTypes;

View File

@ -18,6 +18,7 @@
#define {{DIR}}_CHAIN_UTILS_H_ #define {{DIR}}_CHAIN_UTILS_H_
{% set impl_dir = metadata.impl_dir + "/" if metadata.impl_dir else "" %} {% set impl_dir = metadata.impl_dir + "/" if metadata.impl_dir else "" %}
{% set namespace = metadata.namespace %}
{% set namespace_name = Name(metadata.native_namespace) %} {% set namespace_name = Name(metadata.native_namespace) %}
{% set native_namespace = namespace_name.namespace_case() %} {% set native_namespace = namespace_name.namespace_case() %}
{% set native_dir = impl_dir + namespace_name.Dirs() %} {% set native_dir = impl_dir + namespace_name.Dirs() %}
@ -26,13 +27,44 @@
#include "{{native_dir}}/Error.h" #include "{{native_dir}}/Error.h"
namespace {{native_namespace}} { namespace {{native_namespace}} {
namespace detail {
// Mapping from native types to the expected STypes is implemented as template specializations.
template <typename T>
struct STypeForImpl;
{% for value in types["s type"].values %} {% for value in types["s type"].values %}
{% if value.valid %} {% if value.valid and value.name.get() in types %}
{% set const_qualifier = "const " if types[value.name.get()].chained == "in" else "" %} template <>
{% set chained_struct_type = "ChainedStruct" if types[value.name.get()].chained == "in" else "ChainedStructOut" %} struct STypeForImpl<{{as_cppEnum(value.name)}}> {
void FindInChain({{const_qualifier}}{{chained_struct_type}}* chain, {{const_qualifier}}{{as_cppEnum(value.name)}}** out); static constexpr {{namespace}}::SType value = {{namespace}}::SType::{{as_cppEnum(value.name)}};
};
{% endif %} {% endif %}
{% endfor %} {% endfor %}
template <>
struct STypeForImpl<DawnInstanceDescriptor> {
static constexpr {{namespace}}::SType value = {{namespace}}::SType::DawnInstanceDescriptor;
};
} // namespace detail
template <typename T>
constexpr {{namespace}}::SType STypeFor = detail::STypeForImpl<T>::value;
template <typename T>
void FindInChain(const ChainedStruct* chain, const T** out) {
for (; chain; chain = chain->nextInChain) {
if (chain->sType == STypeFor<T>) {
*out = static_cast<const T*>(chain);
break;
}
}
}
template <typename T>
void FindInChain(ChainedStructOut* chain, T** out) {
for (; chain; chain = chain->nextInChain) {
if (chain->sType == STypeFor<T>) {
*out = static_cast<T*>(chain);
break;
}
}
}
// Verifies that |chain| only contains ChainedStructs of types enumerated in // Verifies that |chain| only contains ChainedStructs of types enumerated in
// |oneOfConstraints| and contains no duplicate sTypes. Each vector in // |oneOfConstraints| and contains no duplicate sTypes. Each vector in
@ -40,7 +72,6 @@ namespace {{native_namespace}} {
// For example: // For example:
// ValidateSTypes(chain, { { ShaderModuleSPIRVDescriptor, ShaderModuleWGSLDescriptor } })) // ValidateSTypes(chain, { { ShaderModuleSPIRVDescriptor, ShaderModuleWGSLDescriptor } }))
// ValidateSTypes(chain, { { Extension1 }, { Extension2 } }) // ValidateSTypes(chain, { { Extension1 }, { Extension2 } })
{% set namespace = metadata.namespace %}
MaybeError ValidateSTypes(const ChainedStruct* chain, MaybeError ValidateSTypes(const ChainedStruct* chain,
std::vector<std::vector<{{namespace}}::SType>> oneOfConstraints); std::vector<std::vector<{{namespace}}::SType>> oneOfConstraints);
MaybeError ValidateSTypes(const ChainedStructOut* chain, MaybeError ValidateSTypes(const ChainedStructOut* chain,

View File

@ -44,15 +44,8 @@ namespace {{native_namespace}} {
{%- endif -%} {%- endif -%}
{%- endmacro %} {%- endmacro %}
struct ChainedStruct { using {{namespace}}::ChainedStruct;
ChainedStruct const * nextInChain = nullptr; using {{namespace}}::ChainedStructOut;
{{namespace}}::SType sType = {{namespace}}::SType::Invalid;
};
struct ChainedStructOut {
ChainedStructOut * nextInChain = nullptr;
{{namespace}}::SType sType = {{namespace}}::SType::Invalid;
};
{% for type in by_category["structure"] %} {% for type in by_category["structure"] %}
{% if type.chained %} {% if type.chained %}

View File

@ -44,6 +44,7 @@ dawn_json_generator("cpp_headers_gen") {
target = "cpp_headers" target = "cpp_headers"
outputs = [ outputs = [
"include/dawn/webgpu_cpp.h", "include/dawn/webgpu_cpp.h",
"include/dawn/webgpu_cpp_chained_struct.h",
"include/dawn/webgpu_cpp_print.h", "include/dawn/webgpu_cpp_print.h",
] ]
} }

View File

@ -21,6 +21,7 @@
#include "dawn/dawn_proc_table.h" #include "dawn/dawn_proc_table.h"
#include "dawn/native/dawn_native_export.h" #include "dawn/native/dawn_native_export.h"
#include "dawn/webgpu.h" #include "dawn/webgpu.h"
#include "dawn/webgpu_cpp_chained_struct.h"
namespace dawn::platform { namespace dawn::platform {
class Platform; class Platform;
@ -128,6 +129,20 @@ struct DAWN_NATIVE_EXPORT AdapterDiscoveryOptionsBase {
enum BackendValidationLevel { Full, Partial, Disabled }; enum BackendValidationLevel { Full, Partial, Disabled };
// Can be chained in InstanceDescriptor
struct DAWN_NATIVE_EXPORT DawnInstanceDescriptor : wgpu::ChainedStruct {
DawnInstanceDescriptor();
static constexpr size_t kFirstMemberAlignment =
wgpu::detail::ConstexprMax(alignof(wgpu::ChainedStruct), alignof(uint32_t));
alignas(kFirstMemberAlignment) uint32_t additionalRuntimeSearchPathsCount = 0;
const char* const* additionalRuntimeSearchPaths;
dawn::platform::Platform* platform = nullptr;
// Equality operators, mostly for testing. Note that this tests
// strict pointer-pointer equality if the struct contains member pointers.
bool operator==(const DawnInstanceDescriptor& rhs) const;
};
// Represents a connection to dawn_native and is used for dependency injection, discovering // Represents a connection to dawn_native and is used for dependency injection, discovering
// system adapters and injecting custom adapters (like a Swiftshader Vulkan adapter). // system adapters and injecting custom adapters (like a Swiftshader Vulkan adapter).
// //
@ -262,6 +277,12 @@ DAWN_NATIVE_EXPORT bool BindGroupLayoutBindingsEqualForTesting(WGPUBindGroupLayo
} // namespace dawn::native } // namespace dawn::native
// Alias the DawnInstanceDescriptor up to wgpu.
// TODO(dawn:1374) Remove this aliasing once the usages are updated.
namespace wgpu {
using dawn::native::DawnInstanceDescriptor;
} // namespace wgpu
// TODO(dawn:824): Remove once the deprecation period is passed. // TODO(dawn:824): Remove once the deprecation period is passed.
namespace dawn_native = dawn::native; namespace dawn_native = dawn::native;

View File

@ -135,6 +135,19 @@ void Adapter::ResetInternalDeviceForTesting() {
AdapterDiscoveryOptionsBase::AdapterDiscoveryOptionsBase(WGPUBackendType type) AdapterDiscoveryOptionsBase::AdapterDiscoveryOptionsBase(WGPUBackendType type)
: backendType(type) {} : backendType(type) {}
// DawnInstanceDescriptor
DawnInstanceDescriptor::DawnInstanceDescriptor() {
sType = wgpu::SType::DawnInstanceDescriptor;
}
bool DawnInstanceDescriptor::operator==(const DawnInstanceDescriptor& rhs) const {
return (nextInChain == rhs.nextInChain) &&
std::tie(additionalRuntimeSearchPathsCount, additionalRuntimeSearchPaths, platform) ==
std::tie(rhs.additionalRuntimeSearchPathsCount, rhs.additionalRuntimeSearchPaths,
rhs.platform);
}
// Instance // Instance
Instance::Instance(const WGPUInstanceDescriptor* desc) Instance::Instance(const WGPUInstanceDescriptor* desc)

View File

@ -186,7 +186,7 @@ MaybeError InstanceBase::Initialize(const InstanceDescriptor* descriptor) {
// Initialize the platform to the default for now. // Initialize the platform to the default for now.
mDefaultPlatform = std::make_unique<dawn::platform::Platform>(); mDefaultPlatform = std::make_unique<dawn::platform::Platform>();
SetPlatform(mDefaultPlatform.get()); SetPlatform(dawnDesc != nullptr ? dawnDesc->platform : mDefaultPlatform.get());
return {}; return {};
} }

View File

@ -49,19 +49,21 @@ void DawnNativeTest::SetUp() {
// adapter and device toggles and allow us to test unsafe apis (including experimental // adapter and device toggles and allow us to test unsafe apis (including experimental
// features). // features).
const char* allowUnsafeApisToggle = "allow_unsafe_apis"; const char* allowUnsafeApisToggle = "allow_unsafe_apis";
WGPUDawnTogglesDescriptor instanceToggles = {}; wgpu::DawnTogglesDescriptor instanceToggles;
instanceToggles.chain.sType = WGPUSType::WGPUSType_DawnTogglesDescriptor;
instanceToggles.enabledTogglesCount = 1; instanceToggles.enabledTogglesCount = 1;
instanceToggles.enabledToggles = &allowUnsafeApisToggle; instanceToggles.enabledToggles = &allowUnsafeApisToggle;
WGPUInstanceDescriptor instanceDesc = {};
instanceDesc.nextInChain = &instanceToggles.chain;
instance = std::make_unique<dawn::native::Instance>(&instanceDesc);
instance->EnableAdapterBlocklist(false);
platform = CreateTestPlatform(); platform = CreateTestPlatform();
dawn::native::FromAPI(instance->Get())->SetPlatformForTesting(platform.get()); wgpu::DawnInstanceDescriptor dawnInstanceDesc;
dawnInstanceDesc.platform = platform.get();
dawnInstanceDesc.nextInChain = &instanceToggles;
wgpu::InstanceDescriptor instanceDesc;
instanceDesc.nextInChain = &dawnInstanceDesc;
instance = std::make_unique<dawn::native::Instance>(
reinterpret_cast<const WGPUInstanceDescriptor*>(&instanceDesc));
instance->EnableAdapterBlocklist(false);
instance->DiscoverDefaultAdapters(); instance->DiscoverDefaultAdapters();
std::vector<dawn::native::Adapter> adapters = instance->GetAdapters(); std::vector<dawn::native::Adapter> adapters = instance->GetAdapters();

View File

@ -373,20 +373,25 @@ void DawnTestEnvironment::ParseArgs(int argc, char** argv) {
} }
} }
std::unique_ptr<dawn::native::Instance> DawnTestEnvironment::CreateInstanceAndDiscoverAdapters() { std::unique_ptr<dawn::native::Instance> DawnTestEnvironment::CreateInstanceAndDiscoverAdapters(
dawn::platform::Platform* platform) {
// Create an instance with toggle AllowUnsafeAPIs enabled, which would be inherited to // Create an instance with toggle AllowUnsafeAPIs enabled, which would be inherited to
// adapter and device toggles and allow us to test unsafe apis (including experimental // adapter and device toggles and allow us to test unsafe apis (including experimental
// features). // features).
const char* allowUnsafeApisToggle = "allow_unsafe_apis"; const char* allowUnsafeApisToggle = "allow_unsafe_apis";
WGPUDawnTogglesDescriptor instanceToggles = {}; wgpu::DawnTogglesDescriptor instanceToggles;
instanceToggles.chain.sType = WGPUSType::WGPUSType_DawnTogglesDescriptor;
instanceToggles.enabledTogglesCount = 1; instanceToggles.enabledTogglesCount = 1;
instanceToggles.enabledToggles = &allowUnsafeApisToggle; instanceToggles.enabledToggles = &allowUnsafeApisToggle;
WGPUInstanceDescriptor instanceDesc = {}; wgpu::DawnInstanceDescriptor dawnInstanceDesc;
instanceDesc.nextInChain = &instanceToggles.chain; dawnInstanceDesc.platform = platform;
dawnInstanceDesc.nextInChain = &instanceToggles;
auto instance = std::make_unique<dawn::native::Instance>(&instanceDesc); wgpu::InstanceDescriptor instanceDesc;
instanceDesc.nextInChain = &dawnInstanceDesc;
auto instance = std::make_unique<dawn::native::Instance>(
reinterpret_cast<const WGPUInstanceDescriptor*>(&instanceDesc));
instance->EnableBeginCaptureOnStartup(mBeginCaptureOnStartup); instance->EnableBeginCaptureOnStartup(mBeginCaptureOnStartup);
instance->SetBackendValidationLevel(mBackendValidationLevel); instance->SetBackendValidationLevel(mBackendValidationLevel);
instance->EnableAdapterBlocklist(false); instance->EnableAdapterBlocklist(false);
@ -1096,6 +1101,9 @@ void DawnTestBase::TearDown() {
EXPECT_EQ(mLastWarningCount, EXPECT_EQ(mLastWarningCount,
dawn::native::GetDeprecationWarningCountForTesting(device.Get())); dawn::native::GetDeprecationWarningCountForTesting(device.Get()));
} }
// Unsets the platform since we are cleaning the per-test platform up with the test case.
dawn::native::FromAPI(gTestEnv->GetInstance()->Get())->SetPlatformForTesting(nullptr);
} }
void DawnTestBase::DestroyDevice(wgpu::Device device) { void DawnTestBase::DestroyDevice(wgpu::Device device) {

View File

@ -183,11 +183,12 @@ class DawnTestEnvironment : public testing::Environment {
bool RunSuppressedTests() const; bool RunSuppressedTests() const;
protected: protected:
std::unique_ptr<dawn::native::Instance> CreateInstanceAndDiscoverAdapters(
dawn::platform::Platform* platform = nullptr);
std::unique_ptr<dawn::native::Instance> mInstance; std::unique_ptr<dawn::native::Instance> mInstance;
private: private:
void ParseArgs(int argc, char** argv); void ParseArgs(int argc, char** argv);
std::unique_ptr<dawn::native::Instance> CreateInstanceAndDiscoverAdapters();
void SelectPreferredAdapterProperties(const dawn::native::Instance* instance); void SelectPreferredAdapterProperties(const dawn::native::Instance* instance);
void PrintTestConfigurationAndAdapterInfo(dawn::native::Instance* instance) const; void PrintTestConfigurationAndAdapterInfo(dawn::native::Instance* instance) const;

View File

@ -122,10 +122,9 @@ DawnPerfTestEnvironment::DawnPerfTestEnvironment(int argc, char** argv)
DawnPerfTestEnvironment::~DawnPerfTestEnvironment() = default; DawnPerfTestEnvironment::~DawnPerfTestEnvironment() = default;
void DawnPerfTestEnvironment::SetUp() { void DawnPerfTestEnvironment::SetUp() {
DawnTestEnvironment::SetUp();
mPlatform = std::make_unique<DawnPerfTestPlatform>(); mPlatform = std::make_unique<DawnPerfTestPlatform>();
mInstance->SetPlatform(mPlatform.get()); mInstance = CreateInstanceAndDiscoverAdapters(mPlatform.get());
ASSERT(mInstance);
// Begin writing the trace event array. // Begin writing the trace event array.
if (mTraceFile != nullptr) { if (mTraceFile != nullptr) {