Pipeline overridable constants: D3D12 backend
D3D12 doesn't have native pipeline constants feature. This is done by using #define. Add some new tests to make sure these define approaches work as expected. Also makes duplicate pipeline constant entries an invalid case. Bug: dawn:1137 Change-Id: Iefed44a749625b535bbafbb39f42699f0b42e06a Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/68860 Reviewed-by: Austin Eng <enga@chromium.org> Commit-Queue: Shrek Shao <shrekshao@google.com>
This commit is contained in:
parent
fd066a3345
commit
c6c4588036
|
@ -55,17 +55,22 @@ namespace dawn_native {
|
|||
// pipelineBase is not yet constructed at this moment so iterate constants from descriptor
|
||||
size_t numUninitializedConstants = metadata.uninitializedOverridableConstants.size();
|
||||
// Keep an initialized constants sets to handle duplicate initialization cases
|
||||
// Only storing that of uninialized constants is needed
|
||||
std::unordered_set<std::string> stageInitializedConstantIdentifiers;
|
||||
for (uint32_t i = 0; i < constantCount; i++) {
|
||||
DAWN_INVALID_IF(metadata.overridableConstants.count(constants[i].key) == 0,
|
||||
"Pipeline overridable constant \"%s\" not found in shader module %s.",
|
||||
"Pipeline overridable constant \"%s\" not found in %s.",
|
||||
constants[i].key, module);
|
||||
|
||||
if (metadata.uninitializedOverridableConstants.count(constants[i].key) > 0 &&
|
||||
stageInitializedConstantIdentifiers.count(constants[i].key) == 0) {
|
||||
numUninitializedConstants--;
|
||||
if (stageInitializedConstantIdentifiers.count(constants[i].key) == 0) {
|
||||
if (metadata.uninitializedOverridableConstants.count(constants[i].key) > 0) {
|
||||
numUninitializedConstants--;
|
||||
}
|
||||
stageInitializedConstantIdentifiers.insert(constants[i].key);
|
||||
} else {
|
||||
// There are duplicate initializations
|
||||
return DAWN_FORMAT_VALIDATION_ERROR(
|
||||
"Pipeline overridable constants \"%s\" is set more than once in %s",
|
||||
constants[i].key, module);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -116,11 +121,10 @@ namespace dawn_native {
|
|||
// Record them internally.
|
||||
bool isFirstStage = mStageMask == wgpu::ShaderStage::None;
|
||||
mStageMask |= StageBit(shaderStage);
|
||||
mStages[shaderStage] = {module, entryPointName, &metadata,
|
||||
std::vector<PipelineConstantEntry>()};
|
||||
mStages[shaderStage] = {module, entryPointName, &metadata, {}};
|
||||
auto& constants = mStages[shaderStage].constants;
|
||||
for (uint32_t i = 0; i < stage.constantCount; i++) {
|
||||
constants.emplace_back(stage.constants[i].key, stage.constants[i].value);
|
||||
constants.emplace(stage.constants[i].key, stage.constants[i].value);
|
||||
}
|
||||
|
||||
// Compute the max() of all minBufferSizes across all stages.
|
||||
|
|
|
@ -37,7 +37,9 @@ namespace dawn_native {
|
|||
const PipelineLayoutBase* layout,
|
||||
SingleShaderStage stage);
|
||||
|
||||
using PipelineConstantEntry = std::pair<std::string, double>;
|
||||
// Use map to make sure constant keys are sorted for creating shader cache keys
|
||||
using PipelineConstantEntries = std::map<std::string, double>;
|
||||
|
||||
struct ProgrammableStage {
|
||||
Ref<ShaderModuleBase> module;
|
||||
std::string entryPoint;
|
||||
|
@ -45,7 +47,7 @@ namespace dawn_native {
|
|||
// The metadata lives as long as module, that's ref-ed in the same structure.
|
||||
const EntryPointMetadata* metadata = nullptr;
|
||||
|
||||
std::vector<PipelineConstantEntry> constants;
|
||||
PipelineConstantEntries constants;
|
||||
};
|
||||
|
||||
class PipelineBase : public ApiObjectBase, public CachedObject {
|
||||
|
|
|
@ -222,9 +222,11 @@ namespace dawn_native {
|
|||
OverridableConstantScalar defaultValue;
|
||||
};
|
||||
|
||||
using OverridableConstantsMap = std::unordered_map<std::string, OverridableConstant>;
|
||||
|
||||
// Map identifier to overridable constant
|
||||
// Identifier is unique: either the variable name or the numeric ID if specified
|
||||
std::unordered_map<std::string, OverridableConstant> overridableConstants;
|
||||
OverridableConstantsMap overridableConstants;
|
||||
|
||||
// Overridable constants that are not initialized in shaders
|
||||
// They need value initialization from pipeline stage or it is a validation error
|
||||
|
@ -254,7 +256,7 @@ namespace dawn_native {
|
|||
// Return true iff the program has an entrypoint called `entryPoint`.
|
||||
bool HasEntryPoint(const std::string& entryPoint) const;
|
||||
|
||||
// Returns the metadata for the given `entryPoint`. HasEntryPoint with the same argument
|
||||
// Return the metadata for the given `entryPoint`. HasEntryPoint with the same argument
|
||||
// must be true.
|
||||
const EntryPointMetadata& GetEntryPoint(const std::string& entryPoint) const;
|
||||
|
||||
|
|
|
@ -48,9 +48,8 @@ namespace dawn_native { namespace d3d12 {
|
|||
d3dDesc.pRootSignature = ToBackend(GetLayout())->GetRootSignature();
|
||||
|
||||
CompiledShader compiledShader;
|
||||
DAWN_TRY_ASSIGN(compiledShader,
|
||||
module->Compile(computeStage.entryPoint.c_str(), SingleShaderStage::Compute,
|
||||
ToBackend(GetLayout()), compileFlags));
|
||||
DAWN_TRY_ASSIGN(compiledShader, module->Compile(computeStage, SingleShaderStage::Compute,
|
||||
ToBackend(GetLayout()), compileFlags));
|
||||
d3dDesc.CS = compiledShader.GetD3D12ShaderBytecode();
|
||||
auto* d3d12Device = device->GetD3D12Device();
|
||||
DAWN_TRY(CheckHRESULT(
|
||||
|
|
|
@ -348,10 +348,10 @@ namespace dawn_native { namespace d3d12 {
|
|||
PerStage<CompiledShader> compiledShader;
|
||||
|
||||
for (auto stage : IterateStages(GetStageMask())) {
|
||||
DAWN_TRY_ASSIGN(compiledShader[stage],
|
||||
ToBackend(pipelineStages[stage].module)
|
||||
->Compile(pipelineStages[stage].entryPoint.c_str(), stage,
|
||||
ToBackend(GetLayout()), compileFlags));
|
||||
DAWN_TRY_ASSIGN(
|
||||
compiledShader[stage],
|
||||
ToBackend(pipelineStages[stage].module)
|
||||
->Compile(pipelineStages[stage], stage, ToBackend(GetLayout()), compileFlags));
|
||||
*shaders[stage] = compiledShader[stage].GetD3D12ShaderBytecode();
|
||||
}
|
||||
|
||||
|
|
|
@ -17,6 +17,8 @@
|
|||
#include "common/Assert.h"
|
||||
#include "common/BitSetIterator.h"
|
||||
#include "common/Log.h"
|
||||
#include "common/WindowsUtils.h"
|
||||
#include "dawn_native/Pipeline.h"
|
||||
#include "dawn_native/TintUtils.h"
|
||||
#include "dawn_native/d3d12/BindGroupLayoutD3D12.h"
|
||||
#include "dawn_native/d3d12/D3D12Error.h"
|
||||
|
@ -32,6 +34,94 @@
|
|||
#include <sstream>
|
||||
#include <unordered_map>
|
||||
|
||||
namespace dawn_native {
|
||||
template <typename StringType, typename T = int32_t>
|
||||
struct NumberToString {
|
||||
static StringType ToStringAsValue(T v);
|
||||
static StringType ToStringAsId(T v);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct NumberToString<std::string, T> {
|
||||
static constexpr char kSpecConstantPrefix[] = "WGSL_SPEC_CONSTANT_";
|
||||
static std::string ToStringAsValue(T v) {
|
||||
return std::to_string(v);
|
||||
}
|
||||
static std::string ToStringAsId(T v) {
|
||||
return std::to_string(v);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct NumberToString<std::wstring, T> {
|
||||
static constexpr WCHAR kSpecConstantPrefix[] = L"WGSL_SPEC_CONSTANT_";
|
||||
static std::wstring ToStringAsValue(T v) {
|
||||
return std::to_wstring(v);
|
||||
}
|
||||
static std::wstring ToStringAsId(T v) {
|
||||
return std::to_wstring(v);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct NumberToString<std::string, float> {
|
||||
static std::string ToStringAsValue(float v) {
|
||||
std::ostringstream out;
|
||||
// 32 bit float has 7 decimal digits of precision so setting n to 8 should be enough
|
||||
out.precision(8);
|
||||
out << std::fixed << v;
|
||||
return out.str();
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct NumberToString<std::wstring, float> {
|
||||
static std::wstring ToStringAsValue(float v) {
|
||||
std::basic_ostringstream<WCHAR> out;
|
||||
// 32 bit float has 7 decimal digits of precision so setting n to 8 should be enough
|
||||
out.precision(8);
|
||||
out << std::fixed << v;
|
||||
return out.str();
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct NumberToString<std::string, uint32_t> {
|
||||
static std::string ToStringAsValue(uint32_t v) {
|
||||
return std::to_string(v) + "u";
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct NumberToString<std::wstring, uint32_t> {
|
||||
static std::wstring ToStringAsValue(uint32_t v) {
|
||||
return std::to_wstring(v) + L"u";
|
||||
}
|
||||
};
|
||||
|
||||
template <typename StringType>
|
||||
StringType GetHLSLValueString(EntryPointMetadata::OverridableConstant::Type dawnType,
|
||||
const OverridableConstantScalar* entry,
|
||||
double value = 0) {
|
||||
switch (dawnType) {
|
||||
case EntryPointMetadata::OverridableConstant::Type::Boolean:
|
||||
return NumberToString<StringType, int32_t>::ToStringAsValue(
|
||||
entry ? entry->b : static_cast<int32_t>(value));
|
||||
case EntryPointMetadata::OverridableConstant::Type::Float32:
|
||||
return NumberToString<StringType, float>::ToStringAsValue(
|
||||
entry ? entry->f32 : static_cast<float>(value));
|
||||
case EntryPointMetadata::OverridableConstant::Type::Int32:
|
||||
return NumberToString<StringType, int32_t>::ToStringAsValue(
|
||||
entry ? entry->i32 : static_cast<int32_t>(value));
|
||||
case EntryPointMetadata::OverridableConstant::Type::Uint32:
|
||||
return NumberToString<StringType, uint32_t>::ToStringAsValue(
|
||||
entry ? entry->u32 : static_cast<uint32_t>(value));
|
||||
default:
|
||||
UNREACHABLE();
|
||||
}
|
||||
}
|
||||
} // namespace dawn_native
|
||||
|
||||
namespace dawn_native { namespace d3d12 {
|
||||
|
||||
namespace {
|
||||
|
@ -109,6 +199,8 @@ namespace dawn_native { namespace d3d12 {
|
|||
bool usesNumWorkgroups;
|
||||
uint32_t numWorkgroupsRegisterSpace;
|
||||
uint32_t numWorkgroupsShaderRegister;
|
||||
const PipelineConstantEntries* pipelineConstantEntries;
|
||||
const EntryPointMetadata::OverridableConstantsMap* shaderEntryPointConstants;
|
||||
|
||||
// FXC/DXC common inputs
|
||||
bool disableWorkgroupInit;
|
||||
|
@ -128,7 +220,8 @@ namespace dawn_native { namespace d3d12 {
|
|||
uint32_t compileFlags,
|
||||
const Device* device,
|
||||
const tint::Program* program,
|
||||
const EntryPointMetadata& entryPoint) {
|
||||
const EntryPointMetadata& entryPoint,
|
||||
const ProgrammableStage& programmableStage) {
|
||||
Compiler compiler;
|
||||
uint64_t dxcVersion = 0;
|
||||
if (device->IsToggleEnabled(Toggle::UseDXC)) {
|
||||
|
@ -200,6 +293,10 @@ namespace dawn_native { namespace d3d12 {
|
|||
request.dxcVersion = compiler == Compiler::DXC ? dxcVersion : 0;
|
||||
request.deviceInfo = &device->GetDeviceInfo();
|
||||
request.hasShaderFloat16Feature = device->IsFeatureEnabled(Feature::ShaderFloat16);
|
||||
request.pipelineConstantEntries = &programmableStage.constants;
|
||||
request.shaderEntryPointConstants =
|
||||
&programmableStage.module->GetEntryPoint(programmableStage.entryPoint)
|
||||
.overridableConstants;
|
||||
return std::move(request);
|
||||
}
|
||||
|
||||
|
@ -251,6 +348,21 @@ namespace dawn_native { namespace d3d12 {
|
|||
stream << " fxcVersion=" << fxcVersion;
|
||||
stream << " dxcVersion=" << dxcVersion;
|
||||
stream << " hasShaderFloat16Feature=" << hasShaderFloat16Feature;
|
||||
|
||||
stream << " overridableConstants={";
|
||||
for (const auto& pipelineConstant : *pipelineConstantEntries) {
|
||||
const std::string& name = pipelineConstant.first;
|
||||
double value = pipelineConstant.second;
|
||||
|
||||
// This is already validated so `name` must exist
|
||||
const auto& moduleConstant = shaderEntryPointConstants->at(name);
|
||||
|
||||
stream << " <" << name << ","
|
||||
<< GetHLSLValueString<std::string>(moduleConstant.type, nullptr, value)
|
||||
<< ">";
|
||||
}
|
||||
stream << " }";
|
||||
|
||||
stream << ")";
|
||||
stream << "\n";
|
||||
|
||||
|
@ -310,6 +422,53 @@ namespace dawn_native { namespace d3d12 {
|
|||
return arguments;
|
||||
}
|
||||
|
||||
template <typename StringType>
|
||||
const std::vector<std::pair<StringType, StringType>> GetOverridableConstantsDefines(
|
||||
const PipelineConstantEntries* pipelineConstantEntries,
|
||||
const EntryPointMetadata::OverridableConstantsMap* shaderEntryPointConstants) {
|
||||
std::vector<std::pair<StringType, StringType>> defineStrings;
|
||||
|
||||
std::unordered_set<std::string> overriddenConstants;
|
||||
|
||||
// Set pipeline overridden values
|
||||
for (const auto& pipelineConstant : *pipelineConstantEntries) {
|
||||
const std::string& name = pipelineConstant.first;
|
||||
double value = pipelineConstant.second;
|
||||
|
||||
overriddenConstants.insert(name);
|
||||
|
||||
// This is already validated so `name` must exist
|
||||
const auto& moduleConstant = shaderEntryPointConstants->at(name);
|
||||
|
||||
defineStrings.emplace_back(
|
||||
NumberToString<StringType>::kSpecConstantPrefix +
|
||||
NumberToString<StringType>::ToStringAsId(
|
||||
static_cast<int32_t>(moduleConstant.id)),
|
||||
GetHLSLValueString<StringType>(moduleConstant.type, nullptr, value));
|
||||
}
|
||||
|
||||
// Set shader initialized default values
|
||||
for (const auto& iter : *shaderEntryPointConstants) {
|
||||
const std::string& name = iter.first;
|
||||
if (overriddenConstants.count(name) != 0) {
|
||||
// This constant already has overridden value
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto& moduleConstant = shaderEntryPointConstants->at(name);
|
||||
|
||||
// Uninitialized default values are okay since they are only defined to pass
|
||||
// compilation but not used
|
||||
defineStrings.emplace_back(NumberToString<StringType>::kSpecConstantPrefix +
|
||||
NumberToString<StringType>::ToStringAsId(
|
||||
static_cast<int32_t>(moduleConstant.id)),
|
||||
GetHLSLValueString<StringType>(
|
||||
moduleConstant.type, &moduleConstant.defaultValue));
|
||||
}
|
||||
|
||||
return defineStrings;
|
||||
}
|
||||
|
||||
ResultOrError<ComPtr<IDxcBlob>> CompileShaderDXC(IDxcLibrary* dxcLibrary,
|
||||
IDxcCompiler* dxcCompiler,
|
||||
const ShaderCompilationRequest& request,
|
||||
|
@ -326,12 +485,21 @@ namespace dawn_native { namespace d3d12 {
|
|||
std::vector<const wchar_t*> arguments =
|
||||
GetDXCArguments(request.compileFlags, request.hasShaderFloat16Feature);
|
||||
|
||||
// Build defines for overridable constants
|
||||
const auto& defineStrings = GetOverridableConstantsDefines<std::wstring>(
|
||||
request.pipelineConstantEntries, request.shaderEntryPointConstants);
|
||||
std::vector<DxcDefine> dxcDefines;
|
||||
dxcDefines.reserve(defineStrings.size());
|
||||
for (const auto& d : defineStrings) {
|
||||
dxcDefines.push_back({d.first.c_str(), d.second.c_str()});
|
||||
}
|
||||
|
||||
ComPtr<IDxcOperationResult> result;
|
||||
DAWN_TRY(CheckHRESULT(
|
||||
dxcCompiler->Compile(sourceBlob.Get(), nullptr, entryPointW.c_str(),
|
||||
request.deviceInfo->shaderProfiles[request.stage].c_str(),
|
||||
arguments.data(), arguments.size(), nullptr, 0, nullptr,
|
||||
&result),
|
||||
arguments.data(), arguments.size(), dxcDefines.data(),
|
||||
dxcDefines.size(), nullptr, &result),
|
||||
"DXC compile"));
|
||||
|
||||
HRESULT hr;
|
||||
|
@ -369,8 +537,24 @@ namespace dawn_native { namespace d3d12 {
|
|||
ComPtr<ID3DBlob> compiledShader;
|
||||
ComPtr<ID3DBlob> errors;
|
||||
|
||||
// Build defines for overridable constants
|
||||
const auto& defineStrings = GetOverridableConstantsDefines<std::string>(
|
||||
request.pipelineConstantEntries, request.shaderEntryPointConstants);
|
||||
|
||||
const D3D_SHADER_MACRO* pDefines = nullptr;
|
||||
std::vector<D3D_SHADER_MACRO> fxcDefines;
|
||||
if (defineStrings.size() > 0) {
|
||||
fxcDefines.reserve(defineStrings.size() + 1);
|
||||
for (const auto& d : defineStrings) {
|
||||
fxcDefines.push_back({d.first.c_str(), d.second.c_str()});
|
||||
}
|
||||
// d3dCompile D3D_SHADER_MACRO* pDefines is a nullptr terminated array
|
||||
fxcDefines.push_back({nullptr, nullptr});
|
||||
pDefines = fxcDefines.data();
|
||||
}
|
||||
|
||||
DAWN_INVALID_IF(FAILED(functions->d3dCompile(
|
||||
hlslSource.c_str(), hlslSource.length(), nullptr, nullptr, nullptr,
|
||||
hlslSource.c_str(), hlslSource.length(), nullptr, pDefines, nullptr,
|
||||
request.entryPointName, targetProfile, request.compileFlags, 0,
|
||||
&compiledShader, &errors)),
|
||||
"D3D compile failed with: %s",
|
||||
|
@ -392,6 +576,9 @@ namespace dawn_native { namespace d3d12 {
|
|||
}
|
||||
transformManager.Add<tint::transform::BindingRemapper>();
|
||||
|
||||
transformManager.Add<tint::transform::SingleEntryPoint>();
|
||||
transformInputs.Add<tint::transform::SingleEntryPoint::Config>(request.entryPointName);
|
||||
|
||||
transformManager.Add<tint::transform::Renamer>();
|
||||
|
||||
if (request.disableSymbolRenaming) {
|
||||
|
@ -508,7 +695,7 @@ namespace dawn_native { namespace d3d12 {
|
|||
return InitializeBase(parseResult);
|
||||
}
|
||||
|
||||
ResultOrError<CompiledShader> ShaderModule::Compile(const char* entryPointName,
|
||||
ResultOrError<CompiledShader> ShaderModule::Compile(const ProgrammableStage& programmableStage,
|
||||
SingleShaderStage stage,
|
||||
PipelineLayout* layout,
|
||||
uint32_t compileFlags) {
|
||||
|
@ -555,9 +742,10 @@ namespace dawn_native { namespace d3d12 {
|
|||
}
|
||||
|
||||
ShaderCompilationRequest request;
|
||||
DAWN_TRY_ASSIGN(request, ShaderCompilationRequest::Create(entryPointName, stage, layout,
|
||||
compileFlags, device, program,
|
||||
GetEntryPoint(entryPointName)));
|
||||
DAWN_TRY_ASSIGN(
|
||||
request, ShaderCompilationRequest::Create(
|
||||
programmableStage.entryPoint.c_str(), stage, layout, compileFlags, device,
|
||||
program, GetEntryPoint(programmableStage.entryPoint), programmableStage));
|
||||
|
||||
PersistentCacheKey shaderCacheKey;
|
||||
DAWN_TRY_ASSIGN(shaderCacheKey, request.CreateCacheKey());
|
||||
|
|
|
@ -20,6 +20,10 @@
|
|||
|
||||
#include "dawn_native/d3d12/d3d12_platform.h"
|
||||
|
||||
namespace dawn_native {
|
||||
struct ProgrammableStage;
|
||||
} // namespace dawn_native
|
||||
|
||||
namespace dawn_native { namespace d3d12 {
|
||||
|
||||
class Device;
|
||||
|
@ -49,7 +53,7 @@ namespace dawn_native { namespace d3d12 {
|
|||
const ShaderModuleDescriptor* descriptor,
|
||||
ShaderModuleParseResult* parseResult);
|
||||
|
||||
ResultOrError<CompiledShader> Compile(const char* entryPointName,
|
||||
ResultOrError<CompiledShader> Compile(const ProgrammableStage& programmableStage,
|
||||
SingleShaderStage stage,
|
||||
PipelineLayout* layout,
|
||||
uint32_t compileFlags);
|
||||
|
|
|
@ -391,8 +391,8 @@ fn main([[location(0)]] pos : vec4<f32>) -> [[builtin(position)]] vec4<f32> {
|
|||
|
||||
// Test overridable constants without numeric identifiers
|
||||
TEST_P(ShaderTests, OverridableConstants) {
|
||||
// TODO(dawn:1137): D3D12 backend is unimplemented
|
||||
DAWN_TEST_UNSUPPORTED_IF(!IsVulkan() && !IsMetal());
|
||||
DAWN_TEST_UNSUPPORTED_IF(IsOpenGL());
|
||||
DAWN_TEST_UNSUPPORTED_IF(IsOpenGLES());
|
||||
|
||||
uint32_t const kCount = 11;
|
||||
std::vector<uint32_t> expected(kCount);
|
||||
|
@ -469,8 +469,8 @@ TEST_P(ShaderTests, OverridableConstants) {
|
|||
|
||||
// Test overridable constants with numeric identifiers
|
||||
TEST_P(ShaderTests, OverridableConstantsNumericIdentifiers) {
|
||||
// TODO(dawn:1137): D3D12 backend is unimplemented
|
||||
DAWN_TEST_UNSUPPORTED_IF(!IsVulkan() && !IsMetal());
|
||||
DAWN_TEST_UNSUPPORTED_IF(IsOpenGL());
|
||||
DAWN_TEST_UNSUPPORTED_IF(IsOpenGLES());
|
||||
|
||||
uint32_t const kCount = 4;
|
||||
std::vector<uint32_t> expected{1u, 2u, 3u, 0u};
|
||||
|
@ -523,17 +523,71 @@ TEST_P(ShaderTests, OverridableConstantsNumericIdentifiers) {
|
|||
EXPECT_BUFFER_U32_RANGE_EQ(expected.data(), buffer, 0, kCount);
|
||||
}
|
||||
|
||||
// Test overridable constants precision
|
||||
// D3D12 HLSL shader uses defines so we want float number to have enough precision
|
||||
TEST_P(ShaderTests, OverridableConstantsPrecision) {
|
||||
DAWN_TEST_UNSUPPORTED_IF(IsOpenGL());
|
||||
DAWN_TEST_UNSUPPORTED_IF(IsOpenGLES());
|
||||
|
||||
uint32_t const kCount = 2;
|
||||
float const kValue1 = 3.14159;
|
||||
float const kValue2 = 3.141592653589793238;
|
||||
std::vector<float> expected{kValue1, kValue2};
|
||||
wgpu::Buffer buffer = CreateBuffer(kCount);
|
||||
|
||||
std::string shader = R"(
|
||||
[[override(1001)]] let c1: f32;
|
||||
[[override(1002)]] let c2: f32;
|
||||
|
||||
[[block]] struct Buf {
|
||||
data : array<f32, 2>;
|
||||
};
|
||||
|
||||
[[group(0), binding(0)]] var<storage, read_write> buf : Buf;
|
||||
|
||||
[[stage(compute), workgroup_size(1)]] fn main() {
|
||||
buf.data[0] = c1;
|
||||
buf.data[1] = c2;
|
||||
})";
|
||||
|
||||
std::vector<wgpu::ConstantEntry> constants;
|
||||
constants.push_back({nullptr, "1001", kValue1});
|
||||
constants.push_back({nullptr, "1002", kValue2});
|
||||
wgpu::ComputePipeline pipeline = CreateComputePipeline(shader, "main", &constants);
|
||||
|
||||
wgpu::BindGroup bindGroup =
|
||||
utils::MakeBindGroup(device, pipeline.GetBindGroupLayout(0), {{0, buffer}});
|
||||
|
||||
wgpu::CommandBuffer commands;
|
||||
{
|
||||
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
|
||||
wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
|
||||
pass.SetPipeline(pipeline);
|
||||
pass.SetBindGroup(0, bindGroup);
|
||||
pass.Dispatch(1);
|
||||
pass.EndPass();
|
||||
|
||||
commands = encoder.Finish();
|
||||
}
|
||||
|
||||
queue.Submit(1, &commands);
|
||||
|
||||
EXPECT_BUFFER_FLOAT_RANGE_EQ(expected.data(), buffer, 0, kCount);
|
||||
}
|
||||
|
||||
// Test overridable constants for different entry points
|
||||
TEST_P(ShaderTests, OverridableConstantsMultipleEntryPoints) {
|
||||
// TODO(dawn:1137): D3D12 backend is unimplemented
|
||||
DAWN_TEST_UNSUPPORTED_IF(!IsVulkan() && !IsMetal());
|
||||
DAWN_TEST_UNSUPPORTED_IF(IsOpenGL());
|
||||
DAWN_TEST_UNSUPPORTED_IF(IsOpenGLES());
|
||||
|
||||
uint32_t const kCount = 1;
|
||||
std::vector<uint32_t> expected1{1u};
|
||||
std::vector<uint32_t> expected2{2u};
|
||||
std::vector<uint32_t> expected3{3u};
|
||||
|
||||
wgpu::Buffer buffer1 = CreateBuffer(kCount);
|
||||
wgpu::Buffer buffer2 = CreateBuffer(kCount);
|
||||
wgpu::Buffer buffer3 = CreateBuffer(kCount);
|
||||
|
||||
std::string shader = R"(
|
||||
[[override(1001)]] let c1: u32;
|
||||
|
@ -552,6 +606,10 @@ TEST_P(ShaderTests, OverridableConstantsMultipleEntryPoints) {
|
|||
[[stage(compute), workgroup_size(1)]] fn main2() {
|
||||
buf.data[0] = c2;
|
||||
}
|
||||
|
||||
[[stage(compute), workgroup_size(1)]] fn main3() {
|
||||
buf.data[0] = 3u;
|
||||
}
|
||||
)";
|
||||
|
||||
std::vector<wgpu::ConstantEntry> constants1;
|
||||
|
@ -575,10 +633,17 @@ TEST_P(ShaderTests, OverridableConstantsMultipleEntryPoints) {
|
|||
csDesc2.compute.constantCount = constants2.size();
|
||||
wgpu::ComputePipeline pipeline2 = device.CreateComputePipeline(&csDesc2);
|
||||
|
||||
wgpu::ComputePipelineDescriptor csDesc3;
|
||||
csDesc3.compute.module = shaderModule;
|
||||
csDesc3.compute.entryPoint = "main3";
|
||||
wgpu::ComputePipeline pipeline3 = device.CreateComputePipeline(&csDesc3);
|
||||
|
||||
wgpu::BindGroup bindGroup1 =
|
||||
utils::MakeBindGroup(device, pipeline1.GetBindGroupLayout(0), {{0, buffer1}});
|
||||
wgpu::BindGroup bindGroup2 =
|
||||
utils::MakeBindGroup(device, pipeline2.GetBindGroupLayout(0), {{0, buffer2}});
|
||||
wgpu::BindGroup bindGroup3 =
|
||||
utils::MakeBindGroup(device, pipeline3.GetBindGroupLayout(0), {{0, buffer3}});
|
||||
|
||||
wgpu::CommandBuffer commands;
|
||||
{
|
||||
|
@ -592,6 +657,10 @@ TEST_P(ShaderTests, OverridableConstantsMultipleEntryPoints) {
|
|||
pass.SetBindGroup(0, bindGroup2);
|
||||
pass.Dispatch(1);
|
||||
|
||||
pass.SetPipeline(pipeline3);
|
||||
pass.SetBindGroup(0, bindGroup3);
|
||||
pass.Dispatch(1);
|
||||
|
||||
pass.EndPass();
|
||||
|
||||
commands = encoder.Finish();
|
||||
|
@ -601,14 +670,15 @@ TEST_P(ShaderTests, OverridableConstantsMultipleEntryPoints) {
|
|||
|
||||
EXPECT_BUFFER_U32_RANGE_EQ(expected1.data(), buffer1, 0, kCount);
|
||||
EXPECT_BUFFER_U32_RANGE_EQ(expected2.data(), buffer2, 0, kCount);
|
||||
EXPECT_BUFFER_U32_RANGE_EQ(expected3.data(), buffer3, 0, kCount);
|
||||
}
|
||||
|
||||
// Test overridable constants with render pipeline
|
||||
// Draw a triangle covering the render target, with vertex position and color values from
|
||||
// overridable constants
|
||||
TEST_P(ShaderTests, OverridableConstantsRenderPipeline) {
|
||||
// TODO(dawn:1137): D3D12 backend is unimplemented
|
||||
DAWN_TEST_UNSUPPORTED_IF(!IsVulkan() && !IsMetal());
|
||||
DAWN_TEST_UNSUPPORTED_IF(IsOpenGL());
|
||||
DAWN_TEST_UNSUPPORTED_IF(IsOpenGLES());
|
||||
|
||||
wgpu::ShaderModule vsModule = utils::CreateShaderModule(device, R"(
|
||||
[[override(1111)]] let xright: f32;
|
||||
|
|
|
@ -105,12 +105,12 @@ TEST_F(ComputePipelineOverridableConstantsValidationTest, ConstantsIdentifierLoo
|
|||
TestCreatePipeline(constants);
|
||||
}
|
||||
{
|
||||
// Valid: set the same constant twice
|
||||
// Error: set the same constant twice
|
||||
std::vector<wgpu::ConstantEntry> constants{
|
||||
{nullptr, "c0", 0},
|
||||
{nullptr, "c0", 1},
|
||||
};
|
||||
TestCreatePipeline(constants);
|
||||
ASSERT_DEVICE_ERROR(TestCreatePipeline(constants));
|
||||
}
|
||||
{
|
||||
// Valid: find by constant numeric id
|
||||
|
@ -158,12 +158,12 @@ TEST_F(ComputePipelineOverridableConstantsValidationTest, UninitializedConstants
|
|||
TestCreatePipeline(constants);
|
||||
}
|
||||
{
|
||||
// Valid: all constants initialized (with duplicate initializations)
|
||||
// Error: duplicate initializations
|
||||
std::vector<wgpu::ConstantEntry> constants{
|
||||
{nullptr, "c0", false}, {nullptr, "c2", 1}, {nullptr, "c5", 1},
|
||||
{nullptr, "c8", 1}, {nullptr, "c2", 2},
|
||||
};
|
||||
TestCreatePipeline(constants);
|
||||
ASSERT_DEVICE_ERROR(TestCreatePipeline(constants));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue