From 0671fe28bf90fd19b030ab2e3a3eb7274d6d88b7 Mon Sep 17 00:00:00 2001 From: Austin Eng Date: Thu, 4 Aug 2022 01:12:56 +0000 Subject: [PATCH] Cache WGSL -> MSL compilation Bug: dawn:1480 Change-Id: Ie2ef7860b38d7f350c99cf2c5451299b23413ec6 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/97882 Commit-Queue: Austin Eng Reviewed-by: Corentin Wallez Reviewed-by: Loko Kung --- src/dawn/native/BUILD.gn | 1 + src/dawn/native/CMakeLists.txt | 1 + src/dawn/native/CacheRequest.h | 29 +- src/dawn/native/ShaderModule.cpp | 137 -------- src/dawn/native/ShaderModule.h | 10 - src/dawn/native/StreamImplTint.cpp | 37 +++ src/dawn/native/TintUtils.cpp | 130 ++++++++ src/dawn/native/TintUtils.h | 15 +- src/dawn/native/VisitableMembers.h | 61 ++++ src/dawn/native/d3d12/ShaderModuleD3D12.cpp | 7 +- src/dawn/native/metal/ShaderModuleMTL.h | 9 - src/dawn/native/metal/ShaderModuleMTL.mm | 308 ++++++++++++------ src/dawn/native/opengl/ShaderModuleGL.cpp | 7 +- src/dawn/native/stream/Stream.h | 23 ++ src/dawn/native/vulkan/ShaderModuleVk.cpp | 2 +- .../tests/end2end/PipelineCachingTests.cpp | 21 +- .../tests/unittests/native/StreamTests.cpp | 53 +++ 17 files changed, 569 insertions(+), 282 deletions(-) create mode 100644 src/dawn/native/VisitableMembers.h diff --git a/src/dawn/native/BUILD.gn b/src/dawn/native/BUILD.gn index 45fec76b20..03fd9adbf9 100644 --- a/src/dawn/native/BUILD.gn +++ b/src/dawn/native/BUILD.gn @@ -333,6 +333,7 @@ source_set("sources") { "UsageValidationMode.h", "VertexFormat.cpp", "VertexFormat.h", + "VisitableMembers.h", "dawn_platform.h", "stream/BlobSource.cpp", "stream/BlobSource.h", diff --git a/src/dawn/native/CMakeLists.txt b/src/dawn/native/CMakeLists.txt index 66e1de036f..434ad66743 100644 --- a/src/dawn/native/CMakeLists.txt +++ b/src/dawn/native/CMakeLists.txt @@ -190,6 +190,7 @@ target_sources(dawn_native PRIVATE "UsageValidationMode.h" "VertexFormat.cpp" "VertexFormat.h" + "VisitableMembers.h" "dawn_platform.h" "webgpu_absl_format.cpp" "webgpu_absl_format.h" diff --git a/src/dawn/native/CacheRequest.h b/src/dawn/native/CacheRequest.h index df8af4a65d..2999449b7d 100644 --- a/src/dawn/native/CacheRequest.h +++ b/src/dawn/native/CacheRequest.h @@ -26,6 +26,7 @@ #include "dawn/native/CacheResult.h" #include "dawn/native/Device.h" #include "dawn/native/Error.h" +#include "dawn/native/VisitableMembers.h" namespace dawn::native { @@ -94,6 +95,15 @@ class CacheRequestImpl { CacheRequestImpl(const CacheRequestImpl&) = delete; CacheRequestImpl& operator=(const CacheRequestImpl&) = delete; + // Create a CacheKey from the request type and all members + CacheKey CreateCacheKey(const DeviceBase* device) const { + CacheKey key = device->GetCacheKey(); + StreamIn(&key, Request::kName); + static_cast(this)->VisitAll( + [&](const auto&... members) { StreamIn(&key, members...); }); + return key; + } + template friend auto LoadOrRun(DeviceBase* device, Request&& r, @@ -168,19 +178,12 @@ class CacheRequestImpl { // X(Bar, bar) // DAWN_MAKE_CACHE_REQUEST(MyCacheRequest, REQUEST_MEMBERS) // #undef REQUEST_MEMBERS -#define DAWN_MAKE_CACHE_REQUEST(Request, MEMBERS) \ - class Request : public ::dawn::native::CacheRequestImpl { \ - public: \ - Request() = default; \ - MEMBERS(DAWN_INTERNAL_CACHE_REQUEST_DECL_STRUCT_MEMBER) \ - \ - /* Create a CacheKey from the request type and all members */ \ - ::dawn::native::CacheKey CreateCacheKey(const ::dawn::native::DeviceBase* device) const { \ - ::dawn::native::CacheKey key = device->GetCacheKey(); \ - StreamIn(&key, #Request); \ - MEMBERS(DAWN_INTERNAL_CACHE_REQUEST_RECORD_KEY) \ - return key; \ - } \ +#define DAWN_MAKE_CACHE_REQUEST(Request, MEMBERS) \ + class Request : public ::dawn::native::CacheRequestImpl { \ + public: \ + static constexpr char kName[] = #Request; \ + Request() = default; \ + DAWN_VISITABLE_MEMBERS(MEMBERS) \ }; // Helper macro for the common pattern of DAWN_TRY_ASSIGN around LoadOrRun. diff --git a/src/dawn/native/ShaderModule.cpp b/src/dawn/native/ShaderModule.cpp index d06665045c..4c81bdde64 100644 --- a/src/dawn/native/ShaderModule.cpp +++ b/src/dawn/native/ShaderModule.cpp @@ -37,87 +37,6 @@ namespace dawn::native { namespace { -tint::transform::VertexFormat ToTintVertexFormat(wgpu::VertexFormat format) { - switch (format) { - case wgpu::VertexFormat::Uint8x2: - return tint::transform::VertexFormat::kUint8x2; - case wgpu::VertexFormat::Uint8x4: - return tint::transform::VertexFormat::kUint8x4; - case wgpu::VertexFormat::Sint8x2: - return tint::transform::VertexFormat::kSint8x2; - case wgpu::VertexFormat::Sint8x4: - return tint::transform::VertexFormat::kSint8x4; - case wgpu::VertexFormat::Unorm8x2: - return tint::transform::VertexFormat::kUnorm8x2; - case wgpu::VertexFormat::Unorm8x4: - return tint::transform::VertexFormat::kUnorm8x4; - case wgpu::VertexFormat::Snorm8x2: - return tint::transform::VertexFormat::kSnorm8x2; - case wgpu::VertexFormat::Snorm8x4: - return tint::transform::VertexFormat::kSnorm8x4; - case wgpu::VertexFormat::Uint16x2: - return tint::transform::VertexFormat::kUint16x2; - case wgpu::VertexFormat::Uint16x4: - return tint::transform::VertexFormat::kUint16x4; - case wgpu::VertexFormat::Sint16x2: - return tint::transform::VertexFormat::kSint16x2; - case wgpu::VertexFormat::Sint16x4: - return tint::transform::VertexFormat::kSint16x4; - case wgpu::VertexFormat::Unorm16x2: - return tint::transform::VertexFormat::kUnorm16x2; - case wgpu::VertexFormat::Unorm16x4: - return tint::transform::VertexFormat::kUnorm16x4; - case wgpu::VertexFormat::Snorm16x2: - return tint::transform::VertexFormat::kSnorm16x2; - case wgpu::VertexFormat::Snorm16x4: - return tint::transform::VertexFormat::kSnorm16x4; - case wgpu::VertexFormat::Float16x2: - return tint::transform::VertexFormat::kFloat16x2; - case wgpu::VertexFormat::Float16x4: - return tint::transform::VertexFormat::kFloat16x4; - case wgpu::VertexFormat::Float32: - return tint::transform::VertexFormat::kFloat32; - case wgpu::VertexFormat::Float32x2: - return tint::transform::VertexFormat::kFloat32x2; - case wgpu::VertexFormat::Float32x3: - return tint::transform::VertexFormat::kFloat32x3; - case wgpu::VertexFormat::Float32x4: - return tint::transform::VertexFormat::kFloat32x4; - case wgpu::VertexFormat::Uint32: - return tint::transform::VertexFormat::kUint32; - case wgpu::VertexFormat::Uint32x2: - return tint::transform::VertexFormat::kUint32x2; - case wgpu::VertexFormat::Uint32x3: - return tint::transform::VertexFormat::kUint32x3; - case wgpu::VertexFormat::Uint32x4: - return tint::transform::VertexFormat::kUint32x4; - case wgpu::VertexFormat::Sint32: - return tint::transform::VertexFormat::kSint32; - case wgpu::VertexFormat::Sint32x2: - return tint::transform::VertexFormat::kSint32x2; - case wgpu::VertexFormat::Sint32x3: - return tint::transform::VertexFormat::kSint32x3; - case wgpu::VertexFormat::Sint32x4: - return tint::transform::VertexFormat::kSint32x4; - - case wgpu::VertexFormat::Undefined: - break; - } - UNREACHABLE(); -} - -tint::transform::VertexStepMode ToTintVertexStepMode(wgpu::VertexStepMode mode) { - switch (mode) { - case wgpu::VertexStepMode::Vertex: - return tint::transform::VertexStepMode::kVertex; - case wgpu::VertexStepMode::Instance: - return tint::transform::VertexStepMode::kInstance; - case wgpu::VertexStepMode::VertexBufferNotUsed: - break; - } - UNREACHABLE(); -} - ResultOrError TintPipelineStageToShaderStage( tint::inspector::PipelineStage stage) { switch (stage) { @@ -1088,39 +1007,6 @@ ResultOrError RunTransforms(tint::transform::Transform* transform return std::move(output.program); } -void AddVertexPullingTransformConfig(const RenderPipelineBase& renderPipeline, - const std::string& entryPoint, - BindGroupIndex pullingBufferBindingSet, - tint::transform::DataMap* transformInputs) { - tint::transform::VertexPulling::Config cfg; - cfg.entry_point_name = entryPoint; - cfg.pulling_group = static_cast(pullingBufferBindingSet); - - cfg.vertex_state.resize(renderPipeline.GetVertexBufferCount()); - for (VertexBufferSlot slot : IterateBitSet(renderPipeline.GetVertexBufferSlotsUsed())) { - const VertexBufferInfo& dawnInfo = renderPipeline.GetVertexBuffer(slot); - tint::transform::VertexBufferLayoutDescriptor* tintInfo = - &cfg.vertex_state[static_cast(slot)]; - - tintInfo->array_stride = dawnInfo.arrayStride; - tintInfo->step_mode = ToTintVertexStepMode(dawnInfo.stepMode); - } - - for (VertexAttributeLocation location : - IterateBitSet(renderPipeline.GetAttributeLocationsUsed())) { - const VertexAttributeInfo& dawnInfo = renderPipeline.GetAttribute(location); - tint::transform::VertexAttributeDescriptor tintInfo; - tintInfo.format = ToTintVertexFormat(dawnInfo.format); - tintInfo.offset = dawnInfo.offset; - tintInfo.shader_location = static_cast(static_cast(location)); - - uint8_t vertexBufferSlot = static_cast(dawnInfo.vertexBufferSlot); - cfg.vertex_state[vertexBufferSlot].attributes.push_back(tintInfo); - } - - transformInputs->Add(cfg); -} - MaybeError ValidateCompatibilityWithPipelineLayout(DeviceBase* device, const EntryPointMetadata& entryPoint, const PipelineLayoutBase* layout) { @@ -1306,29 +1192,6 @@ OwnedCompilationMessages* ShaderModuleBase::GetCompilationMessages() const { return mCompilationMessages.get(); } -// static -void ShaderModuleBase::AddExternalTextureTransform(const PipelineLayoutBase* layout, - tint::transform::Manager* transformManager, - tint::transform::DataMap* transformInputs) { - tint::transform::MultiplanarExternalTexture::BindingsMap newBindingsMap; - for (BindGroupIndex i : IterateBitSet(layout->GetBindGroupLayoutsMask())) { - const BindGroupLayoutBase* bgl = layout->GetBindGroupLayout(i); - - for (const auto& expansion : bgl->GetExternalTextureBindingExpansionMap()) { - newBindingsMap[{static_cast(i), - static_cast(expansion.second.plane0)}] = { - {static_cast(i), static_cast(expansion.second.plane1)}, - {static_cast(i), static_cast(expansion.second.params)}}; - } - } - - if (!newBindingsMap.empty()) { - transformManager->Add(); - transformInputs->Add( - newBindingsMap); - } -} - MaybeError ShaderModuleBase::InitializeBase(ShaderModuleParseResult* parseResult, OwnedCompilationMessages* compilationMessages) { mTintProgram = std::move(parseResult->tintProgram); diff --git a/src/dawn/native/ShaderModule.h b/src/dawn/native/ShaderModule.h index d1e450caec..170b388000 100644 --- a/src/dawn/native/ShaderModule.h +++ b/src/dawn/native/ShaderModule.h @@ -116,12 +116,6 @@ ResultOrError RunTransforms(tint::transform::Transform* transform tint::transform::DataMap* outputs, OwnedCompilationMessages* messages); -/// Creates and adds the tint::transform::VertexPulling::Config to transformInputs. -void AddVertexPullingTransformConfig(const RenderPipelineBase& renderPipeline, - const std::string& entryPoint, - BindGroupIndex pullingBufferBindingSet, - tint::transform::DataMap* transformInputs); - // Mirrors wgpu::SamplerBindingLayout but instead stores a single boolean // for isComparison instead of a wgpu::SamplerBindingType enum. struct ShaderSamplerBindingInfo { @@ -295,10 +289,6 @@ class ShaderModuleBase : public ApiObjectBase, public CachedObject { MaybeError InitializeBase(ShaderModuleParseResult* parseResult, OwnedCompilationMessages* compilationMessages); - static void AddExternalTextureTransform(const PipelineLayoutBase* layout, - tint::transform::Manager* transformManager, - tint::transform::DataMap* transformInputs); - private: ShaderModuleBase(DeviceBase* device, ObjectBase::ErrorTag tag); diff --git a/src/dawn/native/StreamImplTint.cpp b/src/dawn/native/StreamImplTint.cpp index e378ca7ce9..1c0c04de43 100644 --- a/src/dawn/native/StreamImplTint.cpp +++ b/src/dawn/native/StreamImplTint.cpp @@ -59,4 +59,41 @@ void stream::Stream::Write( StreamIn(sink, points.plane_1, points.params); } +template <> +void stream::Stream::Write( + stream::Sink* sink, + const tint::transform::VertexPulling::Config& cfg) { + StreamIn(sink, cfg.entry_point_name, cfg.vertex_state, cfg.pulling_group); +} + +template <> +void stream::Stream::Write( + stream::Sink* sink, + const tint::transform::VertexBufferLayoutDescriptor& layout) { + using Layout = tint::transform::VertexBufferLayoutDescriptor; + static_assert(offsetof(Layout, array_stride) == 0, + "Please update serialization for tint::transform::VertexBufferLayoutDescriptor"); + static_assert(offsetof(Layout, step_mode) == 4, + "Please update serialization for tint::transform::VertexBufferLayoutDescriptor"); + static_assert(offsetof(Layout, attributes) == 8, + "Please update serialization for tint::transform::VertexBufferLayoutDescriptor"); + StreamIn(sink, layout.array_stride, layout.step_mode, layout.attributes); +} + +template <> +void stream::Stream::Write( + stream::Sink* sink, + const tint::transform::VertexAttributeDescriptor& attrib) { + using Attrib = tint::transform::VertexAttributeDescriptor; + static_assert(offsetof(Attrib, format) == 0, + "Please update serialization for tint::transform::VertexAttributeDescriptor"); + static_assert(offsetof(Attrib, offset) == 4, + "Please update serialization for tint::transform::VertexAttributeDescriptor"); + static_assert(offsetof(Attrib, shader_location) == 8, + "Please update serialization for tint::transform::VertexAttributeDescriptor"); + static_assert(sizeof(Attrib) == 12, + "Please update serialization for tint::transform::VertexAttributeDescriptor"); + StreamIn(sink, attrib.format, attrib.offset, attrib.shader_location); +} + } // namespace dawn::native diff --git a/src/dawn/native/TintUtils.cpp b/src/dawn/native/TintUtils.cpp index f66022c9d5..c1585a57b6 100644 --- a/src/dawn/native/TintUtils.cpp +++ b/src/dawn/native/TintUtils.cpp @@ -14,7 +14,10 @@ #include "dawn/native/TintUtils.h" +#include "dawn/native/BindGroupLayout.h" #include "dawn/native/Device.h" +#include "dawn/native/PipelineLayout.h" +#include "dawn/native/RenderPipeline.h" #include "tint/tint.h" @@ -35,6 +38,87 @@ bool InitializeTintErrorReporter() { return true; } +tint::transform::VertexFormat ToTintVertexFormat(wgpu::VertexFormat format) { + switch (format) { + case wgpu::VertexFormat::Uint8x2: + return tint::transform::VertexFormat::kUint8x2; + case wgpu::VertexFormat::Uint8x4: + return tint::transform::VertexFormat::kUint8x4; + case wgpu::VertexFormat::Sint8x2: + return tint::transform::VertexFormat::kSint8x2; + case wgpu::VertexFormat::Sint8x4: + return tint::transform::VertexFormat::kSint8x4; + case wgpu::VertexFormat::Unorm8x2: + return tint::transform::VertexFormat::kUnorm8x2; + case wgpu::VertexFormat::Unorm8x4: + return tint::transform::VertexFormat::kUnorm8x4; + case wgpu::VertexFormat::Snorm8x2: + return tint::transform::VertexFormat::kSnorm8x2; + case wgpu::VertexFormat::Snorm8x4: + return tint::transform::VertexFormat::kSnorm8x4; + case wgpu::VertexFormat::Uint16x2: + return tint::transform::VertexFormat::kUint16x2; + case wgpu::VertexFormat::Uint16x4: + return tint::transform::VertexFormat::kUint16x4; + case wgpu::VertexFormat::Sint16x2: + return tint::transform::VertexFormat::kSint16x2; + case wgpu::VertexFormat::Sint16x4: + return tint::transform::VertexFormat::kSint16x4; + case wgpu::VertexFormat::Unorm16x2: + return tint::transform::VertexFormat::kUnorm16x2; + case wgpu::VertexFormat::Unorm16x4: + return tint::transform::VertexFormat::kUnorm16x4; + case wgpu::VertexFormat::Snorm16x2: + return tint::transform::VertexFormat::kSnorm16x2; + case wgpu::VertexFormat::Snorm16x4: + return tint::transform::VertexFormat::kSnorm16x4; + case wgpu::VertexFormat::Float16x2: + return tint::transform::VertexFormat::kFloat16x2; + case wgpu::VertexFormat::Float16x4: + return tint::transform::VertexFormat::kFloat16x4; + case wgpu::VertexFormat::Float32: + return tint::transform::VertexFormat::kFloat32; + case wgpu::VertexFormat::Float32x2: + return tint::transform::VertexFormat::kFloat32x2; + case wgpu::VertexFormat::Float32x3: + return tint::transform::VertexFormat::kFloat32x3; + case wgpu::VertexFormat::Float32x4: + return tint::transform::VertexFormat::kFloat32x4; + case wgpu::VertexFormat::Uint32: + return tint::transform::VertexFormat::kUint32; + case wgpu::VertexFormat::Uint32x2: + return tint::transform::VertexFormat::kUint32x2; + case wgpu::VertexFormat::Uint32x3: + return tint::transform::VertexFormat::kUint32x3; + case wgpu::VertexFormat::Uint32x4: + return tint::transform::VertexFormat::kUint32x4; + case wgpu::VertexFormat::Sint32: + return tint::transform::VertexFormat::kSint32; + case wgpu::VertexFormat::Sint32x2: + return tint::transform::VertexFormat::kSint32x2; + case wgpu::VertexFormat::Sint32x3: + return tint::transform::VertexFormat::kSint32x3; + case wgpu::VertexFormat::Sint32x4: + return tint::transform::VertexFormat::kSint32x4; + + case wgpu::VertexFormat::Undefined: + break; + } + UNREACHABLE(); +} + +tint::transform::VertexStepMode ToTintVertexStepMode(wgpu::VertexStepMode mode) { + switch (mode) { + case wgpu::VertexStepMode::Vertex: + return tint::transform::VertexStepMode::kVertex; + case wgpu::VertexStepMode::Instance: + return tint::transform::VertexStepMode::kInstance; + case wgpu::VertexStepMode::VertexBufferNotUsed: + break; + } + UNREACHABLE(); +} + } // namespace ScopedTintICEHandler::ScopedTintICEHandler(DeviceBase* device) { @@ -53,6 +137,52 @@ ScopedTintICEHandler::~ScopedTintICEHandler() { tlDevice = nullptr; } +tint::transform::MultiplanarExternalTexture::BindingsMap BuildExternalTextureTransformBindings( + const PipelineLayoutBase* layout) { + tint::transform::MultiplanarExternalTexture::BindingsMap newBindingsMap; + for (BindGroupIndex i : IterateBitSet(layout->GetBindGroupLayoutsMask())) { + const BindGroupLayoutBase* bgl = layout->GetBindGroupLayout(i); + for (const auto& [_, expansion] : bgl->GetExternalTextureBindingExpansionMap()) { + newBindingsMap[{static_cast(i), static_cast(expansion.plane0)}] = { + {static_cast(i), static_cast(expansion.plane1)}, + {static_cast(i), static_cast(expansion.params)}}; + } + } + return newBindingsMap; +} + +tint::transform::VertexPulling::Config BuildVertexPullingTransformConfig( + const RenderPipelineBase& renderPipeline, + const std::string_view& entryPoint, + BindGroupIndex pullingBufferBindingSet) { + tint::transform::VertexPulling::Config cfg; + cfg.entry_point_name = entryPoint; + cfg.pulling_group = static_cast(pullingBufferBindingSet); + + cfg.vertex_state.resize(renderPipeline.GetVertexBufferCount()); + for (VertexBufferSlot slot : IterateBitSet(renderPipeline.GetVertexBufferSlotsUsed())) { + const VertexBufferInfo& dawnInfo = renderPipeline.GetVertexBuffer(slot); + tint::transform::VertexBufferLayoutDescriptor* tintInfo = + &cfg.vertex_state[static_cast(slot)]; + + tintInfo->array_stride = dawnInfo.arrayStride; + tintInfo->step_mode = ToTintVertexStepMode(dawnInfo.stepMode); + } + + for (VertexAttributeLocation location : + IterateBitSet(renderPipeline.GetAttributeLocationsUsed())) { + const VertexAttributeInfo& dawnInfo = renderPipeline.GetAttribute(location); + tint::transform::VertexAttributeDescriptor tintInfo; + tintInfo.format = ToTintVertexFormat(dawnInfo.format); + tintInfo.offset = dawnInfo.offset; + tintInfo.shader_location = static_cast(static_cast(location)); + + uint8_t vertexBufferSlot = static_cast(dawnInfo.vertexBufferSlot); + cfg.vertex_state[vertexBufferSlot].attributes.push_back(tintInfo); + } + return cfg; +} + } // namespace dawn::native bool std::less::operator()(const tint::sem::BindingPoint& a, diff --git a/src/dawn/native/TintUtils.h b/src/dawn/native/TintUtils.h index ba2f27e1f1..e17fc691d2 100644 --- a/src/dawn/native/TintUtils.h +++ b/src/dawn/native/TintUtils.h @@ -18,14 +18,15 @@ #include #include "dawn/common/NonCopyable.h" +#include "dawn/native/IntegerTypes.h" -namespace tint::sem { -struct BindingPoint; -} +#include "tint/tint.h" namespace dawn::native { class DeviceBase; +class PipelineLayoutBase; +class RenderPipelineBase; // Indicates that for the lifetime of this object tint internal compiler errors should be // reported to the given device. @@ -38,6 +39,14 @@ class ScopedTintICEHandler : public NonCopyable { ScopedTintICEHandler(ScopedTintICEHandler&&) = delete; }; +tint::transform::MultiplanarExternalTexture::BindingsMap BuildExternalTextureTransformBindings( + const PipelineLayoutBase* layout); + +tint::transform::VertexPulling::Config BuildVertexPullingTransformConfig( + const RenderPipelineBase& renderPipeline, + const std::string_view& entryPoint, + BindGroupIndex pullingBufferBindingSet); + } // namespace dawn::native // std::less operator for std::map containing BindingPoint diff --git a/src/dawn/native/VisitableMembers.h b/src/dawn/native/VisitableMembers.h new file mode 100644 index 0000000000..2debfad548 --- /dev/null +++ b/src/dawn/native/VisitableMembers.h @@ -0,0 +1,61 @@ +// Copyright 2022 The Dawn Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SRC_DAWN_NATIVE_VISITABLEMEMBERS_H_ +#define SRC_DAWN_NATIVE_VISITABLEMEMBERS_H_ + +#include "dawn/native/stream/Stream.h" + +// Helper for X macro to declare a visitable member. +#define DAWN_INTERNAL_VISITABLE_MEMBER_DECL(type, name) type name{}; + +// Helper for X macro for visiting a visitable member. +#define DAWN_INTERNAL_VISITABLE_MEMBER_ARG(type, name) , name + +namespace dawn::native::detail { +constexpr int kInternalVisitableUnusedForComma = 0; +} // namespace dawn::native::detail + +// Helper X macro to declare members of a class or struct, along with Visit +// methods to call a functor for each member. +// Example usage: +// #define MEMBERS(X) \ +// X(int, a) \ +// X(float, b) \ +// X(Foo, foo) \ +// X(Bar, bar) +// struct MyStruct { +// DAWN_VISITABLE_MEMBERS(MEMBERS) +// }; +// #undef MEMBERS +#define DAWN_VISITABLE_MEMBERS(MEMBERS) \ + MEMBERS(DAWN_INTERNAL_VISITABLE_MEMBER_DECL) \ + \ + template \ + constexpr auto VisitAll(V&& visit) const { \ + return [&](int, const auto&... ms) { \ + return visit(ms...); \ + }(::dawn::native::detail::kInternalVisitableUnusedForComma MEMBERS( \ + DAWN_INTERNAL_VISITABLE_MEMBER_ARG)); \ + } \ + \ + template \ + constexpr auto VisitAll(V&& visit) { \ + return [&](int, auto&... ms) { \ + return visit(ms...); \ + }(::dawn::native::detail::kInternalVisitableUnusedForComma MEMBERS( \ + DAWN_INTERNAL_VISITABLE_MEMBER_ARG)); \ + } + +#endif // SRC_DAWN_NATIVE_VISITABLEMEMBERS_H_ diff --git a/src/dawn/native/d3d12/ShaderModuleD3D12.cpp b/src/dawn/native/d3d12/ShaderModuleD3D12.cpp index f64dc9102f..3e09cc00b6 100644 --- a/src/dawn/native/d3d12/ShaderModuleD3D12.cpp +++ b/src/dawn/native/d3d12/ShaderModuleD3D12.cpp @@ -756,7 +756,12 @@ ResultOrError ShaderModule::Compile(const ProgrammableStage& pro const tint::Program* program = GetTintProgram(); tint::Program programAsValue; - AddExternalTextureTransform(layout, &transformManager, &transformInputs); + auto externalTextureBindings = BuildExternalTextureTransformBindings(layout); + if (!externalTextureBindings.empty()) { + transformManager.Add(); + transformInputs.Add( + std::move(externalTextureBindings)); + } if (stage == SingleShaderStage::Vertex) { transformManager.Add(); diff --git a/src/dawn/native/metal/ShaderModuleMTL.h b/src/dawn/native/metal/ShaderModuleMTL.h index 035922384a..27f1def213 100644 --- a/src/dawn/native/metal/ShaderModuleMTL.h +++ b/src/dawn/native/metal/ShaderModuleMTL.h @@ -55,15 +55,6 @@ class ShaderModule final : public ShaderModuleBase { const RenderPipeline* renderPipeline = nullptr); private: - ResultOrError TranslateToMSL(const char* entryPointName, - SingleShaderStage stage, - const PipelineLayout* layout, - uint32_t sampleMask, - const RenderPipeline* renderPipeline, - std::string* remappedEntryPointName, - bool* needsStorageBufferLength, - bool* hasInvariantAttribute, - std::vector* workgroupAllocations); ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor); ~ShaderModule() override; MaybeError Initialize(ShaderModuleParseResult* parseResult, diff --git a/src/dawn/native/metal/ShaderModuleMTL.mm b/src/dawn/native/metal/ShaderModuleMTL.mm index fa8befd2e5..9cb6a13246 100644 --- a/src/dawn/native/metal/ShaderModuleMTL.mm +++ b/src/dawn/native/metal/ShaderModuleMTL.mm @@ -15,10 +15,14 @@ #include "dawn/native/metal/ShaderModuleMTL.h" #include "dawn/native/BindGroupLayout.h" +#include "dawn/native/CacheRequest.h" #include "dawn/native/TintUtils.h" +#include "dawn/native/VisitableMembers.h" #include "dawn/native/metal/DeviceMTL.h" #include "dawn/native/metal/PipelineLayoutMTL.h" #include "dawn/native/metal/RenderPipelineMTL.h" +#include "dawn/native/stream/BlobSource.h" +#include "dawn/native/stream/ByteVectorSink.h" #include "dawn/platform/DawnPlatform.h" #include "dawn/platform/tracing/TraceEvent.h" @@ -26,6 +30,67 @@ #include +namespace dawn::native::metal { +namespace { + +using OptionalVertexPullingTransformConfig = std::optional; + +#define MSL_COMPILATION_REQUEST_MEMBERS(X) \ + X(const tint::Program*, inputProgram) \ + X(tint::transform::BindingRemapper::BindingPoints, bindingPoints) \ + X(tint::transform::MultiplanarExternalTexture::BindingsMap, externalTextureBindings) \ + X(OptionalVertexPullingTransformConfig, vertexPullingTransformConfig) \ + X(std::string, entryPointName) \ + X(uint32_t, sampleMask) \ + X(bool, emitVertexPointSize) \ + X(bool, isRobustnessEnabled) \ + X(bool, disableSymbolRenaming) \ + X(bool, disableWorkgroupInit) \ + X(CacheKey::UnsafeUnkeyedValue, tracePlatform) + +DAWN_MAKE_CACHE_REQUEST(MslCompilationRequest, MSL_COMPILATION_REQUEST_MEMBERS); +#undef MSL_COMPILATION_REQUEST_MEMBERS + +using WorkgroupAllocations = std::vector; + +#define MSL_COMPILATION_MEMBERS(X) \ + X(std::string, msl) \ + X(std::string, remappedEntryPointName) \ + X(bool, needsStorageBufferLength) \ + X(bool, hasInvariantAttribute) \ + X(WorkgroupAllocations, workgroupAllocations) + +struct MslCompilation { + static ResultOrError FromBlob(Blob blob); + + DAWN_VISITABLE_MEMBERS(MSL_COMPILATION_MEMBERS) +#undef MSL_COMPILATION_MEMBERS +}; + +} // namespace +} // namespace dawn::native::metal + +namespace dawn::native { + +// Define the implementation to store MslCompilation into the BlobCache. +template <> +void BlobCache::Store(const CacheKey& key, const metal::MslCompilation& c) { + stream::ByteVectorSink sink; + c.VisitAll([&](const auto&... members) { StreamIn(&sink, members...); }); + Store(key, CreateBlob(std::move(sink))); +} + +// Define the implementation to load MslCompilation from a blob. +// static +ResultOrError metal::MslCompilation::FromBlob(Blob blob) { + stream::BlobSource source(std::move(blob)); + metal::MslCompilation c; + DAWN_TRY(c.VisitAll([&](auto&... members) { return StreamOut(&source, &members...); })); + return c; +} + +} // namespace dawn::native + namespace dawn::native::metal { // static @@ -50,17 +115,16 @@ MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult, return InitializeBase(parseResult, compilationMessages); } -ResultOrError ShaderModule::TranslateToMSL( - const char* entryPointName, - SingleShaderStage stage, - const PipelineLayout* layout, - uint32_t sampleMask, - const RenderPipeline* renderPipeline, - std::string* remappedEntryPointName, - bool* needsStorageBufferLength, - bool* hasInvariantAttribute, - std::vector* workgroupAllocations) { - ScopedTintICEHandler scopedICEHandler(GetDevice()); +namespace { + +ResultOrError> TranslateToMSL(DeviceBase* device, + const tint::Program* inputProgram, + const char* entryPointName, + SingleShaderStage stage, + const PipelineLayout* layout, + uint32_t sampleMask, + const RenderPipeline* renderPipeline) { + ScopedTintICEHandler scopedICEHandler(device); std::ostringstream errorStream; errorStream << "Tint MSL failure:" << std::endl; @@ -69,7 +133,6 @@ ResultOrError ShaderModule::TranslateToMSL( using BindingRemapper = tint::transform::BindingRemapper; using BindingPoint = tint::transform::BindingPoint; BindingRemapper::BindingPoints bindingPoints; - BindingRemapper::AccessControls accessControls; for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) { const BindGroupLayoutBase::BindingMap& bindingMap = @@ -93,21 +156,13 @@ ResultOrError ShaderModule::TranslateToMSL( } } - tint::transform::Manager transformManager; - tint::transform::DataMap transformInputs; - - // We only remap bindings for the target entry point, so we need to strip all other entry - // points to avoid generating invalid bindings for them. - transformManager.Add(); - transformInputs.Add(entryPointName); - - AddExternalTextureTransform(layout, &transformManager, &transformInputs); + auto externalTextureBindings = BuildExternalTextureTransformBindings(layout); + std::optional vertexPullingTransformConfig; if (stage == SingleShaderStage::Vertex && - GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling)) { - transformManager.Add(); - AddVertexPullingTransformConfig(*renderPipeline, entryPointName, kPullingBufferBindingSet, - &transformInputs); + device->IsToggleEnabled(Toggle::MetalEnableVertexPulling)) { + vertexPullingTransformConfig = BuildVertexPullingTransformConfig( + *renderPipeline, entryPointName, kPullingBufferBindingSet); for (VertexBufferSlot slot : IterateBitSet(renderPipeline->GetVertexBufferSlotsUsed())) { uint32_t metalIndex = renderPipeline->GetMtlVertexBufferIndex(slot); @@ -121,62 +176,127 @@ ResultOrError ShaderModule::TranslateToMSL( } } } - if (GetDevice()->IsRobustnessEnabled()) { - transformManager.Add(); - } - transformManager.Add(); - transformManager.Add(); - if (GetDevice()->IsToggleEnabled(Toggle::DisableSymbolRenaming)) { - // We still need to rename MSL reserved keywords - transformInputs.Add( - tint::transform::Renamer::Target::kMslKeywords); - } - - transformInputs.Add(std::move(bindingPoints), - std::move(accessControls), - /* mayCollide */ true); - - tint::Program program; - tint::transform::DataMap transformOutputs; - { - TRACE_EVENT0(GetDevice()->GetPlatform(), General, "RunTransforms"); - DAWN_TRY_ASSIGN(program, RunTransforms(&transformManager, GetTintProgram(), transformInputs, - &transformOutputs, nullptr)); - } - - if (auto* data = transformOutputs.Get()) { - auto it = data->remappings.find(entryPointName); - if (it != data->remappings.end()) { - *remappedEntryPointName = it->second; - } else { - DAWN_INVALID_IF(!GetDevice()->IsToggleEnabled(Toggle::DisableSymbolRenaming), - "Could not find remapped name for entry point."); - - *remappedEntryPointName = entryPointName; - } - } else { - return DAWN_FORMAT_VALIDATION_ERROR("Transform output missing renamer data."); - } - - tint::writer::msl::Options options; - options.buffer_size_ubo_index = kBufferLengthBufferSlot; - options.fixed_sample_mask = sampleMask; - options.disable_workgroup_init = GetDevice()->IsToggleEnabled(Toggle::DisableWorkgroupInit); - options.emit_vertex_point_size = + MslCompilationRequest req = {}; + req.inputProgram = inputProgram; + req.bindingPoints = std::move(bindingPoints); + req.externalTextureBindings = std::move(externalTextureBindings); + req.vertexPullingTransformConfig = std::move(vertexPullingTransformConfig); + req.entryPointName = entryPointName; + req.sampleMask = sampleMask; + req.emitVertexPointSize = stage == SingleShaderStage::Vertex && renderPipeline->GetPrimitiveTopology() == wgpu::PrimitiveTopology::PointList; - TRACE_EVENT0(GetDevice()->GetPlatform(), General, "tint::writer::msl::Generate"); - auto result = tint::writer::msl::Generate(&program, options); - DAWN_INVALID_IF(!result.success, "An error occured while generating MSL: %s.", result.error); + req.isRobustnessEnabled = device->IsRobustnessEnabled(); + req.disableSymbolRenaming = device->IsToggleEnabled(Toggle::DisableSymbolRenaming); + req.tracePlatform = UnsafeUnkeyedValue(device->GetPlatform()); - *needsStorageBufferLength = result.needs_storage_buffer_sizes; - *hasInvariantAttribute = result.has_invariant_attribute; - *workgroupAllocations = std::move(result.workgroup_allocations[*remappedEntryPointName]); + CacheResult mslCompilation; + DAWN_TRY_LOAD_OR_RUN( + mslCompilation, device, std::move(req), MslCompilation::FromBlob, + [](MslCompilationRequest r) -> ResultOrError { + tint::transform::Manager transformManager; + tint::transform::DataMap transformInputs; - return std::move(result.msl); + // We only remap bindings for the target entry point, so we need to strip all other + // entry points to avoid generating invalid bindings for them. + transformManager.Add(); + transformInputs.Add(r.entryPointName); + + if (!r.externalTextureBindings.empty()) { + transformManager.Add(); + transformInputs.Add( + std::move(r.externalTextureBindings)); + } + + if (r.vertexPullingTransformConfig) { + transformManager.Add(); + transformInputs.Add( + std::move(r.vertexPullingTransformConfig).value()); + } + + if (r.isRobustnessEnabled) { + transformManager.Add(); + } + transformManager.Add(); + transformInputs.Add(std::move(r.bindingPoints), + BindingRemapper::AccessControls{}, + /* mayCollide */ true); + + transformManager.Add(); + + if (r.disableSymbolRenaming) { + // We still need to rename MSL reserved keywords + transformInputs.Add( + tint::transform::Renamer::Target::kMslKeywords); + } + + tint::Program program; + tint::transform::DataMap transformOutputs; + { + TRACE_EVENT0(r.tracePlatform.UnsafeGetValue(), General, "RunTransforms"); + DAWN_TRY_ASSIGN(program, + RunTransforms(&transformManager, r.inputProgram, transformInputs, + &transformOutputs, nullptr)); + } + + std::string remappedEntryPointName; + if (auto* data = transformOutputs.Get()) { + auto it = data->remappings.find(r.entryPointName); + if (it != data->remappings.end()) { + remappedEntryPointName = it->second; + } else { + DAWN_INVALID_IF(!r.disableSymbolRenaming, + "Could not find remapped name for entry point."); + + remappedEntryPointName = r.entryPointName; + } + } else { + return DAWN_FORMAT_VALIDATION_ERROR("Transform output missing renamer data."); + } + + tint::writer::msl::Options options; + options.buffer_size_ubo_index = kBufferLengthBufferSlot; + options.fixed_sample_mask = r.sampleMask; + options.disable_workgroup_init = r.disableWorkgroupInit; + options.emit_vertex_point_size = r.emitVertexPointSize; + TRACE_EVENT0(r.tracePlatform.UnsafeGetValue(), General, "tint::writer::msl::Generate"); + auto result = tint::writer::msl::Generate(&program, options); + DAWN_INVALID_IF(!result.success, "An error occured while generating MSL: %s.", + result.error); + + // Metal uses Clang to compile the shader as C++14. Disable everything in the -Wall + // category. -Wunused-variable in particular comes up a lot in generated code, and some + // (old?) Metal drivers accidentally treat it as a MTLLibraryErrorCompileError instead + // of a warning. + result.msl = R"( + #ifdef __clang__ + #pragma clang diagnostic ignored "-Wall" + #endif + )" + result.msl; + + auto workgroupAllocations = + std::move(result.workgroup_allocations[remappedEntryPointName]); + return MslCompilation{ + std::move(result.msl), + std::move(remappedEntryPointName), + result.needs_storage_buffer_sizes, + result.has_invariant_attribute, + std::move(workgroupAllocations), + }; + }); + + if (device->IsToggleEnabled(Toggle::DumpShaders)) { + std::ostringstream dumpedMsg; + dumpedMsg << "/* Dumped generated MSL */" << std::endl << mslCompilation->msl; + device->EmitLog(WGPULoggingType_Info, dumpedMsg.str().c_str()); + } + + return mslCompilation; } +} // namespace + MaybeError ShaderModule::CreateFunction(const char* entryPointName, SingleShaderStage stage, const PipelineLayout* layout, @@ -194,33 +314,17 @@ MaybeError ShaderModule::CreateFunction(const char* entryPointName, ASSERT(renderPipeline != nullptr); } - std::string remappedEntryPointName; - std::string msl; - bool hasInvariantAttribute = false; - DAWN_TRY_ASSIGN(msl, TranslateToMSL(entryPointName, stage, layout, sampleMask, renderPipeline, - &remappedEntryPointName, &out->needsStorageBufferLength, - &hasInvariantAttribute, &out->workgroupAllocations)); + CacheResult mslCompilation; + DAWN_TRY_ASSIGN(mslCompilation, TranslateToMSL(GetDevice(), GetTintProgram(), entryPointName, + stage, layout, sampleMask, renderPipeline)); + out->needsStorageBufferLength = mslCompilation->needsStorageBufferLength; + out->workgroupAllocations = std::move(mslCompilation->workgroupAllocations); - // Metal uses Clang to compile the shader as C++14. Disable everything in the -Wall - // category. -Wunused-variable in particular comes up a lot in generated code, and some - // (old?) Metal drivers accidentally treat it as a MTLLibraryErrorCompileError instead - // of a warning. - msl = R"( -#ifdef __clang__ -#pragma clang diagnostic ignored "-Wall" -#endif -)" + msl; - - if (GetDevice()->IsToggleEnabled(Toggle::DumpShaders)) { - std::ostringstream dumpedMsg; - dumpedMsg << "/* Dumped generated MSL */" << std::endl << msl; - GetDevice()->EmitLog(WGPULoggingType_Info, dumpedMsg.str().c_str()); - } - - NSRef mslSource = AcquireNSRef([[NSString alloc] initWithUTF8String:msl.c_str()]); + NSRef mslSource = + AcquireNSRef([[NSString alloc] initWithUTF8String:mslCompilation->msl.c_str()]); NSRef compileOptions = AcquireNSRef([[MTLCompileOptions alloc] init]); - if (hasInvariantAttribute) { + if (mslCompilation->hasInvariantAttribute) { if (@available(macOS 11.0, iOS 13.0, *)) { (*compileOptions).preserveInvariance = true; } @@ -243,8 +347,8 @@ MaybeError ShaderModule::CreateFunction(const char* entryPointName, } ASSERT(library != nil); - NSRef name = - AcquireNSRef([[NSString alloc] initWithUTF8String:remappedEntryPointName.c_str()]); + NSRef name = AcquireNSRef( + [[NSString alloc] initWithUTF8String:mslCompilation->remappedEntryPointName.c_str()]); { TRACE_EVENT0(GetDevice()->GetPlatform(), General, "MTLLibrary::newFunctionWithName"); @@ -269,6 +373,10 @@ MaybeError ShaderModule::CreateFunction(const char* entryPointName, } } + if (BlobCache* cache = GetDevice()->GetBlobCache()) { + cache->EnsureStored(mslCompilation); + } + if (GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling) && GetEntryPoint(entryPointName).usedVertexInputs.any()) { out->needsStorageBufferLength = true; diff --git a/src/dawn/native/opengl/ShaderModuleGL.cpp b/src/dawn/native/opengl/ShaderModuleGL.cpp index 735b4dfd78..8e85076a9b 100644 --- a/src/dawn/native/opengl/ShaderModuleGL.cpp +++ b/src/dawn/native/opengl/ShaderModuleGL.cpp @@ -90,7 +90,12 @@ ResultOrError ShaderModule::TranslateToGLSL(const char* entryPointN tint::transform::Manager transformManager; tint::transform::DataMap transformInputs; - AddExternalTextureTransform(layout, &transformManager, &transformInputs); + auto externalTextureBindings = BuildExternalTextureTransformBindings(layout); + if (!externalTextureBindings.empty()) { + transformManager.Add(); + transformInputs.Add( + std::move(externalTextureBindings)); + } tint::Program program; DAWN_TRY_ASSIGN(program, RunTransforms(&transformManager, GetTintProgram(), transformInputs, diff --git a/src/dawn/native/stream/Stream.h b/src/dawn/native/stream/Stream.h index 416aa02134..34333179f8 100644 --- a/src/dawn/native/stream/Stream.h +++ b/src/dawn/native/stream/Stream.h @@ -23,6 +23,8 @@ #include #include +#include + #include "dawn/common/Platform.h" #include "dawn/common/TypedInteger.h" #include "dawn/native/Error.h" @@ -57,6 +59,14 @@ MaybeError StreamOut(Source* s, T* v) { return Stream::Read(s, v); } +// Helper to take an rvalue passed to StreamOut and forward it as a pointer. +// This makes it possible to pass output wrappers like stream::StructMembers inline. +// For example: `DAWN_TRY(StreamOut(&source, stream::StructMembers(...)));` +template +MaybeError StreamOut(Source* s, T&& v) { + return StreamOut(s, &v); +} + // Helper to call StreamIn on a parameter pack. template constexpr void StreamIn(Sink* s, const T& v, const Ts&... vs) { @@ -187,6 +197,19 @@ class Stream>> { } }; +// Stream specialization for std::optional +template +class Stream> { + public: + static void Write(stream::Sink* sink, const std::optional& t) { + bool hasValue = t.has_value(); + StreamIn(sink, hasValue); + if (hasValue) { + StreamIn(sink, *t); + } + } +}; + // Stream specialization for fixed arrays of fundamental types. template class Stream>> { diff --git a/src/dawn/native/vulkan/ShaderModuleVk.cpp b/src/dawn/native/vulkan/ShaderModuleVk.cpp index 8a4f7f73e9..1b40234937 100644 --- a/src/dawn/native/vulkan/ShaderModuleVk.cpp +++ b/src/dawn/native/vulkan/ShaderModuleVk.cpp @@ -232,7 +232,7 @@ ResultOrError ShaderModule::GetHandleAndSpirv( } // Transform external textures into the binding locations specified in the bgl - // TODO(dawn:1082): Replace this block with ShaderModuleBase::AddExternalTextureTransform. + // TODO(dawn:1082): Replace this block with BuildExternalTextureTransformBindings. tint::transform::MultiplanarExternalTexture::BindingsMap newBindingsMap; for (BindGroupIndex i : IterateBitSet(layout->GetBindGroupLayoutsMask())) { const BindGroupLayoutBase* bgl = layout->GetBindGroupLayout(i); diff --git a/src/dawn/tests/end2end/PipelineCachingTests.cpp b/src/dawn/tests/end2end/PipelineCachingTests.cpp index bed41adbd0..5dcf9180e4 100644 --- a/src/dawn/tests/end2end/PipelineCachingTests.cpp +++ b/src/dawn/tests/end2end/PipelineCachingTests.cpp @@ -108,8 +108,8 @@ class PipelineCachingTests : public DawnTest { const EntryCounts counts = { // pipeline caching is only implemented on D3D12/Vulkan IsD3D12() || IsVulkan() ? 1u : 0u, - // shader module caching is only implemented on Vulkan - IsVulkan() ? 1u : 0u, + // shader module caching is only implemented on Vulkan/Metal + IsVulkan() || IsMetal() ? 1u : 0u, }; NiceMock mMockCache; }; @@ -406,7 +406,7 @@ TEST_P(SinglePipelineCachingTests, RenderPipelineBlobCacheDescriptorNegativeCase { wgpu::Device device = CreateDevice(); utils::ComboRenderPipelineDescriptor desc; - desc.primitive.topology = wgpu::PrimitiveTopology::PointList; + desc.EnableDepthStencil(); desc.vertex.module = utils::CreateShaderModule(device, kVertexShaderDefault.data()); desc.vertex.entryPoint = "main"; desc.cFragment.module = utils::CreateShaderModule(device, kFragmentShaderDefault.data()); @@ -586,8 +586,9 @@ TEST_P(SinglePipelineCachingTests, RenderPipelineBlobCacheLayout) { device.CreateRenderPipeline(&desc)); } - // Cache should hit for the shaders, but not for the pipeline. - // The shader is different but compiles to the same due to binding number remapping. + // Cache should not hit for the fragment shader, but should hit for the pipeline. + // Except for D3D12, the shader is different but compiles to the same due to binding number + // remapping. { wgpu::Device device = CreateDevice(); utils::ComboRenderPipelineDescriptor desc; @@ -604,8 +605,14 @@ TEST_P(SinglePipelineCachingTests, RenderPipelineBlobCacheLayout) { {1, wgpu::ShaderStage::Fragment, wgpu::BufferBindingType::Uniform}, }), }); - EXPECT_CACHE_STATS(mMockCache, Hit(2 * counts.shaderModule), Add(counts.pipeline), - device.CreateRenderPipeline(&desc)); + if (!IsD3D12()) { + EXPECT_CACHE_STATS(mMockCache, Hit(counts.shaderModule + counts.pipeline), + Add(counts.shaderModule), device.CreateRenderPipeline(&desc)); + } else { + EXPECT_CACHE_STATS(mMockCache, Hit(counts.shaderModule), + Add(counts.shaderModule + counts.pipeline), + device.CreateRenderPipeline(&desc)); + } } } diff --git a/src/dawn/tests/unittests/native/StreamTests.cpp b/src/dawn/tests/unittests/native/StreamTests.cpp index 964823b8d2..6f257000d4 100644 --- a/src/dawn/tests/unittests/native/StreamTests.cpp +++ b/src/dawn/tests/unittests/native/StreamTests.cpp @@ -22,6 +22,7 @@ #include "dawn/common/TypedInteger.h" #include "dawn/native/Blob.h" +#include "dawn/native/VisitableMembers.h" #include "dawn/native/stream/BlobSource.h" #include "dawn/native/stream/ByteVectorSink.h" #include "dawn/native/stream/Stream.h" @@ -193,6 +194,23 @@ TEST(SerializeTests, StdPair) { EXPECT_CACHE_KEY_EQ(std::make_pair(s, uint32_t(42)), expected); } +// Test that ByteVectorSink serializes std::optional as expected. +TEST(SerializeTests, StdOptional) { + std::string_view s = "webgpu"; + { + ByteVectorSink expected; + StreamIn(&expected, true, s); + + EXPECT_CACHE_KEY_EQ(std::optional(s), expected); + } + { + ByteVectorSink expected; + StreamIn(&expected, false); + + EXPECT_CACHE_KEY_EQ(std::optional(), expected); + } +} + // Test that ByteVectorSink serializes std::unordered_map as expected. TEST(SerializeTests, StdUnorderedMap) { std::unordered_map m; @@ -256,6 +274,41 @@ TEST(StreamTests, SerializeDeserializeParamPack) { EXPECT_EQ(c, cOut); } +#define FOO_MEMBERS(X) \ + X(int, a) \ + X(float, b) \ + X(std::string, c) +struct Foo { + DAWN_VISITABLE_MEMBERS(FOO_MEMBERS) +#undef FOO_MEMBERS +}; + +// Test that serializing then deserializing a struct made with DAWN_VISITABLE_MEMBERS works as +// expected. +TEST(StreamTests, SerializeDeserializeVisitableMembers) { + Foo foo{1, 2, "3"}; + ByteVectorSink sink; + foo.VisitAll([&](const auto&... members) { StreamIn(&sink, members...); }); + + // Test that the serialization is correct. + { + ByteVectorSink expected; + StreamIn(&expected, foo.a, foo.b, foo.c); + EXPECT_THAT(sink, VectorEq(expected)); + } + + // Test that deserialization works for StructMembers, passed inline. + { + BlobSource src(CreateBlob(sink)); + Foo out; + auto err = out.VisitAll([&](auto&... members) { return StreamOut(&src, &members...); }); + EXPECT_FALSE(err.IsError()); + EXPECT_EQ(foo.a, out.a); + EXPECT_EQ(foo.b, out.b); + EXPECT_EQ(foo.c, out.c); + } +} + template std::bitset BitsetFromBitString(const char (&str)[N]) { // N - 1 because the last character is the null terminator.