mirror of
https://github.com/encounter/dawn-cmake.git
synced 2025-07-14 17:16:01 +00:00
D3D12: Support [[num_workgroups]] for Dispatch
This patch implements [[num_workgroups]] on the API side for Dispatch() calls by setting num_workgroups.xyz as root constants. This patch also adds a temporary validation that on D3D12 backend using a compute pipeline with [[num_workgroups]] in a DispatchIndirect call is not supported. BUG=dawn:839 TEST=dawn_end2end_tests Change-Id: Iaee2ffd162e9420e4e80944fbb222f10a4600c6a Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/66580 Reviewed-by: Corentin Wallez <cwallez@chromium.org> Reviewed-by: Austin Eng <enga@chromium.org> Commit-Queue: Jiawei Shao <jiawei.shao@intel.com>
This commit is contained in:
parent
a5c0c8f6be
commit
1349ca182e
@ -677,6 +677,8 @@ namespace dawn_native {
|
|||||||
metadata->localWorkgroupSize.x = entryPoint.workgroup_size_x;
|
metadata->localWorkgroupSize.x = entryPoint.workgroup_size_x;
|
||||||
metadata->localWorkgroupSize.y = entryPoint.workgroup_size_y;
|
metadata->localWorkgroupSize.y = entryPoint.workgroup_size_y;
|
||||||
metadata->localWorkgroupSize.z = entryPoint.workgroup_size_z;
|
metadata->localWorkgroupSize.z = entryPoint.workgroup_size_z;
|
||||||
|
|
||||||
|
metadata->usesNumWorkgroups = entryPoint.num_workgroups_used;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (metadata->stage == SingleShaderStage::Vertex) {
|
if (metadata->stage == SingleShaderStage::Vertex) {
|
||||||
|
@ -204,6 +204,8 @@ namespace dawn_native {
|
|||||||
|
|
||||||
// Store overridableConstants from tint program
|
// Store overridableConstants from tint program
|
||||||
std::unordered_map<std::string, OverridableConstant> overridableConstants;
|
std::unordered_map<std::string, OverridableConstant> overridableConstants;
|
||||||
|
|
||||||
|
bool usesNumWorkgroups = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
class ShaderModuleBase : public ApiObjectBase, public CachedObject {
|
class ShaderModuleBase : public ApiObjectBase, public CachedObject {
|
||||||
|
@ -254,6 +254,18 @@ namespace dawn_native { namespace d3d12 {
|
|||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void RecordNumWorkgroupsForDispatch(ID3D12GraphicsCommandList* commandList,
|
||||||
|
ComputePipeline* pipeline,
|
||||||
|
DispatchCmd* dispatch) {
|
||||||
|
if (!pipeline->UsesNumWorkgroups()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
PipelineLayout* layout = ToBackend(pipeline->GetLayout());
|
||||||
|
commandList->SetComputeRoot32BitConstants(layout->GetNumWorkgroupsParameterIndex(), 3,
|
||||||
|
dispatch, 0);
|
||||||
|
}
|
||||||
|
|
||||||
// Records the necessary barriers for a synchronization scope using the resource usage
|
// Records the necessary barriers for a synchronization scope using the resource usage
|
||||||
// data pre-computed in the frontend. Also performs lazy initialization if required.
|
// data pre-computed in the frontend. Also performs lazy initialization if required.
|
||||||
// Returns whether any UAV are used in the synchronization scope.
|
// Returns whether any UAV are used in the synchronization scope.
|
||||||
@ -1030,6 +1042,7 @@ namespace dawn_native { namespace d3d12 {
|
|||||||
ID3D12GraphicsCommandList* commandList = commandContext->GetCommandList();
|
ID3D12GraphicsCommandList* commandList = commandContext->GetCommandList();
|
||||||
|
|
||||||
Command type;
|
Command type;
|
||||||
|
ComputePipeline* lastPipeline = nullptr;
|
||||||
while (mCommands.NextCommandId(&type)) {
|
while (mCommands.NextCommandId(&type)) {
|
||||||
switch (type) {
|
switch (type) {
|
||||||
case Command::Dispatch: {
|
case Command::Dispatch: {
|
||||||
@ -1045,6 +1058,7 @@ namespace dawn_native { namespace d3d12 {
|
|||||||
resourceUsages.dispatchUsages[currentDispatch]);
|
resourceUsages.dispatchUsages[currentDispatch]);
|
||||||
DAWN_TRY(bindingTracker->Apply(commandContext));
|
DAWN_TRY(bindingTracker->Apply(commandContext));
|
||||||
|
|
||||||
|
RecordNumWorkgroupsForDispatch(commandList, lastPipeline, dispatch);
|
||||||
commandList->Dispatch(dispatch->x, dispatch->y, dispatch->z);
|
commandList->Dispatch(dispatch->x, dispatch->y, dispatch->z);
|
||||||
currentDispatch++;
|
currentDispatch++;
|
||||||
break;
|
break;
|
||||||
@ -1052,6 +1066,14 @@ namespace dawn_native { namespace d3d12 {
|
|||||||
|
|
||||||
case Command::DispatchIndirect: {
|
case Command::DispatchIndirect: {
|
||||||
DispatchIndirectCmd* dispatch = mCommands.NextCommand<DispatchIndirectCmd>();
|
DispatchIndirectCmd* dispatch = mCommands.NextCommand<DispatchIndirectCmd>();
|
||||||
|
|
||||||
|
// TODO(dawn:839): support [[num_workgroups]] for DispatchIndirect calls
|
||||||
|
if (lastPipeline->UsesNumWorkgroups()) {
|
||||||
|
return DAWN_VALIDATION_ERROR(
|
||||||
|
"Using a compute pipeline with [[num_workgroups]] in a "
|
||||||
|
"DispatchIndirect call is not implemented");
|
||||||
|
}
|
||||||
|
|
||||||
Buffer* buffer = ToBackend(dispatch->indirectBuffer.Get());
|
Buffer* buffer = ToBackend(dispatch->indirectBuffer.Get());
|
||||||
|
|
||||||
TransitionAndClearForSyncScope(commandContext,
|
TransitionAndClearForSyncScope(commandContext,
|
||||||
@ -1078,6 +1100,7 @@ namespace dawn_native { namespace d3d12 {
|
|||||||
commandList->SetPipelineState(pipeline->GetPipelineState());
|
commandList->SetPipelineState(pipeline->GetPipelineState());
|
||||||
|
|
||||||
bindingTracker->OnSetPipeline(pipeline);
|
bindingTracker->OnSetPipeline(pipeline);
|
||||||
|
lastPipeline = pipeline;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -84,4 +84,8 @@ namespace dawn_native { namespace d3d12 {
|
|||||||
CreateComputePipelineAsyncTask::RunAsync(std::move(asyncTask));
|
CreateComputePipelineAsyncTask::RunAsync(std::move(asyncTask));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool ComputePipeline::UsesNumWorkgroups() const {
|
||||||
|
return GetStage(SingleShaderStage::Compute).metadata->usesNumWorkgroups;
|
||||||
|
}
|
||||||
|
|
||||||
}} // namespace dawn_native::d3d12
|
}} // namespace dawn_native::d3d12
|
||||||
|
@ -40,6 +40,8 @@ namespace dawn_native { namespace d3d12 {
|
|||||||
// Dawn API
|
// Dawn API
|
||||||
void SetLabelImpl() override;
|
void SetLabelImpl() override;
|
||||||
|
|
||||||
|
bool UsesNumWorkgroups() const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
~ComputePipeline() override;
|
~ComputePipeline() override;
|
||||||
using ComputePipelineBase::ComputePipelineBase;
|
using ComputePipelineBase::ComputePipelineBase;
|
||||||
|
@ -174,6 +174,21 @@ namespace dawn_native { namespace d3d12 {
|
|||||||
// would need to be updated often
|
// would need to be updated often
|
||||||
rootParameters.emplace_back(indexOffsetConstants);
|
rootParameters.emplace_back(indexOffsetConstants);
|
||||||
|
|
||||||
|
// Always allocate 3 constants for num_workgroups_x, num_workgroups_y and num_workgroups_z
|
||||||
|
// for Dispatch calls
|
||||||
|
// NOTE: We should consider delaying root signature creation until we know how many values
|
||||||
|
// we need
|
||||||
|
D3D12_ROOT_PARAMETER numWorkgroupsConstants{};
|
||||||
|
numWorkgroupsConstants.ShaderVisibility = D3D12_SHADER_VISIBILITY_ALL;
|
||||||
|
numWorkgroupsConstants.ParameterType = D3D12_ROOT_PARAMETER_TYPE_32BIT_CONSTANTS;
|
||||||
|
numWorkgroupsConstants.Constants.Num32BitValues = 3;
|
||||||
|
numWorkgroupsConstants.Constants.RegisterSpace = GetNumWorkgroupsRegisterSpace();
|
||||||
|
numWorkgroupsConstants.Constants.ShaderRegister = GetNumWorkgroupsShaderRegister();
|
||||||
|
mNumWorkgroupsParamterIndex = rootParameters.size();
|
||||||
|
// NOTE: We should consider moving this entry to earlier in the root signature since
|
||||||
|
// dispatch sizes would need to be updated often
|
||||||
|
rootParameters.emplace_back(numWorkgroupsConstants);
|
||||||
|
|
||||||
D3D12_ROOT_SIGNATURE_DESC rootSignatureDescriptor;
|
D3D12_ROOT_SIGNATURE_DESC rootSignatureDescriptor;
|
||||||
rootSignatureDescriptor.NumParameters = rootParameters.size();
|
rootSignatureDescriptor.NumParameters = rootParameters.size();
|
||||||
rootSignatureDescriptor.pParameters = rootParameters.data();
|
rootSignatureDescriptor.pParameters = rootParameters.data();
|
||||||
@ -230,7 +245,7 @@ namespace dawn_native { namespace d3d12 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
uint32_t PipelineLayout::GetFirstIndexOffsetRegisterSpace() const {
|
uint32_t PipelineLayout::GetFirstIndexOffsetRegisterSpace() const {
|
||||||
return kReservedRegisterSpace;
|
return kFirstIndexOffsetRegisterSpace;
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t PipelineLayout::GetFirstIndexOffsetShaderRegister() const {
|
uint32_t PipelineLayout::GetFirstIndexOffsetShaderRegister() const {
|
||||||
@ -240,4 +255,16 @@ namespace dawn_native { namespace d3d12 {
|
|||||||
uint32_t PipelineLayout::GetFirstIndexOffsetParameterIndex() const {
|
uint32_t PipelineLayout::GetFirstIndexOffsetParameterIndex() const {
|
||||||
return mFirstIndexOffsetParameterIndex;
|
return mFirstIndexOffsetParameterIndex;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
uint32_t PipelineLayout::GetNumWorkgroupsRegisterSpace() const {
|
||||||
|
return kNumWorkgroupsRegisterSpace;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t PipelineLayout::GetNumWorkgroupsShaderRegister() const {
|
||||||
|
return kNumWorkgroupsBaseRegister;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t PipelineLayout::GetNumWorkgroupsParameterIndex() const {
|
||||||
|
return mNumWorkgroupsParamterIndex;
|
||||||
|
}
|
||||||
}} // namespace dawn_native::d3d12
|
}} // namespace dawn_native::d3d12
|
||||||
|
@ -26,6 +26,9 @@ namespace dawn_native { namespace d3d12 {
|
|||||||
// We reserve a register space that a user cannot use.
|
// We reserve a register space that a user cannot use.
|
||||||
static constexpr uint32_t kReservedRegisterSpace = kMaxBindGroups + 1;
|
static constexpr uint32_t kReservedRegisterSpace = kMaxBindGroups + 1;
|
||||||
static constexpr uint32_t kFirstOffsetInfoBaseRegister = 0;
|
static constexpr uint32_t kFirstOffsetInfoBaseRegister = 0;
|
||||||
|
static constexpr uint32_t kFirstIndexOffsetRegisterSpace = kReservedRegisterSpace;
|
||||||
|
static constexpr uint32_t kNumWorkgroupsRegisterSpace = kReservedRegisterSpace + 1;
|
||||||
|
static constexpr uint32_t kNumWorkgroupsBaseRegister = 0;
|
||||||
|
|
||||||
class Device;
|
class Device;
|
||||||
|
|
||||||
@ -46,6 +49,10 @@ namespace dawn_native { namespace d3d12 {
|
|||||||
uint32_t GetFirstIndexOffsetShaderRegister() const;
|
uint32_t GetFirstIndexOffsetShaderRegister() const;
|
||||||
uint32_t GetFirstIndexOffsetParameterIndex() const;
|
uint32_t GetFirstIndexOffsetParameterIndex() const;
|
||||||
|
|
||||||
|
uint32_t GetNumWorkgroupsRegisterSpace() const;
|
||||||
|
uint32_t GetNumWorkgroupsShaderRegister() const;
|
||||||
|
uint32_t GetNumWorkgroupsParameterIndex() const;
|
||||||
|
|
||||||
ID3D12RootSignature* GetRootSignature() const;
|
ID3D12RootSignature* GetRootSignature() const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@ -59,6 +66,7 @@ namespace dawn_native { namespace d3d12 {
|
|||||||
kMaxBindGroups>
|
kMaxBindGroups>
|
||||||
mDynamicRootParameterIndices;
|
mDynamicRootParameterIndices;
|
||||||
uint32_t mFirstIndexOffsetParameterIndex;
|
uint32_t mFirstIndexOffsetParameterIndex;
|
||||||
|
uint32_t mNumWorkgroupsParamterIndex;
|
||||||
ComPtr<ID3D12RootSignature> mRootSignature;
|
ComPtr<ID3D12RootSignature> mRootSignature;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -106,6 +106,9 @@ namespace dawn_native { namespace d3d12 {
|
|||||||
tint::transform::BindingRemapper::BindingPoints bindingPoints;
|
tint::transform::BindingRemapper::BindingPoints bindingPoints;
|
||||||
tint::transform::BindingRemapper::AccessControls accessControls;
|
tint::transform::BindingRemapper::AccessControls accessControls;
|
||||||
bool isRobustnessEnabled;
|
bool isRobustnessEnabled;
|
||||||
|
bool usesNumWorkgroups;
|
||||||
|
uint32_t numWorkgroupsRegisterSpace;
|
||||||
|
uint32_t numWorkgroupsShaderRegister;
|
||||||
|
|
||||||
// FXC/DXC common inputs
|
// FXC/DXC common inputs
|
||||||
bool disableWorkgroupInit;
|
bool disableWorkgroupInit;
|
||||||
@ -125,7 +128,7 @@ namespace dawn_native { namespace d3d12 {
|
|||||||
uint32_t compileFlags,
|
uint32_t compileFlags,
|
||||||
const Device* device,
|
const Device* device,
|
||||||
const tint::Program* program,
|
const tint::Program* program,
|
||||||
const BindingInfoArray& moduleBindingInfo) {
|
const EntryPointMetadata& entryPoint) {
|
||||||
Compiler compiler;
|
Compiler compiler;
|
||||||
uint64_t dxcVersion = 0;
|
uint64_t dxcVersion = 0;
|
||||||
if (device->IsToggleEnabled(Toggle::UseDXC)) {
|
if (device->IsToggleEnabled(Toggle::UseDXC)) {
|
||||||
@ -145,6 +148,7 @@ namespace dawn_native { namespace d3d12 {
|
|||||||
// Tint AST to make the "bindings" decoration match the offset chosen by
|
// Tint AST to make the "bindings" decoration match the offset chosen by
|
||||||
// d3d12::BindGroupLayout so that Tint produces HLSL with the correct registers
|
// d3d12::BindGroupLayout so that Tint produces HLSL with the correct registers
|
||||||
// assigned to each interface variable.
|
// assigned to each interface variable.
|
||||||
|
const BindingInfoArray& moduleBindingInfo = entryPoint.bindings;
|
||||||
for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) {
|
for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) {
|
||||||
const BindGroupLayout* bgl = ToBackend(layout->GetBindGroupLayout(group));
|
const BindGroupLayout* bgl = ToBackend(layout->GetBindGroupLayout(group));
|
||||||
const auto& groupBindingInfo = moduleBindingInfo[group];
|
const auto& groupBindingInfo = moduleBindingInfo[group];
|
||||||
@ -189,6 +193,9 @@ namespace dawn_native { namespace d3d12 {
|
|||||||
request.isRobustnessEnabled = device->IsRobustnessEnabled();
|
request.isRobustnessEnabled = device->IsRobustnessEnabled();
|
||||||
request.disableWorkgroupInit =
|
request.disableWorkgroupInit =
|
||||||
device->IsToggleEnabled(Toggle::DisableWorkgroupInit);
|
device->IsToggleEnabled(Toggle::DisableWorkgroupInit);
|
||||||
|
request.usesNumWorkgroups = entryPoint.usesNumWorkgroups;
|
||||||
|
request.numWorkgroupsShaderRegister = layout->GetNumWorkgroupsShaderRegister();
|
||||||
|
request.numWorkgroupsRegisterSpace = layout->GetNumWorkgroupsRegisterSpace();
|
||||||
request.fxcVersion = compiler == Compiler::FXC ? GetD3DCompilerVersion() : 0;
|
request.fxcVersion = compiler == Compiler::FXC ? GetD3DCompilerVersion() : 0;
|
||||||
request.dxcVersion = compiler == Compiler::DXC ? dxcVersion : 0;
|
request.dxcVersion = compiler == Compiler::DXC ? dxcVersion : 0;
|
||||||
request.deviceInfo = &device->GetDeviceInfo();
|
request.deviceInfo = &device->GetDeviceInfo();
|
||||||
@ -234,6 +241,10 @@ namespace dawn_native { namespace d3d12 {
|
|||||||
stream << " accessControls=";
|
stream << " accessControls=";
|
||||||
Serialize(stream, accessControls);
|
Serialize(stream, accessControls);
|
||||||
|
|
||||||
|
stream << " useNumWorkgroups=" << usesNumWorkgroups;
|
||||||
|
stream << " numWorkgroupsRegisterSpace=" << numWorkgroupsRegisterSpace;
|
||||||
|
stream << " numWorkgroupsShaderRegister=" << numWorkgroupsShaderRegister;
|
||||||
|
|
||||||
stream << " shaderModel=" << deviceInfo->shaderModel;
|
stream << " shaderModel=" << deviceInfo->shaderModel;
|
||||||
stream << " disableWorkgroupInit=" << disableWorkgroupInit;
|
stream << " disableWorkgroupInit=" << disableWorkgroupInit;
|
||||||
stream << " isRobustnessEnabled=" << isRobustnessEnabled;
|
stream << " isRobustnessEnabled=" << isRobustnessEnabled;
|
||||||
@ -423,6 +434,10 @@ namespace dawn_native { namespace d3d12 {
|
|||||||
|
|
||||||
tint::writer::hlsl::Options options;
|
tint::writer::hlsl::Options options;
|
||||||
options.disable_workgroup_init = request.disableWorkgroupInit;
|
options.disable_workgroup_init = request.disableWorkgroupInit;
|
||||||
|
if (request.usesNumWorkgroups) {
|
||||||
|
options.root_constant_binding_point.group = request.numWorkgroupsRegisterSpace;
|
||||||
|
options.root_constant_binding_point.binding = request.numWorkgroupsShaderRegister;
|
||||||
|
}
|
||||||
auto result = tint::writer::hlsl::Generate(&transformedProgram, options);
|
auto result = tint::writer::hlsl::Generate(&transformedProgram, options);
|
||||||
if (!result.success) {
|
if (!result.success) {
|
||||||
errorStream << "Generator: " << result.error << std::endl;
|
errorStream << "Generator: " << result.error << std::endl;
|
||||||
@ -547,9 +562,9 @@ namespace dawn_native { namespace d3d12 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ShaderCompilationRequest request;
|
ShaderCompilationRequest request;
|
||||||
DAWN_TRY_ASSIGN(request, ShaderCompilationRequest::Create(
|
DAWN_TRY_ASSIGN(request, ShaderCompilationRequest::Create(entryPointName, stage, layout,
|
||||||
entryPointName, stage, layout, compileFlags, device, program,
|
compileFlags, device, program,
|
||||||
GetEntryPoint(entryPointName).bindings));
|
GetEntryPoint(entryPointName)));
|
||||||
|
|
||||||
PersistentCacheKey shaderCacheKey;
|
PersistentCacheKey shaderCacheKey;
|
||||||
DAWN_TRY_ASSIGN(shaderCacheKey, request.CreateCacheKey());
|
DAWN_TRY_ASSIGN(shaderCacheKey, request.CreateCacheKey());
|
||||||
|
@ -26,9 +26,30 @@ class ComputeDispatchTests : public DawnTest {
|
|||||||
DawnTest::SetUp();
|
DawnTest::SetUp();
|
||||||
|
|
||||||
// Write workgroup number into the output buffer if we saw the biggest dispatch
|
// Write workgroup number into the output buffer if we saw the biggest dispatch
|
||||||
// This is a workaround since D3D12 doesn't have gl_NumWorkGroups
|
|
||||||
// To make sure the dispatch was not called, write maximum u32 value for 0 dispatches
|
// To make sure the dispatch was not called, write maximum u32 value for 0 dispatches
|
||||||
wgpu::ShaderModule module = utils::CreateShaderModule(device, R"(
|
wgpu::ShaderModule moduleForDispatch = utils::CreateShaderModule(device, R"(
|
||||||
|
[[block]] struct OutputBuf {
|
||||||
|
workGroups : vec3<u32>;
|
||||||
|
};
|
||||||
|
|
||||||
|
[[group(0), binding(0)]] var<storage, read_write> output : OutputBuf;
|
||||||
|
|
||||||
|
[[stage(compute), workgroup_size(1, 1, 1)]]
|
||||||
|
fn main([[builtin(global_invocation_id)]] GlobalInvocationID : vec3<u32>,
|
||||||
|
[[builtin(num_workgroups)]] dispatch : vec3<u32>) {
|
||||||
|
if (dispatch.x == 0u || dispatch.y == 0u || dispatch.z == 0u) {
|
||||||
|
output.workGroups = vec3<u32>(0xFFFFFFFFu, 0xFFFFFFFFu, 0xFFFFFFFFu);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (all(GlobalInvocationID == dispatch - vec3<u32>(1u, 1u, 1u))) {
|
||||||
|
output.workGroups = dispatch;
|
||||||
|
}
|
||||||
|
})");
|
||||||
|
|
||||||
|
// TODO(dawn:839): use moduleForDispatch for indirect dispatch tests when D3D12 supports
|
||||||
|
// [[num_workgroups]] for indirect dispatch.
|
||||||
|
wgpu::ShaderModule moduleForDispatchIndirect = utils::CreateShaderModule(device, R"(
|
||||||
[[block]] struct InputBuf {
|
[[block]] struct InputBuf {
|
||||||
expectedDispatch : vec3<u32>;
|
expectedDispatch : vec3<u32>;
|
||||||
};
|
};
|
||||||
@ -54,9 +75,12 @@ class ComputeDispatchTests : public DawnTest {
|
|||||||
})");
|
})");
|
||||||
|
|
||||||
wgpu::ComputePipelineDescriptor csDesc;
|
wgpu::ComputePipelineDescriptor csDesc;
|
||||||
csDesc.compute.module = module;
|
csDesc.compute.module = moduleForDispatch;
|
||||||
csDesc.compute.entryPoint = "main";
|
csDesc.compute.entryPoint = "main";
|
||||||
pipeline = device.CreateComputePipeline(&csDesc);
|
pipelineForDispatch = device.CreateComputePipeline(&csDesc);
|
||||||
|
|
||||||
|
csDesc.compute.module = moduleForDispatchIndirect;
|
||||||
|
pipelineForDispatchIndirect = device.CreateComputePipeline(&csDesc);
|
||||||
}
|
}
|
||||||
|
|
||||||
void DirectTest(uint32_t x, uint32_t y, uint32_t z) {
|
void DirectTest(uint32_t x, uint32_t y, uint32_t z) {
|
||||||
@ -66,23 +90,18 @@ class ComputeDispatchTests : public DawnTest {
|
|||||||
wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst,
|
wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst,
|
||||||
kSentinelData);
|
kSentinelData);
|
||||||
|
|
||||||
std::initializer_list<uint32_t> expectedBufferData{x, y, z};
|
|
||||||
wgpu::Buffer expectedBuffer = utils::CreateBufferFromData<uint32_t>(
|
|
||||||
device, wgpu::BufferUsage::Uniform, expectedBufferData);
|
|
||||||
|
|
||||||
// Set up bind group and issue dispatch
|
// Set up bind group and issue dispatch
|
||||||
wgpu::BindGroup bindGroup =
|
wgpu::BindGroup bindGroup =
|
||||||
utils::MakeBindGroup(device, pipeline.GetBindGroupLayout(0),
|
utils::MakeBindGroup(device, pipelineForDispatch.GetBindGroupLayout(0),
|
||||||
{
|
{
|
||||||
{0, expectedBuffer, 0, 3 * sizeof(uint32_t)},
|
{0, dst, 0, 3 * sizeof(uint32_t)},
|
||||||
{1, dst, 0, 3 * sizeof(uint32_t)},
|
|
||||||
});
|
});
|
||||||
|
|
||||||
wgpu::CommandBuffer commands;
|
wgpu::CommandBuffer commands;
|
||||||
{
|
{
|
||||||
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
|
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
|
||||||
wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
|
wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
|
||||||
pass.SetPipeline(pipeline);
|
pass.SetPipeline(pipelineForDispatch);
|
||||||
pass.SetBindGroup(0, bindGroup);
|
pass.SetBindGroup(0, bindGroup);
|
||||||
pass.Dispatch(x, y, z);
|
pass.Dispatch(x, y, z);
|
||||||
pass.EndPass();
|
pass.EndPass();
|
||||||
@ -93,7 +112,7 @@ class ComputeDispatchTests : public DawnTest {
|
|||||||
queue.Submit(1, &commands);
|
queue.Submit(1, &commands);
|
||||||
|
|
||||||
std::vector<uint32_t> expected =
|
std::vector<uint32_t> expected =
|
||||||
x == 0 || y == 0 || z == 0 ? kSentinelData : expectedBufferData;
|
x == 0 || y == 0 || z == 0 ? kSentinelData : std::initializer_list<uint32_t>{x, y, z};
|
||||||
|
|
||||||
// Verify the dispatch got called if all group counts are not zero
|
// Verify the dispatch got called if all group counts are not zero
|
||||||
EXPECT_BUFFER_U32_RANGE_EQ(&expected[0], dst, 0, 3);
|
EXPECT_BUFFER_U32_RANGE_EQ(&expected[0], dst, 0, 3);
|
||||||
@ -118,7 +137,7 @@ class ComputeDispatchTests : public DawnTest {
|
|||||||
|
|
||||||
// Set up bind group and issue dispatch
|
// Set up bind group and issue dispatch
|
||||||
wgpu::BindGroup bindGroup =
|
wgpu::BindGroup bindGroup =
|
||||||
utils::MakeBindGroup(device, pipeline.GetBindGroupLayout(0),
|
utils::MakeBindGroup(device, pipelineForDispatchIndirect.GetBindGroupLayout(0),
|
||||||
{
|
{
|
||||||
{0, expectedBuffer, 0, 3 * sizeof(uint32_t)},
|
{0, expectedBuffer, 0, 3 * sizeof(uint32_t)},
|
||||||
{1, dst, 0, 3 * sizeof(uint32_t)},
|
{1, dst, 0, 3 * sizeof(uint32_t)},
|
||||||
@ -128,7 +147,7 @@ class ComputeDispatchTests : public DawnTest {
|
|||||||
{
|
{
|
||||||
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
|
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
|
||||||
wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
|
wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
|
||||||
pass.SetPipeline(pipeline);
|
pass.SetPipeline(pipelineForDispatchIndirect);
|
||||||
pass.SetBindGroup(0, bindGroup);
|
pass.SetBindGroup(0, bindGroup);
|
||||||
pass.DispatchIndirect(indirectBuffer, indirectOffset);
|
pass.DispatchIndirect(indirectBuffer, indirectOffset);
|
||||||
pass.EndPass();
|
pass.EndPass();
|
||||||
@ -153,7 +172,8 @@ class ComputeDispatchTests : public DawnTest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
wgpu::ComputePipeline pipeline;
|
wgpu::ComputePipeline pipelineForDispatch;
|
||||||
|
wgpu::ComputePipeline pipelineForDispatchIndirect;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Test basic direct
|
// Test basic direct
|
||||||
|
Loading…
x
Reference in New Issue
Block a user