Validate workgroup size and storage requirements

We define hard limits on these attributes for compute stages; this change
enforces them.

BUG: dawn:322
Change-Id: I9b279774e877b5d40d912cb9f812f23d61c20a42
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/56806
Commit-Queue: Ken Rockot <rockot@google.com>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Author: Ken Rockot
Date:   2021-07-21 20:19:20 +00:00
Committed by: Dawn LUCI CQ
parent e50f8c65b7
commit 59668e95c7
3 changed files with 130 additions and 25 deletions
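For illustration only (not part of the commit): assuming the limits in common/Constants.h match the WebGPU defaults (256 for X and Y, 64 for Z, and 256 invocations per workgroup in total), compute shaders like the following sketches are now rejected when the module is created:

// Rejected: the X dimension (257) exceeds the assumed per-dimension limit of 256.
[[stage(compute), workgroup_size(257, 1, 1)]] fn main() {}

// Rejected: every dimension is individually in range, but 16 * 16 * 2 = 512
// total invocations exceeds the assumed per-workgroup limit of 256.
[[stage(compute), workgroup_size(16, 16, 2)]] fn main() {}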


@@ -14,6 +14,7 @@
 #include "dawn_native/ShaderModule.h"

+#include "common/Constants.h"
 #include "common/HashUtils.h"
 #include "common/VertexFormatUtils.h"
 #include "dawn_native/BindGroupLayout.h"
@@ -903,6 +904,50 @@ namespace dawn_native {
         DAWN_TRY_ASSIGN(metadata->stage, TintPipelineStageToShaderStage(entryPoint.stage));

         if (metadata->stage == SingleShaderStage::Compute) {
+            if (entryPoint.workgroup_size_x > kMaxComputeWorkgroupSizeX) {
+                errorStream << "Workgroup X dimension exceeds maximum allowed: "
+                            << entryPoint.workgroup_size_x << " > "
+                            << kMaxComputeWorkgroupSizeX;
+                return DAWN_VALIDATION_ERROR(errorStream.str());
+            }
+            if (entryPoint.workgroup_size_y > kMaxComputeWorkgroupSizeY) {
+                errorStream << "Workgroup Y dimension exceeds maximum allowed: "
+                            << entryPoint.workgroup_size_y << " > "
+                            << kMaxComputeWorkgroupSizeY;
+                return DAWN_VALIDATION_ERROR(errorStream.str());
+            }
+            if (entryPoint.workgroup_size_z > kMaxComputeWorkgroupSizeZ) {
+                errorStream << "Workgroup Z dimension exceeds maximum allowed: "
+                            << entryPoint.workgroup_size_z << " > "
+                            << kMaxComputeWorkgroupSizeZ;
+                return DAWN_VALIDATION_ERROR(errorStream.str());
+            }
+
+            // Dimensions have already been validated against their individual limits above.
+            // This assertion ensures that the product of such limited dimensions cannot
+            // possibly overflow a uint32_t.
+            static_assert(static_cast<uint64_t>(kMaxComputeWorkgroupSizeX) *
+                                  kMaxComputeWorkgroupSizeY * kMaxComputeWorkgroupSizeZ <=
+                              std::numeric_limits<uint32_t>::max(),
+                          "Per-dimension workgroup size limits are too high");
+            uint32_t num_invocations = entryPoint.workgroup_size_x *
+                                       entryPoint.workgroup_size_y *
+                                       entryPoint.workgroup_size_z;
+            if (num_invocations > kMaxComputeWorkgroupInvocations) {
+                errorStream << "Number of workgroup invocations exceeds maximum allowed: "
+                            << num_invocations << " > " << kMaxComputeWorkgroupInvocations;
+                return DAWN_VALIDATION_ERROR(errorStream.str());
+            }
+
+            const size_t workgroup_storage_size =
+                inspector.GetWorkgroupStorageSize(entryPoint.name);
+            if (workgroup_storage_size > kMaxComputeWorkgroupStorageSize) {
+                errorStream << "Workgroup shared storage size for " << entryPoint.name
+                            << " exceeds the maximum allowed: " << workgroup_storage_size
+                            << " > " << kMaxComputeWorkgroupStorageSize;
+                return DAWN_VALIDATION_ERROR(errorStream.str());
+            }
+
             metadata->localWorkgroupSize.x = entryPoint.workgroup_size_x;
             metadata->localWorkgroupSize.y = entryPoint.workgroup_size_y;
             metadata->localWorkgroupSize.z = entryPoint.workgroup_size_z;
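A quick sanity check on the static_assert above, under the same assumed per-dimension limits of 256, 256, and 64: 256 * 256 * 64 = 4,194,304, comfortably below the uint32_t maximum of 4,294,967,295, so computing num_invocations in 32 bits cannot overflow once each dimension has passed its own check.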


@@ -17,7 +17,7 @@
 #include "utils/WGPUHelpers.h"

 namespace {
-    constexpr uint32_t kTileSize = 64u;
+    constexpr uint32_t kTileSize = 32u;

     const std::string& kMatMulFloatHeader = R"(
         [[block]] struct Uniforms {
@@ -62,18 +62,18 @@ namespace {
         let RowPerThread : u32 = 4u;
         let ColPerThread : u32 = 4u;
-        let TileAOuter : u32 = 64u;
-        let TileBOuter : u32 = 64u;
-        let TileInner : u32 = 64u;)";
+        let TileAOuter : u32 = 32u;
+        let TileBOuter : u32 = 32u;
+        let TileInner : u32 = 32u;)";

     const std::string& kMatMulFloatSharedArray1D = R"(
-        var<workgroup> mm_Asub : array<f32, 4096>;
-        var<workgroup> mm_Bsub : array<f32, 4096>;)";
+        var<workgroup> mm_Asub : array<f32, 1024>;
+        var<workgroup> mm_Bsub : array<f32, 1024>;)";

     const std::string& kMatMulFloatSharedArray2D = R"(
-        var<workgroup> mm_Asub : array<array<f32, 64>, 64>;
-        var<workgroup> mm_Bsub : array<array<f32, 64>, 64>;)";
+        var<workgroup> mm_Asub : array<array<f32, 32>, 32>;
+        var<workgroup> mm_Bsub : array<array<f32, 32>, 32>;)";

     const std::string& kMatMulFloatBodyPart1 = R"(
-        [[stage(compute), workgroup_size(16, 16, 1)]]
+        [[stage(compute), workgroup_size(8, 8, 1)]]
         fn main([[builtin(local_invocation_id)]] local_id : vec3<u32>,
                 [[builtin(global_invocation_id)]] global_id : vec3<u32>) {
             let tileRow : u32 = local_id.y * RowPerThread;
@@ -95,9 +95,9 @@ namespace {
                 acc[index] = 0.;
             }

-            let ColPerThreadA : u32 = TileInner / 16u;
+            let ColPerThreadA : u32 = TileInner / 8u;
             let tileColA : u32 = local_id.x * ColPerThreadA;
-            let RowPerThreadB : u32 = TileInner / 16u;
+            let RowPerThreadB : u32 = TileInner / 8u;
             let tileRowB : u32 = local_id.y * RowPerThreadB;

             // Loop over shared dimension.
@@ -229,17 +229,16 @@ namespace {
         let RowPerThread : u32 = 4u;
         let ColPerThread : u32 = 4u;
-        let TileAOuter : u32 = 64u;
-        let TileBOuter : u32 = 64u;
-        let TileInner : u32 = 64u;)";
+        let TileOuter : u32 = 32u;
+        let TileInner : u32 = 32u;)";

     const std::string& kMatMulVec4SharedArray1D = R"(
-        var<workgroup> mm_Asub : array<vec4<f32>, 1024>;
-        var<workgroup> mm_Bsub : array<vec4<f32>, 1024>;)";
+        var<workgroup> mm_Asub : array<vec4<f32>, 256>;
+        var<workgroup> mm_Bsub : array<vec4<f32>, 256>;)";

     const std::string& kMatMulVec4SharedArray2D = R"(
-        var<workgroup> mm_Asub : array<array<vec4<f32>, 16>, 64>;
-        var<workgroup> mm_Bsub : array<array<vec4<f32>, 16>, 64>;)";
+        var<workgroup> mm_Asub : array<array<vec4<f32>, 8>, 32>;
+        var<workgroup> mm_Bsub : array<array<vec4<f32>, 8>, 32>;)";

     const std::string& kMatMulVec4BodyPart1 = R"(
-        [[stage(compute), workgroup_size(16, 16, 1)]]
+        [[stage(compute), workgroup_size(8, 8, 1)]]
         fn main([[builtin(local_invocation_id)]] local_id : vec3<u32>,
                 [[builtin(global_invocation_id)]] global_id : vec3<u32>) {
             let tileRow : u32 = local_id.y * RowPerThread;
@@ -262,7 +261,7 @@ namespace {
             }
             var globalColA : u32 = tileCol;
-            let RowPerThreadB : u32 = TileInner / 16u;
+            let RowPerThreadB : u32 = TileInner / 8u;
             let tileRowB : u32 = local_id.y * RowPerThreadB;

             // Loop over shared dimension.
@@ -281,7 +280,7 @@ namespace {
             for (var innerRow : u32 = 0u; innerRow < RowPerThreadB; innerRow = innerRow + 1u) {
                 let inputRow : u32 = tileRowB + innerRow;
                 let inputCol : u32 = tileCol;
-                let index : u32 = inputRow * TileBOuter / ColPerThread + inputCol;
+                let index : u32 = inputRow * TileOuter / ColPerThread + inputCol;
                 mm_Bsub[index] = mm_readB(t * TileInner + inputRow, globalCol);;
             }
@@ -289,10 +288,10 @@ namespace {
             // Compute acc values for a single thread.
             for (var k : u32 = 0u; k < TileInner / ColPerThread; k = k + 1u) {
-                BCached[0] = mm_Bsub[(k * ColPerThread) * (TileBOuter / ColPerThread) + tileCol];
-                BCached[1] = mm_Bsub[(k * ColPerThread + 1u) * (TileBOuter / ColPerThread) + tileCol];
-                BCached[2] = mm_Bsub[(k * ColPerThread + 2u) * (TileBOuter / ColPerThread) + tileCol];
-                BCached[3] = mm_Bsub[(k * ColPerThread + 3u) * (TileBOuter / ColPerThread) + tileCol];
+                BCached[0] = mm_Bsub[(k * ColPerThread) * (TileOuter / ColPerThread) + tileCol];
+                BCached[1] = mm_Bsub[(k * ColPerThread + 1u) * (TileOuter / ColPerThread) + tileCol];
+                BCached[2] = mm_Bsub[(k * ColPerThread + 2u) * (TileOuter / ColPerThread) + tileCol];
+                BCached[3] = mm_Bsub[(k * ColPerThread + 3u) * (TileOuter / ColPerThread) + tileCol];

                 for (var i : u32 = 0u; i < RowPerThread; i = i + 1u) {
                     ACached = mm_Asub[(tileRow + i) * (TileInner / ColPerThread) + k];)";


@@ -282,3 +282,64 @@ TEST_F(ShaderModuleValidationTest, MaximumShaderIOLocations) {
         ASSERT_DEVICE_ERROR(utils::CreateShaderModule(device, fragmentShader.c_str()));
     }
 }
+
+// Tests that we validate workgroup size limits.
+TEST_F(ShaderModuleValidationTest, ComputeWorkgroupSizeLimits) {
+    DAWN_SKIP_TEST_IF(!HasToggleEnabled("use_tint_generator"));
+
+    auto MakeShaderWithWorkgroupSize = [this](uint32_t x, uint32_t y, uint32_t z) {
+        std::ostringstream ss;
+        ss << "[[stage(compute), workgroup_size(" << x << "," << y << "," << z
+           << ")]] fn main() {}";
+        utils::CreateShaderModule(device, ss.str().c_str());
+    };
+
+    MakeShaderWithWorkgroupSize(1, 1, 1);
+    MakeShaderWithWorkgroupSize(kMaxComputeWorkgroupSizeX, 1, 1);
+    MakeShaderWithWorkgroupSize(1, kMaxComputeWorkgroupSizeY, 1);
+    MakeShaderWithWorkgroupSize(1, 1, kMaxComputeWorkgroupSizeZ);
+
+    ASSERT_DEVICE_ERROR(MakeShaderWithWorkgroupSize(kMaxComputeWorkgroupSizeX + 1, 1, 1));
+    ASSERT_DEVICE_ERROR(MakeShaderWithWorkgroupSize(1, kMaxComputeWorkgroupSizeY + 1, 1));
+    ASSERT_DEVICE_ERROR(MakeShaderWithWorkgroupSize(1, 1, kMaxComputeWorkgroupSizeZ + 1));
+
+    // No individual dimension exceeds its limit, but the combined size should
+    // definitely exceed the total invocation limit.
+    ASSERT_DEVICE_ERROR(MakeShaderWithWorkgroupSize(
+        kMaxComputeWorkgroupSizeX, kMaxComputeWorkgroupSizeY, kMaxComputeWorkgroupSizeZ));
+}
+
+// Tests that we validate workgroup storage size limits.
+TEST_F(ShaderModuleValidationTest, ComputeWorkgroupStorageSizeLimits) {
+    DAWN_SKIP_TEST_IF(!HasToggleEnabled("use_tint_generator"));
+
+    constexpr uint32_t kVec4Size = 16;
+    constexpr uint32_t kMaxVec4Count = kMaxComputeWorkgroupStorageSize / kVec4Size;
+    constexpr uint32_t kMat4Size = 64;
+    constexpr uint32_t kMaxMat4Count = kMaxComputeWorkgroupStorageSize / kMat4Size;
+
+    auto MakeShaderWithWorkgroupStorage = [this](uint32_t vec4_count, uint32_t mat4_count) {
+        std::ostringstream ss;
+        std::ostringstream body;
+        if (vec4_count > 0) {
+            ss << "var<workgroup> vec4_data: array<vec4<f32>, " << vec4_count << ">;";
+            body << "ignore(vec4_data);";
+        }
+        if (mat4_count > 0) {
+            ss << "var<workgroup> mat4_data: array<mat4x4<f32>, " << mat4_count << ">;";
+            body << "ignore(mat4_data);";
+        }
+        ss << "[[stage(compute), workgroup_size(1)]] fn main() { " << body.str() << " }";
+        utils::CreateShaderModule(device, ss.str().c_str());
+    };
+
+    MakeShaderWithWorkgroupStorage(1, 1);
+    MakeShaderWithWorkgroupStorage(kMaxVec4Count, 0);
+    MakeShaderWithWorkgroupStorage(0, kMaxMat4Count);
+    MakeShaderWithWorkgroupStorage(kMaxVec4Count - 4, 1);
+    MakeShaderWithWorkgroupStorage(4, kMaxMat4Count - 1);
+
+    ASSERT_DEVICE_ERROR(MakeShaderWithWorkgroupStorage(kMaxVec4Count + 1, 0));
+    ASSERT_DEVICE_ERROR(MakeShaderWithWorkgroupStorage(kMaxVec4Count - 3, 1));
+    ASSERT_DEVICE_ERROR(MakeShaderWithWorkgroupStorage(0, kMaxMat4Count + 1));
+    ASSERT_DEVICE_ERROR(MakeShaderWithWorkgroupStorage(4, kMaxMat4Count));
+}
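Note on the boundary cases above: one mat4x4<f32> occupies kMat4Size = 64 bytes, exactly four vec4<f32>s at kVec4Size = 16 bytes each. Assuming kMaxComputeWorkgroupStorageSize is the WebGPU default of 16384 bytes, kMaxVec4Count = 1024 and kMaxMat4Count = 256: the passing case (kMaxVec4Count - 4, 1) totals 1020 * 16 + 64 = 16384 bytes, exactly at the limit, while the failing case (kMaxVec4Count - 3, 1) totals 1021 * 16 + 64 = 16400 bytes, one vec4 over.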