Pipeline overridable constants: D3D12 backend
D3D12 doesn't have native pipeline constants feature. This is done by using #define. Add some new tests to make sure these define approaches work as expected. Also makes duplicate pipeline constant entries an invalid case. Bug: dawn:1137 Change-Id: Iefed44a749625b535bbafbb39f42699f0b42e06a Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/68860 Reviewed-by: Austin Eng <enga@chromium.org> Commit-Queue: Shrek Shao <shrekshao@google.com>
This commit is contained in:
parent
fd066a3345
commit
c6c4588036
|
@ -55,17 +55,22 @@ namespace dawn_native {
|
||||||
// pipelineBase is not yet constructed at this moment so iterate constants from descriptor
|
// pipelineBase is not yet constructed at this moment so iterate constants from descriptor
|
||||||
size_t numUninitializedConstants = metadata.uninitializedOverridableConstants.size();
|
size_t numUninitializedConstants = metadata.uninitializedOverridableConstants.size();
|
||||||
// Keep an initialized constants sets to handle duplicate initialization cases
|
// Keep an initialized constants sets to handle duplicate initialization cases
|
||||||
// Only storing that of uninialized constants is needed
|
|
||||||
std::unordered_set<std::string> stageInitializedConstantIdentifiers;
|
std::unordered_set<std::string> stageInitializedConstantIdentifiers;
|
||||||
for (uint32_t i = 0; i < constantCount; i++) {
|
for (uint32_t i = 0; i < constantCount; i++) {
|
||||||
DAWN_INVALID_IF(metadata.overridableConstants.count(constants[i].key) == 0,
|
DAWN_INVALID_IF(metadata.overridableConstants.count(constants[i].key) == 0,
|
||||||
"Pipeline overridable constant \"%s\" not found in shader module %s.",
|
"Pipeline overridable constant \"%s\" not found in %s.",
|
||||||
constants[i].key, module);
|
constants[i].key, module);
|
||||||
|
|
||||||
if (metadata.uninitializedOverridableConstants.count(constants[i].key) > 0 &&
|
if (stageInitializedConstantIdentifiers.count(constants[i].key) == 0) {
|
||||||
stageInitializedConstantIdentifiers.count(constants[i].key) == 0) {
|
if (metadata.uninitializedOverridableConstants.count(constants[i].key) > 0) {
|
||||||
numUninitializedConstants--;
|
numUninitializedConstants--;
|
||||||
|
}
|
||||||
stageInitializedConstantIdentifiers.insert(constants[i].key);
|
stageInitializedConstantIdentifiers.insert(constants[i].key);
|
||||||
|
} else {
|
||||||
|
// There are duplicate initializations
|
||||||
|
return DAWN_FORMAT_VALIDATION_ERROR(
|
||||||
|
"Pipeline overridable constants \"%s\" is set more than once in %s",
|
||||||
|
constants[i].key, module);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -116,11 +121,10 @@ namespace dawn_native {
|
||||||
// Record them internally.
|
// Record them internally.
|
||||||
bool isFirstStage = mStageMask == wgpu::ShaderStage::None;
|
bool isFirstStage = mStageMask == wgpu::ShaderStage::None;
|
||||||
mStageMask |= StageBit(shaderStage);
|
mStageMask |= StageBit(shaderStage);
|
||||||
mStages[shaderStage] = {module, entryPointName, &metadata,
|
mStages[shaderStage] = {module, entryPointName, &metadata, {}};
|
||||||
std::vector<PipelineConstantEntry>()};
|
|
||||||
auto& constants = mStages[shaderStage].constants;
|
auto& constants = mStages[shaderStage].constants;
|
||||||
for (uint32_t i = 0; i < stage.constantCount; i++) {
|
for (uint32_t i = 0; i < stage.constantCount; i++) {
|
||||||
constants.emplace_back(stage.constants[i].key, stage.constants[i].value);
|
constants.emplace(stage.constants[i].key, stage.constants[i].value);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Compute the max() of all minBufferSizes across all stages.
|
// Compute the max() of all minBufferSizes across all stages.
|
||||||
|
|
|
@ -37,7 +37,9 @@ namespace dawn_native {
|
||||||
const PipelineLayoutBase* layout,
|
const PipelineLayoutBase* layout,
|
||||||
SingleShaderStage stage);
|
SingleShaderStage stage);
|
||||||
|
|
||||||
using PipelineConstantEntry = std::pair<std::string, double>;
|
// Use map to make sure constant keys are sorted for creating shader cache keys
|
||||||
|
using PipelineConstantEntries = std::map<std::string, double>;
|
||||||
|
|
||||||
struct ProgrammableStage {
|
struct ProgrammableStage {
|
||||||
Ref<ShaderModuleBase> module;
|
Ref<ShaderModuleBase> module;
|
||||||
std::string entryPoint;
|
std::string entryPoint;
|
||||||
|
@ -45,7 +47,7 @@ namespace dawn_native {
|
||||||
// The metadata lives as long as module, that's ref-ed in the same structure.
|
// The metadata lives as long as module, that's ref-ed in the same structure.
|
||||||
const EntryPointMetadata* metadata = nullptr;
|
const EntryPointMetadata* metadata = nullptr;
|
||||||
|
|
||||||
std::vector<PipelineConstantEntry> constants;
|
PipelineConstantEntries constants;
|
||||||
};
|
};
|
||||||
|
|
||||||
class PipelineBase : public ApiObjectBase, public CachedObject {
|
class PipelineBase : public ApiObjectBase, public CachedObject {
|
||||||
|
|
|
@ -222,9 +222,11 @@ namespace dawn_native {
|
||||||
OverridableConstantScalar defaultValue;
|
OverridableConstantScalar defaultValue;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
using OverridableConstantsMap = std::unordered_map<std::string, OverridableConstant>;
|
||||||
|
|
||||||
// Map identifier to overridable constant
|
// Map identifier to overridable constant
|
||||||
// Identifier is unique: either the variable name or the numeric ID if specified
|
// Identifier is unique: either the variable name or the numeric ID if specified
|
||||||
std::unordered_map<std::string, OverridableConstant> overridableConstants;
|
OverridableConstantsMap overridableConstants;
|
||||||
|
|
||||||
// Overridable constants that are not initialized in shaders
|
// Overridable constants that are not initialized in shaders
|
||||||
// They need value initialization from pipeline stage or it is a validation error
|
// They need value initialization from pipeline stage or it is a validation error
|
||||||
|
@ -254,7 +256,7 @@ namespace dawn_native {
|
||||||
// Return true iff the program has an entrypoint called `entryPoint`.
|
// Return true iff the program has an entrypoint called `entryPoint`.
|
||||||
bool HasEntryPoint(const std::string& entryPoint) const;
|
bool HasEntryPoint(const std::string& entryPoint) const;
|
||||||
|
|
||||||
// Returns the metadata for the given `entryPoint`. HasEntryPoint with the same argument
|
// Return the metadata for the given `entryPoint`. HasEntryPoint with the same argument
|
||||||
// must be true.
|
// must be true.
|
||||||
const EntryPointMetadata& GetEntryPoint(const std::string& entryPoint) const;
|
const EntryPointMetadata& GetEntryPoint(const std::string& entryPoint) const;
|
||||||
|
|
||||||
|
|
|
@ -48,9 +48,8 @@ namespace dawn_native { namespace d3d12 {
|
||||||
d3dDesc.pRootSignature = ToBackend(GetLayout())->GetRootSignature();
|
d3dDesc.pRootSignature = ToBackend(GetLayout())->GetRootSignature();
|
||||||
|
|
||||||
CompiledShader compiledShader;
|
CompiledShader compiledShader;
|
||||||
DAWN_TRY_ASSIGN(compiledShader,
|
DAWN_TRY_ASSIGN(compiledShader, module->Compile(computeStage, SingleShaderStage::Compute,
|
||||||
module->Compile(computeStage.entryPoint.c_str(), SingleShaderStage::Compute,
|
ToBackend(GetLayout()), compileFlags));
|
||||||
ToBackend(GetLayout()), compileFlags));
|
|
||||||
d3dDesc.CS = compiledShader.GetD3D12ShaderBytecode();
|
d3dDesc.CS = compiledShader.GetD3D12ShaderBytecode();
|
||||||
auto* d3d12Device = device->GetD3D12Device();
|
auto* d3d12Device = device->GetD3D12Device();
|
||||||
DAWN_TRY(CheckHRESULT(
|
DAWN_TRY(CheckHRESULT(
|
||||||
|
|
|
@ -348,10 +348,10 @@ namespace dawn_native { namespace d3d12 {
|
||||||
PerStage<CompiledShader> compiledShader;
|
PerStage<CompiledShader> compiledShader;
|
||||||
|
|
||||||
for (auto stage : IterateStages(GetStageMask())) {
|
for (auto stage : IterateStages(GetStageMask())) {
|
||||||
DAWN_TRY_ASSIGN(compiledShader[stage],
|
DAWN_TRY_ASSIGN(
|
||||||
ToBackend(pipelineStages[stage].module)
|
compiledShader[stage],
|
||||||
->Compile(pipelineStages[stage].entryPoint.c_str(), stage,
|
ToBackend(pipelineStages[stage].module)
|
||||||
ToBackend(GetLayout()), compileFlags));
|
->Compile(pipelineStages[stage], stage, ToBackend(GetLayout()), compileFlags));
|
||||||
*shaders[stage] = compiledShader[stage].GetD3D12ShaderBytecode();
|
*shaders[stage] = compiledShader[stage].GetD3D12ShaderBytecode();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -17,6 +17,8 @@
|
||||||
#include "common/Assert.h"
|
#include "common/Assert.h"
|
||||||
#include "common/BitSetIterator.h"
|
#include "common/BitSetIterator.h"
|
||||||
#include "common/Log.h"
|
#include "common/Log.h"
|
||||||
|
#include "common/WindowsUtils.h"
|
||||||
|
#include "dawn_native/Pipeline.h"
|
||||||
#include "dawn_native/TintUtils.h"
|
#include "dawn_native/TintUtils.h"
|
||||||
#include "dawn_native/d3d12/BindGroupLayoutD3D12.h"
|
#include "dawn_native/d3d12/BindGroupLayoutD3D12.h"
|
||||||
#include "dawn_native/d3d12/D3D12Error.h"
|
#include "dawn_native/d3d12/D3D12Error.h"
|
||||||
|
@ -32,6 +34,94 @@
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
|
|
||||||
|
namespace dawn_native {
|
||||||
|
template <typename StringType, typename T = int32_t>
|
||||||
|
struct NumberToString {
|
||||||
|
static StringType ToStringAsValue(T v);
|
||||||
|
static StringType ToStringAsId(T v);
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct NumberToString<std::string, T> {
|
||||||
|
static constexpr char kSpecConstantPrefix[] = "WGSL_SPEC_CONSTANT_";
|
||||||
|
static std::string ToStringAsValue(T v) {
|
||||||
|
return std::to_string(v);
|
||||||
|
}
|
||||||
|
static std::string ToStringAsId(T v) {
|
||||||
|
return std::to_string(v);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct NumberToString<std::wstring, T> {
|
||||||
|
static constexpr WCHAR kSpecConstantPrefix[] = L"WGSL_SPEC_CONSTANT_";
|
||||||
|
static std::wstring ToStringAsValue(T v) {
|
||||||
|
return std::to_wstring(v);
|
||||||
|
}
|
||||||
|
static std::wstring ToStringAsId(T v) {
|
||||||
|
return std::to_wstring(v);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct NumberToString<std::string, float> {
|
||||||
|
static std::string ToStringAsValue(float v) {
|
||||||
|
std::ostringstream out;
|
||||||
|
// 32 bit float has 7 decimal digits of precision so setting n to 8 should be enough
|
||||||
|
out.precision(8);
|
||||||
|
out << std::fixed << v;
|
||||||
|
return out.str();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct NumberToString<std::wstring, float> {
|
||||||
|
static std::wstring ToStringAsValue(float v) {
|
||||||
|
std::basic_ostringstream<WCHAR> out;
|
||||||
|
// 32 bit float has 7 decimal digits of precision so setting n to 8 should be enough
|
||||||
|
out.precision(8);
|
||||||
|
out << std::fixed << v;
|
||||||
|
return out.str();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct NumberToString<std::string, uint32_t> {
|
||||||
|
static std::string ToStringAsValue(uint32_t v) {
|
||||||
|
return std::to_string(v) + "u";
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct NumberToString<std::wstring, uint32_t> {
|
||||||
|
static std::wstring ToStringAsValue(uint32_t v) {
|
||||||
|
return std::to_wstring(v) + L"u";
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename StringType>
|
||||||
|
StringType GetHLSLValueString(EntryPointMetadata::OverridableConstant::Type dawnType,
|
||||||
|
const OverridableConstantScalar* entry,
|
||||||
|
double value = 0) {
|
||||||
|
switch (dawnType) {
|
||||||
|
case EntryPointMetadata::OverridableConstant::Type::Boolean:
|
||||||
|
return NumberToString<StringType, int32_t>::ToStringAsValue(
|
||||||
|
entry ? entry->b : static_cast<int32_t>(value));
|
||||||
|
case EntryPointMetadata::OverridableConstant::Type::Float32:
|
||||||
|
return NumberToString<StringType, float>::ToStringAsValue(
|
||||||
|
entry ? entry->f32 : static_cast<float>(value));
|
||||||
|
case EntryPointMetadata::OverridableConstant::Type::Int32:
|
||||||
|
return NumberToString<StringType, int32_t>::ToStringAsValue(
|
||||||
|
entry ? entry->i32 : static_cast<int32_t>(value));
|
||||||
|
case EntryPointMetadata::OverridableConstant::Type::Uint32:
|
||||||
|
return NumberToString<StringType, uint32_t>::ToStringAsValue(
|
||||||
|
entry ? entry->u32 : static_cast<uint32_t>(value));
|
||||||
|
default:
|
||||||
|
UNREACHABLE();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace dawn_native
|
||||||
|
|
||||||
namespace dawn_native { namespace d3d12 {
|
namespace dawn_native { namespace d3d12 {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
@ -109,6 +199,8 @@ namespace dawn_native { namespace d3d12 {
|
||||||
bool usesNumWorkgroups;
|
bool usesNumWorkgroups;
|
||||||
uint32_t numWorkgroupsRegisterSpace;
|
uint32_t numWorkgroupsRegisterSpace;
|
||||||
uint32_t numWorkgroupsShaderRegister;
|
uint32_t numWorkgroupsShaderRegister;
|
||||||
|
const PipelineConstantEntries* pipelineConstantEntries;
|
||||||
|
const EntryPointMetadata::OverridableConstantsMap* shaderEntryPointConstants;
|
||||||
|
|
||||||
// FXC/DXC common inputs
|
// FXC/DXC common inputs
|
||||||
bool disableWorkgroupInit;
|
bool disableWorkgroupInit;
|
||||||
|
@ -128,7 +220,8 @@ namespace dawn_native { namespace d3d12 {
|
||||||
uint32_t compileFlags,
|
uint32_t compileFlags,
|
||||||
const Device* device,
|
const Device* device,
|
||||||
const tint::Program* program,
|
const tint::Program* program,
|
||||||
const EntryPointMetadata& entryPoint) {
|
const EntryPointMetadata& entryPoint,
|
||||||
|
const ProgrammableStage& programmableStage) {
|
||||||
Compiler compiler;
|
Compiler compiler;
|
||||||
uint64_t dxcVersion = 0;
|
uint64_t dxcVersion = 0;
|
||||||
if (device->IsToggleEnabled(Toggle::UseDXC)) {
|
if (device->IsToggleEnabled(Toggle::UseDXC)) {
|
||||||
|
@ -200,6 +293,10 @@ namespace dawn_native { namespace d3d12 {
|
||||||
request.dxcVersion = compiler == Compiler::DXC ? dxcVersion : 0;
|
request.dxcVersion = compiler == Compiler::DXC ? dxcVersion : 0;
|
||||||
request.deviceInfo = &device->GetDeviceInfo();
|
request.deviceInfo = &device->GetDeviceInfo();
|
||||||
request.hasShaderFloat16Feature = device->IsFeatureEnabled(Feature::ShaderFloat16);
|
request.hasShaderFloat16Feature = device->IsFeatureEnabled(Feature::ShaderFloat16);
|
||||||
|
request.pipelineConstantEntries = &programmableStage.constants;
|
||||||
|
request.shaderEntryPointConstants =
|
||||||
|
&programmableStage.module->GetEntryPoint(programmableStage.entryPoint)
|
||||||
|
.overridableConstants;
|
||||||
return std::move(request);
|
return std::move(request);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -251,6 +348,21 @@ namespace dawn_native { namespace d3d12 {
|
||||||
stream << " fxcVersion=" << fxcVersion;
|
stream << " fxcVersion=" << fxcVersion;
|
||||||
stream << " dxcVersion=" << dxcVersion;
|
stream << " dxcVersion=" << dxcVersion;
|
||||||
stream << " hasShaderFloat16Feature=" << hasShaderFloat16Feature;
|
stream << " hasShaderFloat16Feature=" << hasShaderFloat16Feature;
|
||||||
|
|
||||||
|
stream << " overridableConstants={";
|
||||||
|
for (const auto& pipelineConstant : *pipelineConstantEntries) {
|
||||||
|
const std::string& name = pipelineConstant.first;
|
||||||
|
double value = pipelineConstant.second;
|
||||||
|
|
||||||
|
// This is already validated so `name` must exist
|
||||||
|
const auto& moduleConstant = shaderEntryPointConstants->at(name);
|
||||||
|
|
||||||
|
stream << " <" << name << ","
|
||||||
|
<< GetHLSLValueString<std::string>(moduleConstant.type, nullptr, value)
|
||||||
|
<< ">";
|
||||||
|
}
|
||||||
|
stream << " }";
|
||||||
|
|
||||||
stream << ")";
|
stream << ")";
|
||||||
stream << "\n";
|
stream << "\n";
|
||||||
|
|
||||||
|
@ -310,6 +422,53 @@ namespace dawn_native { namespace d3d12 {
|
||||||
return arguments;
|
return arguments;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename StringType>
|
||||||
|
const std::vector<std::pair<StringType, StringType>> GetOverridableConstantsDefines(
|
||||||
|
const PipelineConstantEntries* pipelineConstantEntries,
|
||||||
|
const EntryPointMetadata::OverridableConstantsMap* shaderEntryPointConstants) {
|
||||||
|
std::vector<std::pair<StringType, StringType>> defineStrings;
|
||||||
|
|
||||||
|
std::unordered_set<std::string> overriddenConstants;
|
||||||
|
|
||||||
|
// Set pipeline overridden values
|
||||||
|
for (const auto& pipelineConstant : *pipelineConstantEntries) {
|
||||||
|
const std::string& name = pipelineConstant.first;
|
||||||
|
double value = pipelineConstant.second;
|
||||||
|
|
||||||
|
overriddenConstants.insert(name);
|
||||||
|
|
||||||
|
// This is already validated so `name` must exist
|
||||||
|
const auto& moduleConstant = shaderEntryPointConstants->at(name);
|
||||||
|
|
||||||
|
defineStrings.emplace_back(
|
||||||
|
NumberToString<StringType>::kSpecConstantPrefix +
|
||||||
|
NumberToString<StringType>::ToStringAsId(
|
||||||
|
static_cast<int32_t>(moduleConstant.id)),
|
||||||
|
GetHLSLValueString<StringType>(moduleConstant.type, nullptr, value));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set shader initialized default values
|
||||||
|
for (const auto& iter : *shaderEntryPointConstants) {
|
||||||
|
const std::string& name = iter.first;
|
||||||
|
if (overriddenConstants.count(name) != 0) {
|
||||||
|
// This constant already has overridden value
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto& moduleConstant = shaderEntryPointConstants->at(name);
|
||||||
|
|
||||||
|
// Uninitialized default values are okay since they are only defined to pass
|
||||||
|
// compilation but not used
|
||||||
|
defineStrings.emplace_back(NumberToString<StringType>::kSpecConstantPrefix +
|
||||||
|
NumberToString<StringType>::ToStringAsId(
|
||||||
|
static_cast<int32_t>(moduleConstant.id)),
|
||||||
|
GetHLSLValueString<StringType>(
|
||||||
|
moduleConstant.type, &moduleConstant.defaultValue));
|
||||||
|
}
|
||||||
|
|
||||||
|
return defineStrings;
|
||||||
|
}
|
||||||
|
|
||||||
ResultOrError<ComPtr<IDxcBlob>> CompileShaderDXC(IDxcLibrary* dxcLibrary,
|
ResultOrError<ComPtr<IDxcBlob>> CompileShaderDXC(IDxcLibrary* dxcLibrary,
|
||||||
IDxcCompiler* dxcCompiler,
|
IDxcCompiler* dxcCompiler,
|
||||||
const ShaderCompilationRequest& request,
|
const ShaderCompilationRequest& request,
|
||||||
|
@ -326,12 +485,21 @@ namespace dawn_native { namespace d3d12 {
|
||||||
std::vector<const wchar_t*> arguments =
|
std::vector<const wchar_t*> arguments =
|
||||||
GetDXCArguments(request.compileFlags, request.hasShaderFloat16Feature);
|
GetDXCArguments(request.compileFlags, request.hasShaderFloat16Feature);
|
||||||
|
|
||||||
|
// Build defines for overridable constants
|
||||||
|
const auto& defineStrings = GetOverridableConstantsDefines<std::wstring>(
|
||||||
|
request.pipelineConstantEntries, request.shaderEntryPointConstants);
|
||||||
|
std::vector<DxcDefine> dxcDefines;
|
||||||
|
dxcDefines.reserve(defineStrings.size());
|
||||||
|
for (const auto& d : defineStrings) {
|
||||||
|
dxcDefines.push_back({d.first.c_str(), d.second.c_str()});
|
||||||
|
}
|
||||||
|
|
||||||
ComPtr<IDxcOperationResult> result;
|
ComPtr<IDxcOperationResult> result;
|
||||||
DAWN_TRY(CheckHRESULT(
|
DAWN_TRY(CheckHRESULT(
|
||||||
dxcCompiler->Compile(sourceBlob.Get(), nullptr, entryPointW.c_str(),
|
dxcCompiler->Compile(sourceBlob.Get(), nullptr, entryPointW.c_str(),
|
||||||
request.deviceInfo->shaderProfiles[request.stage].c_str(),
|
request.deviceInfo->shaderProfiles[request.stage].c_str(),
|
||||||
arguments.data(), arguments.size(), nullptr, 0, nullptr,
|
arguments.data(), arguments.size(), dxcDefines.data(),
|
||||||
&result),
|
dxcDefines.size(), nullptr, &result),
|
||||||
"DXC compile"));
|
"DXC compile"));
|
||||||
|
|
||||||
HRESULT hr;
|
HRESULT hr;
|
||||||
|
@ -369,8 +537,24 @@ namespace dawn_native { namespace d3d12 {
|
||||||
ComPtr<ID3DBlob> compiledShader;
|
ComPtr<ID3DBlob> compiledShader;
|
||||||
ComPtr<ID3DBlob> errors;
|
ComPtr<ID3DBlob> errors;
|
||||||
|
|
||||||
|
// Build defines for overridable constants
|
||||||
|
const auto& defineStrings = GetOverridableConstantsDefines<std::string>(
|
||||||
|
request.pipelineConstantEntries, request.shaderEntryPointConstants);
|
||||||
|
|
||||||
|
const D3D_SHADER_MACRO* pDefines = nullptr;
|
||||||
|
std::vector<D3D_SHADER_MACRO> fxcDefines;
|
||||||
|
if (defineStrings.size() > 0) {
|
||||||
|
fxcDefines.reserve(defineStrings.size() + 1);
|
||||||
|
for (const auto& d : defineStrings) {
|
||||||
|
fxcDefines.push_back({d.first.c_str(), d.second.c_str()});
|
||||||
|
}
|
||||||
|
// d3dCompile D3D_SHADER_MACRO* pDefines is a nullptr terminated array
|
||||||
|
fxcDefines.push_back({nullptr, nullptr});
|
||||||
|
pDefines = fxcDefines.data();
|
||||||
|
}
|
||||||
|
|
||||||
DAWN_INVALID_IF(FAILED(functions->d3dCompile(
|
DAWN_INVALID_IF(FAILED(functions->d3dCompile(
|
||||||
hlslSource.c_str(), hlslSource.length(), nullptr, nullptr, nullptr,
|
hlslSource.c_str(), hlslSource.length(), nullptr, pDefines, nullptr,
|
||||||
request.entryPointName, targetProfile, request.compileFlags, 0,
|
request.entryPointName, targetProfile, request.compileFlags, 0,
|
||||||
&compiledShader, &errors)),
|
&compiledShader, &errors)),
|
||||||
"D3D compile failed with: %s",
|
"D3D compile failed with: %s",
|
||||||
|
@ -392,6 +576,9 @@ namespace dawn_native { namespace d3d12 {
|
||||||
}
|
}
|
||||||
transformManager.Add<tint::transform::BindingRemapper>();
|
transformManager.Add<tint::transform::BindingRemapper>();
|
||||||
|
|
||||||
|
transformManager.Add<tint::transform::SingleEntryPoint>();
|
||||||
|
transformInputs.Add<tint::transform::SingleEntryPoint::Config>(request.entryPointName);
|
||||||
|
|
||||||
transformManager.Add<tint::transform::Renamer>();
|
transformManager.Add<tint::transform::Renamer>();
|
||||||
|
|
||||||
if (request.disableSymbolRenaming) {
|
if (request.disableSymbolRenaming) {
|
||||||
|
@ -508,7 +695,7 @@ namespace dawn_native { namespace d3d12 {
|
||||||
return InitializeBase(parseResult);
|
return InitializeBase(parseResult);
|
||||||
}
|
}
|
||||||
|
|
||||||
ResultOrError<CompiledShader> ShaderModule::Compile(const char* entryPointName,
|
ResultOrError<CompiledShader> ShaderModule::Compile(const ProgrammableStage& programmableStage,
|
||||||
SingleShaderStage stage,
|
SingleShaderStage stage,
|
||||||
PipelineLayout* layout,
|
PipelineLayout* layout,
|
||||||
uint32_t compileFlags) {
|
uint32_t compileFlags) {
|
||||||
|
@ -555,9 +742,10 @@ namespace dawn_native { namespace d3d12 {
|
||||||
}
|
}
|
||||||
|
|
||||||
ShaderCompilationRequest request;
|
ShaderCompilationRequest request;
|
||||||
DAWN_TRY_ASSIGN(request, ShaderCompilationRequest::Create(entryPointName, stage, layout,
|
DAWN_TRY_ASSIGN(
|
||||||
compileFlags, device, program,
|
request, ShaderCompilationRequest::Create(
|
||||||
GetEntryPoint(entryPointName)));
|
programmableStage.entryPoint.c_str(), stage, layout, compileFlags, device,
|
||||||
|
program, GetEntryPoint(programmableStage.entryPoint), programmableStage));
|
||||||
|
|
||||||
PersistentCacheKey shaderCacheKey;
|
PersistentCacheKey shaderCacheKey;
|
||||||
DAWN_TRY_ASSIGN(shaderCacheKey, request.CreateCacheKey());
|
DAWN_TRY_ASSIGN(shaderCacheKey, request.CreateCacheKey());
|
||||||
|
|
|
@ -20,6 +20,10 @@
|
||||||
|
|
||||||
#include "dawn_native/d3d12/d3d12_platform.h"
|
#include "dawn_native/d3d12/d3d12_platform.h"
|
||||||
|
|
||||||
|
namespace dawn_native {
|
||||||
|
struct ProgrammableStage;
|
||||||
|
} // namespace dawn_native
|
||||||
|
|
||||||
namespace dawn_native { namespace d3d12 {
|
namespace dawn_native { namespace d3d12 {
|
||||||
|
|
||||||
class Device;
|
class Device;
|
||||||
|
@ -49,7 +53,7 @@ namespace dawn_native { namespace d3d12 {
|
||||||
const ShaderModuleDescriptor* descriptor,
|
const ShaderModuleDescriptor* descriptor,
|
||||||
ShaderModuleParseResult* parseResult);
|
ShaderModuleParseResult* parseResult);
|
||||||
|
|
||||||
ResultOrError<CompiledShader> Compile(const char* entryPointName,
|
ResultOrError<CompiledShader> Compile(const ProgrammableStage& programmableStage,
|
||||||
SingleShaderStage stage,
|
SingleShaderStage stage,
|
||||||
PipelineLayout* layout,
|
PipelineLayout* layout,
|
||||||
uint32_t compileFlags);
|
uint32_t compileFlags);
|
||||||
|
|
|
@ -391,8 +391,8 @@ fn main([[location(0)]] pos : vec4<f32>) -> [[builtin(position)]] vec4<f32> {
|
||||||
|
|
||||||
// Test overridable constants without numeric identifiers
|
// Test overridable constants without numeric identifiers
|
||||||
TEST_P(ShaderTests, OverridableConstants) {
|
TEST_P(ShaderTests, OverridableConstants) {
|
||||||
// TODO(dawn:1137): D3D12 backend is unimplemented
|
DAWN_TEST_UNSUPPORTED_IF(IsOpenGL());
|
||||||
DAWN_TEST_UNSUPPORTED_IF(!IsVulkan() && !IsMetal());
|
DAWN_TEST_UNSUPPORTED_IF(IsOpenGLES());
|
||||||
|
|
||||||
uint32_t const kCount = 11;
|
uint32_t const kCount = 11;
|
||||||
std::vector<uint32_t> expected(kCount);
|
std::vector<uint32_t> expected(kCount);
|
||||||
|
@ -469,8 +469,8 @@ TEST_P(ShaderTests, OverridableConstants) {
|
||||||
|
|
||||||
// Test overridable constants with numeric identifiers
|
// Test overridable constants with numeric identifiers
|
||||||
TEST_P(ShaderTests, OverridableConstantsNumericIdentifiers) {
|
TEST_P(ShaderTests, OverridableConstantsNumericIdentifiers) {
|
||||||
// TODO(dawn:1137): D3D12 backend is unimplemented
|
DAWN_TEST_UNSUPPORTED_IF(IsOpenGL());
|
||||||
DAWN_TEST_UNSUPPORTED_IF(!IsVulkan() && !IsMetal());
|
DAWN_TEST_UNSUPPORTED_IF(IsOpenGLES());
|
||||||
|
|
||||||
uint32_t const kCount = 4;
|
uint32_t const kCount = 4;
|
||||||
std::vector<uint32_t> expected{1u, 2u, 3u, 0u};
|
std::vector<uint32_t> expected{1u, 2u, 3u, 0u};
|
||||||
|
@ -523,17 +523,71 @@ TEST_P(ShaderTests, OverridableConstantsNumericIdentifiers) {
|
||||||
EXPECT_BUFFER_U32_RANGE_EQ(expected.data(), buffer, 0, kCount);
|
EXPECT_BUFFER_U32_RANGE_EQ(expected.data(), buffer, 0, kCount);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test overridable constants precision
|
||||||
|
// D3D12 HLSL shader uses defines so we want float number to have enough precision
|
||||||
|
TEST_P(ShaderTests, OverridableConstantsPrecision) {
|
||||||
|
DAWN_TEST_UNSUPPORTED_IF(IsOpenGL());
|
||||||
|
DAWN_TEST_UNSUPPORTED_IF(IsOpenGLES());
|
||||||
|
|
||||||
|
uint32_t const kCount = 2;
|
||||||
|
float const kValue1 = 3.14159;
|
||||||
|
float const kValue2 = 3.141592653589793238;
|
||||||
|
std::vector<float> expected{kValue1, kValue2};
|
||||||
|
wgpu::Buffer buffer = CreateBuffer(kCount);
|
||||||
|
|
||||||
|
std::string shader = R"(
|
||||||
|
[[override(1001)]] let c1: f32;
|
||||||
|
[[override(1002)]] let c2: f32;
|
||||||
|
|
||||||
|
[[block]] struct Buf {
|
||||||
|
data : array<f32, 2>;
|
||||||
|
};
|
||||||
|
|
||||||
|
[[group(0), binding(0)]] var<storage, read_write> buf : Buf;
|
||||||
|
|
||||||
|
[[stage(compute), workgroup_size(1)]] fn main() {
|
||||||
|
buf.data[0] = c1;
|
||||||
|
buf.data[1] = c2;
|
||||||
|
})";
|
||||||
|
|
||||||
|
std::vector<wgpu::ConstantEntry> constants;
|
||||||
|
constants.push_back({nullptr, "1001", kValue1});
|
||||||
|
constants.push_back({nullptr, "1002", kValue2});
|
||||||
|
wgpu::ComputePipeline pipeline = CreateComputePipeline(shader, "main", &constants);
|
||||||
|
|
||||||
|
wgpu::BindGroup bindGroup =
|
||||||
|
utils::MakeBindGroup(device, pipeline.GetBindGroupLayout(0), {{0, buffer}});
|
||||||
|
|
||||||
|
wgpu::CommandBuffer commands;
|
||||||
|
{
|
||||||
|
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
|
||||||
|
wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
|
||||||
|
pass.SetPipeline(pipeline);
|
||||||
|
pass.SetBindGroup(0, bindGroup);
|
||||||
|
pass.Dispatch(1);
|
||||||
|
pass.EndPass();
|
||||||
|
|
||||||
|
commands = encoder.Finish();
|
||||||
|
}
|
||||||
|
|
||||||
|
queue.Submit(1, &commands);
|
||||||
|
|
||||||
|
EXPECT_BUFFER_FLOAT_RANGE_EQ(expected.data(), buffer, 0, kCount);
|
||||||
|
}
|
||||||
|
|
||||||
// Test overridable constants for different entry points
|
// Test overridable constants for different entry points
|
||||||
TEST_P(ShaderTests, OverridableConstantsMultipleEntryPoints) {
|
TEST_P(ShaderTests, OverridableConstantsMultipleEntryPoints) {
|
||||||
// TODO(dawn:1137): D3D12 backend is unimplemented
|
DAWN_TEST_UNSUPPORTED_IF(IsOpenGL());
|
||||||
DAWN_TEST_UNSUPPORTED_IF(!IsVulkan() && !IsMetal());
|
DAWN_TEST_UNSUPPORTED_IF(IsOpenGLES());
|
||||||
|
|
||||||
uint32_t const kCount = 1;
|
uint32_t const kCount = 1;
|
||||||
std::vector<uint32_t> expected1{1u};
|
std::vector<uint32_t> expected1{1u};
|
||||||
std::vector<uint32_t> expected2{2u};
|
std::vector<uint32_t> expected2{2u};
|
||||||
|
std::vector<uint32_t> expected3{3u};
|
||||||
|
|
||||||
wgpu::Buffer buffer1 = CreateBuffer(kCount);
|
wgpu::Buffer buffer1 = CreateBuffer(kCount);
|
||||||
wgpu::Buffer buffer2 = CreateBuffer(kCount);
|
wgpu::Buffer buffer2 = CreateBuffer(kCount);
|
||||||
|
wgpu::Buffer buffer3 = CreateBuffer(kCount);
|
||||||
|
|
||||||
std::string shader = R"(
|
std::string shader = R"(
|
||||||
[[override(1001)]] let c1: u32;
|
[[override(1001)]] let c1: u32;
|
||||||
|
@ -552,6 +606,10 @@ TEST_P(ShaderTests, OverridableConstantsMultipleEntryPoints) {
|
||||||
[[stage(compute), workgroup_size(1)]] fn main2() {
|
[[stage(compute), workgroup_size(1)]] fn main2() {
|
||||||
buf.data[0] = c2;
|
buf.data[0] = c2;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
[[stage(compute), workgroup_size(1)]] fn main3() {
|
||||||
|
buf.data[0] = 3u;
|
||||||
|
}
|
||||||
)";
|
)";
|
||||||
|
|
||||||
std::vector<wgpu::ConstantEntry> constants1;
|
std::vector<wgpu::ConstantEntry> constants1;
|
||||||
|
@ -575,10 +633,17 @@ TEST_P(ShaderTests, OverridableConstantsMultipleEntryPoints) {
|
||||||
csDesc2.compute.constantCount = constants2.size();
|
csDesc2.compute.constantCount = constants2.size();
|
||||||
wgpu::ComputePipeline pipeline2 = device.CreateComputePipeline(&csDesc2);
|
wgpu::ComputePipeline pipeline2 = device.CreateComputePipeline(&csDesc2);
|
||||||
|
|
||||||
|
wgpu::ComputePipelineDescriptor csDesc3;
|
||||||
|
csDesc3.compute.module = shaderModule;
|
||||||
|
csDesc3.compute.entryPoint = "main3";
|
||||||
|
wgpu::ComputePipeline pipeline3 = device.CreateComputePipeline(&csDesc3);
|
||||||
|
|
||||||
wgpu::BindGroup bindGroup1 =
|
wgpu::BindGroup bindGroup1 =
|
||||||
utils::MakeBindGroup(device, pipeline1.GetBindGroupLayout(0), {{0, buffer1}});
|
utils::MakeBindGroup(device, pipeline1.GetBindGroupLayout(0), {{0, buffer1}});
|
||||||
wgpu::BindGroup bindGroup2 =
|
wgpu::BindGroup bindGroup2 =
|
||||||
utils::MakeBindGroup(device, pipeline2.GetBindGroupLayout(0), {{0, buffer2}});
|
utils::MakeBindGroup(device, pipeline2.GetBindGroupLayout(0), {{0, buffer2}});
|
||||||
|
wgpu::BindGroup bindGroup3 =
|
||||||
|
utils::MakeBindGroup(device, pipeline3.GetBindGroupLayout(0), {{0, buffer3}});
|
||||||
|
|
||||||
wgpu::CommandBuffer commands;
|
wgpu::CommandBuffer commands;
|
||||||
{
|
{
|
||||||
|
@ -592,6 +657,10 @@ TEST_P(ShaderTests, OverridableConstantsMultipleEntryPoints) {
|
||||||
pass.SetBindGroup(0, bindGroup2);
|
pass.SetBindGroup(0, bindGroup2);
|
||||||
pass.Dispatch(1);
|
pass.Dispatch(1);
|
||||||
|
|
||||||
|
pass.SetPipeline(pipeline3);
|
||||||
|
pass.SetBindGroup(0, bindGroup3);
|
||||||
|
pass.Dispatch(1);
|
||||||
|
|
||||||
pass.EndPass();
|
pass.EndPass();
|
||||||
|
|
||||||
commands = encoder.Finish();
|
commands = encoder.Finish();
|
||||||
|
@ -601,14 +670,15 @@ TEST_P(ShaderTests, OverridableConstantsMultipleEntryPoints) {
|
||||||
|
|
||||||
EXPECT_BUFFER_U32_RANGE_EQ(expected1.data(), buffer1, 0, kCount);
|
EXPECT_BUFFER_U32_RANGE_EQ(expected1.data(), buffer1, 0, kCount);
|
||||||
EXPECT_BUFFER_U32_RANGE_EQ(expected2.data(), buffer2, 0, kCount);
|
EXPECT_BUFFER_U32_RANGE_EQ(expected2.data(), buffer2, 0, kCount);
|
||||||
|
EXPECT_BUFFER_U32_RANGE_EQ(expected3.data(), buffer3, 0, kCount);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test overridable constants with render pipeline
|
// Test overridable constants with render pipeline
|
||||||
// Draw a triangle covering the render target, with vertex position and color values from
|
// Draw a triangle covering the render target, with vertex position and color values from
|
||||||
// overridable constants
|
// overridable constants
|
||||||
TEST_P(ShaderTests, OverridableConstantsRenderPipeline) {
|
TEST_P(ShaderTests, OverridableConstantsRenderPipeline) {
|
||||||
// TODO(dawn:1137): D3D12 backend is unimplemented
|
DAWN_TEST_UNSUPPORTED_IF(IsOpenGL());
|
||||||
DAWN_TEST_UNSUPPORTED_IF(!IsVulkan() && !IsMetal());
|
DAWN_TEST_UNSUPPORTED_IF(IsOpenGLES());
|
||||||
|
|
||||||
wgpu::ShaderModule vsModule = utils::CreateShaderModule(device, R"(
|
wgpu::ShaderModule vsModule = utils::CreateShaderModule(device, R"(
|
||||||
[[override(1111)]] let xright: f32;
|
[[override(1111)]] let xright: f32;
|
||||||
|
|
|
@ -105,12 +105,12 @@ TEST_F(ComputePipelineOverridableConstantsValidationTest, ConstantsIdentifierLoo
|
||||||
TestCreatePipeline(constants);
|
TestCreatePipeline(constants);
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
// Valid: set the same constant twice
|
// Error: set the same constant twice
|
||||||
std::vector<wgpu::ConstantEntry> constants{
|
std::vector<wgpu::ConstantEntry> constants{
|
||||||
{nullptr, "c0", 0},
|
{nullptr, "c0", 0},
|
||||||
{nullptr, "c0", 1},
|
{nullptr, "c0", 1},
|
||||||
};
|
};
|
||||||
TestCreatePipeline(constants);
|
ASSERT_DEVICE_ERROR(TestCreatePipeline(constants));
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
// Valid: find by constant numeric id
|
// Valid: find by constant numeric id
|
||||||
|
@ -158,12 +158,12 @@ TEST_F(ComputePipelineOverridableConstantsValidationTest, UninitializedConstants
|
||||||
TestCreatePipeline(constants);
|
TestCreatePipeline(constants);
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
// Valid: all constants initialized (with duplicate initializations)
|
// Error: duplicate initializations
|
||||||
std::vector<wgpu::ConstantEntry> constants{
|
std::vector<wgpu::ConstantEntry> constants{
|
||||||
{nullptr, "c0", false}, {nullptr, "c2", 1}, {nullptr, "c5", 1},
|
{nullptr, "c0", false}, {nullptr, "c2", 1}, {nullptr, "c5", 1},
|
||||||
{nullptr, "c8", 1}, {nullptr, "c2", 2},
|
{nullptr, "c8", 1}, {nullptr, "c2", 2},
|
||||||
};
|
};
|
||||||
TestCreatePipeline(constants);
|
ASSERT_DEVICE_ERROR(TestCreatePipeline(constants));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue