From 2d41f8c1df45b1aa642b5149060a8ec9ad9a0c63 Mon Sep 17 00:00:00 2001 From: Ken Rockot Date: Wed, 21 Jul 2021 20:44:09 +0000 Subject: [PATCH] Enforce per-dimension dispatch size limits Note that this is for direct dispatch calls only. Indirect dispatch calls are still not validated. Bug: dawn:1006 Change-Id: I061c15208a01dfb803923823ba4afd38667cad22 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/59122 Reviewed-by: Ryan Harrison Reviewed-by: Austin Eng Commit-Queue: Ken Rockot --- src/dawn_native/ComputePassEncoder.cpp | 15 +++++ .../validation/ComputeValidationTests.cpp | 67 ++++++++++++++++++- 2 files changed, 80 insertions(+), 2 deletions(-) diff --git a/src/dawn_native/ComputePassEncoder.cpp b/src/dawn_native/ComputePassEncoder.cpp index 88c7e69950..dcc5df83c3 100644 --- a/src/dawn_native/ComputePassEncoder.cpp +++ b/src/dawn_native/ComputePassEncoder.cpp @@ -25,6 +25,18 @@ namespace dawn_native { + namespace { + + MaybeError ValidatePerDimensionDispatchSizeLimit(uint32_t size) { + if (size > kMaxComputePerDimensionDispatchSize) { + return DAWN_VALIDATION_ERROR("Dispatch size exceeds defined limits"); + } + + return {}; + } + + } // namespace + ComputePassEncoder::ComputePassEncoder(DeviceBase* device, CommandEncoder* commandEncoder, EncodingContext* encodingContext) @@ -63,6 +75,9 @@ namespace dawn_native { mEncodingContext->TryEncode(this, [&](CommandAllocator* allocator) -> MaybeError { if (IsValidationEnabled()) { DAWN_TRY(mCommandBufferState.ValidateCanDispatch()); + DAWN_TRY(ValidatePerDimensionDispatchSizeLimit(x)); + DAWN_TRY(ValidatePerDimensionDispatchSizeLimit(y)); + DAWN_TRY(ValidatePerDimensionDispatchSizeLimit(z)); } // Record the synchronization scope for Dispatch, which is just the current bindgroups. diff --git a/src/tests/unittests/validation/ComputeValidationTests.cpp b/src/tests/unittests/validation/ComputeValidationTests.cpp index 8135997c99..6d66b87d9d 100644 --- a/src/tests/unittests/validation/ComputeValidationTests.cpp +++ b/src/tests/unittests/validation/ComputeValidationTests.cpp @@ -12,9 +12,72 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "common/Constants.h" #include "tests/unittests/validation/ValidationTest.h" - -class ComputeValidationTest : public ValidationTest {}; +#include "utils/WGPUHelpers.h" // TODO(cwallez@chromium.org): Add a regression test for Disptach validation trying to acces the // input state. + +class ComputeValidationTest : public ValidationTest { + protected: + void SetUp() override { + ValidationTest::SetUp(); + + wgpu::ShaderModule computeModule = utils::CreateShaderModule(device, R"( + [[stage(compute), workgroup_size(1)]] fn main() { + })"); + + // Set up compute pipeline + wgpu::PipelineLayout pl = utils::MakeBasicPipelineLayout(device, nullptr); + + wgpu::ComputePipelineDescriptor csDesc; + csDesc.layout = pl; + csDesc.compute.module = computeModule; + csDesc.compute.entryPoint = "main"; + pipeline = device.CreateComputePipeline(&csDesc); + } + + void TestDispatch(uint32_t x, uint32_t y, uint32_t z) { + wgpu::CommandEncoder encoder = device.CreateCommandEncoder(); + wgpu::ComputePassEncoder pass = encoder.BeginComputePass(); + pass.SetPipeline(pipeline); + pass.Dispatch(x, y, z); + pass.EndPass(); + encoder.Finish(); + } + + wgpu::ComputePipeline pipeline; +}; + +// Check that 1x1x1 dispatch is OK. +TEST_F(ComputeValidationTest, PerDimensionDispatchSizeLimits_SmallestValid) { + TestDispatch(1, 1, 1); +} + +// Check that the largest allowed dispatch is OK. +TEST_F(ComputeValidationTest, PerDimensionDispatchSizeLimits_LargestValid) { + constexpr uint32_t kMax = kMaxComputePerDimensionDispatchSize; + TestDispatch(kMax, kMax, kMax); +} + +// Check that exceeding the maximum on the X dimension results in validation failure. +TEST_F(ComputeValidationTest, PerDimensionDispatchSizeLimits_InvalidX) { + ASSERT_DEVICE_ERROR(TestDispatch(kMaxComputePerDimensionDispatchSize + 1, 1, 1)); +} + +// Check that exceeding the maximum on the Y dimension results in validation failure. +TEST_F(ComputeValidationTest, PerDimensionDispatchSizeLimits_InvalidY) { + ASSERT_DEVICE_ERROR(TestDispatch(1, kMaxComputePerDimensionDispatchSize + 1, 1)); +} + +// Check that exceeding the maximum on the Z dimension results in validation failure. +TEST_F(ComputeValidationTest, PerDimensionDispatchSizeLimits_InvalidZ) { + ASSERT_DEVICE_ERROR(TestDispatch(1, 1, kMaxComputePerDimensionDispatchSize + 1)); +} + +// Check that exceeding the maximum on all dimensions results in validation failure. +TEST_F(ComputeValidationTest, PerDimensionDispatchSizeLimits_InvalidAll) { + constexpr uint32_t kMax = kMaxComputePerDimensionDispatchSize; + ASSERT_DEVICE_ERROR(TestDispatch(kMax + 1, kMax + 1, kMax + 1)); +}