diff --git a/src/dawn/native/BUILD.gn b/src/dawn/native/BUILD.gn index bcff7d799e..19ca4d4b85 100644 --- a/src/dawn/native/BUILD.gn +++ b/src/dawn/native/BUILD.gn @@ -313,6 +313,7 @@ source_set("sources") { "Sampler.h", "ScratchBuffer.cpp", "ScratchBuffer.h", + "Serializable.h", "ShaderModule.cpp", "ShaderModule.h", "StagingBuffer.cpp", diff --git a/src/dawn/native/BlobCache.h b/src/dawn/native/BlobCache.h index 4dae92d49b..e3035cd862 100644 --- a/src/dawn/native/BlobCache.h +++ b/src/dawn/native/BlobCache.h @@ -43,16 +43,12 @@ class BlobCache { void Store(const CacheKey& key, size_t valueSize, const void* value); void Store(const CacheKey& key, const Blob& value); - // Other types may specialize BlobCache::Store to define how T is serialized into the cache. - template - void Store(const CacheKey& key, const T& value); - // Store a CacheResult into the cache if it isn't cached yet. - // Calls Store which should be defined elsewhere. + // Calls T::ToBlob which should be defined elsewhere. template void EnsureStored(const CacheResult& cacheResult) { if (!cacheResult.IsCached()) { - Store(cacheResult.GetCacheKey(), *cacheResult); + Store(cacheResult.GetCacheKey(), cacheResult->ToBlob()); } } diff --git a/src/dawn/native/CMakeLists.txt b/src/dawn/native/CMakeLists.txt index 897363a7f3..3f417bcbd2 100644 --- a/src/dawn/native/CMakeLists.txt +++ b/src/dawn/native/CMakeLists.txt @@ -170,6 +170,7 @@ target_sources(dawn_native PRIVATE "Sampler.h" "ScratchBuffer.cpp" "ScratchBuffer.h" + "Serializable.h" "ShaderModule.cpp" "ShaderModule.h" "StagingBuffer.cpp" diff --git a/src/dawn/native/Serializable.h b/src/dawn/native/Serializable.h new file mode 100644 index 0000000000..80d8339e2d --- /dev/null +++ b/src/dawn/native/Serializable.h @@ -0,0 +1,74 @@ +// Copyright 2022 The Dawn Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SRC_DAWN_NATIVE_SERIALIZABLE_H_ +#define SRC_DAWN_NATIVE_SERIALIZABLE_H_ + +#include + +#include "dawn/native/VisitableMembers.h" +#include "dawn/native/stream/BlobSource.h" +#include "dawn/native/stream/ByteVectorSink.h" +#include "dawn/native/stream/Stream.h" + +namespace dawn::native { + +// Base CRTP for implementing StreamIn/StreamOut/FromBlob/ToBlob for Derived, +// assuming Derived has VisitAll methods provided by DAWN_VISITABLE_MEMBERS. +template +class Serializable { + public: + friend void StreamIn(stream::Sink* s, const Derived& in) { + in.VisitAll([&](const auto&... members) { StreamIn(s, members...); }); + } + + friend MaybeError StreamOut(stream::Source* s, Derived* out) { + return out->VisitAll([&](auto&... members) { return StreamOut(s, &members...); }); + } + + static ResultOrError FromBlob(Blob blob) { + stream::BlobSource source(std::move(blob)); + Derived out; + DAWN_TRY(StreamOut(&source, &out)); + return out; + } + + Blob ToBlob() const { + stream::ByteVectorSink sink; + StreamIn(&sink, static_cast(*this)); + return CreateBlob(std::move(sink)); + } +}; +} // namespace dawn::native + +// Helper macro to define a struct or class along with VisitAll methods to call +// a functor on all members. Derives from Visitable which provides +// implementations of StreamIn/StreamOut/FromBlob/ToBlob. +// Example usage: +// #define MEMBERS(X) \ +// X(int, a) \ +// X(float, b) \ +// X(Foo, foo) \ +// X(Bar, bar) +// DAWN_SERIALIZABLE(struct, MyStruct, MEMBERS) { +// void SomeAdditionalMethod(); +// }; +// #undef MEMBERS +#define DAWN_SERIALIZABLE(qualifier, Name, MEMBERS) \ + struct Name##__Contents { \ + DAWN_VISITABLE_MEMBERS(MEMBERS) \ + }; \ + qualifier Name : Name##__Contents, public ::dawn::native::Serializable + +#endif // SRC_DAWN_NATIVE_SERIALIZABLE_H_ diff --git a/src/dawn/native/VisitableMembers.h b/src/dawn/native/VisitableMembers.h index 2debfad548..82fb58d535 100644 --- a/src/dawn/native/VisitableMembers.h +++ b/src/dawn/native/VisitableMembers.h @@ -12,9 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef SRC_DAWN_NATIVE_VISITABLEMEMBERS_H_ -#define SRC_DAWN_NATIVE_VISITABLEMEMBERS_H_ +#ifndef SRC_DAWN_NATIVE_VISITABLE_H_ +#define SRC_DAWN_NATIVE_VISITABLE_H_ +#include + +#include "dawn/native/stream/BlobSource.h" +#include "dawn/native/stream/ByteVectorSink.h" #include "dawn/native/stream/Stream.h" // Helper for X macro to declare a visitable member. @@ -27,8 +31,8 @@ namespace dawn::native::detail { constexpr int kInternalVisitableUnusedForComma = 0; } // namespace dawn::native::detail -// Helper X macro to declare members of a class or struct, along with Visit -// methods to call a functor for each member. +// Helper X macro to declare members of a class or struct, along with VisitAll +// methods to call a functor on all members. // Example usage: // #define MEMBERS(X) \ // X(int, a) \ @@ -58,4 +62,4 @@ constexpr int kInternalVisitableUnusedForComma = 0; DAWN_INTERNAL_VISITABLE_MEMBER_ARG)); \ } -#endif // SRC_DAWN_NATIVE_VISITABLEMEMBERS_H_ +#endif // SRC_DAWN_NATIVE_VISITABLE_H_ diff --git a/src/dawn/native/d3d12/ShaderModuleD3D12.cpp b/src/dawn/native/d3d12/ShaderModuleD3D12.cpp index 2fd04bdfc4..ce83dbefee 100644 --- a/src/dawn/native/d3d12/ShaderModuleD3D12.cpp +++ b/src/dawn/native/d3d12/ShaderModuleD3D12.cpp @@ -173,21 +173,13 @@ enum class Compiler { FXC, DXC }; X(IDxcCompiler*, dxcCompiler) \ X(DefineStrings, defineStrings) -struct HlslCompilationRequest { - DAWN_VISITABLE_MEMBERS(HLSL_COMPILATION_REQUEST_MEMBERS) +DAWN_SERIALIZABLE(struct, HlslCompilationRequest, HLSL_COMPILATION_REQUEST_MEMBERS){}; +#undef HLSL_COMPILATION_REQUEST_MEMBERS - friend void StreamIn(stream::Sink* sink, const HlslCompilationRequest& r) { - r.VisitAll([&](const auto&... members) { StreamIn(sink, members...); }); - } -}; - -struct D3DBytecodeCompilationRequest { - DAWN_VISITABLE_MEMBERS(D3D_BYTECODE_COMPILATION_REQUEST_MEMBERS) - - friend void StreamIn(stream::Sink* sink, const D3DBytecodeCompilationRequest& r) { - r.VisitAll([&](const auto&... members) { StreamIn(sink, members...); }); - } -}; +DAWN_SERIALIZABLE(struct, + D3DBytecodeCompilationRequest, + D3D_BYTECODE_COMPILATION_REQUEST_MEMBERS){}; +#undef D3D_BYTECODE_COMPILATION_REQUEST_MEMBERS #define D3D_COMPILATION_REQUEST_MEMBERS(X) \ X(HlslCompilationRequest, hlsl) \ @@ -195,8 +187,6 @@ struct D3DBytecodeCompilationRequest { X(CacheKey::UnsafeUnkeyedValue, tracePlatform) DAWN_MAKE_CACHE_REQUEST(D3DCompilationRequest, D3D_COMPILATION_REQUEST_MEMBERS); -#undef HLSL_COMPILATION_REQUEST_MEMBERS -#undef D3D_BYTECODE_COMPILATION_REQUEST_MEMBERS #undef D3D_COMPILATION_REQUEST_MEMBERS std::vector GetDXCArguments(uint32_t compileFlags, bool enable16BitTypes) { @@ -727,24 +717,3 @@ D3D12_SHADER_BYTECODE CompiledShader::GetD3D12ShaderBytecode() const { } } // namespace dawn::native::d3d12 - -namespace dawn::native { - -// Define the implementation to store d3d12::CompiledShader into the BlobCache. -template <> -void BlobCache::Store(const CacheKey& key, const d3d12::CompiledShader& c) { - stream::ByteVectorSink sink; - c.VisitAll([&](const auto&... members) { StreamIn(&sink, members...); }); - Store(key, CreateBlob(std::move(sink))); -} - -// Define the implementation to load d3d12::CompiledShader from a Blob. -// static -ResultOrError d3d12::CompiledShader::FromBlob(Blob blob) { - stream::BlobSource source(std::move(blob)); - d3d12::CompiledShader c; - DAWN_TRY(c.VisitAll([&](auto&... members) { return StreamOut(&source, &members...); })); - return c; -} - -} // namespace dawn::native diff --git a/src/dawn/native/d3d12/ShaderModuleD3D12.h b/src/dawn/native/d3d12/ShaderModuleD3D12.h index 528e1e47ba..8c70bb45a4 100644 --- a/src/dawn/native/d3d12/ShaderModuleD3D12.h +++ b/src/dawn/native/d3d12/ShaderModuleD3D12.h @@ -18,8 +18,8 @@ #include #include "dawn/native/Blob.h" +#include "dawn/native/Serializable.h" #include "dawn/native/ShaderModule.h" -#include "dawn/native/VisitableMembers.h" #include "dawn/native/d3d12/d3d12_platform.h" namespace dawn::native { @@ -40,14 +40,10 @@ class PipelineLayout; // information used to emulate vertex/instance index starts. It also holds the `hlslSource` for the // shader compilation, which is only transiently available during Compile, and cleared before it // returns. It is not written to or loaded from the cache unless Toggle dump_shaders is true. -struct CompiledShader { - static ResultOrError FromBlob(Blob blob); - +DAWN_SERIALIZABLE(struct, CompiledShader, COMPILED_SHADER_MEMBERS) { D3D12_SHADER_BYTECODE GetD3D12ShaderBytecode() const; - - DAWN_VISITABLE_MEMBERS(COMPILED_SHADER_MEMBERS) -#undef COMPILED_SHADER_MEMBERS }; +#undef COMPILED_SHADER_MEMBERS class ShaderModule final : public ShaderModuleBase { public: diff --git a/src/dawn/native/metal/ShaderModuleMTL.mm b/src/dawn/native/metal/ShaderModuleMTL.mm index 9cb6a13246..e56eb22d9b 100644 --- a/src/dawn/native/metal/ShaderModuleMTL.mm +++ b/src/dawn/native/metal/ShaderModuleMTL.mm @@ -16,8 +16,8 @@ #include "dawn/native/BindGroupLayout.h" #include "dawn/native/CacheRequest.h" +#include "dawn/native/Serializable.h" #include "dawn/native/TintUtils.h" -#include "dawn/native/VisitableMembers.h" #include "dawn/native/metal/DeviceMTL.h" #include "dawn/native/metal/PipelineLayoutMTL.h" #include "dawn/native/metal/RenderPipelineMTL.h" @@ -60,37 +60,12 @@ using WorkgroupAllocations = std::vector; X(bool, hasInvariantAttribute) \ X(WorkgroupAllocations, workgroupAllocations) -struct MslCompilation { - static ResultOrError FromBlob(Blob blob); - - DAWN_VISITABLE_MEMBERS(MSL_COMPILATION_MEMBERS) +DAWN_SERIALIZABLE(struct, MslCompilation, MSL_COMPILATION_MEMBERS){}; #undef MSL_COMPILATION_MEMBERS -}; } // namespace } // namespace dawn::native::metal -namespace dawn::native { - -// Define the implementation to store MslCompilation into the BlobCache. -template <> -void BlobCache::Store(const CacheKey& key, const metal::MslCompilation& c) { - stream::ByteVectorSink sink; - c.VisitAll([&](const auto&... members) { StreamIn(&sink, members...); }); - Store(key, CreateBlob(std::move(sink))); -} - -// Define the implementation to load MslCompilation from a blob. -// static -ResultOrError metal::MslCompilation::FromBlob(Blob blob) { - stream::BlobSource source(std::move(blob)); - metal::MslCompilation c; - DAWN_TRY(c.VisitAll([&](auto&... members) { return StreamOut(&source, &members...); })); - return c; -} - -} // namespace dawn::native - namespace dawn::native::metal { // static @@ -277,13 +252,13 @@ ResultOrError> TranslateToMSL(DeviceBase* device, auto workgroupAllocations = std::move(result.workgroup_allocations[remappedEntryPointName]); - return MslCompilation{ + return MslCompilation{{ std::move(result.msl), std::move(remappedEntryPointName), result.needs_storage_buffer_sizes, result.has_invariant_attribute, std::move(workgroupAllocations), - }; + }}; }); if (device->IsToggleEnabled(Toggle::DumpShaders)) { diff --git a/src/dawn/native/opengl/ShaderModuleGL.cpp b/src/dawn/native/opengl/ShaderModuleGL.cpp index 30ca69293a..3ad4cc25ff 100644 --- a/src/dawn/native/opengl/ShaderModuleGL.cpp +++ b/src/dawn/native/opengl/ShaderModuleGL.cpp @@ -31,7 +31,6 @@ #include "tint/tint.h" namespace dawn::native { - namespace { GLenum GLShaderType(SingleShaderStage stage) { @@ -73,53 +72,11 @@ DAWN_MAKE_CACHE_REQUEST(GLSLCompilationRequest, GLSL_COMPILATION_REQUEST_MEMBERS X(std::string, glsl) \ X(bool, needsPlaceholderSampler) \ X(opengl::CombinedSamplerInfo, combinedSamplerInfo) -struct GLSLCompilation { - DAWN_VISITABLE_MEMBERS(GLSL_COMPILATION_MEMBERS) + +DAWN_SERIALIZABLE(struct, GLSLCompilation, GLSL_COMPILATION_MEMBERS){}; #undef GLSL_COMPILATION_MEMBERS - static ResultOrError FromBlob(Blob blob) { - stream::BlobSource source(std::move(blob)); - GLSLCompilation out; - DAWN_TRY(out.VisitAll([&](auto&... members) { return StreamOut(&source, &members...); })); - return out; - } -}; - } // namespace - -template <> -void BlobCache::Store(const CacheKey& key, const GLSLCompilation& c) { - stream::ByteVectorSink sink; - c.VisitAll([&](const auto&... members) { StreamIn(&sink, members...); }); - Store(key, CreateBlob(std::move(sink))); -} - -template <> -void stream::Stream::Write( - stream::Sink* s, - const opengl::BindingLocation& bindingLocation) { - bindingLocation.VisitAll([&](auto&... members) { return StreamIn(s, members...); }); -} - -template <> -MaybeError stream::Stream::Read(stream::Source* s, - opengl::BindingLocation* bindingLocation) { - return bindingLocation->VisitAll([&](auto&... members) { return StreamOut(s, &members...); }); -} - -template <> -void stream::Stream::Write( - stream::Sink* s, - const opengl::CombinedSampler& combinedSampler) { - combinedSampler.VisitAll([&](auto&... members) { return StreamIn(s, members...); }); -} - -template <> -MaybeError stream::Stream::Read(stream::Source* s, - opengl::CombinedSampler* combinedSampler) { - return combinedSampler->VisitAll([&](auto&... members) { return StreamOut(s, &members...); }); -} - } // namespace dawn::native namespace dawn::native::opengl { @@ -280,8 +237,8 @@ ResultOrError ShaderModule::CompileShader(const OpenGLFunctions& gl, DAWN_INVALID_IF(!result.success, "An error occured while generating GLSL: %s.", result.error); - return GLSLCompilation{std::move(result.glsl), needsPlaceholderSampler, - std::move(combinedSamplerInfo)}; + return GLSLCompilation{ + {std::move(result.glsl), needsPlaceholderSampler, std::move(combinedSamplerInfo)}}; }); if (GetDevice()->IsToggleEnabled(Toggle::DumpShaders)) { diff --git a/src/dawn/native/opengl/ShaderModuleGL.h b/src/dawn/native/opengl/ShaderModuleGL.h index 838d7e8a5d..b19b4c62c8 100644 --- a/src/dawn/native/opengl/ShaderModuleGL.h +++ b/src/dawn/native/opengl/ShaderModuleGL.h @@ -20,8 +20,8 @@ #include #include +#include "dawn/native/Serializable.h" #include "dawn/native/ShaderModule.h" -#include "dawn/native/VisitableMembers.h" #include "dawn/native/opengl/opengl_platform.h" namespace dawn::native { @@ -44,10 +44,9 @@ std::string GetBindingName(BindGroupIndex group, BindingNumber bindingNumber); #define BINDING_LOCATION_MEMBERS(X) \ X(BindGroupIndex, group) \ X(BindingNumber, binding) -struct BindingLocation { - DAWN_VISITABLE_MEMBERS(BINDING_LOCATION_MEMBERS) +DAWN_SERIALIZABLE(struct, BindingLocation, BINDING_LOCATION_MEMBERS){}; #undef BINDING_LOCATION_MEMBERS -}; + bool operator<(const BindingLocation& a, const BindingLocation& b); #define COMBINED_SAMPLER_MEMBERS(X) \ @@ -58,12 +57,11 @@ bool operator<(const BindingLocation& a, const BindingLocation& b); /* |samplerLocation| is unused. */ \ X(bool, usePlaceholderSampler) -struct CombinedSampler { - DAWN_VISITABLE_MEMBERS(COMBINED_SAMPLER_MEMBERS) -#undef COMBINED_SAMPLER_MEMBERS - +DAWN_SERIALIZABLE(struct, CombinedSampler, COMBINED_SAMPLER_MEMBERS) { std::string GetName() const; }; +#undef COMBINED_SAMPLER_MEMBERS + bool operator<(const CombinedSampler& a, const CombinedSampler& b); using CombinedSamplerInfo = std::vector; diff --git a/src/dawn/native/vulkan/ShaderModuleVk.cpp b/src/dawn/native/vulkan/ShaderModuleVk.cpp index 1b40234937..5477d84784 100644 --- a/src/dawn/native/vulkan/ShaderModuleVk.cpp +++ b/src/dawn/native/vulkan/ShaderModuleVk.cpp @@ -44,6 +44,8 @@ class ShaderModule::Spirv : private Blob { return static_cast(blob); } + const Blob& ToBlob() const { return *this; } + static Spirv Create(std::vector code) { Blob blob = CreateBlob(std::move(code)); ASSERT(IsPtrAligned(blob.Data(), alignof(uint32_t))); @@ -56,17 +58,6 @@ class ShaderModule::Spirv : private Blob { } // namespace dawn::native::vulkan -namespace dawn::native { - -// Define the implementation to store vulkan::ShaderModule::Spirv into the BlobCache. -template <> -void BlobCache::Store(const CacheKey& key, - const vulkan::ShaderModule::Spirv& spirv) { - Store(key, spirv.WordCount() * sizeof(uint32_t), spirv.Code()); -} - -} // namespace dawn::native - namespace dawn::native::vulkan { class ShaderModule::ConcurrentTransformedShaderModuleCache { diff --git a/src/dawn/tests/unittests/native/StreamTests.cpp b/src/dawn/tests/unittests/native/StreamTests.cpp index a1196f3465..2246fb3bba 100644 --- a/src/dawn/tests/unittests/native/StreamTests.cpp +++ b/src/dawn/tests/unittests/native/StreamTests.cpp @@ -22,7 +22,7 @@ #include "dawn/common/TypedInteger.h" #include "dawn/native/Blob.h" -#include "dawn/native/VisitableMembers.h" +#include "dawn/native/Serializable.h" #include "dawn/native/stream/BlobSource.h" #include "dawn/native/stream/ByteVectorSink.h" #include "dawn/native/stream/Stream.h" @@ -302,17 +302,15 @@ TEST(StreamTests, SerializeDeserializeParamPack) { X(int, a) \ X(float, b) \ X(std::string, c) -struct Foo { - DAWN_VISITABLE_MEMBERS(FOO_MEMBERS) +DAWN_SERIALIZABLE(struct, Foo, FOO_MEMBERS){}; #undef FOO_MEMBERS -}; -// Test that serializing then deserializing a struct made with DAWN_VISITABLE_MEMBERS works as +// Test that serializing then deserializing a struct made with DAWN_SERIALIZABLE works as // expected. -TEST(StreamTests, SerializeDeserializeVisitableMembers) { - Foo foo{1, 2, "3"}; +TEST(StreamTests, SerializeDeserializeVisitable) { + Foo foo{{1, 2, "3"}}; ByteVectorSink sink; - foo.VisitAll([&](const auto&... members) { StreamIn(&sink, members...); }); + StreamIn(&sink, foo); // Test that the serialization is correct. { @@ -325,7 +323,7 @@ TEST(StreamTests, SerializeDeserializeVisitableMembers) { { BlobSource src(CreateBlob(sink)); Foo out; - auto err = out.VisitAll([&](auto&... members) { return StreamOut(&src, &members...); }); + auto err = StreamOut(&src, &out); EXPECT_FALSE(err.IsError()); EXPECT_EQ(foo.a, out.a); EXPECT_EQ(foo.b, out.b);