Pipeline overridable constants: D3D12 backend

D3D12 doesn't have native pipeline constants feature.
This is done by using #define. Add some new tests to make
sure these define approaches work as expected.

Also makes duplicate pipeline constant entries an invalid
case.

Bug: dawn:1137
Change-Id: Iefed44a749625b535bbafbb39f42699f0b42e06a
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/68860
Reviewed-by: Austin Eng <enga@chromium.org>
Commit-Queue: Shrek Shao <shrekshao@google.com>
This commit is contained in:
shrekshao 2021-11-18 17:47:45 +00:00 committed by Dawn LUCI CQ
parent fd066a3345
commit c6c4588036
9 changed files with 309 additions and 40 deletions

View File

@ -55,17 +55,22 @@ namespace dawn_native {
// pipelineBase is not yet constructed at this moment so iterate constants from descriptor // pipelineBase is not yet constructed at this moment so iterate constants from descriptor
size_t numUninitializedConstants = metadata.uninitializedOverridableConstants.size(); size_t numUninitializedConstants = metadata.uninitializedOverridableConstants.size();
// Keep an initialized constants sets to handle duplicate initialization cases // Keep an initialized constants sets to handle duplicate initialization cases
// Only storing that of uninialized constants is needed
std::unordered_set<std::string> stageInitializedConstantIdentifiers; std::unordered_set<std::string> stageInitializedConstantIdentifiers;
for (uint32_t i = 0; i < constantCount; i++) { for (uint32_t i = 0; i < constantCount; i++) {
DAWN_INVALID_IF(metadata.overridableConstants.count(constants[i].key) == 0, DAWN_INVALID_IF(metadata.overridableConstants.count(constants[i].key) == 0,
"Pipeline overridable constant \"%s\" not found in shader module %s.", "Pipeline overridable constant \"%s\" not found in %s.",
constants[i].key, module); constants[i].key, module);
if (metadata.uninitializedOverridableConstants.count(constants[i].key) > 0 && if (stageInitializedConstantIdentifiers.count(constants[i].key) == 0) {
stageInitializedConstantIdentifiers.count(constants[i].key) == 0) { if (metadata.uninitializedOverridableConstants.count(constants[i].key) > 0) {
numUninitializedConstants--; numUninitializedConstants--;
}
stageInitializedConstantIdentifiers.insert(constants[i].key); stageInitializedConstantIdentifiers.insert(constants[i].key);
} else {
// There are duplicate initializations
return DAWN_FORMAT_VALIDATION_ERROR(
"Pipeline overridable constants \"%s\" is set more than once in %s",
constants[i].key, module);
} }
} }
@ -116,11 +121,10 @@ namespace dawn_native {
// Record them internally. // Record them internally.
bool isFirstStage = mStageMask == wgpu::ShaderStage::None; bool isFirstStage = mStageMask == wgpu::ShaderStage::None;
mStageMask |= StageBit(shaderStage); mStageMask |= StageBit(shaderStage);
mStages[shaderStage] = {module, entryPointName, &metadata, mStages[shaderStage] = {module, entryPointName, &metadata, {}};
std::vector<PipelineConstantEntry>()};
auto& constants = mStages[shaderStage].constants; auto& constants = mStages[shaderStage].constants;
for (uint32_t i = 0; i < stage.constantCount; i++) { for (uint32_t i = 0; i < stage.constantCount; i++) {
constants.emplace_back(stage.constants[i].key, stage.constants[i].value); constants.emplace(stage.constants[i].key, stage.constants[i].value);
} }
// Compute the max() of all minBufferSizes across all stages. // Compute the max() of all minBufferSizes across all stages.

View File

@ -37,7 +37,9 @@ namespace dawn_native {
const PipelineLayoutBase* layout, const PipelineLayoutBase* layout,
SingleShaderStage stage); SingleShaderStage stage);
using PipelineConstantEntry = std::pair<std::string, double>; // Use map to make sure constant keys are sorted for creating shader cache keys
using PipelineConstantEntries = std::map<std::string, double>;
struct ProgrammableStage { struct ProgrammableStage {
Ref<ShaderModuleBase> module; Ref<ShaderModuleBase> module;
std::string entryPoint; std::string entryPoint;
@ -45,7 +47,7 @@ namespace dawn_native {
// The metadata lives as long as module, that's ref-ed in the same structure. // The metadata lives as long as module, that's ref-ed in the same structure.
const EntryPointMetadata* metadata = nullptr; const EntryPointMetadata* metadata = nullptr;
std::vector<PipelineConstantEntry> constants; PipelineConstantEntries constants;
}; };
class PipelineBase : public ApiObjectBase, public CachedObject { class PipelineBase : public ApiObjectBase, public CachedObject {

View File

@ -222,9 +222,11 @@ namespace dawn_native {
OverridableConstantScalar defaultValue; OverridableConstantScalar defaultValue;
}; };
using OverridableConstantsMap = std::unordered_map<std::string, OverridableConstant>;
// Map identifier to overridable constant // Map identifier to overridable constant
// Identifier is unique: either the variable name or the numeric ID if specified // Identifier is unique: either the variable name or the numeric ID if specified
std::unordered_map<std::string, OverridableConstant> overridableConstants; OverridableConstantsMap overridableConstants;
// Overridable constants that are not initialized in shaders // Overridable constants that are not initialized in shaders
// They need value initialization from pipeline stage or it is a validation error // They need value initialization from pipeline stage or it is a validation error
@ -254,7 +256,7 @@ namespace dawn_native {
// Return true iff the program has an entrypoint called `entryPoint`. // Return true iff the program has an entrypoint called `entryPoint`.
bool HasEntryPoint(const std::string& entryPoint) const; bool HasEntryPoint(const std::string& entryPoint) const;
// Returns the metadata for the given `entryPoint`. HasEntryPoint with the same argument // Return the metadata for the given `entryPoint`. HasEntryPoint with the same argument
// must be true. // must be true.
const EntryPointMetadata& GetEntryPoint(const std::string& entryPoint) const; const EntryPointMetadata& GetEntryPoint(const std::string& entryPoint) const;

View File

@ -48,9 +48,8 @@ namespace dawn_native { namespace d3d12 {
d3dDesc.pRootSignature = ToBackend(GetLayout())->GetRootSignature(); d3dDesc.pRootSignature = ToBackend(GetLayout())->GetRootSignature();
CompiledShader compiledShader; CompiledShader compiledShader;
DAWN_TRY_ASSIGN(compiledShader, DAWN_TRY_ASSIGN(compiledShader, module->Compile(computeStage, SingleShaderStage::Compute,
module->Compile(computeStage.entryPoint.c_str(), SingleShaderStage::Compute, ToBackend(GetLayout()), compileFlags));
ToBackend(GetLayout()), compileFlags));
d3dDesc.CS = compiledShader.GetD3D12ShaderBytecode(); d3dDesc.CS = compiledShader.GetD3D12ShaderBytecode();
auto* d3d12Device = device->GetD3D12Device(); auto* d3d12Device = device->GetD3D12Device();
DAWN_TRY(CheckHRESULT( DAWN_TRY(CheckHRESULT(

View File

@ -348,10 +348,10 @@ namespace dawn_native { namespace d3d12 {
PerStage<CompiledShader> compiledShader; PerStage<CompiledShader> compiledShader;
for (auto stage : IterateStages(GetStageMask())) { for (auto stage : IterateStages(GetStageMask())) {
DAWN_TRY_ASSIGN(compiledShader[stage], DAWN_TRY_ASSIGN(
ToBackend(pipelineStages[stage].module) compiledShader[stage],
->Compile(pipelineStages[stage].entryPoint.c_str(), stage, ToBackend(pipelineStages[stage].module)
ToBackend(GetLayout()), compileFlags)); ->Compile(pipelineStages[stage], stage, ToBackend(GetLayout()), compileFlags));
*shaders[stage] = compiledShader[stage].GetD3D12ShaderBytecode(); *shaders[stage] = compiledShader[stage].GetD3D12ShaderBytecode();
} }

View File

@ -17,6 +17,8 @@
#include "common/Assert.h" #include "common/Assert.h"
#include "common/BitSetIterator.h" #include "common/BitSetIterator.h"
#include "common/Log.h" #include "common/Log.h"
#include "common/WindowsUtils.h"
#include "dawn_native/Pipeline.h"
#include "dawn_native/TintUtils.h" #include "dawn_native/TintUtils.h"
#include "dawn_native/d3d12/BindGroupLayoutD3D12.h" #include "dawn_native/d3d12/BindGroupLayoutD3D12.h"
#include "dawn_native/d3d12/D3D12Error.h" #include "dawn_native/d3d12/D3D12Error.h"
@ -32,6 +34,94 @@
#include <sstream> #include <sstream>
#include <unordered_map> #include <unordered_map>
namespace dawn_native {
template <typename StringType, typename T = int32_t>
struct NumberToString {
static StringType ToStringAsValue(T v);
static StringType ToStringAsId(T v);
};
template <typename T>
struct NumberToString<std::string, T> {
static constexpr char kSpecConstantPrefix[] = "WGSL_SPEC_CONSTANT_";
static std::string ToStringAsValue(T v) {
return std::to_string(v);
}
static std::string ToStringAsId(T v) {
return std::to_string(v);
}
};
template <typename T>
struct NumberToString<std::wstring, T> {
static constexpr WCHAR kSpecConstantPrefix[] = L"WGSL_SPEC_CONSTANT_";
static std::wstring ToStringAsValue(T v) {
return std::to_wstring(v);
}
static std::wstring ToStringAsId(T v) {
return std::to_wstring(v);
}
};
template <>
struct NumberToString<std::string, float> {
static std::string ToStringAsValue(float v) {
std::ostringstream out;
// 32 bit float has 7 decimal digits of precision so setting n to 8 should be enough
out.precision(8);
out << std::fixed << v;
return out.str();
}
};
template <>
struct NumberToString<std::wstring, float> {
static std::wstring ToStringAsValue(float v) {
std::basic_ostringstream<WCHAR> out;
// 32 bit float has 7 decimal digits of precision so setting n to 8 should be enough
out.precision(8);
out << std::fixed << v;
return out.str();
}
};
template <>
struct NumberToString<std::string, uint32_t> {
static std::string ToStringAsValue(uint32_t v) {
return std::to_string(v) + "u";
}
};
template <>
struct NumberToString<std::wstring, uint32_t> {
static std::wstring ToStringAsValue(uint32_t v) {
return std::to_wstring(v) + L"u";
}
};
template <typename StringType>
StringType GetHLSLValueString(EntryPointMetadata::OverridableConstant::Type dawnType,
const OverridableConstantScalar* entry,
double value = 0) {
switch (dawnType) {
case EntryPointMetadata::OverridableConstant::Type::Boolean:
return NumberToString<StringType, int32_t>::ToStringAsValue(
entry ? entry->b : static_cast<int32_t>(value));
case EntryPointMetadata::OverridableConstant::Type::Float32:
return NumberToString<StringType, float>::ToStringAsValue(
entry ? entry->f32 : static_cast<float>(value));
case EntryPointMetadata::OverridableConstant::Type::Int32:
return NumberToString<StringType, int32_t>::ToStringAsValue(
entry ? entry->i32 : static_cast<int32_t>(value));
case EntryPointMetadata::OverridableConstant::Type::Uint32:
return NumberToString<StringType, uint32_t>::ToStringAsValue(
entry ? entry->u32 : static_cast<uint32_t>(value));
default:
UNREACHABLE();
}
}
} // namespace dawn_native
namespace dawn_native { namespace d3d12 { namespace dawn_native { namespace d3d12 {
namespace { namespace {
@ -109,6 +199,8 @@ namespace dawn_native { namespace d3d12 {
bool usesNumWorkgroups; bool usesNumWorkgroups;
uint32_t numWorkgroupsRegisterSpace; uint32_t numWorkgroupsRegisterSpace;
uint32_t numWorkgroupsShaderRegister; uint32_t numWorkgroupsShaderRegister;
const PipelineConstantEntries* pipelineConstantEntries;
const EntryPointMetadata::OverridableConstantsMap* shaderEntryPointConstants;
// FXC/DXC common inputs // FXC/DXC common inputs
bool disableWorkgroupInit; bool disableWorkgroupInit;
@ -128,7 +220,8 @@ namespace dawn_native { namespace d3d12 {
uint32_t compileFlags, uint32_t compileFlags,
const Device* device, const Device* device,
const tint::Program* program, const tint::Program* program,
const EntryPointMetadata& entryPoint) { const EntryPointMetadata& entryPoint,
const ProgrammableStage& programmableStage) {
Compiler compiler; Compiler compiler;
uint64_t dxcVersion = 0; uint64_t dxcVersion = 0;
if (device->IsToggleEnabled(Toggle::UseDXC)) { if (device->IsToggleEnabled(Toggle::UseDXC)) {
@ -200,6 +293,10 @@ namespace dawn_native { namespace d3d12 {
request.dxcVersion = compiler == Compiler::DXC ? dxcVersion : 0; request.dxcVersion = compiler == Compiler::DXC ? dxcVersion : 0;
request.deviceInfo = &device->GetDeviceInfo(); request.deviceInfo = &device->GetDeviceInfo();
request.hasShaderFloat16Feature = device->IsFeatureEnabled(Feature::ShaderFloat16); request.hasShaderFloat16Feature = device->IsFeatureEnabled(Feature::ShaderFloat16);
request.pipelineConstantEntries = &programmableStage.constants;
request.shaderEntryPointConstants =
&programmableStage.module->GetEntryPoint(programmableStage.entryPoint)
.overridableConstants;
return std::move(request); return std::move(request);
} }
@ -251,6 +348,21 @@ namespace dawn_native { namespace d3d12 {
stream << " fxcVersion=" << fxcVersion; stream << " fxcVersion=" << fxcVersion;
stream << " dxcVersion=" << dxcVersion; stream << " dxcVersion=" << dxcVersion;
stream << " hasShaderFloat16Feature=" << hasShaderFloat16Feature; stream << " hasShaderFloat16Feature=" << hasShaderFloat16Feature;
stream << " overridableConstants={";
for (const auto& pipelineConstant : *pipelineConstantEntries) {
const std::string& name = pipelineConstant.first;
double value = pipelineConstant.second;
// This is already validated so `name` must exist
const auto& moduleConstant = shaderEntryPointConstants->at(name);
stream << " <" << name << ","
<< GetHLSLValueString<std::string>(moduleConstant.type, nullptr, value)
<< ">";
}
stream << " }";
stream << ")"; stream << ")";
stream << "\n"; stream << "\n";
@ -310,6 +422,53 @@ namespace dawn_native { namespace d3d12 {
return arguments; return arguments;
} }
template <typename StringType>
const std::vector<std::pair<StringType, StringType>> GetOverridableConstantsDefines(
const PipelineConstantEntries* pipelineConstantEntries,
const EntryPointMetadata::OverridableConstantsMap* shaderEntryPointConstants) {
std::vector<std::pair<StringType, StringType>> defineStrings;
std::unordered_set<std::string> overriddenConstants;
// Set pipeline overridden values
for (const auto& pipelineConstant : *pipelineConstantEntries) {
const std::string& name = pipelineConstant.first;
double value = pipelineConstant.second;
overriddenConstants.insert(name);
// This is already validated so `name` must exist
const auto& moduleConstant = shaderEntryPointConstants->at(name);
defineStrings.emplace_back(
NumberToString<StringType>::kSpecConstantPrefix +
NumberToString<StringType>::ToStringAsId(
static_cast<int32_t>(moduleConstant.id)),
GetHLSLValueString<StringType>(moduleConstant.type, nullptr, value));
}
// Set shader initialized default values
for (const auto& iter : *shaderEntryPointConstants) {
const std::string& name = iter.first;
if (overriddenConstants.count(name) != 0) {
// This constant already has overridden value
continue;
}
const auto& moduleConstant = shaderEntryPointConstants->at(name);
// Uninitialized default values are okay since they are only defined to pass
// compilation but not used
defineStrings.emplace_back(NumberToString<StringType>::kSpecConstantPrefix +
NumberToString<StringType>::ToStringAsId(
static_cast<int32_t>(moduleConstant.id)),
GetHLSLValueString<StringType>(
moduleConstant.type, &moduleConstant.defaultValue));
}
return defineStrings;
}
ResultOrError<ComPtr<IDxcBlob>> CompileShaderDXC(IDxcLibrary* dxcLibrary, ResultOrError<ComPtr<IDxcBlob>> CompileShaderDXC(IDxcLibrary* dxcLibrary,
IDxcCompiler* dxcCompiler, IDxcCompiler* dxcCompiler,
const ShaderCompilationRequest& request, const ShaderCompilationRequest& request,
@ -326,12 +485,21 @@ namespace dawn_native { namespace d3d12 {
std::vector<const wchar_t*> arguments = std::vector<const wchar_t*> arguments =
GetDXCArguments(request.compileFlags, request.hasShaderFloat16Feature); GetDXCArguments(request.compileFlags, request.hasShaderFloat16Feature);
// Build defines for overridable constants
const auto& defineStrings = GetOverridableConstantsDefines<std::wstring>(
request.pipelineConstantEntries, request.shaderEntryPointConstants);
std::vector<DxcDefine> dxcDefines;
dxcDefines.reserve(defineStrings.size());
for (const auto& d : defineStrings) {
dxcDefines.push_back({d.first.c_str(), d.second.c_str()});
}
ComPtr<IDxcOperationResult> result; ComPtr<IDxcOperationResult> result;
DAWN_TRY(CheckHRESULT( DAWN_TRY(CheckHRESULT(
dxcCompiler->Compile(sourceBlob.Get(), nullptr, entryPointW.c_str(), dxcCompiler->Compile(sourceBlob.Get(), nullptr, entryPointW.c_str(),
request.deviceInfo->shaderProfiles[request.stage].c_str(), request.deviceInfo->shaderProfiles[request.stage].c_str(),
arguments.data(), arguments.size(), nullptr, 0, nullptr, arguments.data(), arguments.size(), dxcDefines.data(),
&result), dxcDefines.size(), nullptr, &result),
"DXC compile")); "DXC compile"));
HRESULT hr; HRESULT hr;
@ -369,8 +537,24 @@ namespace dawn_native { namespace d3d12 {
ComPtr<ID3DBlob> compiledShader; ComPtr<ID3DBlob> compiledShader;
ComPtr<ID3DBlob> errors; ComPtr<ID3DBlob> errors;
// Build defines for overridable constants
const auto& defineStrings = GetOverridableConstantsDefines<std::string>(
request.pipelineConstantEntries, request.shaderEntryPointConstants);
const D3D_SHADER_MACRO* pDefines = nullptr;
std::vector<D3D_SHADER_MACRO> fxcDefines;
if (defineStrings.size() > 0) {
fxcDefines.reserve(defineStrings.size() + 1);
for (const auto& d : defineStrings) {
fxcDefines.push_back({d.first.c_str(), d.second.c_str()});
}
// d3dCompile D3D_SHADER_MACRO* pDefines is a nullptr terminated array
fxcDefines.push_back({nullptr, nullptr});
pDefines = fxcDefines.data();
}
DAWN_INVALID_IF(FAILED(functions->d3dCompile( DAWN_INVALID_IF(FAILED(functions->d3dCompile(
hlslSource.c_str(), hlslSource.length(), nullptr, nullptr, nullptr, hlslSource.c_str(), hlslSource.length(), nullptr, pDefines, nullptr,
request.entryPointName, targetProfile, request.compileFlags, 0, request.entryPointName, targetProfile, request.compileFlags, 0,
&compiledShader, &errors)), &compiledShader, &errors)),
"D3D compile failed with: %s", "D3D compile failed with: %s",
@ -392,6 +576,9 @@ namespace dawn_native { namespace d3d12 {
} }
transformManager.Add<tint::transform::BindingRemapper>(); transformManager.Add<tint::transform::BindingRemapper>();
transformManager.Add<tint::transform::SingleEntryPoint>();
transformInputs.Add<tint::transform::SingleEntryPoint::Config>(request.entryPointName);
transformManager.Add<tint::transform::Renamer>(); transformManager.Add<tint::transform::Renamer>();
if (request.disableSymbolRenaming) { if (request.disableSymbolRenaming) {
@ -508,7 +695,7 @@ namespace dawn_native { namespace d3d12 {
return InitializeBase(parseResult); return InitializeBase(parseResult);
} }
ResultOrError<CompiledShader> ShaderModule::Compile(const char* entryPointName, ResultOrError<CompiledShader> ShaderModule::Compile(const ProgrammableStage& programmableStage,
SingleShaderStage stage, SingleShaderStage stage,
PipelineLayout* layout, PipelineLayout* layout,
uint32_t compileFlags) { uint32_t compileFlags) {
@ -555,9 +742,10 @@ namespace dawn_native { namespace d3d12 {
} }
ShaderCompilationRequest request; ShaderCompilationRequest request;
DAWN_TRY_ASSIGN(request, ShaderCompilationRequest::Create(entryPointName, stage, layout, DAWN_TRY_ASSIGN(
compileFlags, device, program, request, ShaderCompilationRequest::Create(
GetEntryPoint(entryPointName))); programmableStage.entryPoint.c_str(), stage, layout, compileFlags, device,
program, GetEntryPoint(programmableStage.entryPoint), programmableStage));
PersistentCacheKey shaderCacheKey; PersistentCacheKey shaderCacheKey;
DAWN_TRY_ASSIGN(shaderCacheKey, request.CreateCacheKey()); DAWN_TRY_ASSIGN(shaderCacheKey, request.CreateCacheKey());

View File

@ -20,6 +20,10 @@
#include "dawn_native/d3d12/d3d12_platform.h" #include "dawn_native/d3d12/d3d12_platform.h"
namespace dawn_native {
struct ProgrammableStage;
} // namespace dawn_native
namespace dawn_native { namespace d3d12 { namespace dawn_native { namespace d3d12 {
class Device; class Device;
@ -49,7 +53,7 @@ namespace dawn_native { namespace d3d12 {
const ShaderModuleDescriptor* descriptor, const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult); ShaderModuleParseResult* parseResult);
ResultOrError<CompiledShader> Compile(const char* entryPointName, ResultOrError<CompiledShader> Compile(const ProgrammableStage& programmableStage,
SingleShaderStage stage, SingleShaderStage stage,
PipelineLayout* layout, PipelineLayout* layout,
uint32_t compileFlags); uint32_t compileFlags);

View File

@ -391,8 +391,8 @@ fn main([[location(0)]] pos : vec4<f32>) -> [[builtin(position)]] vec4<f32> {
// Test overridable constants without numeric identifiers // Test overridable constants without numeric identifiers
TEST_P(ShaderTests, OverridableConstants) { TEST_P(ShaderTests, OverridableConstants) {
// TODO(dawn:1137): D3D12 backend is unimplemented DAWN_TEST_UNSUPPORTED_IF(IsOpenGL());
DAWN_TEST_UNSUPPORTED_IF(!IsVulkan() && !IsMetal()); DAWN_TEST_UNSUPPORTED_IF(IsOpenGLES());
uint32_t const kCount = 11; uint32_t const kCount = 11;
std::vector<uint32_t> expected(kCount); std::vector<uint32_t> expected(kCount);
@ -469,8 +469,8 @@ TEST_P(ShaderTests, OverridableConstants) {
// Test overridable constants with numeric identifiers // Test overridable constants with numeric identifiers
TEST_P(ShaderTests, OverridableConstantsNumericIdentifiers) { TEST_P(ShaderTests, OverridableConstantsNumericIdentifiers) {
// TODO(dawn:1137): D3D12 backend is unimplemented DAWN_TEST_UNSUPPORTED_IF(IsOpenGL());
DAWN_TEST_UNSUPPORTED_IF(!IsVulkan() && !IsMetal()); DAWN_TEST_UNSUPPORTED_IF(IsOpenGLES());
uint32_t const kCount = 4; uint32_t const kCount = 4;
std::vector<uint32_t> expected{1u, 2u, 3u, 0u}; std::vector<uint32_t> expected{1u, 2u, 3u, 0u};
@ -523,17 +523,71 @@ TEST_P(ShaderTests, OverridableConstantsNumericIdentifiers) {
EXPECT_BUFFER_U32_RANGE_EQ(expected.data(), buffer, 0, kCount); EXPECT_BUFFER_U32_RANGE_EQ(expected.data(), buffer, 0, kCount);
} }
// Test overridable constants precision
// D3D12 HLSL shader uses defines so we want float number to have enough precision
TEST_P(ShaderTests, OverridableConstantsPrecision) {
DAWN_TEST_UNSUPPORTED_IF(IsOpenGL());
DAWN_TEST_UNSUPPORTED_IF(IsOpenGLES());
uint32_t const kCount = 2;
float const kValue1 = 3.14159;
float const kValue2 = 3.141592653589793238;
std::vector<float> expected{kValue1, kValue2};
wgpu::Buffer buffer = CreateBuffer(kCount);
std::string shader = R"(
[[override(1001)]] let c1: f32;
[[override(1002)]] let c2: f32;
[[block]] struct Buf {
data : array<f32, 2>;
};
[[group(0), binding(0)]] var<storage, read_write> buf : Buf;
[[stage(compute), workgroup_size(1)]] fn main() {
buf.data[0] = c1;
buf.data[1] = c2;
})";
std::vector<wgpu::ConstantEntry> constants;
constants.push_back({nullptr, "1001", kValue1});
constants.push_back({nullptr, "1002", kValue2});
wgpu::ComputePipeline pipeline = CreateComputePipeline(shader, "main", &constants);
wgpu::BindGroup bindGroup =
utils::MakeBindGroup(device, pipeline.GetBindGroupLayout(0), {{0, buffer}});
wgpu::CommandBuffer commands;
{
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
pass.SetPipeline(pipeline);
pass.SetBindGroup(0, bindGroup);
pass.Dispatch(1);
pass.EndPass();
commands = encoder.Finish();
}
queue.Submit(1, &commands);
EXPECT_BUFFER_FLOAT_RANGE_EQ(expected.data(), buffer, 0, kCount);
}
// Test overridable constants for different entry points // Test overridable constants for different entry points
TEST_P(ShaderTests, OverridableConstantsMultipleEntryPoints) { TEST_P(ShaderTests, OverridableConstantsMultipleEntryPoints) {
// TODO(dawn:1137): D3D12 backend is unimplemented DAWN_TEST_UNSUPPORTED_IF(IsOpenGL());
DAWN_TEST_UNSUPPORTED_IF(!IsVulkan() && !IsMetal()); DAWN_TEST_UNSUPPORTED_IF(IsOpenGLES());
uint32_t const kCount = 1; uint32_t const kCount = 1;
std::vector<uint32_t> expected1{1u}; std::vector<uint32_t> expected1{1u};
std::vector<uint32_t> expected2{2u}; std::vector<uint32_t> expected2{2u};
std::vector<uint32_t> expected3{3u};
wgpu::Buffer buffer1 = CreateBuffer(kCount); wgpu::Buffer buffer1 = CreateBuffer(kCount);
wgpu::Buffer buffer2 = CreateBuffer(kCount); wgpu::Buffer buffer2 = CreateBuffer(kCount);
wgpu::Buffer buffer3 = CreateBuffer(kCount);
std::string shader = R"( std::string shader = R"(
[[override(1001)]] let c1: u32; [[override(1001)]] let c1: u32;
@ -552,6 +606,10 @@ TEST_P(ShaderTests, OverridableConstantsMultipleEntryPoints) {
[[stage(compute), workgroup_size(1)]] fn main2() { [[stage(compute), workgroup_size(1)]] fn main2() {
buf.data[0] = c2; buf.data[0] = c2;
} }
[[stage(compute), workgroup_size(1)]] fn main3() {
buf.data[0] = 3u;
}
)"; )";
std::vector<wgpu::ConstantEntry> constants1; std::vector<wgpu::ConstantEntry> constants1;
@ -575,10 +633,17 @@ TEST_P(ShaderTests, OverridableConstantsMultipleEntryPoints) {
csDesc2.compute.constantCount = constants2.size(); csDesc2.compute.constantCount = constants2.size();
wgpu::ComputePipeline pipeline2 = device.CreateComputePipeline(&csDesc2); wgpu::ComputePipeline pipeline2 = device.CreateComputePipeline(&csDesc2);
wgpu::ComputePipelineDescriptor csDesc3;
csDesc3.compute.module = shaderModule;
csDesc3.compute.entryPoint = "main3";
wgpu::ComputePipeline pipeline3 = device.CreateComputePipeline(&csDesc3);
wgpu::BindGroup bindGroup1 = wgpu::BindGroup bindGroup1 =
utils::MakeBindGroup(device, pipeline1.GetBindGroupLayout(0), {{0, buffer1}}); utils::MakeBindGroup(device, pipeline1.GetBindGroupLayout(0), {{0, buffer1}});
wgpu::BindGroup bindGroup2 = wgpu::BindGroup bindGroup2 =
utils::MakeBindGroup(device, pipeline2.GetBindGroupLayout(0), {{0, buffer2}}); utils::MakeBindGroup(device, pipeline2.GetBindGroupLayout(0), {{0, buffer2}});
wgpu::BindGroup bindGroup3 =
utils::MakeBindGroup(device, pipeline3.GetBindGroupLayout(0), {{0, buffer3}});
wgpu::CommandBuffer commands; wgpu::CommandBuffer commands;
{ {
@ -592,6 +657,10 @@ TEST_P(ShaderTests, OverridableConstantsMultipleEntryPoints) {
pass.SetBindGroup(0, bindGroup2); pass.SetBindGroup(0, bindGroup2);
pass.Dispatch(1); pass.Dispatch(1);
pass.SetPipeline(pipeline3);
pass.SetBindGroup(0, bindGroup3);
pass.Dispatch(1);
pass.EndPass(); pass.EndPass();
commands = encoder.Finish(); commands = encoder.Finish();
@ -601,14 +670,15 @@ TEST_P(ShaderTests, OverridableConstantsMultipleEntryPoints) {
EXPECT_BUFFER_U32_RANGE_EQ(expected1.data(), buffer1, 0, kCount); EXPECT_BUFFER_U32_RANGE_EQ(expected1.data(), buffer1, 0, kCount);
EXPECT_BUFFER_U32_RANGE_EQ(expected2.data(), buffer2, 0, kCount); EXPECT_BUFFER_U32_RANGE_EQ(expected2.data(), buffer2, 0, kCount);
EXPECT_BUFFER_U32_RANGE_EQ(expected3.data(), buffer3, 0, kCount);
} }
// Test overridable constants with render pipeline // Test overridable constants with render pipeline
// Draw a triangle covering the render target, with vertex position and color values from // Draw a triangle covering the render target, with vertex position and color values from
// overridable constants // overridable constants
TEST_P(ShaderTests, OverridableConstantsRenderPipeline) { TEST_P(ShaderTests, OverridableConstantsRenderPipeline) {
// TODO(dawn:1137): D3D12 backend is unimplemented DAWN_TEST_UNSUPPORTED_IF(IsOpenGL());
DAWN_TEST_UNSUPPORTED_IF(!IsVulkan() && !IsMetal()); DAWN_TEST_UNSUPPORTED_IF(IsOpenGLES());
wgpu::ShaderModule vsModule = utils::CreateShaderModule(device, R"( wgpu::ShaderModule vsModule = utils::CreateShaderModule(device, R"(
[[override(1111)]] let xright: f32; [[override(1111)]] let xright: f32;

View File

@ -105,12 +105,12 @@ TEST_F(ComputePipelineOverridableConstantsValidationTest, ConstantsIdentifierLoo
TestCreatePipeline(constants); TestCreatePipeline(constants);
} }
{ {
// Valid: set the same constant twice // Error: set the same constant twice
std::vector<wgpu::ConstantEntry> constants{ std::vector<wgpu::ConstantEntry> constants{
{nullptr, "c0", 0}, {nullptr, "c0", 0},
{nullptr, "c0", 1}, {nullptr, "c0", 1},
}; };
TestCreatePipeline(constants); ASSERT_DEVICE_ERROR(TestCreatePipeline(constants));
} }
{ {
// Valid: find by constant numeric id // Valid: find by constant numeric id
@ -158,12 +158,12 @@ TEST_F(ComputePipelineOverridableConstantsValidationTest, UninitializedConstants
TestCreatePipeline(constants); TestCreatePipeline(constants);
} }
{ {
// Valid: all constants initialized (with duplicate initializations) // Error: duplicate initializations
std::vector<wgpu::ConstantEntry> constants{ std::vector<wgpu::ConstantEntry> constants{
{nullptr, "c0", false}, {nullptr, "c2", 1}, {nullptr, "c5", 1}, {nullptr, "c0", false}, {nullptr, "c2", 1}, {nullptr, "c5", 1},
{nullptr, "c8", 1}, {nullptr, "c2", 2}, {nullptr, "c8", 1}, {nullptr, "c2", 2},
}; };
TestCreatePipeline(constants); ASSERT_DEVICE_ERROR(TestCreatePipeline(constants));
} }
} }