Cache WGSL -> MSL compilation

Bug: dawn:1480
Change-Id: Ie2ef7860b38d7f350c99cf2c5451299b23413ec6
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/97882
Commit-Queue: Austin Eng <enga@chromium.org>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Loko Kung <lokokung@google.com>
This commit is contained in:
Austin Eng 2022-08-04 01:12:56 +00:00 committed by Dawn LUCI CQ
parent e40bd8e964
commit 0671fe28bf
17 changed files with 569 additions and 282 deletions

View File

@ -333,6 +333,7 @@ source_set("sources") {
"UsageValidationMode.h",
"VertexFormat.cpp",
"VertexFormat.h",
"VisitableMembers.h",
"dawn_platform.h",
"stream/BlobSource.cpp",
"stream/BlobSource.h",

View File

@ -190,6 +190,7 @@ target_sources(dawn_native PRIVATE
"UsageValidationMode.h"
"VertexFormat.cpp"
"VertexFormat.h"
"VisitableMembers.h"
"dawn_platform.h"
"webgpu_absl_format.cpp"
"webgpu_absl_format.h"

View File

@ -26,6 +26,7 @@
#include "dawn/native/CacheResult.h"
#include "dawn/native/Device.h"
#include "dawn/native/Error.h"
#include "dawn/native/VisitableMembers.h"
namespace dawn::native {
@ -94,6 +95,15 @@ class CacheRequestImpl {
CacheRequestImpl(const CacheRequestImpl&) = delete;
CacheRequestImpl& operator=(const CacheRequestImpl&) = delete;
// Create a CacheKey from the request type and all members
CacheKey CreateCacheKey(const DeviceBase* device) const {
CacheKey key = device->GetCacheKey();
StreamIn(&key, Request::kName);
static_cast<const Request*>(this)->VisitAll(
[&](const auto&... members) { StreamIn(&key, members...); });
return key;
}
template <typename CacheHitFn, typename CacheMissFn>
friend auto LoadOrRun(DeviceBase* device,
Request&& r,
@ -171,16 +181,9 @@ class CacheRequestImpl {
#define DAWN_MAKE_CACHE_REQUEST(Request, MEMBERS) \
class Request : public ::dawn::native::CacheRequestImpl<Request> { \
public: \
static constexpr char kName[] = #Request; \
Request() = default; \
MEMBERS(DAWN_INTERNAL_CACHE_REQUEST_DECL_STRUCT_MEMBER) \
\
/* Create a CacheKey from the request type and all members */ \
::dawn::native::CacheKey CreateCacheKey(const ::dawn::native::DeviceBase* device) const { \
::dawn::native::CacheKey key = device->GetCacheKey(); \
StreamIn(&key, #Request); \
MEMBERS(DAWN_INTERNAL_CACHE_REQUEST_RECORD_KEY) \
return key; \
} \
DAWN_VISITABLE_MEMBERS(MEMBERS) \
};
// Helper macro for the common pattern of DAWN_TRY_ASSIGN around LoadOrRun.

View File

@ -37,87 +37,6 @@ namespace dawn::native {
namespace {
tint::transform::VertexFormat ToTintVertexFormat(wgpu::VertexFormat format) {
switch (format) {
case wgpu::VertexFormat::Uint8x2:
return tint::transform::VertexFormat::kUint8x2;
case wgpu::VertexFormat::Uint8x4:
return tint::transform::VertexFormat::kUint8x4;
case wgpu::VertexFormat::Sint8x2:
return tint::transform::VertexFormat::kSint8x2;
case wgpu::VertexFormat::Sint8x4:
return tint::transform::VertexFormat::kSint8x4;
case wgpu::VertexFormat::Unorm8x2:
return tint::transform::VertexFormat::kUnorm8x2;
case wgpu::VertexFormat::Unorm8x4:
return tint::transform::VertexFormat::kUnorm8x4;
case wgpu::VertexFormat::Snorm8x2:
return tint::transform::VertexFormat::kSnorm8x2;
case wgpu::VertexFormat::Snorm8x4:
return tint::transform::VertexFormat::kSnorm8x4;
case wgpu::VertexFormat::Uint16x2:
return tint::transform::VertexFormat::kUint16x2;
case wgpu::VertexFormat::Uint16x4:
return tint::transform::VertexFormat::kUint16x4;
case wgpu::VertexFormat::Sint16x2:
return tint::transform::VertexFormat::kSint16x2;
case wgpu::VertexFormat::Sint16x4:
return tint::transform::VertexFormat::kSint16x4;
case wgpu::VertexFormat::Unorm16x2:
return tint::transform::VertexFormat::kUnorm16x2;
case wgpu::VertexFormat::Unorm16x4:
return tint::transform::VertexFormat::kUnorm16x4;
case wgpu::VertexFormat::Snorm16x2:
return tint::transform::VertexFormat::kSnorm16x2;
case wgpu::VertexFormat::Snorm16x4:
return tint::transform::VertexFormat::kSnorm16x4;
case wgpu::VertexFormat::Float16x2:
return tint::transform::VertexFormat::kFloat16x2;
case wgpu::VertexFormat::Float16x4:
return tint::transform::VertexFormat::kFloat16x4;
case wgpu::VertexFormat::Float32:
return tint::transform::VertexFormat::kFloat32;
case wgpu::VertexFormat::Float32x2:
return tint::transform::VertexFormat::kFloat32x2;
case wgpu::VertexFormat::Float32x3:
return tint::transform::VertexFormat::kFloat32x3;
case wgpu::VertexFormat::Float32x4:
return tint::transform::VertexFormat::kFloat32x4;
case wgpu::VertexFormat::Uint32:
return tint::transform::VertexFormat::kUint32;
case wgpu::VertexFormat::Uint32x2:
return tint::transform::VertexFormat::kUint32x2;
case wgpu::VertexFormat::Uint32x3:
return tint::transform::VertexFormat::kUint32x3;
case wgpu::VertexFormat::Uint32x4:
return tint::transform::VertexFormat::kUint32x4;
case wgpu::VertexFormat::Sint32:
return tint::transform::VertexFormat::kSint32;
case wgpu::VertexFormat::Sint32x2:
return tint::transform::VertexFormat::kSint32x2;
case wgpu::VertexFormat::Sint32x3:
return tint::transform::VertexFormat::kSint32x3;
case wgpu::VertexFormat::Sint32x4:
return tint::transform::VertexFormat::kSint32x4;
case wgpu::VertexFormat::Undefined:
break;
}
UNREACHABLE();
}
tint::transform::VertexStepMode ToTintVertexStepMode(wgpu::VertexStepMode mode) {
switch (mode) {
case wgpu::VertexStepMode::Vertex:
return tint::transform::VertexStepMode::kVertex;
case wgpu::VertexStepMode::Instance:
return tint::transform::VertexStepMode::kInstance;
case wgpu::VertexStepMode::VertexBufferNotUsed:
break;
}
UNREACHABLE();
}
ResultOrError<SingleShaderStage> TintPipelineStageToShaderStage(
tint::inspector::PipelineStage stage) {
switch (stage) {
@ -1088,39 +1007,6 @@ ResultOrError<tint::Program> RunTransforms(tint::transform::Transform* transform
return std::move(output.program);
}
void AddVertexPullingTransformConfig(const RenderPipelineBase& renderPipeline,
const std::string& entryPoint,
BindGroupIndex pullingBufferBindingSet,
tint::transform::DataMap* transformInputs) {
tint::transform::VertexPulling::Config cfg;
cfg.entry_point_name = entryPoint;
cfg.pulling_group = static_cast<uint32_t>(pullingBufferBindingSet);
cfg.vertex_state.resize(renderPipeline.GetVertexBufferCount());
for (VertexBufferSlot slot : IterateBitSet(renderPipeline.GetVertexBufferSlotsUsed())) {
const VertexBufferInfo& dawnInfo = renderPipeline.GetVertexBuffer(slot);
tint::transform::VertexBufferLayoutDescriptor* tintInfo =
&cfg.vertex_state[static_cast<uint8_t>(slot)];
tintInfo->array_stride = dawnInfo.arrayStride;
tintInfo->step_mode = ToTintVertexStepMode(dawnInfo.stepMode);
}
for (VertexAttributeLocation location :
IterateBitSet(renderPipeline.GetAttributeLocationsUsed())) {
const VertexAttributeInfo& dawnInfo = renderPipeline.GetAttribute(location);
tint::transform::VertexAttributeDescriptor tintInfo;
tintInfo.format = ToTintVertexFormat(dawnInfo.format);
tintInfo.offset = dawnInfo.offset;
tintInfo.shader_location = static_cast<uint32_t>(static_cast<uint8_t>(location));
uint8_t vertexBufferSlot = static_cast<uint8_t>(dawnInfo.vertexBufferSlot);
cfg.vertex_state[vertexBufferSlot].attributes.push_back(tintInfo);
}
transformInputs->Add<tint::transform::VertexPulling::Config>(cfg);
}
MaybeError ValidateCompatibilityWithPipelineLayout(DeviceBase* device,
const EntryPointMetadata& entryPoint,
const PipelineLayoutBase* layout) {
@ -1306,29 +1192,6 @@ OwnedCompilationMessages* ShaderModuleBase::GetCompilationMessages() const {
return mCompilationMessages.get();
}
// static
void ShaderModuleBase::AddExternalTextureTransform(const PipelineLayoutBase* layout,
tint::transform::Manager* transformManager,
tint::transform::DataMap* transformInputs) {
tint::transform::MultiplanarExternalTexture::BindingsMap newBindingsMap;
for (BindGroupIndex i : IterateBitSet(layout->GetBindGroupLayoutsMask())) {
const BindGroupLayoutBase* bgl = layout->GetBindGroupLayout(i);
for (const auto& expansion : bgl->GetExternalTextureBindingExpansionMap()) {
newBindingsMap[{static_cast<uint32_t>(i),
static_cast<uint32_t>(expansion.second.plane0)}] = {
{static_cast<uint32_t>(i), static_cast<uint32_t>(expansion.second.plane1)},
{static_cast<uint32_t>(i), static_cast<uint32_t>(expansion.second.params)}};
}
}
if (!newBindingsMap.empty()) {
transformManager->Add<tint::transform::MultiplanarExternalTexture>();
transformInputs->Add<tint::transform::MultiplanarExternalTexture::NewBindingPoints>(
newBindingsMap);
}
}
MaybeError ShaderModuleBase::InitializeBase(ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages) {
mTintProgram = std::move(parseResult->tintProgram);

View File

@ -116,12 +116,6 @@ ResultOrError<tint::Program> RunTransforms(tint::transform::Transform* transform
tint::transform::DataMap* outputs,
OwnedCompilationMessages* messages);
/// Creates and adds the tint::transform::VertexPulling::Config to transformInputs.
void AddVertexPullingTransformConfig(const RenderPipelineBase& renderPipeline,
const std::string& entryPoint,
BindGroupIndex pullingBufferBindingSet,
tint::transform::DataMap* transformInputs);
// Mirrors wgpu::SamplerBindingLayout but instead stores a single boolean
// for isComparison instead of a wgpu::SamplerBindingType enum.
struct ShaderSamplerBindingInfo {
@ -295,10 +289,6 @@ class ShaderModuleBase : public ApiObjectBase, public CachedObject {
MaybeError InitializeBase(ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages);
static void AddExternalTextureTransform(const PipelineLayoutBase* layout,
tint::transform::Manager* transformManager,
tint::transform::DataMap* transformInputs);
private:
ShaderModuleBase(DeviceBase* device, ObjectBase::ErrorTag tag);

View File

@ -59,4 +59,41 @@ void stream::Stream<tint::transform::BindingPoints>::Write(
StreamIn(sink, points.plane_1, points.params);
}
template <>
void stream::Stream<tint::transform::VertexPulling::Config>::Write(
stream::Sink* sink,
const tint::transform::VertexPulling::Config& cfg) {
StreamIn(sink, cfg.entry_point_name, cfg.vertex_state, cfg.pulling_group);
}
template <>
void stream::Stream<tint::transform::VertexBufferLayoutDescriptor>::Write(
stream::Sink* sink,
const tint::transform::VertexBufferLayoutDescriptor& layout) {
using Layout = tint::transform::VertexBufferLayoutDescriptor;
static_assert(offsetof(Layout, array_stride) == 0,
"Please update serialization for tint::transform::VertexBufferLayoutDescriptor");
static_assert(offsetof(Layout, step_mode) == 4,
"Please update serialization for tint::transform::VertexBufferLayoutDescriptor");
static_assert(offsetof(Layout, attributes) == 8,
"Please update serialization for tint::transform::VertexBufferLayoutDescriptor");
StreamIn(sink, layout.array_stride, layout.step_mode, layout.attributes);
}
template <>
void stream::Stream<tint::transform::VertexAttributeDescriptor>::Write(
stream::Sink* sink,
const tint::transform::VertexAttributeDescriptor& attrib) {
using Attrib = tint::transform::VertexAttributeDescriptor;
static_assert(offsetof(Attrib, format) == 0,
"Please update serialization for tint::transform::VertexAttributeDescriptor");
static_assert(offsetof(Attrib, offset) == 4,
"Please update serialization for tint::transform::VertexAttributeDescriptor");
static_assert(offsetof(Attrib, shader_location) == 8,
"Please update serialization for tint::transform::VertexAttributeDescriptor");
static_assert(sizeof(Attrib) == 12,
"Please update serialization for tint::transform::VertexAttributeDescriptor");
StreamIn(sink, attrib.format, attrib.offset, attrib.shader_location);
}
} // namespace dawn::native

View File

@ -14,7 +14,10 @@
#include "dawn/native/TintUtils.h"
#include "dawn/native/BindGroupLayout.h"
#include "dawn/native/Device.h"
#include "dawn/native/PipelineLayout.h"
#include "dawn/native/RenderPipeline.h"
#include "tint/tint.h"
@ -35,6 +38,87 @@ bool InitializeTintErrorReporter() {
return true;
}
tint::transform::VertexFormat ToTintVertexFormat(wgpu::VertexFormat format) {
switch (format) {
case wgpu::VertexFormat::Uint8x2:
return tint::transform::VertexFormat::kUint8x2;
case wgpu::VertexFormat::Uint8x4:
return tint::transform::VertexFormat::kUint8x4;
case wgpu::VertexFormat::Sint8x2:
return tint::transform::VertexFormat::kSint8x2;
case wgpu::VertexFormat::Sint8x4:
return tint::transform::VertexFormat::kSint8x4;
case wgpu::VertexFormat::Unorm8x2:
return tint::transform::VertexFormat::kUnorm8x2;
case wgpu::VertexFormat::Unorm8x4:
return tint::transform::VertexFormat::kUnorm8x4;
case wgpu::VertexFormat::Snorm8x2:
return tint::transform::VertexFormat::kSnorm8x2;
case wgpu::VertexFormat::Snorm8x4:
return tint::transform::VertexFormat::kSnorm8x4;
case wgpu::VertexFormat::Uint16x2:
return tint::transform::VertexFormat::kUint16x2;
case wgpu::VertexFormat::Uint16x4:
return tint::transform::VertexFormat::kUint16x4;
case wgpu::VertexFormat::Sint16x2:
return tint::transform::VertexFormat::kSint16x2;
case wgpu::VertexFormat::Sint16x4:
return tint::transform::VertexFormat::kSint16x4;
case wgpu::VertexFormat::Unorm16x2:
return tint::transform::VertexFormat::kUnorm16x2;
case wgpu::VertexFormat::Unorm16x4:
return tint::transform::VertexFormat::kUnorm16x4;
case wgpu::VertexFormat::Snorm16x2:
return tint::transform::VertexFormat::kSnorm16x2;
case wgpu::VertexFormat::Snorm16x4:
return tint::transform::VertexFormat::kSnorm16x4;
case wgpu::VertexFormat::Float16x2:
return tint::transform::VertexFormat::kFloat16x2;
case wgpu::VertexFormat::Float16x4:
return tint::transform::VertexFormat::kFloat16x4;
case wgpu::VertexFormat::Float32:
return tint::transform::VertexFormat::kFloat32;
case wgpu::VertexFormat::Float32x2:
return tint::transform::VertexFormat::kFloat32x2;
case wgpu::VertexFormat::Float32x3:
return tint::transform::VertexFormat::kFloat32x3;
case wgpu::VertexFormat::Float32x4:
return tint::transform::VertexFormat::kFloat32x4;
case wgpu::VertexFormat::Uint32:
return tint::transform::VertexFormat::kUint32;
case wgpu::VertexFormat::Uint32x2:
return tint::transform::VertexFormat::kUint32x2;
case wgpu::VertexFormat::Uint32x3:
return tint::transform::VertexFormat::kUint32x3;
case wgpu::VertexFormat::Uint32x4:
return tint::transform::VertexFormat::kUint32x4;
case wgpu::VertexFormat::Sint32:
return tint::transform::VertexFormat::kSint32;
case wgpu::VertexFormat::Sint32x2:
return tint::transform::VertexFormat::kSint32x2;
case wgpu::VertexFormat::Sint32x3:
return tint::transform::VertexFormat::kSint32x3;
case wgpu::VertexFormat::Sint32x4:
return tint::transform::VertexFormat::kSint32x4;
case wgpu::VertexFormat::Undefined:
break;
}
UNREACHABLE();
}
tint::transform::VertexStepMode ToTintVertexStepMode(wgpu::VertexStepMode mode) {
switch (mode) {
case wgpu::VertexStepMode::Vertex:
return tint::transform::VertexStepMode::kVertex;
case wgpu::VertexStepMode::Instance:
return tint::transform::VertexStepMode::kInstance;
case wgpu::VertexStepMode::VertexBufferNotUsed:
break;
}
UNREACHABLE();
}
} // namespace
ScopedTintICEHandler::ScopedTintICEHandler(DeviceBase* device) {
@ -53,6 +137,52 @@ ScopedTintICEHandler::~ScopedTintICEHandler() {
tlDevice = nullptr;
}
tint::transform::MultiplanarExternalTexture::BindingsMap BuildExternalTextureTransformBindings(
const PipelineLayoutBase* layout) {
tint::transform::MultiplanarExternalTexture::BindingsMap newBindingsMap;
for (BindGroupIndex i : IterateBitSet(layout->GetBindGroupLayoutsMask())) {
const BindGroupLayoutBase* bgl = layout->GetBindGroupLayout(i);
for (const auto& [_, expansion] : bgl->GetExternalTextureBindingExpansionMap()) {
newBindingsMap[{static_cast<uint32_t>(i), static_cast<uint32_t>(expansion.plane0)}] = {
{static_cast<uint32_t>(i), static_cast<uint32_t>(expansion.plane1)},
{static_cast<uint32_t>(i), static_cast<uint32_t>(expansion.params)}};
}
}
return newBindingsMap;
}
tint::transform::VertexPulling::Config BuildVertexPullingTransformConfig(
const RenderPipelineBase& renderPipeline,
const std::string_view& entryPoint,
BindGroupIndex pullingBufferBindingSet) {
tint::transform::VertexPulling::Config cfg;
cfg.entry_point_name = entryPoint;
cfg.pulling_group = static_cast<uint32_t>(pullingBufferBindingSet);
cfg.vertex_state.resize(renderPipeline.GetVertexBufferCount());
for (VertexBufferSlot slot : IterateBitSet(renderPipeline.GetVertexBufferSlotsUsed())) {
const VertexBufferInfo& dawnInfo = renderPipeline.GetVertexBuffer(slot);
tint::transform::VertexBufferLayoutDescriptor* tintInfo =
&cfg.vertex_state[static_cast<uint8_t>(slot)];
tintInfo->array_stride = dawnInfo.arrayStride;
tintInfo->step_mode = ToTintVertexStepMode(dawnInfo.stepMode);
}
for (VertexAttributeLocation location :
IterateBitSet(renderPipeline.GetAttributeLocationsUsed())) {
const VertexAttributeInfo& dawnInfo = renderPipeline.GetAttribute(location);
tint::transform::VertexAttributeDescriptor tintInfo;
tintInfo.format = ToTintVertexFormat(dawnInfo.format);
tintInfo.offset = dawnInfo.offset;
tintInfo.shader_location = static_cast<uint32_t>(static_cast<uint8_t>(location));
uint8_t vertexBufferSlot = static_cast<uint8_t>(dawnInfo.vertexBufferSlot);
cfg.vertex_state[vertexBufferSlot].attributes.push_back(tintInfo);
}
return cfg;
}
} // namespace dawn::native
bool std::less<tint::sem::BindingPoint>::operator()(const tint::sem::BindingPoint& a,

View File

@ -18,14 +18,15 @@
#include <functional>
#include "dawn/common/NonCopyable.h"
#include "dawn/native/IntegerTypes.h"
namespace tint::sem {
struct BindingPoint;
}
#include "tint/tint.h"
namespace dawn::native {
class DeviceBase;
class PipelineLayoutBase;
class RenderPipelineBase;
// Indicates that for the lifetime of this object tint internal compiler errors should be
// reported to the given device.
@ -38,6 +39,14 @@ class ScopedTintICEHandler : public NonCopyable {
ScopedTintICEHandler(ScopedTintICEHandler&&) = delete;
};
tint::transform::MultiplanarExternalTexture::BindingsMap BuildExternalTextureTransformBindings(
const PipelineLayoutBase* layout);
tint::transform::VertexPulling::Config BuildVertexPullingTransformConfig(
const RenderPipelineBase& renderPipeline,
const std::string_view& entryPoint,
BindGroupIndex pullingBufferBindingSet);
} // namespace dawn::native
// std::less operator for std::map containing BindingPoint

View File

@ -0,0 +1,61 @@
// Copyright 2022 The Dawn Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_DAWN_NATIVE_VISITABLEMEMBERS_H_
#define SRC_DAWN_NATIVE_VISITABLEMEMBERS_H_
#include "dawn/native/stream/Stream.h"
// Helper for X macro to declare a visitable member.
#define DAWN_INTERNAL_VISITABLE_MEMBER_DECL(type, name) type name{};
// Helper for X macro for visiting a visitable member.
#define DAWN_INTERNAL_VISITABLE_MEMBER_ARG(type, name) , name
namespace dawn::native::detail {
constexpr int kInternalVisitableUnusedForComma = 0;
} // namespace dawn::native::detail
// Helper X macro to declare members of a class or struct, along with Visit
// methods to call a functor for each member.
// Example usage:
// #define MEMBERS(X) \
// X(int, a) \
// X(float, b) \
// X(Foo, foo) \
// X(Bar, bar)
// struct MyStruct {
// DAWN_VISITABLE_MEMBERS(MEMBERS)
// };
// #undef MEMBERS
#define DAWN_VISITABLE_MEMBERS(MEMBERS) \
MEMBERS(DAWN_INTERNAL_VISITABLE_MEMBER_DECL) \
\
template <typename V> \
constexpr auto VisitAll(V&& visit) const { \
return [&](int, const auto&... ms) { \
return visit(ms...); \
}(::dawn::native::detail::kInternalVisitableUnusedForComma MEMBERS( \
DAWN_INTERNAL_VISITABLE_MEMBER_ARG)); \
} \
\
template <typename V> \
constexpr auto VisitAll(V&& visit) { \
return [&](int, auto&... ms) { \
return visit(ms...); \
}(::dawn::native::detail::kInternalVisitableUnusedForComma MEMBERS( \
DAWN_INTERNAL_VISITABLE_MEMBER_ARG)); \
}
#endif // SRC_DAWN_NATIVE_VISITABLEMEMBERS_H_

View File

@ -756,7 +756,12 @@ ResultOrError<CompiledShader> ShaderModule::Compile(const ProgrammableStage& pro
const tint::Program* program = GetTintProgram();
tint::Program programAsValue;
AddExternalTextureTransform(layout, &transformManager, &transformInputs);
auto externalTextureBindings = BuildExternalTextureTransformBindings(layout);
if (!externalTextureBindings.empty()) {
transformManager.Add<tint::transform::MultiplanarExternalTexture>();
transformInputs.Add<tint::transform::MultiplanarExternalTexture::NewBindingPoints>(
std::move(externalTextureBindings));
}
if (stage == SingleShaderStage::Vertex) {
transformManager.Add<tint::transform::FirstIndexOffset>();

View File

@ -55,15 +55,6 @@ class ShaderModule final : public ShaderModuleBase {
const RenderPipeline* renderPipeline = nullptr);
private:
ResultOrError<std::string> TranslateToMSL(const char* entryPointName,
SingleShaderStage stage,
const PipelineLayout* layout,
uint32_t sampleMask,
const RenderPipeline* renderPipeline,
std::string* remappedEntryPointName,
bool* needsStorageBufferLength,
bool* hasInvariantAttribute,
std::vector<uint32_t>* workgroupAllocations);
ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
~ShaderModule() override;
MaybeError Initialize(ShaderModuleParseResult* parseResult,

View File

@ -15,10 +15,14 @@
#include "dawn/native/metal/ShaderModuleMTL.h"
#include "dawn/native/BindGroupLayout.h"
#include "dawn/native/CacheRequest.h"
#include "dawn/native/TintUtils.h"
#include "dawn/native/VisitableMembers.h"
#include "dawn/native/metal/DeviceMTL.h"
#include "dawn/native/metal/PipelineLayoutMTL.h"
#include "dawn/native/metal/RenderPipelineMTL.h"
#include "dawn/native/stream/BlobSource.h"
#include "dawn/native/stream/ByteVectorSink.h"
#include "dawn/platform/DawnPlatform.h"
#include "dawn/platform/tracing/TraceEvent.h"
@ -26,6 +30,67 @@
#include <sstream>
namespace dawn::native::metal {
namespace {
using OptionalVertexPullingTransformConfig = std::optional<tint::transform::VertexPulling::Config>;
#define MSL_COMPILATION_REQUEST_MEMBERS(X) \
X(const tint::Program*, inputProgram) \
X(tint::transform::BindingRemapper::BindingPoints, bindingPoints) \
X(tint::transform::MultiplanarExternalTexture::BindingsMap, externalTextureBindings) \
X(OptionalVertexPullingTransformConfig, vertexPullingTransformConfig) \
X(std::string, entryPointName) \
X(uint32_t, sampleMask) \
X(bool, emitVertexPointSize) \
X(bool, isRobustnessEnabled) \
X(bool, disableSymbolRenaming) \
X(bool, disableWorkgroupInit) \
X(CacheKey::UnsafeUnkeyedValue<dawn::platform::Platform*>, tracePlatform)
DAWN_MAKE_CACHE_REQUEST(MslCompilationRequest, MSL_COMPILATION_REQUEST_MEMBERS);
#undef MSL_COMPILATION_REQUEST_MEMBERS
using WorkgroupAllocations = std::vector<uint32_t>;
#define MSL_COMPILATION_MEMBERS(X) \
X(std::string, msl) \
X(std::string, remappedEntryPointName) \
X(bool, needsStorageBufferLength) \
X(bool, hasInvariantAttribute) \
X(WorkgroupAllocations, workgroupAllocations)
struct MslCompilation {
static ResultOrError<MslCompilation> FromBlob(Blob blob);
DAWN_VISITABLE_MEMBERS(MSL_COMPILATION_MEMBERS)
#undef MSL_COMPILATION_MEMBERS
};
} // namespace
} // namespace dawn::native::metal
namespace dawn::native {
// Define the implementation to store MslCompilation into the BlobCache.
template <>
void BlobCache::Store<metal::MslCompilation>(const CacheKey& key, const metal::MslCompilation& c) {
stream::ByteVectorSink sink;
c.VisitAll([&](const auto&... members) { StreamIn(&sink, members...); });
Store(key, CreateBlob(std::move(sink)));
}
// Define the implementation to load MslCompilation from a blob.
// static
ResultOrError<metal::MslCompilation> metal::MslCompilation::FromBlob(Blob blob) {
stream::BlobSource source(std::move(blob));
metal::MslCompilation c;
DAWN_TRY(c.VisitAll([&](auto&... members) { return StreamOut(&source, &members...); }));
return c;
}
} // namespace dawn::native
namespace dawn::native::metal {
// static
@ -50,17 +115,16 @@ MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult,
return InitializeBase(parseResult, compilationMessages);
}
ResultOrError<std::string> ShaderModule::TranslateToMSL(
namespace {
ResultOrError<CacheResult<MslCompilation>> TranslateToMSL(DeviceBase* device,
const tint::Program* inputProgram,
const char* entryPointName,
SingleShaderStage stage,
const PipelineLayout* layout,
uint32_t sampleMask,
const RenderPipeline* renderPipeline,
std::string* remappedEntryPointName,
bool* needsStorageBufferLength,
bool* hasInvariantAttribute,
std::vector<uint32_t>* workgroupAllocations) {
ScopedTintICEHandler scopedICEHandler(GetDevice());
const RenderPipeline* renderPipeline) {
ScopedTintICEHandler scopedICEHandler(device);
std::ostringstream errorStream;
errorStream << "Tint MSL failure:" << std::endl;
@ -69,7 +133,6 @@ ResultOrError<std::string> ShaderModule::TranslateToMSL(
using BindingRemapper = tint::transform::BindingRemapper;
using BindingPoint = tint::transform::BindingPoint;
BindingRemapper::BindingPoints bindingPoints;
BindingRemapper::AccessControls accessControls;
for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) {
const BindGroupLayoutBase::BindingMap& bindingMap =
@ -93,21 +156,13 @@ ResultOrError<std::string> ShaderModule::TranslateToMSL(
}
}
tint::transform::Manager transformManager;
tint::transform::DataMap transformInputs;
// We only remap bindings for the target entry point, so we need to strip all other entry
// points to avoid generating invalid bindings for them.
transformManager.Add<tint::transform::SingleEntryPoint>();
transformInputs.Add<tint::transform::SingleEntryPoint::Config>(entryPointName);
AddExternalTextureTransform(layout, &transformManager, &transformInputs);
auto externalTextureBindings = BuildExternalTextureTransformBindings(layout);
std::optional<tint::transform::VertexPulling::Config> vertexPullingTransformConfig;
if (stage == SingleShaderStage::Vertex &&
GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling)) {
transformManager.Add<tint::transform::VertexPulling>();
AddVertexPullingTransformConfig(*renderPipeline, entryPointName, kPullingBufferBindingSet,
&transformInputs);
device->IsToggleEnabled(Toggle::MetalEnableVertexPulling)) {
vertexPullingTransformConfig = BuildVertexPullingTransformConfig(
*renderPipeline, entryPointName, kPullingBufferBindingSet);
for (VertexBufferSlot slot : IterateBitSet(renderPipeline->GetVertexBufferSlotsUsed())) {
uint32_t metalIndex = renderPipeline->GetMtlVertexBufferIndex(slot);
@ -121,39 +176,80 @@ ResultOrError<std::string> ShaderModule::TranslateToMSL(
}
}
}
if (GetDevice()->IsRobustnessEnabled()) {
MslCompilationRequest req = {};
req.inputProgram = inputProgram;
req.bindingPoints = std::move(bindingPoints);
req.externalTextureBindings = std::move(externalTextureBindings);
req.vertexPullingTransformConfig = std::move(vertexPullingTransformConfig);
req.entryPointName = entryPointName;
req.sampleMask = sampleMask;
req.emitVertexPointSize =
stage == SingleShaderStage::Vertex &&
renderPipeline->GetPrimitiveTopology() == wgpu::PrimitiveTopology::PointList;
req.isRobustnessEnabled = device->IsRobustnessEnabled();
req.disableSymbolRenaming = device->IsToggleEnabled(Toggle::DisableSymbolRenaming);
req.tracePlatform = UnsafeUnkeyedValue(device->GetPlatform());
CacheResult<MslCompilation> mslCompilation;
DAWN_TRY_LOAD_OR_RUN(
mslCompilation, device, std::move(req), MslCompilation::FromBlob,
[](MslCompilationRequest r) -> ResultOrError<MslCompilation> {
tint::transform::Manager transformManager;
tint::transform::DataMap transformInputs;
// We only remap bindings for the target entry point, so we need to strip all other
// entry points to avoid generating invalid bindings for them.
transformManager.Add<tint::transform::SingleEntryPoint>();
transformInputs.Add<tint::transform::SingleEntryPoint::Config>(r.entryPointName);
if (!r.externalTextureBindings.empty()) {
transformManager.Add<tint::transform::MultiplanarExternalTexture>();
transformInputs.Add<tint::transform::MultiplanarExternalTexture::NewBindingPoints>(
std::move(r.externalTextureBindings));
}
if (r.vertexPullingTransformConfig) {
transformManager.Add<tint::transform::VertexPulling>();
transformInputs.Add<tint::transform::VertexPulling::Config>(
std::move(r.vertexPullingTransformConfig).value());
}
if (r.isRobustnessEnabled) {
transformManager.Add<tint::transform::Robustness>();
}
transformManager.Add<tint::transform::BindingRemapper>();
transformManager.Add<BindingRemapper>();
transformInputs.Add<BindingRemapper::Remappings>(std::move(r.bindingPoints),
BindingRemapper::AccessControls{},
/* mayCollide */ true);
transformManager.Add<tint::transform::Renamer>();
if (GetDevice()->IsToggleEnabled(Toggle::DisableSymbolRenaming)) {
if (r.disableSymbolRenaming) {
// We still need to rename MSL reserved keywords
transformInputs.Add<tint::transform::Renamer::Config>(
tint::transform::Renamer::Target::kMslKeywords);
}
transformInputs.Add<BindingRemapper::Remappings>(std::move(bindingPoints),
std::move(accessControls),
/* mayCollide */ true);
tint::Program program;
tint::transform::DataMap transformOutputs;
{
TRACE_EVENT0(GetDevice()->GetPlatform(), General, "RunTransforms");
DAWN_TRY_ASSIGN(program, RunTransforms(&transformManager, GetTintProgram(), transformInputs,
TRACE_EVENT0(r.tracePlatform.UnsafeGetValue(), General, "RunTransforms");
DAWN_TRY_ASSIGN(program,
RunTransforms(&transformManager, r.inputProgram, transformInputs,
&transformOutputs, nullptr));
}
std::string remappedEntryPointName;
if (auto* data = transformOutputs.Get<tint::transform::Renamer::Data>()) {
auto it = data->remappings.find(entryPointName);
auto it = data->remappings.find(r.entryPointName);
if (it != data->remappings.end()) {
*remappedEntryPointName = it->second;
remappedEntryPointName = it->second;
} else {
DAWN_INVALID_IF(!GetDevice()->IsToggleEnabled(Toggle::DisableSymbolRenaming),
DAWN_INVALID_IF(!r.disableSymbolRenaming,
"Could not find remapped name for entry point.");
*remappedEntryPointName = entryPointName;
remappedEntryPointName = r.entryPointName;
}
} else {
return DAWN_FORMAT_VALIDATION_ERROR("Transform output missing renamer data.");
@ -161,22 +257,46 @@ ResultOrError<std::string> ShaderModule::TranslateToMSL(
tint::writer::msl::Options options;
options.buffer_size_ubo_index = kBufferLengthBufferSlot;
options.fixed_sample_mask = sampleMask;
options.disable_workgroup_init = GetDevice()->IsToggleEnabled(Toggle::DisableWorkgroupInit);
options.emit_vertex_point_size =
stage == SingleShaderStage::Vertex &&
renderPipeline->GetPrimitiveTopology() == wgpu::PrimitiveTopology::PointList;
TRACE_EVENT0(GetDevice()->GetPlatform(), General, "tint::writer::msl::Generate");
options.fixed_sample_mask = r.sampleMask;
options.disable_workgroup_init = r.disableWorkgroupInit;
options.emit_vertex_point_size = r.emitVertexPointSize;
TRACE_EVENT0(r.tracePlatform.UnsafeGetValue(), General, "tint::writer::msl::Generate");
auto result = tint::writer::msl::Generate(&program, options);
DAWN_INVALID_IF(!result.success, "An error occured while generating MSL: %s.", result.error);
DAWN_INVALID_IF(!result.success, "An error occured while generating MSL: %s.",
result.error);
*needsStorageBufferLength = result.needs_storage_buffer_sizes;
*hasInvariantAttribute = result.has_invariant_attribute;
*workgroupAllocations = std::move(result.workgroup_allocations[*remappedEntryPointName]);
// Metal uses Clang to compile the shader as C++14. Disable everything in the -Wall
// category. -Wunused-variable in particular comes up a lot in generated code, and some
// (old?) Metal drivers accidentally treat it as a MTLLibraryErrorCompileError instead
// of a warning.
result.msl = R"(
#ifdef __clang__
#pragma clang diagnostic ignored "-Wall"
#endif
)" + result.msl;
return std::move(result.msl);
auto workgroupAllocations =
std::move(result.workgroup_allocations[remappedEntryPointName]);
return MslCompilation{
std::move(result.msl),
std::move(remappedEntryPointName),
result.needs_storage_buffer_sizes,
result.has_invariant_attribute,
std::move(workgroupAllocations),
};
});
if (device->IsToggleEnabled(Toggle::DumpShaders)) {
std::ostringstream dumpedMsg;
dumpedMsg << "/* Dumped generated MSL */" << std::endl << mslCompilation->msl;
device->EmitLog(WGPULoggingType_Info, dumpedMsg.str().c_str());
}
return mslCompilation;
}
} // namespace
MaybeError ShaderModule::CreateFunction(const char* entryPointName,
SingleShaderStage stage,
const PipelineLayout* layout,
@ -194,33 +314,17 @@ MaybeError ShaderModule::CreateFunction(const char* entryPointName,
ASSERT(renderPipeline != nullptr);
}
std::string remappedEntryPointName;
std::string msl;
bool hasInvariantAttribute = false;
DAWN_TRY_ASSIGN(msl, TranslateToMSL(entryPointName, stage, layout, sampleMask, renderPipeline,
&remappedEntryPointName, &out->needsStorageBufferLength,
&hasInvariantAttribute, &out->workgroupAllocations));
CacheResult<MslCompilation> mslCompilation;
DAWN_TRY_ASSIGN(mslCompilation, TranslateToMSL(GetDevice(), GetTintProgram(), entryPointName,
stage, layout, sampleMask, renderPipeline));
out->needsStorageBufferLength = mslCompilation->needsStorageBufferLength;
out->workgroupAllocations = std::move(mslCompilation->workgroupAllocations);
// Metal uses Clang to compile the shader as C++14. Disable everything in the -Wall
// category. -Wunused-variable in particular comes up a lot in generated code, and some
// (old?) Metal drivers accidentally treat it as a MTLLibraryErrorCompileError instead
// of a warning.
msl = R"(
#ifdef __clang__
#pragma clang diagnostic ignored "-Wall"
#endif
)" + msl;
if (GetDevice()->IsToggleEnabled(Toggle::DumpShaders)) {
std::ostringstream dumpedMsg;
dumpedMsg << "/* Dumped generated MSL */" << std::endl << msl;
GetDevice()->EmitLog(WGPULoggingType_Info, dumpedMsg.str().c_str());
}
NSRef<NSString> mslSource = AcquireNSRef([[NSString alloc] initWithUTF8String:msl.c_str()]);
NSRef<NSString> mslSource =
AcquireNSRef([[NSString alloc] initWithUTF8String:mslCompilation->msl.c_str()]);
NSRef<MTLCompileOptions> compileOptions = AcquireNSRef([[MTLCompileOptions alloc] init]);
if (hasInvariantAttribute) {
if (mslCompilation->hasInvariantAttribute) {
if (@available(macOS 11.0, iOS 13.0, *)) {
(*compileOptions).preserveInvariance = true;
}
@ -243,8 +347,8 @@ MaybeError ShaderModule::CreateFunction(const char* entryPointName,
}
ASSERT(library != nil);
NSRef<NSString> name =
AcquireNSRef([[NSString alloc] initWithUTF8String:remappedEntryPointName.c_str()]);
NSRef<NSString> name = AcquireNSRef(
[[NSString alloc] initWithUTF8String:mslCompilation->remappedEntryPointName.c_str()]);
{
TRACE_EVENT0(GetDevice()->GetPlatform(), General, "MTLLibrary::newFunctionWithName");
@ -269,6 +373,10 @@ MaybeError ShaderModule::CreateFunction(const char* entryPointName,
}
}
if (BlobCache* cache = GetDevice()->GetBlobCache()) {
cache->EnsureStored(mslCompilation);
}
if (GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling) &&
GetEntryPoint(entryPointName).usedVertexInputs.any()) {
out->needsStorageBufferLength = true;

View File

@ -90,7 +90,12 @@ ResultOrError<std::string> ShaderModule::TranslateToGLSL(const char* entryPointN
tint::transform::Manager transformManager;
tint::transform::DataMap transformInputs;
AddExternalTextureTransform(layout, &transformManager, &transformInputs);
auto externalTextureBindings = BuildExternalTextureTransformBindings(layout);
if (!externalTextureBindings.empty()) {
transformManager.Add<tint::transform::MultiplanarExternalTexture>();
transformInputs.Add<tint::transform::MultiplanarExternalTexture::NewBindingPoints>(
std::move(externalTextureBindings));
}
tint::Program program;
DAWN_TRY_ASSIGN(program, RunTransforms(&transformManager, GetTintProgram(), transformInputs,

View File

@ -23,6 +23,8 @@
#include <utility>
#include <vector>
#include <optional>
#include "dawn/common/Platform.h"
#include "dawn/common/TypedInteger.h"
#include "dawn/native/Error.h"
@ -57,6 +59,14 @@ MaybeError StreamOut(Source* s, T* v) {
return Stream<T>::Read(s, v);
}
// Helper to take an rvalue passed to StreamOut and forward it as a pointer.
// This makes it possible to pass output wrappers like stream::StructMembers inline.
// For example: `DAWN_TRY(StreamOut(&source, stream::StructMembers(...)));`
template <typename T>
MaybeError StreamOut(Source* s, T&& v) {
return StreamOut(s, &v);
}
// Helper to call StreamIn on a parameter pack.
template <typename T, typename... Ts>
constexpr void StreamIn(Sink* s, const T& v, const Ts&... vs) {
@ -187,6 +197,19 @@ class Stream<T, std::enable_if_t<std::is_pointer_v<T>>> {
}
};
// Stream specialization for std::optional
template <typename T>
class Stream<std::optional<T>> {
public:
static void Write(stream::Sink* sink, const std::optional<T>& t) {
bool hasValue = t.has_value();
StreamIn(sink, hasValue);
if (hasValue) {
StreamIn(sink, *t);
}
}
};
// Stream specialization for fixed arrays of fundamental types.
template <typename T, size_t N>
class Stream<T[N], std::enable_if_t<std::is_fundamental_v<T>>> {

View File

@ -232,7 +232,7 @@ ResultOrError<ShaderModule::ModuleAndSpirv> ShaderModule::GetHandleAndSpirv(
}
// Transform external textures into the binding locations specified in the bgl
// TODO(dawn:1082): Replace this block with ShaderModuleBase::AddExternalTextureTransform.
// TODO(dawn:1082): Replace this block with BuildExternalTextureTransformBindings.
tint::transform::MultiplanarExternalTexture::BindingsMap newBindingsMap;
for (BindGroupIndex i : IterateBitSet(layout->GetBindGroupLayoutsMask())) {
const BindGroupLayoutBase* bgl = layout->GetBindGroupLayout(i);

View File

@ -108,8 +108,8 @@ class PipelineCachingTests : public DawnTest {
const EntryCounts counts = {
// pipeline caching is only implemented on D3D12/Vulkan
IsD3D12() || IsVulkan() ? 1u : 0u,
// shader module caching is only implemented on Vulkan
IsVulkan() ? 1u : 0u,
// shader module caching is only implemented on Vulkan/Metal
IsVulkan() || IsMetal() ? 1u : 0u,
};
NiceMock<CachingInterfaceMock> mMockCache;
};
@ -406,7 +406,7 @@ TEST_P(SinglePipelineCachingTests, RenderPipelineBlobCacheDescriptorNegativeCase
{
wgpu::Device device = CreateDevice();
utils::ComboRenderPipelineDescriptor desc;
desc.primitive.topology = wgpu::PrimitiveTopology::PointList;
desc.EnableDepthStencil();
desc.vertex.module = utils::CreateShaderModule(device, kVertexShaderDefault.data());
desc.vertex.entryPoint = "main";
desc.cFragment.module = utils::CreateShaderModule(device, kFragmentShaderDefault.data());
@ -586,8 +586,9 @@ TEST_P(SinglePipelineCachingTests, RenderPipelineBlobCacheLayout) {
device.CreateRenderPipeline(&desc));
}
// Cache should hit for the shaders, but not for the pipeline.
// The shader is different but compiles to the same due to binding number remapping.
// Cache should not hit for the fragment shader, but should hit for the pipeline.
// Except for D3D12, the shader is different but compiles to the same due to binding number
// remapping.
{
wgpu::Device device = CreateDevice();
utils::ComboRenderPipelineDescriptor desc;
@ -604,9 +605,15 @@ TEST_P(SinglePipelineCachingTests, RenderPipelineBlobCacheLayout) {
{1, wgpu::ShaderStage::Fragment, wgpu::BufferBindingType::Uniform},
}),
});
EXPECT_CACHE_STATS(mMockCache, Hit(2 * counts.shaderModule), Add(counts.pipeline),
if (!IsD3D12()) {
EXPECT_CACHE_STATS(mMockCache, Hit(counts.shaderModule + counts.pipeline),
Add(counts.shaderModule), device.CreateRenderPipeline(&desc));
} else {
EXPECT_CACHE_STATS(mMockCache, Hit(counts.shaderModule),
Add(counts.shaderModule + counts.pipeline),
device.CreateRenderPipeline(&desc));
}
}
}
// Tests that pipeline creation does not hits the cache when it is enabled but we use different

View File

@ -22,6 +22,7 @@
#include "dawn/common/TypedInteger.h"
#include "dawn/native/Blob.h"
#include "dawn/native/VisitableMembers.h"
#include "dawn/native/stream/BlobSource.h"
#include "dawn/native/stream/ByteVectorSink.h"
#include "dawn/native/stream/Stream.h"
@ -193,6 +194,23 @@ TEST(SerializeTests, StdPair) {
EXPECT_CACHE_KEY_EQ(std::make_pair(s, uint32_t(42)), expected);
}
// Test that ByteVectorSink serializes std::optional as expected.
TEST(SerializeTests, StdOptional) {
std::string_view s = "webgpu";
{
ByteVectorSink expected;
StreamIn(&expected, true, s);
EXPECT_CACHE_KEY_EQ(std::optional(s), expected);
}
{
ByteVectorSink expected;
StreamIn(&expected, false);
EXPECT_CACHE_KEY_EQ(std::optional<std::string_view>(), expected);
}
}
// Test that ByteVectorSink serializes std::unordered_map as expected.
TEST(SerializeTests, StdUnorderedMap) {
std::unordered_map<uint32_t, std::string_view> m;
@ -256,6 +274,41 @@ TEST(StreamTests, SerializeDeserializeParamPack) {
EXPECT_EQ(c, cOut);
}
#define FOO_MEMBERS(X) \
X(int, a) \
X(float, b) \
X(std::string, c)
struct Foo {
DAWN_VISITABLE_MEMBERS(FOO_MEMBERS)
#undef FOO_MEMBERS
};
// Test that serializing then deserializing a struct made with DAWN_VISITABLE_MEMBERS works as
// expected.
TEST(StreamTests, SerializeDeserializeVisitableMembers) {
Foo foo{1, 2, "3"};
ByteVectorSink sink;
foo.VisitAll([&](const auto&... members) { StreamIn(&sink, members...); });
// Test that the serialization is correct.
{
ByteVectorSink expected;
StreamIn(&expected, foo.a, foo.b, foo.c);
EXPECT_THAT(sink, VectorEq(expected));
}
// Test that deserialization works for StructMembers, passed inline.
{
BlobSource src(CreateBlob(sink));
Foo out;
auto err = out.VisitAll([&](auto&... members) { return StreamOut(&src, &members...); });
EXPECT_FALSE(err.IsError());
EXPECT_EQ(foo.a, out.a);
EXPECT_EQ(foo.b, out.b);
EXPECT_EQ(foo.c, out.c);
}
}
template <size_t N>
std::bitset<N - 1> BitsetFromBitString(const char (&str)[N]) {
// N - 1 because the last character is the null terminator.