From 5346e770c94a7e629d3c9769f9e9a6410bae6dff Mon Sep 17 00:00:00 2001 From: Brian Ho Date: Thu, 22 Apr 2021 17:49:42 +0000 Subject: [PATCH] Add helper functions to iterate over ChainedStructs This CL adds two helpers for more ergonomic processing of ChainedStructs. 1. FindInChain(): Iterates through the chain and automatically casts the ChainedStruct into the appropriate child type before returning. 2. ValidateSTypes(): Verifies that the chain only contains structs with sTypes from a pre-defined set. This also allows the caller to specify one-of constraints. 3. ValidateSingleSType(): Verifies that the chain contains a single struct with a specific sType or is an empty chain. This is a common case of |ValidateSTypes()| and is separated out as a fast-path. Change-Id: I938df0bf2a9b1800b1105fb7f80fbde20bef8ec8 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/47680 Commit-Queue: Brian Ho Reviewed-by: Corentin Wallez --- docs/codegen.md | 1 + generator/dawn_json_generator.py | 8 + .../templates/dawn_native/ChainUtils.cpp | 61 ++++++ generator/templates/dawn_native/ChainUtils.h | 81 ++++++++ src/dawn_native/BUILD.gn | 2 + src/dawn_native/RenderPipeline.cpp | 25 ++- src/dawn_native/ShaderModule.cpp | 119 +++++------- src/dawn_native/Surface.cpp | 147 ++++++-------- src/tests/BUILD.gn | 1 + src/tests/unittests/ChainUtilsTests.cpp | 181 ++++++++++++++++++ 10 files changed, 459 insertions(+), 167 deletions(-) create mode 100644 generator/templates/dawn_native/ChainUtils.cpp create mode 100644 generator/templates/dawn_native/ChainUtils.h create mode 100644 src/tests/unittests/ChainUtilsTests.cpp diff --git a/docs/codegen.md b/docs/codegen.md index 9cdf40b064..d79c0e9d6d 100644 --- a/docs/codegen.md +++ b/docs/codegen.md @@ -18,6 +18,7 @@ At this time it is used to generate: - validation helper functions for dawn_native - the definition of dawn_native's proc table - dawn_native's internal version of the webgpu.h types + - utilities for working with dawn_native's chained structs - a lot of dawn_wire parts, see below Internally `dawn.json` is a dictionary from the "canonical name" of things to their definition. The "canonical name" is a space-separated (mostly) lower-case version of the name that's parsed into a `Name` Python object. Then that name can be turned into various casings with `.CamelCase()` `.SNAKE_CASE()`, etc. When `dawn.json` things reference each other, it is always via these "canonical names". diff --git a/generator/dawn_json_generator.py b/generator/dawn_json_generator.py index 84a2b99c1a..858be5219c 100644 --- a/generator/dawn_json_generator.py +++ b/generator/dawn_json_generator.py @@ -765,6 +765,14 @@ class MultiGeneratorFromDawnJSON(Generator): renders.append( FileRender('dawn_native/ProcTable.cpp', 'src/dawn_native/ProcTable.cpp', frontend_params)) + renders.append( + FileRender('dawn_native/ChainUtils.h', + 'src/dawn_native/ChainUtils_autogen.h', + frontend_params)) + renders.append( + FileRender('dawn_native/ChainUtils.cpp', + 'src/dawn_native/ChainUtils_autogen.cpp', + frontend_params)) if 'dawn_wire' in targets: additional_params = compute_wire_params(api_params, wire_json) diff --git a/generator/templates/dawn_native/ChainUtils.cpp b/generator/templates/dawn_native/ChainUtils.cpp new file mode 100644 index 0000000000..2a42db2a32 --- /dev/null +++ b/generator/templates/dawn_native/ChainUtils.cpp @@ -0,0 +1,61 @@ +// Copyright 2021 The Dawn Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "dawn_native/ChainUtils_autogen.h" + +#include + +namespace dawn_native { + +{% for value in types["s type"].values %} + {% if value.valid %} + void FindInChain(const ChainedStruct* chain, const {{as_cppEnum(value.name)}}** out) { + for (; chain; chain = chain->nextInChain) { + if (chain->sType == wgpu::SType::{{as_cppEnum(value.name)}}) { + *out = static_cast(chain); + break; + } + } + } + {% endif %} +{% endfor %} + +MaybeError ValidateSTypes(const ChainedStruct* chain, + std::vector> oneOfConstraints) { + std::unordered_set allSTypes; + for (; chain; chain = chain->nextInChain) { + if (allSTypes.find(chain->sType) != allSTypes.end()) { + return DAWN_VALIDATION_ERROR("Chain cannot have duplicate sTypes"); + } + allSTypes.insert(chain->sType); + } + for (const auto& oneOfConstraint : oneOfConstraints) { + bool satisfied = false; + for (wgpu::SType oneOfSType : oneOfConstraint) { + if (allSTypes.find(oneOfSType) != allSTypes.end()) { + if (satisfied) { + return DAWN_VALIDATION_ERROR("Unsupported sType combination"); + } + satisfied = true; + allSTypes.erase(oneOfSType); + } + } + } + if (!allSTypes.empty()) { + return DAWN_VALIDATION_ERROR("Unsupported sType"); + } + return {}; +} + +} // namespace dawn_native diff --git a/generator/templates/dawn_native/ChainUtils.h b/generator/templates/dawn_native/ChainUtils.h new file mode 100644 index 0000000000..ce4659173d --- /dev/null +++ b/generator/templates/dawn_native/ChainUtils.h @@ -0,0 +1,81 @@ +// Copyright 2021 The Dawn Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DAWNNATIVE_CHAIN_UTILS_H_ +#define DAWNNATIVE_CHAIN_UTILS_H_ + +#include "dawn_native/dawn_platform.h" +#include "dawn_native/Error.h" + +namespace dawn_native { + {% for value in types["s type"].values %} + {% if value.valid %} + void FindInChain(const ChainedStruct* chain, const {{as_cppEnum(value.name)}}** out); + {% endif %} + {% endfor %} + + // Verifies that |chain| only contains ChainedStructs of types enumerated in + // |oneOfConstraints| and contains no duplicate sTypes. Each vector in + // |oneOfConstraints| defines a set of sTypes that cannot coexist in the same chain. + // For example: + // ValidateSTypes(chain, { { ShaderModuleSPIRVDescriptor, ShaderModuleWGSLDescriptor } })) + // ValidateSTypes(chain, { { Extension1 }, { Extension2 } }) + MaybeError ValidateSTypes(const ChainedStruct* chain, + std::vector> oneOfConstraints); + + template + MaybeError ValidateSingleSTypeInner(const ChainedStruct* chain, T sType) { + if (chain->sType != sType) { + return DAWN_VALIDATION_ERROR("Unsupported sType"); + } + return {}; + } + + template + MaybeError ValidateSingleSTypeInner(const ChainedStruct* chain, T sType, Args... sTypes) { + if (chain->sType == sType) { + return {}; + } + return ValidateSingleSTypeInner(chain, sTypes...); + } + + // Verifies that |chain| contains a single ChainedStruct of type |sType| or no ChainedStructs + // at all. + template + MaybeError ValidateSingleSType(const ChainedStruct* chain, T sType) { + if (chain == nullptr) { + return {}; + } + if (chain->nextInChain != nullptr) { + return DAWN_VALIDATION_ERROR("Chain can only contain a single chained struct"); + } + return ValidateSingleSTypeInner(chain, sType); + } + + // Verifies that |chain| contains a single ChainedStruct with a type enumerated in the + // parameter pack or no ChainedStructs at all. + template + MaybeError ValidateSingleSType(const ChainedStruct* chain, T sType, Args... sTypes) { + if (chain == nullptr) { + return {}; + } + if (chain->nextInChain != nullptr) { + return DAWN_VALIDATION_ERROR("Chain can only contain a single chained struct"); + } + return ValidateSingleSTypeInner(chain, sType, sTypes...); + } + +} // namespace dawn_native + +#endif // DAWNNATIVE_CHAIN_UTILS_H_ diff --git a/src/dawn_native/BUILD.gn b/src/dawn_native/BUILD.gn index 803ee30246..2a5a59b436 100644 --- a/src/dawn_native/BUILD.gn +++ b/src/dawn_native/BUILD.gn @@ -91,6 +91,8 @@ config("dawn_native_vulkan_rpath") { dawn_json_generator("dawn_native_utils_gen") { target = "dawn_native_utils" outputs = [ + "src/dawn_native/ChainUtils_autogen.h", + "src/dawn_native/ChainUtils_autogen.cpp", "src/dawn_native/ProcTable.cpp", "src/dawn_native/wgpu_structs_autogen.h", "src/dawn_native/wgpu_structs_autogen.cpp", diff --git a/src/dawn_native/RenderPipeline.cpp b/src/dawn_native/RenderPipeline.cpp index c09bdfb8e9..c32034dd01 100644 --- a/src/dawn_native/RenderPipeline.cpp +++ b/src/dawn_native/RenderPipeline.cpp @@ -16,6 +16,7 @@ #include "common/BitSetIterator.h" #include "common/VertexFormatUtils.h" +#include "dawn_native/ChainUtils_autogen.h" #include "dawn_native/Commands.h" #include "dawn_native/Device.h" #include "dawn_native/ObjectContentHasher.h" @@ -133,16 +134,13 @@ namespace dawn_native { MaybeError ValidatePrimitiveState(const DeviceBase* device, const PrimitiveState* descriptor) { - const ChainedStruct* chained = descriptor->nextInChain; - if (chained != nullptr) { - if (chained->sType != wgpu::SType::PrimitiveDepthClampingState) { - return DAWN_VALIDATION_ERROR("Unsupported sType"); - } - if (!device->IsExtensionEnabled(Extension::DepthClamping)) { - return DAWN_VALIDATION_ERROR("The depth clamping feature is not supported"); - } + DAWN_TRY(ValidateSingleSType(descriptor->nextInChain, + wgpu::SType::PrimitiveDepthClampingState)); + const PrimitiveDepthClampingState* clampInfo = nullptr; + FindInChain(descriptor->nextInChain, &clampInfo); + if (clampInfo && !device->IsExtensionEnabled(Extension::DepthClamping)) { + return DAWN_VALIDATION_ERROR("The depth clamping feature is not supported"); } - DAWN_TRY(ValidatePrimitiveTopology(descriptor->topology)); DAWN_TRY(ValidateIndexFormat(descriptor->stripIndexFormat)); DAWN_TRY(ValidateFrontFace(descriptor->frontFace)); @@ -426,11 +424,10 @@ namespace dawn_native { } mPrimitive = descriptor->primitive; - const ChainedStruct* chained = mPrimitive.nextInChain; - if (chained != nullptr) { - ASSERT(chained->sType == wgpu::SType::PrimitiveDepthClampingState); - const auto* clampState = static_cast(chained); - mClampDepth = clampState->clampDepth; + const PrimitiveDepthClampingState* clampInfo = nullptr; + FindInChain(mPrimitive.nextInChain, &clampInfo); + if (clampInfo) { + mClampDepth = clampInfo->clampDepth; } mMultisample = descriptor->multisample; diff --git a/src/dawn_native/ShaderModule.cpp b/src/dawn_native/ShaderModule.cpp index a9cceac5b3..5e5762fced 100644 --- a/src/dawn_native/ShaderModule.cpp +++ b/src/dawn_native/ShaderModule.cpp @@ -17,6 +17,7 @@ #include "common/HashUtils.h" #include "common/VertexFormatUtils.h" #include "dawn_native/BindGroupLayout.h" +#include "dawn_native/ChainUtils_autogen.h" #include "dawn_native/CompilationMessages.h" #include "dawn_native/Device.h" #include "dawn_native/ObjectContentHasher.h" @@ -1069,65 +1070,56 @@ namespace dawn_native { return DAWN_VALIDATION_ERROR("Shader module descriptor missing chained descriptor"); } // For now only a single SPIRV or WGSL subdescriptor is allowed. - if (chainedDescriptor->nextInChain != nullptr) { - return DAWN_VALIDATION_ERROR( - "Shader module descriptor chained nextInChain must be nullptr"); - } + DAWN_TRY(ValidateSingleSType(chainedDescriptor, + wgpu::SType::ShaderModuleSPIRVDescriptor, + wgpu::SType::ShaderModuleWGSLDescriptor)); OwnedCompilationMessages* outMessages = parseResult->compilationMessages.get(); ScopedTintICEHandler scopedICEHandler(device); - switch (chainedDescriptor->sType) { - case wgpu::SType::ShaderModuleSPIRVDescriptor: { - const auto* spirvDesc = - static_cast(chainedDescriptor); - std::vector spirv(spirvDesc->code, spirvDesc->code + spirvDesc->codeSize); - if (device->IsToggleEnabled(Toggle::UseTintGenerator)) { - tint::Program program; - DAWN_TRY_ASSIGN(program, ParseSPIRV(spirv, outMessages)); - parseResult->tintProgram = std::make_unique(std::move(program)); - } else { - if (device->IsValidationEnabled()) { - DAWN_TRY(ValidateSpirv(spirv.data(), spirv.size())); - } - parseResult->spirv = std::move(spirv); - } - break; - } - - case wgpu::SType::ShaderModuleWGSLDescriptor: { - const auto* wgslDesc = - static_cast(chainedDescriptor); - - auto tintSource = std::make_unique("", wgslDesc->source); + const ShaderModuleSPIRVDescriptor* spirvDesc = nullptr; + FindInChain(chainedDescriptor, &spirvDesc); + const ShaderModuleWGSLDescriptor* wgslDesc = nullptr; + FindInChain(chainedDescriptor, &wgslDesc); + if (spirvDesc) { + std::vector spirv(spirvDesc->code, spirvDesc->code + spirvDesc->codeSize); + if (device->IsToggleEnabled(Toggle::UseTintGenerator)) { tint::Program program; - DAWN_TRY_ASSIGN(program, ParseWGSL(&tintSource->file, outMessages)); - - if (device->IsToggleEnabled(Toggle::UseTintGenerator)) { - parseResult->tintProgram = std::make_unique(std::move(program)); - parseResult->tintSource = std::move(tintSource); - } else { - tint::transform::Manager transformManager; - transformManager.Add(); - transformManager.Add(); - - tint::transform::DataMap transformInputs; - - DAWN_TRY_ASSIGN(program, RunTransforms(&transformManager, &program, - transformInputs, nullptr, outMessages)); - - std::vector spirv; - DAWN_TRY_ASSIGN(spirv, ModuleToSPIRV(&program)); + DAWN_TRY_ASSIGN(program, ParseSPIRV(spirv, outMessages)); + parseResult->tintProgram = std::make_unique(std::move(program)); + } else { + if (device->IsValidationEnabled()) { DAWN_TRY(ValidateSpirv(spirv.data(), spirv.size())); - - parseResult->spirv = std::move(spirv); } - break; + parseResult->spirv = std::move(spirv); + } + } else if (wgslDesc) { + auto tintSource = std::make_unique("", wgslDesc->source); + + tint::Program program; + DAWN_TRY_ASSIGN(program, ParseWGSL(&tintSource->file, outMessages)); + + if (device->IsToggleEnabled(Toggle::UseTintGenerator)) { + parseResult->tintProgram = std::make_unique(std::move(program)); + parseResult->tintSource = std::move(tintSource); + } else { + tint::transform::Manager transformManager; + transformManager.Add(); + transformManager.Add(); + + tint::transform::DataMap transformInputs; + + DAWN_TRY_ASSIGN(program, RunTransforms(&transformManager, &program, + transformInputs, nullptr, outMessages)); + + std::vector spirv; + DAWN_TRY_ASSIGN(spirv, ModuleToSPIRV(&program)); + DAWN_TRY(ValidateSpirv(spirv.data(), spirv.size())); + + parseResult->spirv = std::move(spirv); } - default: - return DAWN_VALIDATION_ERROR("Unsupported sType"); } return {}; @@ -1216,23 +1208,18 @@ namespace dawn_native { ShaderModuleBase::ShaderModuleBase(DeviceBase* device, const ShaderModuleDescriptor* descriptor) : CachedObject(device), mType(Type::Undefined) { ASSERT(descriptor->nextInChain != nullptr); - switch (descriptor->nextInChain->sType) { - case wgpu::SType::ShaderModuleSPIRVDescriptor: { - mType = Type::Spirv; - const auto* spirvDesc = - static_cast(descriptor->nextInChain); - mOriginalSpirv.assign(spirvDesc->code, spirvDesc->code + spirvDesc->codeSize); - break; - } - case wgpu::SType::ShaderModuleWGSLDescriptor: { - mType = Type::Wgsl; - const auto* wgslDesc = - static_cast(descriptor->nextInChain); - mWgsl = std::string(wgslDesc->source); - break; - } - default: - UNREACHABLE(); + const ShaderModuleSPIRVDescriptor* spirvDesc = nullptr; + FindInChain(descriptor->nextInChain, &spirvDesc); + const ShaderModuleWGSLDescriptor* wgslDesc = nullptr; + FindInChain(descriptor->nextInChain, &wgslDesc); + ASSERT(spirvDesc || wgslDesc); + + if (spirvDesc) { + mType = Type::Spirv; + mOriginalSpirv.assign(spirvDesc->code, spirvDesc->code + spirvDesc->codeSize); + } else if (wgslDesc) { + mType = Type::Wgsl; + mWgsl = std::string(wgslDesc->source); } } diff --git a/src/dawn_native/Surface.cpp b/src/dawn_native/Surface.cpp index 4afe05ed45..9b317bcdb8 100644 --- a/src/dawn_native/Surface.cpp +++ b/src/dawn_native/Surface.cpp @@ -15,6 +15,7 @@ #include "dawn_native/Surface.h" #include "common/Platform.h" +#include "dawn_native/ChainUtils_autogen.h" #include "dawn_native/Instance.h" #include "dawn_native/SwapChain.h" @@ -34,75 +35,60 @@ namespace dawn_native { MaybeError ValidateSurfaceDescriptor(const InstanceBase* instance, const SurfaceDescriptor* descriptor) { - // TODO(cwallez@chromium.org): Have some type of helper to iterate over all the chained - // structures. if (descriptor->nextInChain == nullptr) { return DAWN_VALIDATION_ERROR("Surface cannot be created with just the base descriptor"); } - const ChainedStruct* chainedDescriptor = descriptor->nextInChain; - if (chainedDescriptor->nextInChain != nullptr) { - return DAWN_VALIDATION_ERROR("Cannot specify two windows for a single surface"); - } + DAWN_TRY(ValidateSingleSType(descriptor->nextInChain, + wgpu::SType::SurfaceDescriptorFromMetalLayer, + wgpu::SType::SurfaceDescriptorFromWindowsHWND, + wgpu::SType::SurfaceDescriptorFromXlib)); - switch (chainedDescriptor->sType) { #if defined(DAWN_ENABLE_BACKEND_METAL) - case wgpu::SType::SurfaceDescriptorFromMetalLayer: { - const SurfaceDescriptorFromMetalLayer* metalDesc = - static_cast(chainedDescriptor); - - // Check that the layer is a CAMetalLayer (or a derived class). - if (!InheritsFromCAMetalLayer(metalDesc->layer)) { - return DAWN_VALIDATION_ERROR("layer must be a CAMetalLayer"); - } - break; - } + const SurfaceDescriptorFromMetalLayer* metalDesc = nullptr; + FindInChain(descriptor->nextInChain, &metalDesc); + if (!metalDesc) { + return DAWN_VALIDATION_ERROR("Unsupported sType"); + } + // Check that the layer is a CAMetalLayer (or a derived class). + if (!InheritsFromCAMetalLayer(metalDesc->layer)) { + return DAWN_VALIDATION_ERROR("layer must be a CAMetalLayer"); + } #endif // defined(DAWN_ENABLE_BACKEND_METAL) #if defined(DAWN_PLATFORM_WINDOWS) - case wgpu::SType::SurfaceDescriptorFromWindowsHWND: { - const SurfaceDescriptorFromWindowsHWND* hwndDesc = - static_cast(chainedDescriptor); - - // It is not possible to validate an HINSTANCE. - - // Validate the hwnd using the windows.h IsWindow function. - if (IsWindow(static_cast(hwndDesc->hwnd)) == 0) { - return DAWN_VALIDATION_ERROR("Invalid HWND"); - } - break; - } + const SurfaceDescriptorFromWindowsHWND* hwndDesc = nullptr; + FindInChain(descriptor->nextInChain, &hwndDesc); + if (!hwndDesc) { + return DAWN_VALIDATION_ERROR("Unsupported sType"); + } + // Validate the hwnd using the windows.h IsWindow function. + if (IsWindow(static_cast(hwndDesc->hwnd)) == 0) { + return DAWN_VALIDATION_ERROR("Invalid HWND"); + } #endif // defined(DAWN_PLATFORM_WINDOWS) #if defined(DAWN_USE_X11) - case wgpu::SType::SurfaceDescriptorFromXlib: { - const SurfaceDescriptorFromXlib* xDesc = - static_cast(chainedDescriptor); - - // It is not possible to validate an X Display. - - // Check the validity of the window by calling a getter function on the window that - // returns a status code. If the window is bad the call return a status of zero. We - // need to set a temporary X11 error handler while doing this because the default - // X11 error handler exits the program on any error. - XErrorHandler oldErrorHandler = - XSetErrorHandler([](Display*, XErrorEvent*) { return 0; }); - XWindowAttributes attributes; - int status = XGetWindowAttributes(reinterpret_cast(xDesc->display), - xDesc->window, &attributes); - XSetErrorHandler(oldErrorHandler); - - if (status == 0) { - return DAWN_VALIDATION_ERROR("Invalid X Window"); - } - break; - } -#endif // defined(DAWN_USE_X11) - - case wgpu::SType::SurfaceDescriptorFromCanvasHTMLSelector: - default: - return DAWN_VALIDATION_ERROR("Unsupported sType"); + const SurfaceDescriptorFromXlib* xDesc = nullptr; + FindInChain(descriptor->nextInChain, &xDesc); + if (!xDesc) { + return DAWN_VALIDATION_ERROR("Unsupported sType"); } + // Check the validity of the window by calling a getter function on the window that + // returns a status code. If the window is bad the call return a status of zero. We + // need to set a temporary X11 error handler while doing this because the default + // X11 error handler exits the program on any error. + XErrorHandler oldErrorHandler = + XSetErrorHandler([](Display*, XErrorEvent*) { return 0; }); + XWindowAttributes attributes; + int status = XGetWindowAttributes(reinterpret_cast(xDesc->display), + xDesc->window, &attributes); + XSetErrorHandler(oldErrorHandler); + + if (status == 0) { + return DAWN_VALIDATION_ERROR("Invalid X Window"); + } +#endif // defined(DAWN_USE_X11) return {}; } @@ -110,37 +96,24 @@ namespace dawn_native { Surface::Surface(InstanceBase* instance, const SurfaceDescriptor* descriptor) : mInstance(instance) { ASSERT(descriptor->nextInChain != nullptr); - const ChainedStruct* chainedDescriptor = descriptor->nextInChain; - - switch (chainedDescriptor->sType) { - case wgpu::SType::SurfaceDescriptorFromMetalLayer: { - const SurfaceDescriptorFromMetalLayer* metalDesc = - static_cast(chainedDescriptor); - mType = Type::MetalLayer; - mMetalLayer = metalDesc->layer; - break; - } - - case wgpu::SType::SurfaceDescriptorFromWindowsHWND: { - const SurfaceDescriptorFromWindowsHWND* hwndDesc = - static_cast(chainedDescriptor); - mType = Type::WindowsHWND; - mHInstance = hwndDesc->hinstance; - mHWND = hwndDesc->hwnd; - break; - } - - case wgpu::SType::SurfaceDescriptorFromXlib: { - const SurfaceDescriptorFromXlib* xDesc = - static_cast(chainedDescriptor); - mType = Type::Xlib; - mXDisplay = xDesc->display; - mXWindow = xDesc->window; - break; - } - - default: - UNREACHABLE(); + const SurfaceDescriptorFromMetalLayer* metalDesc = nullptr; + const SurfaceDescriptorFromWindowsHWND* hwndDesc = nullptr; + const SurfaceDescriptorFromXlib* xDesc = nullptr; + FindInChain(descriptor->nextInChain, &metalDesc); + FindInChain(descriptor->nextInChain, &hwndDesc); + FindInChain(descriptor->nextInChain, &xDesc); + ASSERT(metalDesc || hwndDesc || xDesc); + if (metalDesc) { + mType = Type::MetalLayer; + mMetalLayer = metalDesc->layer; + } else if (hwndDesc) { + mType = Type::WindowsHWND; + mHInstance = hwndDesc->hinstance; + mHWND = hwndDesc->hwnd; + } else if (xDesc) { + mType = Type::Xlib; + mXDisplay = xDesc->display; + mXWindow = xDesc->window; } } diff --git a/src/tests/BUILD.gn b/src/tests/BUILD.gn index faf4b224f9..284cc41cda 100644 --- a/src/tests/BUILD.gn +++ b/src/tests/BUILD.gn @@ -156,6 +156,7 @@ test("dawn_unittests") { "unittests/BitSetIteratorTests.cpp", "unittests/BuddyAllocatorTests.cpp", "unittests/BuddyMemoryAllocatorTests.cpp", + "unittests/ChainUtilsTests.cpp", "unittests/CommandAllocatorTests.cpp", "unittests/EnumClassBitmasksTests.cpp", "unittests/EnumMaskIteratorTests.cpp", diff --git a/src/tests/unittests/ChainUtilsTests.cpp b/src/tests/unittests/ChainUtilsTests.cpp new file mode 100644 index 0000000000..2d437298f8 --- /dev/null +++ b/src/tests/unittests/ChainUtilsTests.cpp @@ -0,0 +1,181 @@ +// Copyright 2021 The Dawn Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "dawn_native/ChainUtils_autogen.h" +#include "dawn_native/dawn_platform.h" + +// Checks that we cannot find any structs in an empty chain +TEST(ChainUtilsTests, FindEmptyChain) { + const dawn_native::PrimitiveDepthClampingState* info = nullptr; + dawn_native::FindInChain(nullptr, &info); + + ASSERT_EQ(nullptr, info); +} + +// Checks that searching a chain for a present struct returns that struct +TEST(ChainUtilsTests, FindPresentInChain) { + dawn_native::PrimitiveDepthClampingState chain1; + dawn_native::ShaderModuleSPIRVDescriptor chain2; + chain1.nextInChain = &chain2; + const dawn_native::PrimitiveDepthClampingState* info1 = nullptr; + const dawn_native::ShaderModuleSPIRVDescriptor* info2 = nullptr; + dawn_native::FindInChain(&chain1, &info1); + dawn_native::FindInChain(&chain1, &info2); + + ASSERT_NE(nullptr, info1); + ASSERT_NE(nullptr, info2); +} + +// Checks that searching a chain for a struct that doesn't exist returns a nullptr +TEST(ChainUtilsTests, FindMissingInChain) { + dawn_native::PrimitiveDepthClampingState chain1; + dawn_native::ShaderModuleSPIRVDescriptor chain2; + chain1.nextInChain = &chain2; + const dawn_native::SurfaceDescriptorFromMetalLayer* info = nullptr; + dawn_native::FindInChain(&chain1, &info); + + ASSERT_EQ(nullptr, info); +} + +// Checks that validation rejects chains with duplicate STypes +TEST(ChainUtilsTests, ValidateDuplicateSTypes) { + dawn_native::PrimitiveDepthClampingState chain1; + dawn_native::ShaderModuleSPIRVDescriptor chain2; + dawn_native::PrimitiveDepthClampingState chain3; + chain1.nextInChain = &chain2; + chain2.nextInChain = &chain3; + + dawn_native::MaybeError result = dawn_native::ValidateSTypes(&chain1, {}); + ASSERT_TRUE(result.IsError()); + result.AcquireError(); +} + +// Checks that validation rejects chains that contain unspecified STypes +TEST(ChainUtilsTests, ValidateUnspecifiedSTypes) { + dawn_native::PrimitiveDepthClampingState chain1; + dawn_native::ShaderModuleSPIRVDescriptor chain2; + dawn_native::ShaderModuleWGSLDescriptor chain3; + chain1.nextInChain = &chain2; + chain2.nextInChain = &chain3; + + dawn_native::MaybeError result = dawn_native::ValidateSTypes(&chain1, { + {wgpu::SType::PrimitiveDepthClampingState}, + {wgpu::SType::ShaderModuleSPIRVDescriptor}, + }); + ASSERT_TRUE(result.IsError()); + result.AcquireError(); +} + +// Checks that validation rejects chains that contain multiple STypes from the same oneof +// constraint. +TEST(ChainUtilsTests, ValidateOneOfFailure) { + dawn_native::PrimitiveDepthClampingState chain1; + dawn_native::ShaderModuleSPIRVDescriptor chain2; + dawn_native::ShaderModuleWGSLDescriptor chain3; + chain1.nextInChain = &chain2; + chain2.nextInChain = &chain3; + + dawn_native::MaybeError result = dawn_native::ValidateSTypes(&chain1, + {{wgpu::SType::ShaderModuleSPIRVDescriptor, wgpu::SType::ShaderModuleWGSLDescriptor}}); + ASSERT_TRUE(result.IsError()); + result.AcquireError(); +} + +// Checks that validation accepts chains that match the constraints. +TEST(ChainUtilsTests, ValidateSuccess) { + dawn_native::PrimitiveDepthClampingState chain1; + dawn_native::ShaderModuleSPIRVDescriptor chain2; + chain1.nextInChain = &chain2; + + dawn_native::MaybeError result = dawn_native::ValidateSTypes(&chain1, { + {wgpu::SType::ShaderModuleSPIRVDescriptor, wgpu::SType::ShaderModuleWGSLDescriptor}, + {wgpu::SType::PrimitiveDepthClampingState}, + {wgpu::SType::SurfaceDescriptorFromMetalLayer}, + }); + ASSERT_TRUE(result.IsSuccess()); +} + +// Checks that validation always passes on empty chains. +TEST(ChainUtilsTests, ValidateEmptyChain) { + dawn_native::MaybeError result = dawn_native::ValidateSTypes(nullptr, { + {wgpu::SType::ShaderModuleSPIRVDescriptor}, + {wgpu::SType::PrimitiveDepthClampingState}, + }); + ASSERT_TRUE(result.IsSuccess()); + + result = dawn_native::ValidateSTypes(nullptr, {}); + ASSERT_TRUE(result.IsSuccess()); +} + +// Checks that singleton validation always passes on empty chains. +TEST(ChainUtilsTests, ValidateSingleEmptyChain) { + dawn_native::MaybeError result = dawn_native::ValidateSingleSType(nullptr, + wgpu::SType::ShaderModuleSPIRVDescriptor); + ASSERT_TRUE(result.IsSuccess()); + + result = dawn_native::ValidateSingleSType(nullptr, + wgpu::SType::ShaderModuleSPIRVDescriptor, wgpu::SType::PrimitiveDepthClampingState); + ASSERT_TRUE(result.IsSuccess()); +} + +// Checks that singleton validation always fails on chains with multiple children. +TEST(ChainUtilsTests, ValidateSingleMultiChain) { + dawn_native::PrimitiveDepthClampingState chain1; + dawn_native::ShaderModuleSPIRVDescriptor chain2; + chain1.nextInChain = &chain2; + + dawn_native::MaybeError result = dawn_native::ValidateSingleSType(&chain1, + wgpu::SType::PrimitiveDepthClampingState); + ASSERT_TRUE(result.IsError()); + result.AcquireError(); + + result = dawn_native::ValidateSingleSType(&chain1, + wgpu::SType::PrimitiveDepthClampingState, wgpu::SType::ShaderModuleSPIRVDescriptor); + ASSERT_TRUE(result.IsError()); + result.AcquireError(); +} + +// Checks that singleton validation passes when the oneof constraint is met. +TEST(ChainUtilsTests, ValidateSingleSatisfied) { + dawn_native::ShaderModuleWGSLDescriptor chain1; + + dawn_native::MaybeError result = dawn_native::ValidateSingleSType(&chain1, + wgpu::SType::ShaderModuleWGSLDescriptor); + ASSERT_TRUE(result.IsSuccess()); + + result = dawn_native::ValidateSingleSType(&chain1, + wgpu::SType::ShaderModuleSPIRVDescriptor, wgpu::SType::ShaderModuleWGSLDescriptor); + ASSERT_TRUE(result.IsSuccess()); + + result = dawn_native::ValidateSingleSType(&chain1, + wgpu::SType::ShaderModuleWGSLDescriptor, wgpu::SType::ShaderModuleSPIRVDescriptor); + ASSERT_TRUE(result.IsSuccess()); +} + +// Checks that singleton validation passes when the oneof constraint is not met. +TEST(ChainUtilsTests, ValidateSingleUnsatisfied) { + dawn_native::PrimitiveDepthClampingState chain1; + + dawn_native::MaybeError result = dawn_native::ValidateSingleSType(&chain1, + wgpu::SType::ShaderModuleWGSLDescriptor); + ASSERT_TRUE(result.IsError()); + result.AcquireError(); + + result = dawn_native::ValidateSingleSType(&chain1, + wgpu::SType::ShaderModuleSPIRVDescriptor, wgpu::SType::ShaderModuleWGSLDescriptor); + ASSERT_TRUE(result.IsError()); + result.AcquireError(); +}