Cache WGSL -> DXBC/DXIL compilation

Bug: dawn:1480
Change-Id: I858111f62be457c2e7cd5017bbf4c10e76395e83
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/95340
Reviewed-by: Loko Kung <lokokung@google.com>
Commit-Queue: Austin Eng <enga@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
Austin Eng 2022-08-05 20:09:07 +00:00 committed by Dawn LUCI CQ
parent 9443cebd53
commit a0df1384f2
14 changed files with 529 additions and 452 deletions

View File

@ -17,6 +17,7 @@
#include "dawn/common/Assert.h" #include "dawn/common/Assert.h"
#include "dawn/common/Math.h" #include "dawn/common/Math.h"
#include "dawn/native/Blob.h" #include "dawn/native/Blob.h"
#include "dawn/native/stream/Stream.h"
namespace dawn::native { namespace dawn::native {
@ -99,4 +100,29 @@ void Blob::AlignTo(size_t alignment) {
*this = std::move(blob); *this = std::move(blob);
} }
template <>
void stream::Stream<Blob>::Write(stream::Sink* s, const Blob& b) {
size_t size = b.Size();
StreamIn(s, size);
if (size > 0) {
void* ptr = s->GetSpace(size);
memcpy(ptr, b.Data(), size);
}
}
template <>
MaybeError stream::Stream<Blob>::Read(stream::Source* s, Blob* b) {
size_t size;
DAWN_TRY(StreamOut(s, &size));
if (size > 0) {
const void* ptr;
DAWN_TRY(s->Read(&ptr, size));
*b = CreateBlob(size);
memcpy(b->Data(), ptr, size);
} else {
*b = Blob();
}
return {};
}
} // namespace dawn::native } // namespace dawn::native

View File

@ -14,7 +14,8 @@
#include "dawn/native/stream/Stream.h" #include "dawn/native/stream/Stream.h"
#include "tint/tint.h" #include "dawn/native/TintUtils.h"
#include "tint/writer/array_length_from_uniform_options.h"
namespace dawn::native { namespace dawn::native {
@ -96,4 +97,21 @@ void stream::Stream<tint::transform::VertexAttributeDescriptor>::Write(
StreamIn(sink, attrib.format, attrib.offset, attrib.shader_location); StreamIn(sink, attrib.format, attrib.offset, attrib.shader_location);
} }
// static
template <>
void stream::Stream<tint::writer::ArrayLengthFromUniformOptions>::Write(
stream::Sink* sink,
const tint::writer::ArrayLengthFromUniformOptions& o) {
static_assert(offsetof(tint::writer::ArrayLengthFromUniformOptions, ubo_binding) == 0,
"Please update serialization for tint::writer::ArrayLengthFromUniformOptions");
static_assert(
offsetof(tint::writer::ArrayLengthFromUniformOptions, bindpoint_to_size_index) == 8,
"Please update serialization for tint::writer::ArrayLengthFromUniformOptions");
static_assert(
sizeof(tint::writer::ArrayLengthFromUniformOptions) ==
8 + sizeof(tint::writer::ArrayLengthFromUniformOptions::bindpoint_to_size_index),
"Please update serialization for tint::writer::ArrayLengthFromUniformOptions");
StreamIn(sink, o.ubo_binding, o.bindpoint_to_size_index);
}
} // namespace dawn::native } // namespace dawn::native

View File

@ -185,7 +185,10 @@ tint::transform::VertexPulling::Config BuildVertexPullingTransformConfig(
} // namespace dawn::native } // namespace dawn::native
bool std::less<tint::sem::BindingPoint>::operator()(const tint::sem::BindingPoint& a, namespace tint::sem {
const tint::sem::BindingPoint& b) const {
bool operator<(const BindingPoint& a, const BindingPoint& b) {
return std::tie(a.group, a.binding) < std::tie(b.group, b.binding); return std::tie(a.group, a.binding) < std::tie(b.group, b.binding);
} }
} // namespace tint::sem

View File

@ -49,10 +49,11 @@ tint::transform::VertexPulling::Config BuildVertexPullingTransformConfig(
} // namespace dawn::native } // namespace dawn::native
// std::less operator for std::map containing BindingPoint namespace tint::sem {
template <>
struct std::less<tint::sem::BindingPoint> { // Defin operator< for std::map containing BindingPoint
bool operator()(const tint::sem::BindingPoint& a, const tint::sem::BindingPoint& b) const; bool operator<(const BindingPoint& a, const BindingPoint& b);
};
} // namespace tint::sem
#endif // SRC_DAWN_NATIVE_TINTUTILS_H_ #endif // SRC_DAWN_NATIVE_TINTUTILS_H_

View File

@ -28,4 +28,16 @@ Blob CreateBlob(ComPtr<ID3DBlob> blob) {
}); });
} }
Blob CreateBlob(ComPtr<IDxcBlob> blob) {
// Detach so the deleter callback can "own" the reference
IDxcBlob* ptr = blob.Detach();
return Blob::UnsafeCreateWithDeleter(reinterpret_cast<uint8_t*>(ptr->GetBufferPointer()),
ptr->GetBufferSize(), [=]() {
// Reattach and drop to delete it.
ComPtr<IDxcBlob> b;
b.Attach(ptr);
b = nullptr;
});
}
} // namespace dawn::native } // namespace dawn::native

View File

@ -18,5 +18,6 @@
namespace dawn::native { namespace dawn::native {
Blob CreateBlob(ComPtr<ID3DBlob> blob); Blob CreateBlob(ComPtr<ID3DBlob> blob);
Blob CreateBlob(ComPtr<IDxcBlob> blob);
} // namespace dawn::native } // namespace dawn::native

View File

@ -29,79 +29,41 @@
#include "dawn/common/Log.h" #include "dawn/common/Log.h"
#include "dawn/common/WindowsUtils.h" #include "dawn/common/WindowsUtils.h"
#include "dawn/native/CacheKey.h" #include "dawn/native/CacheKey.h"
#include "dawn/native/CacheRequest.h"
#include "dawn/native/Pipeline.h" #include "dawn/native/Pipeline.h"
#include "dawn/native/TintUtils.h" #include "dawn/native/TintUtils.h"
#include "dawn/native/d3d12/AdapterD3D12.h" #include "dawn/native/d3d12/AdapterD3D12.h"
#include "dawn/native/d3d12/BackendD3D12.h" #include "dawn/native/d3d12/BackendD3D12.h"
#include "dawn/native/d3d12/BindGroupLayoutD3D12.h" #include "dawn/native/d3d12/BindGroupLayoutD3D12.h"
#include "dawn/native/d3d12/BlobD3D12.h"
#include "dawn/native/d3d12/D3D12Error.h" #include "dawn/native/d3d12/D3D12Error.h"
#include "dawn/native/d3d12/DeviceD3D12.h" #include "dawn/native/d3d12/DeviceD3D12.h"
#include "dawn/native/d3d12/PipelineLayoutD3D12.h" #include "dawn/native/d3d12/PipelineLayoutD3D12.h"
#include "dawn/native/d3d12/PlatformFunctions.h" #include "dawn/native/d3d12/PlatformFunctions.h"
#include "dawn/native/d3d12/UtilsD3D12.h" #include "dawn/native/d3d12/UtilsD3D12.h"
#include "dawn/native/stream/BlobSource.h"
#include "dawn/native/stream/ByteVectorSink.h"
#include "dawn/platform/DawnPlatform.h" #include "dawn/platform/DawnPlatform.h"
#include "dawn/platform/tracing/TraceEvent.h" #include "dawn/platform/tracing/TraceEvent.h"
#include "tint/tint.h" #include "tint/tint.h"
namespace dawn::native::stream {
// Define no-op serializations for pD3DCompile, IDxcLibrary, and IDxcCompiler.
// These are output-only interfaces used to generate bytecode.
template <>
void Stream<IDxcLibrary*>::Write(Sink*, IDxcLibrary* const&) {}
template <>
void Stream<IDxcCompiler*>::Write(Sink*, IDxcCompiler* const&) {}
template <>
void Stream<pD3DCompile>::Write(Sink*, pD3DCompile const&) {}
} // namespace dawn::native::stream
namespace dawn::native::d3d12 { namespace dawn::native::d3d12 {
namespace { namespace {
uint64_t GetD3DCompilerVersion() {
return D3D_COMPILER_VERSION;
}
struct CompareBindingPoint {
constexpr bool operator()(const tint::transform::BindingPoint& lhs,
const tint::transform::BindingPoint& rhs) const {
if (lhs.group != rhs.group) {
return lhs.group < rhs.group;
} else {
return lhs.binding < rhs.binding;
}
}
};
void StreamIn(std::stringstream& output, const tint::ast::Access& access) {
output << access;
}
void StreamIn(std::stringstream& output, const tint::transform::BindingPoint& binding_point) {
output << "(BindingPoint";
output << " group=" << binding_point.group;
output << " binding=" << binding_point.binding;
output << ")";
}
template <typename T, typename = typename std::enable_if<std::is_fundamental<T>::value>::type>
void StreamIn(std::stringstream& output, const T& val) {
output << val;
}
template <typename T>
void StreamIn(std::stringstream& output,
const std::unordered_map<tint::transform::BindingPoint, T>& map) {
output << "(map";
std::map<tint::transform::BindingPoint, T, CompareBindingPoint> sorted(map.begin(), map.end());
for (auto& [bindingPoint, value] : sorted) {
output << " ";
StreamIn(output, bindingPoint);
output << "=";
StreamIn(output, value);
}
output << ")";
}
void StreamIn(std::stringstream& output,
const tint::writer::ArrayLengthFromUniformOptions& arrayLengthFromUniform) {
output << "(ArrayLengthFromUniformOptions";
output << " ubo_binding=";
StreamIn(output, arrayLengthFromUniform.ubo_binding);
output << " bindpoint_to_size_index=";
StreamIn(output, arrayLengthFromUniform.bindpoint_to_size_index);
output << ")";
}
// 32 bit float has 7 decimal digits of precision so setting n to 8 should be enough // 32 bit float has 7 decimal digits of precision so setting n to 8 should be enough
std::string FloatToStringWithPrecision(float v, std::streamsize n = 8) { std::string FloatToStringWithPrecision(float v, std::streamsize n = 8) {
@ -130,253 +92,113 @@ std::string GetHLSLValueString(EntryPointMetadata::Override::Type dawnType,
constexpr char kSpecConstantPrefix[] = "WGSL_SPEC_CONSTANT_"; constexpr char kSpecConstantPrefix[] = "WGSL_SPEC_CONSTANT_";
void GetOverridableConstantsDefines( using DefineStrings = std::vector<std::pair<std::string, std::string>>;
std::vector<std::pair<std::string, std::string>>* defineStrings,
const PipelineConstantEntries* pipelineConstantEntries, DefineStrings GetOverridableConstantsDefines(
const EntryPointMetadata::OverridesMap* shaderEntryPointConstants) { const PipelineConstantEntries& pipelineConstantEntries,
const EntryPointMetadata::OverridesMap& shaderEntryPointConstants) {
DefineStrings defineStrings;
std::unordered_set<std::string> overriddenConstants; std::unordered_set<std::string> overriddenConstants;
// Set pipeline overridden values // Set pipeline overridden values
for (const auto& [name, value] : *pipelineConstantEntries) { for (const auto& [name, value] : pipelineConstantEntries) {
overriddenConstants.insert(name); overriddenConstants.insert(name);
// This is already validated so `name` must exist // This is already validated so `name` must exist
const auto& moduleConstant = shaderEntryPointConstants->at(name); const auto& moduleConstant = shaderEntryPointConstants.at(name);
defineStrings->emplace_back( defineStrings.emplace_back(
kSpecConstantPrefix + std::to_string(static_cast<int32_t>(moduleConstant.id)), kSpecConstantPrefix + std::to_string(static_cast<int32_t>(moduleConstant.id)),
GetHLSLValueString(moduleConstant.type, nullptr, value)); GetHLSLValueString(moduleConstant.type, nullptr, value));
} }
// Set shader initialized default values // Set shader initialized default values
for (const auto& iter : *shaderEntryPointConstants) { for (const auto& iter : shaderEntryPointConstants) {
const std::string& name = iter.first; const std::string& name = iter.first;
if (overriddenConstants.count(name) != 0) { if (overriddenConstants.count(name) != 0) {
// This constant already has overridden value // This constant already has overridden value
continue; continue;
} }
const auto& moduleConstant = shaderEntryPointConstants->at(name); const auto& moduleConstant = shaderEntryPointConstants.at(name);
// Uninitialized default values are okay since they ar only defined to pass // Uninitialized default values are okay since they ar only defined to pass
// compilation but not used // compilation but not used
defineStrings->emplace_back( defineStrings.emplace_back(
kSpecConstantPrefix + std::to_string(static_cast<int32_t>(moduleConstant.id)), kSpecConstantPrefix + std::to_string(static_cast<int32_t>(moduleConstant.id)),
GetHLSLValueString(moduleConstant.type, &moduleConstant.defaultValue)); GetHLSLValueString(moduleConstant.type, &moduleConstant.defaultValue));
} }
return defineStrings;
} }
// The inputs to a shader compilation. These have been intentionally isolated from the enum class Compiler { FXC, DXC };
// device to help ensure that the pipeline cache key contains all inputs for compilation.
struct ShaderCompilationRequest {
enum Compiler { FXC, DXC };
// Common inputs #define HLSL_COMPILATION_REQUEST_MEMBERS(X) \
Compiler compiler; X(const tint::Program*, inputProgram) \
const tint::Program* program; X(std::string_view, entryPointName) \
const char* entryPointName; X(SingleShaderStage, stage) \
SingleShaderStage stage; X(uint32_t, shaderModel) \
uint32_t compileFlags; X(uint32_t, compileFlags) \
bool disableSymbolRenaming; X(Compiler, compiler) \
tint::transform::BindingRemapper::BindingPoints remappedBindingPoints; X(uint64_t, compilerVersion) \
tint::transform::BindingRemapper::AccessControls remappedAccessControls; X(std::wstring_view, dxcShaderProfile) \
bool isRobustnessEnabled; X(std::string_view, fxcShaderProfile) \
bool usesNumWorkgroups; X(pD3DCompile, d3dCompile) \
uint32_t numWorkgroupsRegisterSpace; X(IDxcLibrary*, dxcLibrary) \
uint32_t numWorkgroupsShaderRegister; X(IDxcCompiler*, dxcCompiler) \
tint::writer::ArrayLengthFromUniformOptions arrayLengthFromUniform; X(uint32_t, firstIndexOffsetShaderRegister) \
std::vector<std::pair<std::string, std::string>> defineStrings; X(uint32_t, firstIndexOffsetRegisterSpace) \
X(bool, usesNumWorkgroups) \
X(uint32_t, numWorkgroupsShaderRegister) \
X(uint32_t, numWorkgroupsRegisterSpace) \
X(DefineStrings, defineStrings) \
X(tint::transform::MultiplanarExternalTexture::BindingsMap, newBindingsMap) \
X(tint::writer::ArrayLengthFromUniformOptions, arrayLengthFromUniform) \
X(tint::transform::BindingRemapper::BindingPoints, remappedBindingPoints) \
X(tint::transform::BindingRemapper::AccessControls, remappedAccessControls) \
X(bool, disableSymbolRenaming) \
X(bool, isRobustnessEnabled) \
X(bool, disableWorkgroupInit) \
X(bool, dumpShaders)
// FXC/DXC common inputs #define D3D_BYTECODE_COMPILATION_REQUEST_MEMBERS(X) \
bool disableWorkgroupInit; X(bool, hasShaderFloat16Feature) \
X(uint32_t, compileFlags) \
X(Compiler, compiler) \
X(uint64_t, compilerVersion) \
X(std::wstring_view, dxcShaderProfile) \
X(std::string_view, fxcShaderProfile) \
X(pD3DCompile, d3dCompile) \
X(IDxcLibrary*, dxcLibrary) \
X(IDxcCompiler*, dxcCompiler) \
X(DefineStrings, defineStrings)
// FXC inputs struct HlslCompilationRequest {
uint64_t fxcVersion; DAWN_VISITABLE_MEMBERS(HLSL_COMPILATION_REQUEST_MEMBERS)
// DXC inputs friend void StreamIn(stream::Sink* sink, const HlslCompilationRequest& r) {
uint64_t dxcVersion; r.VisitAll([&](const auto&... members) { StreamIn(sink, members...); });
const D3D12DeviceInfo* deviceInfo;
bool hasShaderFloat16Feature;
static ResultOrError<ShaderCompilationRequest> Create(
const char* entryPointName,
SingleShaderStage stage,
const PipelineLayout* layout,
uint32_t compileFlags,
const Device* device,
const tint::Program* program,
const EntryPointMetadata& entryPoint,
const ProgrammableStage& programmableStage) {
Compiler compiler;
uint64_t dxcVersion = 0;
if (device->IsToggleEnabled(Toggle::UseDXC)) {
compiler = Compiler::DXC;
DAWN_TRY_ASSIGN(dxcVersion,
ToBackend(device->GetAdapter())->GetBackend()->GetDXCompilerVersion());
} else {
compiler = Compiler::FXC;
}
using tint::transform::BindingPoint;
using tint::transform::BindingRemapper;
BindingRemapper::BindingPoints remappedBindingPoints;
BindingRemapper::AccessControls remappedAccessControls;
tint::writer::ArrayLengthFromUniformOptions arrayLengthFromUniform;
arrayLengthFromUniform.ubo_binding = {
layout->GetDynamicStorageBufferLengthsRegisterSpace(),
layout->GetDynamicStorageBufferLengthsShaderRegister()};
const BindingInfoArray& moduleBindingInfo = entryPoint.bindings;
for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) {
const BindGroupLayout* bgl = ToBackend(layout->GetBindGroupLayout(group));
const auto& groupBindingInfo = moduleBindingInfo[group];
// d3d12::BindGroupLayout packs the bindings per HLSL register-space. We modify
// the Tint AST to make the "bindings" decoration match the offset chosen by
// d3d12::BindGroupLayout so that Tint produces HLSL with the correct registers
// assigned to each interface variable.
for (const auto& [binding, bindingInfo] : groupBindingInfo) {
BindingIndex bindingIndex = bgl->GetBindingIndex(binding);
BindingPoint srcBindingPoint{static_cast<uint32_t>(group),
static_cast<uint32_t>(binding)};
BindingPoint dstBindingPoint{static_cast<uint32_t>(group),
bgl->GetShaderRegister(bindingIndex)};
if (srcBindingPoint != dstBindingPoint) {
remappedBindingPoints.emplace(srcBindingPoint, dstBindingPoint);
}
// Declaring a read-only storage buffer in HLSL but specifying a storage
// buffer in the BGL produces the wrong output. Force read-only storage
// buffer bindings to be treated as UAV instead of SRV. Internal storage
// buffer is a storage buffer used in the internal pipeline.
const bool forceStorageBufferAsUAV =
(bindingInfo.buffer.type == wgpu::BufferBindingType::ReadOnlyStorage &&
(bgl->GetBindingInfo(bindingIndex).buffer.type ==
wgpu::BufferBindingType::Storage ||
bgl->GetBindingInfo(bindingIndex).buffer.type ==
kInternalStorageBufferBinding));
if (forceStorageBufferAsUAV) {
remappedAccessControls.emplace(srcBindingPoint, tint::ast::Access::kReadWrite);
}
}
// Add arrayLengthFromUniform options
{
for (const auto& bindingAndRegisterOffset :
layout->GetDynamicStorageBufferLengthInfo()[group].bindingAndRegisterOffsets) {
BindingNumber binding = bindingAndRegisterOffset.binding;
uint32_t registerOffset = bindingAndRegisterOffset.registerOffset;
BindingPoint bindingPoint{static_cast<uint32_t>(group),
static_cast<uint32_t>(binding)};
// Get the renamed binding point if it was remapped.
auto it = remappedBindingPoints.find(bindingPoint);
if (it != remappedBindingPoints.end()) {
bindingPoint = it->second;
}
arrayLengthFromUniform.bindpoint_to_size_index.emplace(bindingPoint,
registerOffset);
}
}
}
ShaderCompilationRequest request;
request.compiler = compiler;
request.program = program;
request.entryPointName = entryPointName;
request.stage = stage;
request.compileFlags = compileFlags;
request.disableSymbolRenaming = device->IsToggleEnabled(Toggle::DisableSymbolRenaming);
request.remappedBindingPoints = std::move(remappedBindingPoints);
request.remappedAccessControls = std::move(remappedAccessControls);
request.isRobustnessEnabled = device->IsRobustnessEnabled();
request.disableWorkgroupInit = device->IsToggleEnabled(Toggle::DisableWorkgroupInit);
request.usesNumWorkgroups = entryPoint.usesNumWorkgroups;
request.numWorkgroupsShaderRegister = layout->GetNumWorkgroupsShaderRegister();
request.numWorkgroupsRegisterSpace = layout->GetNumWorkgroupsRegisterSpace();
request.arrayLengthFromUniform = std::move(arrayLengthFromUniform);
request.fxcVersion = compiler == Compiler::FXC ? GetD3DCompilerVersion() : 0;
request.dxcVersion = compiler == Compiler::DXC ? dxcVersion : 0;
request.deviceInfo = &device->GetDeviceInfo();
request.hasShaderFloat16Feature = device->IsFeatureEnabled(Feature::ShaderFloat16);
GetOverridableConstantsDefines(
&request.defineStrings, &programmableStage.constants,
&programmableStage.module->GetEntryPoint(programmableStage.entryPoint).overrides);
return std::move(request);
}
// TODO(dawn:1341): Move to use CacheKey instead of the vector.
ResultOrError<std::vector<uint8_t>> CreateCacheKey() const {
// Generate the WGSL from the Tint program so it's normalized.
// TODO(tint:1180): Consider using a binary serialization of the tint AST for a more
// compact representation.
auto result = tint::writer::wgsl::Generate(program, tint::writer::wgsl::Options{});
if (!result.success) {
std::ostringstream errorStream;
errorStream << "Tint WGSL failure:" << std::endl;
errorStream << "Generator: " << result.error << std::endl;
return DAWN_INTERNAL_ERROR(errorStream.str().c_str());
}
std::stringstream stream;
// Prefix the key with the type to avoid collisions from another type that could
// have the same key.
stream << static_cast<uint32_t>(CacheKey::Type::Shader);
stream << "\n";
stream << result.wgsl.length();
stream << "\n";
stream << result.wgsl;
stream << "\n";
stream << "(ShaderCompilationRequest";
stream << " compiler=" << compiler;
stream << " entryPointName=" << entryPointName;
stream << " stage=" << uint32_t(stage);
stream << " compileFlags=" << compileFlags;
stream << " disableSymbolRenaming=" << disableSymbolRenaming;
stream << " remappedBindingPoints=";
StreamIn(stream, remappedBindingPoints);
stream << " remappedAccessControls=";
StreamIn(stream, remappedAccessControls);
stream << " useNumWorkgroups=" << usesNumWorkgroups;
stream << " numWorkgroupsRegisterSpace=" << numWorkgroupsRegisterSpace;
stream << " numWorkgroupsShaderRegister=" << numWorkgroupsShaderRegister;
stream << " arrayLengthFromUniform=";
StreamIn(stream, arrayLengthFromUniform);
stream << " shaderModel=" << deviceInfo->shaderModel;
stream << " disableWorkgroupInit=" << disableWorkgroupInit;
stream << " isRobustnessEnabled=" << isRobustnessEnabled;
stream << " fxcVersion=" << fxcVersion;
stream << " dxcVersion=" << dxcVersion;
stream << " hasShaderFloat16Feature=" << hasShaderFloat16Feature;
stream << " defines={";
for (const auto& [name, value] : defineStrings) {
stream << " <" << name << "," << value << ">";
}
stream << " }";
stream << ")";
stream << "\n";
return std::vector<uint8_t>(std::istreambuf_iterator<char>{stream},
std::istreambuf_iterator<char>{});
} }
}; };
struct D3DBytecodeCompilationRequest {
DAWN_VISITABLE_MEMBERS(D3D_BYTECODE_COMPILATION_REQUEST_MEMBERS)
friend void StreamIn(stream::Sink* sink, const D3DBytecodeCompilationRequest& r) {
r.VisitAll([&](const auto&... members) { StreamIn(sink, members...); });
}
};
#define D3D_COMPILATION_REQUEST_MEMBERS(X) \
X(HlslCompilationRequest, hlsl) \
X(D3DBytecodeCompilationRequest, bytecode) \
X(CacheKey::UnsafeUnkeyedValue<dawn::platform::Platform*>, tracePlatform)
DAWN_MAKE_CACHE_REQUEST(D3DCompilationRequest, D3D_COMPILATION_REQUEST_MEMBERS);
#undef HLSL_COMPILATION_REQUEST_MEMBERS
#undef D3D_BYTECODE_COMPILATION_REQUEST_MEMBERS
#undef D3D_COMPILATION_REQUEST_MEMBERS
std::vector<const wchar_t*> GetDXCArguments(uint32_t compileFlags, bool enable16BitTypes) { std::vector<const wchar_t*> GetDXCArguments(uint32_t compileFlags, bool enable16BitTypes) {
std::vector<const wchar_t*> arguments; std::vector<const wchar_t*> arguments;
if (compileFlags & D3DCOMPILE_ENABLE_BACKWARDS_COMPATIBILITY) { if (compileFlags & D3DCOMPILE_ENABLE_BACKWARDS_COMPATIBILITY) {
@ -429,25 +251,24 @@ std::vector<const wchar_t*> GetDXCArguments(uint32_t compileFlags, bool enable16
return arguments; return arguments;
} }
ResultOrError<ComPtr<IDxcBlob>> CompileShaderDXC(IDxcLibrary* dxcLibrary, ResultOrError<ComPtr<IDxcBlob>> CompileShaderDXC(const D3DBytecodeCompilationRequest& r,
IDxcCompiler* dxcCompiler, const std::string& entryPointName,
const ShaderCompilationRequest& request,
const std::string& hlslSource) { const std::string& hlslSource) {
ComPtr<IDxcBlobEncoding> sourceBlob; ComPtr<IDxcBlobEncoding> sourceBlob;
DAWN_TRY(CheckHRESULT(dxcLibrary->CreateBlobWithEncodingOnHeapCopy( DAWN_TRY(CheckHRESULT(r.dxcLibrary->CreateBlobWithEncodingFromPinned(
hlslSource.c_str(), hlslSource.length(), CP_UTF8, &sourceBlob), hlslSource.c_str(), hlslSource.length(), CP_UTF8, &sourceBlob),
"DXC create blob")); "DXC create blob"));
std::wstring entryPointW; std::wstring entryPointW;
DAWN_TRY_ASSIGN(entryPointW, ConvertStringToWstring(request.entryPointName)); DAWN_TRY_ASSIGN(entryPointW, ConvertStringToWstring(entryPointName));
std::vector<const wchar_t*> arguments = std::vector<const wchar_t*> arguments =
GetDXCArguments(request.compileFlags, request.hasShaderFloat16Feature); GetDXCArguments(r.compileFlags, r.hasShaderFloat16Feature);
// Build defines for overridable constants // Build defines for overridable constants
std::vector<std::pair<std::wstring, std::wstring>> defineStrings; std::vector<std::pair<std::wstring, std::wstring>> defineStrings;
defineStrings.reserve(request.defineStrings.size()); defineStrings.reserve(r.defineStrings.size());
for (const auto& [name, value] : request.defineStrings) { for (const auto& [name, value] : r.defineStrings) {
defineStrings.emplace_back(UTF8ToWStr(name.c_str()), UTF8ToWStr(value.c_str())); defineStrings.emplace_back(UTF8ToWStr(name.c_str()), UTF8ToWStr(value.c_str()));
} }
@ -458,12 +279,11 @@ ResultOrError<ComPtr<IDxcBlob>> CompileShaderDXC(IDxcLibrary* dxcLibrary,
} }
ComPtr<IDxcOperationResult> result; ComPtr<IDxcOperationResult> result;
DAWN_TRY( DAWN_TRY(CheckHRESULT(
CheckHRESULT(dxcCompiler->Compile(sourceBlob.Get(), nullptr, entryPointW.c_str(), r.dxcCompiler->Compile(sourceBlob.Get(), nullptr, entryPointW.c_str(),
request.deviceInfo->shaderProfiles[request.stage].c_str(), r.dxcShaderProfile.data(), arguments.data(), arguments.size(),
arguments.data(), arguments.size(), dxcDefines.data(), dxcDefines.data(), dxcDefines.size(), nullptr, &result),
dxcDefines.size(), nullptr, &result), "DXC compile"));
"DXC compile"));
HRESULT hr; HRESULT hr;
DAWN_TRY(CheckHRESULT(result->GetStatus(&hr), "DXC get status")); DAWN_TRY(CheckHRESULT(result->GetStatus(&hr), "DXC get status"));
@ -543,31 +363,18 @@ std::string CompileFlagsToStringFXC(uint32_t compileFlags) {
return result; return result;
} }
ResultOrError<ComPtr<ID3DBlob>> CompileShaderFXC(const PlatformFunctions* functions, ResultOrError<ComPtr<ID3DBlob>> CompileShaderFXC(const D3DBytecodeCompilationRequest& r,
const ShaderCompilationRequest& request, const std::string& entryPointName,
const std::string& hlslSource) { const std::string& hlslSource) {
const char* targetProfile = nullptr;
switch (request.stage) {
case SingleShaderStage::Vertex:
targetProfile = "vs_5_1";
break;
case SingleShaderStage::Fragment:
targetProfile = "ps_5_1";
break;
case SingleShaderStage::Compute:
targetProfile = "cs_5_1";
break;
}
ComPtr<ID3DBlob> compiledShader; ComPtr<ID3DBlob> compiledShader;
ComPtr<ID3DBlob> errors; ComPtr<ID3DBlob> errors;
// Build defines for overridable constants // Build defines for overridable constants
const D3D_SHADER_MACRO* pDefines = nullptr; const D3D_SHADER_MACRO* pDefines = nullptr;
std::vector<D3D_SHADER_MACRO> fxcDefines; std::vector<D3D_SHADER_MACRO> fxcDefines;
if (request.defineStrings.size() > 0) { if (r.defineStrings.size() > 0) {
fxcDefines.reserve(request.defineStrings.size() + 1); fxcDefines.reserve(r.defineStrings.size() + 1);
for (const auto& [name, value] : request.defineStrings) { for (const auto& [name, value] : r.defineStrings) {
fxcDefines.push_back({name.c_str(), value.c_str()}); fxcDefines.push_back({name.c_str(), value.c_str()});
} }
// d3dCompile D3D_SHADER_MACRO* pDefines is a nullptr terminated array // d3dCompile D3D_SHADER_MACRO* pDefines is a nullptr terminated array
@ -575,36 +382,49 @@ ResultOrError<ComPtr<ID3DBlob>> CompileShaderFXC(const PlatformFunctions* functi
pDefines = fxcDefines.data(); pDefines = fxcDefines.data();
} }
DAWN_INVALID_IF( DAWN_INVALID_IF(FAILED(r.d3dCompile(hlslSource.c_str(), hlslSource.length(), nullptr, pDefines,
FAILED(functions->d3dCompile(hlslSource.c_str(), hlslSource.length(), nullptr, pDefines, nullptr, entryPointName.c_str(), r.fxcShaderProfile.data(),
nullptr, request.entryPointName, targetProfile, r.compileFlags, 0, &compiledShader, &errors)),
request.compileFlags, 0, &compiledShader, &errors)), "D3D compile failed with: %s", static_cast<char*>(errors->GetBufferPointer()));
"D3D compile failed with: %s", static_cast<char*>(errors->GetBufferPointer()));
return std::move(compiledShader); return std::move(compiledShader);
} }
ResultOrError<std::string> TranslateToHLSL(dawn::platform::Platform* platform, ResultOrError<std::string> TranslateToHLSL(
const ShaderCompilationRequest& request, HlslCompilationRequest r,
std::string* remappedEntryPointName) { CacheKey::UnsafeUnkeyedValue<dawn::platform::Platform*> tracePlatform,
std::string* remappedEntryPointName,
bool* usesVertexOrInstanceIndex) {
std::ostringstream errorStream; std::ostringstream errorStream;
errorStream << "Tint HLSL failure:" << std::endl; errorStream << "Tint HLSL failure:" << std::endl;
tint::transform::Manager transformManager; tint::transform::Manager transformManager;
tint::transform::DataMap transformInputs; tint::transform::DataMap transformInputs;
if (request.isRobustnessEnabled) { if (!r.newBindingsMap.empty()) {
transformManager.Add<tint::transform::MultiplanarExternalTexture>();
transformInputs.Add<tint::transform::MultiplanarExternalTexture::NewBindingPoints>(
std::move(r.newBindingsMap));
}
if (r.stage == SingleShaderStage::Vertex) {
transformManager.Add<tint::transform::FirstIndexOffset>();
transformInputs.Add<tint::transform::FirstIndexOffset::BindingPoint>(
r.firstIndexOffsetShaderRegister, r.firstIndexOffsetRegisterSpace);
}
if (r.isRobustnessEnabled) {
transformManager.Add<tint::transform::Robustness>(); transformManager.Add<tint::transform::Robustness>();
} }
transformManager.Add<tint::transform::BindingRemapper>(); transformManager.Add<tint::transform::BindingRemapper>();
transformManager.Add<tint::transform::SingleEntryPoint>(); transformManager.Add<tint::transform::SingleEntryPoint>();
transformInputs.Add<tint::transform::SingleEntryPoint::Config>(request.entryPointName); transformInputs.Add<tint::transform::SingleEntryPoint::Config>(r.entryPointName.data());
transformManager.Add<tint::transform::Renamer>(); transformManager.Add<tint::transform::Renamer>();
if (request.disableSymbolRenaming) { if (r.disableSymbolRenaming) {
// We still need to rename HLSL reserved keywords // We still need to rename HLSL reserved keywords
transformInputs.Add<tint::transform::Renamer::Config>( transformInputs.Add<tint::transform::Renamer::Config>(
tint::transform::Renamer::Target::kHlslKeywords); tint::transform::Renamer::Target::kHlslKeywords);
@ -615,104 +435,92 @@ ResultOrError<std::string> TranslateToHLSL(dawn::platform::Platform* platform,
// different types. // different types.
const bool mayCollide = true; const bool mayCollide = true;
transformInputs.Add<tint::transform::BindingRemapper::Remappings>( transformInputs.Add<tint::transform::BindingRemapper::Remappings>(
std::move(request.remappedBindingPoints), std::move(request.remappedAccessControls), std::move(r.remappedBindingPoints), std::move(r.remappedAccessControls), mayCollide);
mayCollide);
tint::Program transformedProgram; tint::Program transformedProgram;
tint::transform::DataMap transformOutputs; tint::transform::DataMap transformOutputs;
{ {
TRACE_EVENT0(platform, General, "RunTransforms"); TRACE_EVENT0(tracePlatform.UnsafeGetValue(), General, "RunTransforms");
DAWN_TRY_ASSIGN(transformedProgram, DAWN_TRY_ASSIGN(transformedProgram,
RunTransforms(&transformManager, request.program, transformInputs, RunTransforms(&transformManager, r.inputProgram, transformInputs,
&transformOutputs, nullptr)); &transformOutputs, nullptr));
} }
if (auto* data = transformOutputs.Get<tint::transform::Renamer::Data>()) { if (auto* data = transformOutputs.Get<tint::transform::Renamer::Data>()) {
auto it = data->remappings.find(request.entryPointName); auto it = data->remappings.find(r.entryPointName.data());
if (it != data->remappings.end()) { if (it != data->remappings.end()) {
*remappedEntryPointName = it->second; *remappedEntryPointName = it->second;
} else { } else {
DAWN_INVALID_IF(!request.disableSymbolRenaming, DAWN_INVALID_IF(!r.disableSymbolRenaming,
"Could not find remapped name for entry point."); "Could not find remapped name for entry point.");
*remappedEntryPointName = request.entryPointName; *remappedEntryPointName = r.entryPointName;
} }
} else { } else {
return DAWN_FORMAT_VALIDATION_ERROR("Transform output missing renamer data."); return DAWN_FORMAT_VALIDATION_ERROR("Transform output missing renamer data.");
} }
if (r.stage == SingleShaderStage::Vertex) {
if (auto* data = transformOutputs.Get<tint::transform::FirstIndexOffset::Data>()) {
*usesVertexOrInstanceIndex = data->has_vertex_or_instance_index;
} else {
return DAWN_FORMAT_VALIDATION_ERROR(
"Transform output missing first index offset data.");
}
}
tint::writer::hlsl::Options options; tint::writer::hlsl::Options options;
options.disable_workgroup_init = request.disableWorkgroupInit; options.disable_workgroup_init = r.disableWorkgroupInit;
if (request.usesNumWorkgroups) { if (r.usesNumWorkgroups) {
options.root_constant_binding_point = tint::sem::BindingPoint{ options.root_constant_binding_point =
request.numWorkgroupsRegisterSpace, request.numWorkgroupsShaderRegister}; tint::sem::BindingPoint{r.numWorkgroupsRegisterSpace, r.numWorkgroupsShaderRegister};
} }
// TODO(dawn:549): HLSL generation outputs the indices into the // TODO(dawn:549): HLSL generation outputs the indices into the
// array_length_from_uniform buffer that were actually used. When the blob cache can // array_length_from_uniform buffer that were actually used. When the blob cache can
// store more than compiled shaders, we should reflect these used indices and store // store more than compiled shaders, we should reflect these used indices and store
// them as well. This would allow us to only upload root constants that are actually // them as well. This would allow us to only upload root constants that are actually
// read by the shader. // read by the shader.
options.array_length_from_uniform = request.arrayLengthFromUniform; options.array_length_from_uniform = r.arrayLengthFromUniform;
TRACE_EVENT0(platform, General, "tint::writer::hlsl::Generate"); TRACE_EVENT0(tracePlatform.UnsafeGetValue(), General, "tint::writer::hlsl::Generate");
auto result = tint::writer::hlsl::Generate(&transformedProgram, options); auto result = tint::writer::hlsl::Generate(&transformedProgram, options);
DAWN_INVALID_IF(!result.success, "An error occured while generating HLSL: %s", result.error); DAWN_INVALID_IF(!result.success, "An error occured while generating HLSL: %s", result.error);
return std::move(result.hlsl); return std::move(result.hlsl);
} }
template <typename F> ResultOrError<CompiledShader> CompileShader(D3DCompilationRequest r) {
MaybeError CompileShader(dawn::platform::Platform* platform, CompiledShader compiledShader;
const PlatformFunctions* functions,
IDxcLibrary* dxcLibrary,
IDxcCompiler* dxcCompiler,
ShaderCompilationRequest&& request,
bool dumpShaders,
F&& DumpShadersEmitLog,
CompiledShader* compiledShader) {
// Compile the source shader to HLSL. // Compile the source shader to HLSL.
std::string hlslSource;
std::string remappedEntryPoint; std::string remappedEntryPoint;
DAWN_TRY_ASSIGN(hlslSource, TranslateToHLSL(platform, request, &remappedEntryPoint)); DAWN_TRY_ASSIGN(compiledShader.hlslSource,
if (dumpShaders) { TranslateToHLSL(std::move(r.hlsl), r.tracePlatform, &remappedEntryPoint,
std::ostringstream dumpedMsg; &compiledShader.usesVertexOrInstanceIndex));
dumpedMsg << "/* Dumped generated HLSL */" << std::endl << hlslSource;
DumpShadersEmitLog(WGPULoggingType_Info, dumpedMsg.str().c_str()); switch (r.bytecode.compiler) {
} case Compiler::DXC: {
request.entryPointName = remappedEntryPoint.c_str(); TRACE_EVENT0(r.tracePlatform.UnsafeGetValue(), General, "CompileShaderDXC");
switch (request.compiler) { ComPtr<IDxcBlob> compiledDXCShader;
case ShaderCompilationRequest::Compiler::DXC: { DAWN_TRY_ASSIGN(compiledDXCShader, CompileShaderDXC(r.bytecode, remappedEntryPoint,
TRACE_EVENT0(platform, General, "CompileShaderDXC"); compiledShader.hlslSource));
DAWN_TRY_ASSIGN(compiledShader->compiledDXCShader, compiledShader.shaderBlob = CreateBlob(std::move(compiledDXCShader));
CompileShaderDXC(dxcLibrary, dxcCompiler, request, hlslSource));
break; break;
} }
case ShaderCompilationRequest::Compiler::FXC: { case Compiler::FXC: {
TRACE_EVENT0(platform, General, "CompileShaderFXC"); TRACE_EVENT0(r.tracePlatform.UnsafeGetValue(), General, "CompileShaderFXC");
DAWN_TRY_ASSIGN(compiledShader->compiledFXCShader, ComPtr<ID3DBlob> compiledFXCShader;
CompileShaderFXC(functions, request, hlslSource)); DAWN_TRY_ASSIGN(compiledFXCShader, CompileShaderFXC(r.bytecode, remappedEntryPoint,
compiledShader.hlslSource));
compiledShader.shaderBlob = CreateBlob(std::move(compiledFXCShader));
break; break;
} }
} }
if (dumpShaders && request.compiler == ShaderCompilationRequest::Compiler::FXC) { // If dumpShaders is false, we don't need the HLSL for logging. Clear the contents so it
std::ostringstream dumpedMsg; // isn't stored into the cache.
dumpedMsg << "/* FXC compile flags */ " << std::endl if (!r.hlsl.dumpShaders) {
<< CompileFlagsToStringFXC(request.compileFlags) << std::endl; compiledShader.hlslSource = "";
dumpedMsg << "/* Dumped disassembled DXBC */" << std::endl;
ComPtr<ID3DBlob> disassembly;
if (FAILED(functions->d3dDisassemble(compiledShader->compiledFXCShader->GetBufferPointer(),
compiledShader->compiledFXCShader->GetBufferSize(), 0,
nullptr, &disassembly))) {
dumpedMsg << "D3D disassemble failed" << std::endl;
} else {
dumpedMsg << reinterpret_cast<const char*>(disassembly->GetBufferPointer());
}
DumpShadersEmitLog(WGPULoggingType_Info, dumpedMsg.str().c_str());
} }
return compiledShader;
return {};
} }
} // anonymous namespace } // anonymous namespace
@ -741,74 +549,202 @@ ResultOrError<CompiledShader> ShaderModule::Compile(const ProgrammableStage& pro
SingleShaderStage stage, SingleShaderStage stage,
const PipelineLayout* layout, const PipelineLayout* layout,
uint32_t compileFlags) { uint32_t compileFlags) {
TRACE_EVENT0(GetDevice()->GetPlatform(), General, "ShaderModuleD3D12::Compile"); Device* device = ToBackend(GetDevice());
TRACE_EVENT0(device->GetPlatform(), General, "ShaderModuleD3D12::Compile");
ASSERT(!IsError()); ASSERT(!IsError());
ScopedTintICEHandler scopedICEHandler(GetDevice()); ScopedTintICEHandler scopedICEHandler(device);
const EntryPointMetadata& entryPoint = GetEntryPoint(programmableStage.entryPoint);
Device* device = ToBackend(GetDevice()); D3DCompilationRequest req = {};
req.tracePlatform = UnsafeUnkeyedValue(device->GetPlatform());
req.hlsl.shaderModel = device->GetDeviceInfo().shaderModel;
req.hlsl.disableSymbolRenaming = device->IsToggleEnabled(Toggle::DisableSymbolRenaming);
req.hlsl.isRobustnessEnabled = device->IsRobustnessEnabled();
req.hlsl.disableWorkgroupInit = device->IsToggleEnabled(Toggle::DisableWorkgroupInit);
req.hlsl.dumpShaders = device->IsToggleEnabled(Toggle::DumpShaders);
CompiledShader compiledShader = {}; req.bytecode.hasShaderFloat16Feature = device->IsFeatureEnabled(Feature::ShaderFloat16);
req.bytecode.compileFlags = compileFlags;
tint::transform::Manager transformManager; req.bytecode.defineStrings =
tint::transform::DataMap transformInputs; GetOverridableConstantsDefines(programmableStage.constants, entryPoint.overrides);
if (device->IsToggleEnabled(Toggle::UseDXC)) {
const tint::Program* program = GetTintProgram(); req.bytecode.compiler = Compiler::DXC;
tint::Program programAsValue; req.bytecode.dxcLibrary = device->GetDxcLibrary().Get();
req.bytecode.dxcCompiler = device->GetDxcCompiler().Get();
auto externalTextureBindings = BuildExternalTextureTransformBindings(layout); DAWN_TRY_ASSIGN(req.bytecode.compilerVersion,
if (!externalTextureBindings.empty()) { ToBackend(device->GetAdapter())->GetBackend()->GetDXCompilerVersion());
transformManager.Add<tint::transform::MultiplanarExternalTexture>(); req.bytecode.dxcShaderProfile = device->GetDeviceInfo().shaderProfiles[stage];
transformInputs.Add<tint::transform::MultiplanarExternalTexture::NewBindingPoints>( } else {
std::move(externalTextureBindings)); req.bytecode.compiler = Compiler::FXC;
} req.bytecode.d3dCompile = device->GetFunctions()->d3dCompile;
req.bytecode.compilerVersion = D3D_COMPILER_VERSION;
if (stage == SingleShaderStage::Vertex) { switch (stage) {
transformManager.Add<tint::transform::FirstIndexOffset>(); case SingleShaderStage::Vertex:
transformInputs.Add<tint::transform::FirstIndexOffset::BindingPoint>( req.bytecode.fxcShaderProfile = "vs_5_1";
layout->GetFirstIndexOffsetShaderRegister(), break;
layout->GetFirstIndexOffsetRegisterSpace()); case SingleShaderStage::Fragment:
} req.bytecode.fxcShaderProfile = "ps_5_1";
break;
tint::transform::DataMap transformOutputs; case SingleShaderStage::Compute:
DAWN_TRY_ASSIGN(programAsValue, RunTransforms(&transformManager, program, transformInputs, req.bytecode.fxcShaderProfile = "cs_5_1";
&transformOutputs, nullptr)); break;
program = &programAsValue;
if (stage == SingleShaderStage::Vertex) {
if (auto* data = transformOutputs.Get<tint::transform::FirstIndexOffset::Data>()) {
// TODO(dawn:549): Consider adding this information to the pipeline cache once we
// can store more than the shader blob in it.
compiledShader.usesVertexOrInstanceIndex = data->has_vertex_or_instance_index;
} }
} }
ShaderCompilationRequest request; using tint::transform::BindingPoint;
DAWN_TRY_ASSIGN(request, using tint::transform::BindingRemapper;
ShaderCompilationRequest::Create(
programmableStage.entryPoint.c_str(), stage, layout, compileFlags, device,
program, GetEntryPoint(programmableStage.entryPoint), programmableStage));
// TODO(dawn:1341): Add shader cache key generation and caching for the compiled shader. BindingRemapper::BindingPoints remappedBindingPoints;
DAWN_TRY(CompileShader( BindingRemapper::AccessControls remappedAccessControls;
device->GetPlatform(), device->GetFunctions(),
device->IsToggleEnabled(Toggle::UseDXC) ? device->GetDxcLibrary().Get() : nullptr, tint::writer::ArrayLengthFromUniformOptions arrayLengthFromUniform;
device->IsToggleEnabled(Toggle::UseDXC) ? device->GetDxcCompiler().Get() : nullptr, arrayLengthFromUniform.ubo_binding = {layout->GetDynamicStorageBufferLengthsRegisterSpace(),
std::move(request), device->IsToggleEnabled(Toggle::DumpShaders), layout->GetDynamicStorageBufferLengthsShaderRegister()};
[&](WGPULoggingType loggingType, const char* message) {
GetDevice()->EmitLog(loggingType, message); const BindingInfoArray& moduleBindingInfo = entryPoint.bindings;
}, for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) {
&compiledShader)); const BindGroupLayout* bgl = ToBackend(layout->GetBindGroupLayout(group));
return std::move(compiledShader); const auto& groupBindingInfo = moduleBindingInfo[group];
// d3d12::BindGroupLayout packs the bindings per HLSL register-space. We modify
// the Tint AST to make the "bindings" decoration match the offset chosen by
// d3d12::BindGroupLayout so that Tint produces HLSL with the correct registers
// assigned to each interface variable.
for (const auto& [binding, bindingInfo] : groupBindingInfo) {
BindingIndex bindingIndex = bgl->GetBindingIndex(binding);
BindingPoint srcBindingPoint{static_cast<uint32_t>(group),
static_cast<uint32_t>(binding)};
BindingPoint dstBindingPoint{static_cast<uint32_t>(group),
bgl->GetShaderRegister(bindingIndex)};
if (srcBindingPoint != dstBindingPoint) {
remappedBindingPoints.emplace(srcBindingPoint, dstBindingPoint);
}
// Declaring a read-only storage buffer in HLSL but specifying a storage
// buffer in the BGL produces the wrong output. Force read-only storage
// buffer bindings to be treated as UAV instead of SRV. Internal storage
// buffer is a storage buffer used in the internal pipeline.
const bool forceStorageBufferAsUAV =
(bindingInfo.buffer.type == wgpu::BufferBindingType::ReadOnlyStorage &&
(bgl->GetBindingInfo(bindingIndex).buffer.type ==
wgpu::BufferBindingType::Storage ||
bgl->GetBindingInfo(bindingIndex).buffer.type == kInternalStorageBufferBinding));
if (forceStorageBufferAsUAV) {
remappedAccessControls.emplace(srcBindingPoint, tint::ast::Access::kReadWrite);
}
}
// Add arrayLengthFromUniform options
{
for (const auto& bindingAndRegisterOffset :
layout->GetDynamicStorageBufferLengthInfo()[group].bindingAndRegisterOffsets) {
BindingNumber binding = bindingAndRegisterOffset.binding;
uint32_t registerOffset = bindingAndRegisterOffset.registerOffset;
BindingPoint bindingPoint{static_cast<uint32_t>(group),
static_cast<uint32_t>(binding)};
// Get the renamed binding point if it was remapped.
auto it = remappedBindingPoints.find(bindingPoint);
if (it != remappedBindingPoints.end()) {
bindingPoint = it->second;
}
arrayLengthFromUniform.bindpoint_to_size_index.emplace(bindingPoint,
registerOffset);
}
}
}
req.hlsl.inputProgram = GetTintProgram();
req.hlsl.entryPointName = programmableStage.entryPoint.c_str();
req.hlsl.stage = stage;
req.hlsl.firstIndexOffsetShaderRegister = layout->GetFirstIndexOffsetShaderRegister();
req.hlsl.firstIndexOffsetRegisterSpace = layout->GetFirstIndexOffsetRegisterSpace();
req.hlsl.usesNumWorkgroups = entryPoint.usesNumWorkgroups;
req.hlsl.numWorkgroupsShaderRegister = layout->GetNumWorkgroupsShaderRegister();
req.hlsl.numWorkgroupsRegisterSpace = layout->GetNumWorkgroupsRegisterSpace();
req.hlsl.remappedBindingPoints = std::move(remappedBindingPoints);
req.hlsl.remappedAccessControls = std::move(remappedAccessControls);
req.hlsl.newBindingsMap = BuildExternalTextureTransformBindings(layout);
req.hlsl.arrayLengthFromUniform = std::move(arrayLengthFromUniform);
CacheResult<CompiledShader> compiledShader;
DAWN_TRY_LOAD_OR_RUN(compiledShader, device, std::move(req), CompiledShader::FromBlob,
CompileShader);
if (device->IsToggleEnabled(Toggle::DumpShaders)) {
std::ostringstream dumpedMsg;
dumpedMsg << "/* Dumped generated HLSL */" << std::endl
<< compiledShader->hlslSource << std::endl;
device->EmitLog(WGPULoggingType_Info, dumpedMsg.str().c_str());
if (device->IsToggleEnabled(Toggle::UseDXC)) {
dumpedMsg << "/* Dumped disassembled DXIL */" << std::endl;
D3D12_SHADER_BYTECODE code = compiledShader->GetD3D12ShaderBytecode();
ComPtr<IDxcBlobEncoding> dxcBlob;
ComPtr<IDxcBlobEncoding> disassembly;
if (FAILED(device->GetDxcLibrary()->CreateBlobWithEncodingFromPinned(
code.pShaderBytecode, code.BytecodeLength, 0, &dxcBlob)) ||
FAILED(device->GetDxcCompiler()->Disassemble(dxcBlob.Get(), &disassembly))) {
dumpedMsg << "DXC disassemble failed" << std::endl;
} else {
dumpedMsg << std::string_view(
static_cast<const char*>(disassembly->GetBufferPointer()),
disassembly->GetBufferSize());
}
} else {
dumpedMsg << "/* FXC compile flags */ " << std::endl
<< CompileFlagsToStringFXC(compileFlags) << std::endl;
dumpedMsg << "/* Dumped disassembled DXBC */" << std::endl;
ComPtr<ID3DBlob> disassembly;
D3D12_SHADER_BYTECODE code = compiledShader->GetD3D12ShaderBytecode();
if (FAILED(device->GetFunctions()->d3dDisassemble(
code.pShaderBytecode, code.BytecodeLength, 0, nullptr, &disassembly))) {
dumpedMsg << "D3D disassemble failed" << std::endl;
} else {
dumpedMsg << std::string_view(
static_cast<const char*>(disassembly->GetBufferPointer()),
disassembly->GetBufferSize());
}
}
device->EmitLog(WGPULoggingType_Info, dumpedMsg.str().c_str());
}
if (BlobCache* cache = device->GetBlobCache()) {
cache->EnsureStored(compiledShader);
}
// Clear the hlslSource. It is only used for logging and should not be used
// outside of the compilation.
CompiledShader result = compiledShader.Acquire();
result.hlslSource = "";
return result;
} }
D3D12_SHADER_BYTECODE CompiledShader::GetD3D12ShaderBytecode() const { D3D12_SHADER_BYTECODE CompiledShader::GetD3D12ShaderBytecode() const {
if (compiledFXCShader != nullptr) { return {shaderBlob.Data(), shaderBlob.Size()};
return {compiledFXCShader->GetBufferPointer(), compiledFXCShader->GetBufferSize()};
} else if (compiledDXCShader != nullptr) {
return {compiledDXCShader->GetBufferPointer(), compiledDXCShader->GetBufferSize()};
}
UNREACHABLE();
return {};
} }
} // namespace dawn::native::d3d12 } // namespace dawn::native::d3d12
namespace dawn::native {
// Define the implementation to store d3d12::CompiledShader into the BlobCache.
template <>
void BlobCache::Store<d3d12::CompiledShader>(const CacheKey& key, const d3d12::CompiledShader& c) {
stream::ByteVectorSink sink;
c.VisitAll([&](const auto&... members) { StreamIn(&sink, members...); });
Store(key, CreateBlob(std::move(sink)));
}
// Define the implementation to load d3d12::CompiledShader from a Blob.
// static
ResultOrError<d3d12::CompiledShader> d3d12::CompiledShader::FromBlob(Blob blob) {
stream::BlobSource source(std::move(blob));
d3d12::CompiledShader c;
DAWN_TRY(c.VisitAll([&](auto&... members) { return StreamOut(&source, &members...); }));
return c;
}
} // namespace dawn::native

View File

@ -15,8 +15,11 @@
#ifndef SRC_DAWN_NATIVE_D3D12_SHADERMODULED3D12_H_ #ifndef SRC_DAWN_NATIVE_D3D12_SHADERMODULED3D12_H_
#define SRC_DAWN_NATIVE_D3D12_SHADERMODULED3D12_H_ #define SRC_DAWN_NATIVE_D3D12_SHADERMODULED3D12_H_
#include "dawn/native/ShaderModule.h" #include <string>
#include "dawn/native/Blob.h"
#include "dawn/native/ShaderModule.h"
#include "dawn/native/VisitableMembers.h"
#include "dawn/native/d3d12/d3d12_platform.h" #include "dawn/native/d3d12/d3d12_platform.h"
namespace dawn::native { namespace dawn::native {
@ -28,14 +31,22 @@ namespace dawn::native::d3d12 {
class Device; class Device;
class PipelineLayout; class PipelineLayout;
// Manages a ref to one of the various representations of shader blobs and information used to #define COMPILED_SHADER_MEMBERS(X) \
// emulate vertex/instance index starts X(Blob, shaderBlob) \
X(std::string, hlslSource) \
X(bool, usesVertexOrInstanceIndex)
// `CompiledShader` holds a ref to one of the various representations of shader blobs and
// information used to emulate vertex/instance index starts. It also holds the `hlslSource` for the
// shader compilation, which is only transiently available during Compile, and cleared before it
// returns. It is not written to or loaded from the cache unless Toggle dump_shaders is true.
struct CompiledShader { struct CompiledShader {
ComPtr<ID3DBlob> compiledFXCShader; static ResultOrError<CompiledShader> FromBlob(Blob blob);
ComPtr<IDxcBlob> compiledDXCShader;
D3D12_SHADER_BYTECODE GetD3D12ShaderBytecode() const; D3D12_SHADER_BYTECODE GetD3D12ShaderBytecode() const;
bool usesVertexOrInstanceIndex; DAWN_VISITABLE_MEMBERS(COMPILED_SHADER_MEMBERS)
#undef COMPILED_SHADER_MEMBERS
}; };
class ShaderModule final : public ShaderModuleBase { class ShaderModule final : public ShaderModuleBase {

View File

@ -81,19 +81,19 @@ bool NeedBufferSizeWorkaroundForBufferTextureCopyOnD3D12(const BufferCopy& buffe
} // anonymous namespace } // anonymous namespace
ResultOrError<std::wstring> ConvertStringToWstring(const char* str) { ResultOrError<std::wstring> ConvertStringToWstring(std::string_view s) {
size_t len = strlen(str); size_t len = s.length();
if (len == 0) { if (len == 0) {
return std::wstring(); return std::wstring();
} }
int numChars = MultiByteToWideChar(CP_UTF8, MB_ERR_INVALID_CHARS, str, len, nullptr, 0); int numChars = MultiByteToWideChar(CP_UTF8, MB_ERR_INVALID_CHARS, s.data(), len, nullptr, 0);
if (numChars == 0) { if (numChars == 0) {
return DAWN_INTERNAL_ERROR("Failed to convert string to wide string"); return DAWN_INTERNAL_ERROR("Failed to convert string to wide string");
} }
std::wstring result; std::wstring result;
result.resize(numChars); result.resize(numChars);
int numConvertedChars = int numConvertedChars =
MultiByteToWideChar(CP_UTF8, MB_ERR_INVALID_CHARS, str, len, &result[0], numChars); MultiByteToWideChar(CP_UTF8, MB_ERR_INVALID_CHARS, s.data(), len, &result[0], numChars);
if (numConvertedChars != numChars) { if (numConvertedChars != numChars) {
return DAWN_INTERNAL_ERROR("Failed to convert string to wide string"); return DAWN_INTERNAL_ERROR("Failed to convert string to wide string");
} }

View File

@ -26,7 +26,7 @@
namespace dawn::native::d3d12 { namespace dawn::native::d3d12 {
ResultOrError<std::wstring> ConvertStringToWstring(const char* str); ResultOrError<std::wstring> ConvertStringToWstring(std::string_view s);
D3D12_COMPARISON_FUNC ToD3D12ComparisonFunc(wgpu::CompareFunction func); D3D12_COMPARISON_FUNC ToD3D12ComparisonFunc(wgpu::CompareFunction func);

View File

@ -48,4 +48,14 @@ void Stream<std::string_view>::Write(Sink* s, const std::string_view& t) {
} }
} }
template <>
void Stream<std::wstring_view>::Write(Sink* s, const std::wstring_view& t) {
StreamIn(s, t.length());
size_t size = t.length() * sizeof(wchar_t);
if (size > 0) {
void* ptr = s->GetSpace(size);
memcpy(ptr, t.data(), size);
}
}
} // namespace dawn::native::stream } // namespace dawn::native::stream

View File

@ -297,10 +297,9 @@ class Stream<std::unordered_map<K, V>> {
public: public:
static void Write(stream::Sink* sink, const std::unordered_map<K, V>& m) { static void Write(stream::Sink* sink, const std::unordered_map<K, V>& m) {
std::vector<std::pair<K, V>> ordered(m.begin(), m.end()); std::vector<std::pair<K, V>> ordered(m.begin(), m.end());
std::sort(ordered.begin(), ordered.end(), std::sort(
[](const std::pair<K, V>& a, const std::pair<K, V>& b) { ordered.begin(), ordered.end(),
return std::less<K>{}(a.first, b.first); [](const std::pair<K, V>& a, const std::pair<K, V>& b) { return a.first < b.first; });
});
StreamIn(sink, ordered); StreamIn(sink, ordered);
} }
}; };

View File

@ -108,8 +108,8 @@ class PipelineCachingTests : public DawnTest {
const EntryCounts counts = { const EntryCounts counts = {
// pipeline caching is only implemented on D3D12/Vulkan // pipeline caching is only implemented on D3D12/Vulkan
IsD3D12() || IsVulkan() ? 1u : 0u, IsD3D12() || IsVulkan() ? 1u : 0u,
// shader module caching is only implemented on Vulkan/Metal // shader module caching is only implemented on Vulkan/D3D12/Metal
IsVulkan() || IsMetal() ? 1u : 0u, IsVulkan() || IsMetal() || IsD3D12() ? 1u : 0u,
}; };
NiceMock<CachingInterfaceMock> mMockCache; NiceMock<CachingInterfaceMock> mMockCache;
}; };
@ -646,6 +646,7 @@ TEST_P(SinglePipelineCachingTests, RenderPipelineBlobCacheIsolationKey) {
DAWN_INSTANTIATE_TEST(SinglePipelineCachingTests, DAWN_INSTANTIATE_TEST(SinglePipelineCachingTests,
D3D12Backend({"enable_blob_cache"}), D3D12Backend({"enable_blob_cache"}),
D3D12Backend({"enable_blob_cache", "use_dxc"}),
MetalBackend({"enable_blob_cache"}), MetalBackend({"enable_blob_cache"}),
OpenGLBackend({"enable_blob_cache"}), OpenGLBackend({"enable_blob_cache"}),
OpenGLESBackend({"enable_blob_cache"}), OpenGLESBackend({"enable_blob_cache"}),

View File

@ -174,6 +174,30 @@ TEST(SerializeTests, StdStringViews) {
EXPECT_CACHE_KEY_EQ(str, expected); EXPECT_CACHE_KEY_EQ(str, expected);
} }
// Test that ByteVectorSink serializes std::wstring_views as expected.
TEST(SerializeTests, StdWStringViews) {
static constexpr std::wstring_view str(L"Hello world!");
ByteVectorSink expected;
StreamIn(&expected, size_t(str.length()));
size_t bytes = str.length() * sizeof(wchar_t);
memcpy(expected.GetSpace(bytes), str.data(), bytes);
EXPECT_CACHE_KEY_EQ(str, expected);
}
// Test that ByteVectorSink serializes Blobs as expected.
TEST(SerializeTests, Blob) {
uint8_t data[] = "dawn native Blob";
Blob blob = Blob::UnsafeCreateWithDeleter(data, sizeof(data), []() {});
ByteVectorSink expected;
StreamIn(&expected, sizeof(data));
expected.insert(expected.end(), data, data + sizeof(data));
EXPECT_CACHE_KEY_EQ(blob, expected);
}
// Test that ByteVectorSink serializes other ByteVectorSinks as expected. // Test that ByteVectorSink serializes other ByteVectorSinks as expected.
TEST(SerializeTests, ByteVectorSinks) { TEST(SerializeTests, ByteVectorSinks) {
ByteVectorSink data = {'d', 'a', 't', 'a'}; ByteVectorSink data = {'d', 'a', 't', 'a'};
@ -309,6 +333,41 @@ TEST(StreamTests, SerializeDeserializeVisitableMembers) {
} }
} }
// Test that serializing then deserializing a Blob yields the same data.
// Tested here instead of in the type-parameterized tests since Blobs are not copyable.
TEST(StreamTests, SerializeDeserializeBlobs) {
// Test an empty blob
{
Blob blob;
EXPECT_EQ(blob.Size(), 0u);
ByteVectorSink sink;
StreamIn(&sink, blob);
BlobSource src(CreateBlob(sink));
Blob out;
auto err = StreamOut(&src, &out);
EXPECT_FALSE(err.IsError());
EXPECT_EQ(blob.Size(), out.Size());
EXPECT_EQ(memcmp(blob.Data(), out.Data(), blob.Size()), 0);
}
// Test a blob with some data
{
Blob blob = CreateBlob(std::vector<double>{6.24, 3.12222});
ByteVectorSink sink;
StreamIn(&sink, blob);
BlobSource src(CreateBlob(sink));
Blob out;
auto err = StreamOut(&src, &out);
EXPECT_FALSE(err.IsError());
EXPECT_EQ(blob.Size(), out.Size());
EXPECT_EQ(memcmp(blob.Data(), out.Data(), blob.Size()), 0);
}
}
template <size_t N> template <size_t N>
std::bitset<N - 1> BitsetFromBitString(const char (&str)[N]) { std::bitset<N - 1> BitsetFromBitString(const char (&str)[N]) {
// N - 1 because the last character is the null terminator. // N - 1 because the last character is the null terminator.