Validate workgroup size and storage requirements
We define hard limits on these attributes for compute stages. This change enforces them by returning a validation error when a compute entry point exceeds any of the limits.

BUG: dawn:322
Change-Id: I9b279774e877b5d40d912cb9f812f23d61c20a42
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/56806
Commit-Queue: Ken Rockot <rockot@google.com>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
This commit is contained in:
parent
e50f8c65b7
commit
59668e95c7
|
@ -14,6 +14,7 @@
|
||||||
|
|
||||||
#include "dawn_native/ShaderModule.h"
|
#include "dawn_native/ShaderModule.h"
|
||||||
|
|
||||||
|
#include "common/Constants.h"
|
||||||
#include "common/HashUtils.h"
|
#include "common/HashUtils.h"
|
||||||
#include "common/VertexFormatUtils.h"
|
#include "common/VertexFormatUtils.h"
|
||||||
#include "dawn_native/BindGroupLayout.h"
|
#include "dawn_native/BindGroupLayout.h"
|
||||||
|
@ -903,6 +904,50 @@ namespace dawn_native {
|
||||||
DAWN_TRY_ASSIGN(metadata->stage, TintPipelineStageToShaderStage(entryPoint.stage));
|
DAWN_TRY_ASSIGN(metadata->stage, TintPipelineStageToShaderStage(entryPoint.stage));
|
||||||
|
|
||||||
if (metadata->stage == SingleShaderStage::Compute) {
|
if (metadata->stage == SingleShaderStage::Compute) {
|
||||||
|
if (entryPoint.workgroup_size_x > kMaxComputeWorkgroupSizeX) {
|
||||||
|
errorStream << "Workgroup X dimension exceeds maximum allowed:"
|
||||||
|
<< entryPoint.workgroup_size_x << " > "
|
||||||
|
<< kMaxComputeWorkgroupSizeX;
|
||||||
|
return DAWN_VALIDATION_ERROR(errorStream.str());
|
||||||
|
}
|
||||||
|
if (entryPoint.workgroup_size_y > kMaxComputeWorkgroupSizeY) {
|
||||||
|
errorStream << "Workgroup Y dimension exceeds maximum allowed: "
|
||||||
|
<< entryPoint.workgroup_size_y << " > "
|
||||||
|
<< kMaxComputeWorkgroupSizeY;
|
||||||
|
return DAWN_VALIDATION_ERROR(errorStream.str());
|
||||||
|
}
|
||||||
|
if (entryPoint.workgroup_size_z > kMaxComputeWorkgroupSizeZ) {
|
||||||
|
errorStream << "Workgroup Z dimension exceeds maximum allowed: "
|
||||||
|
<< entryPoint.workgroup_size_z << " > "
|
||||||
|
<< kMaxComputeWorkgroupSizeZ;
|
||||||
|
return DAWN_VALIDATION_ERROR(errorStream.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dimensions have already been validated against their individual limits above.
|
||||||
|
// This assertion ensures that the product of such limited dimensions cannot
|
||||||
|
// possibly overflow a uint32_t.
|
||||||
|
static_assert(static_cast<uint64_t>(kMaxComputeWorkgroupSizeX) *
|
||||||
|
kMaxComputeWorkgroupSizeY * kMaxComputeWorkgroupSizeZ <=
|
||||||
|
std::numeric_limits<uint32_t>::max(),
|
||||||
|
"Per-dimension workgroup size limits are too high");
|
||||||
|
uint32_t num_invocations = entryPoint.workgroup_size_x *
|
||||||
|
entryPoint.workgroup_size_y *
|
||||||
|
entryPoint.workgroup_size_z;
|
||||||
|
if (num_invocations > kMaxComputeWorkgroupInvocations) {
|
||||||
|
errorStream << "Number of workgroup invocations exceeds maximum allowed: "
|
||||||
|
<< num_invocations << " > " << kMaxComputeWorkgroupInvocations;
|
||||||
|
return DAWN_VALIDATION_ERROR(errorStream.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
const size_t workgroup_storage_size =
|
||||||
|
inspector.GetWorkgroupStorageSize(entryPoint.name);
|
||||||
|
if (workgroup_storage_size > kMaxComputeWorkgroupStorageSize) {
|
||||||
|
errorStream << "Workgroup shared storage size for " << entryPoint.name
|
||||||
|
<< " exceeds the maximum allowed: " << workgroup_storage_size
|
||||||
|
<< " > " << kMaxComputeWorkgroupStorageSize;
|
||||||
|
return DAWN_VALIDATION_ERROR(errorStream.str());
|
||||||
|
}
|
||||||
|
|
||||||
metadata->localWorkgroupSize.x = entryPoint.workgroup_size_x;
|
metadata->localWorkgroupSize.x = entryPoint.workgroup_size_x;
|
||||||
metadata->localWorkgroupSize.y = entryPoint.workgroup_size_y;
|
metadata->localWorkgroupSize.y = entryPoint.workgroup_size_y;
|
||||||
metadata->localWorkgroupSize.z = entryPoint.workgroup_size_z;
|
metadata->localWorkgroupSize.z = entryPoint.workgroup_size_z;
|
||||||
|
|
|
@ -17,7 +17,7 @@
|
||||||
#include "utils/WGPUHelpers.h"
|
#include "utils/WGPUHelpers.h"
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
constexpr uint32_t kTileSize = 64u;
|
constexpr uint32_t kTileSize = 32u;
|
||||||
|
|
||||||
const std::string& kMatMulFloatHeader = R"(
|
const std::string& kMatMulFloatHeader = R"(
|
||||||
[[block]] struct Uniforms {
|
[[block]] struct Uniforms {
|
||||||
|
@ -62,18 +62,18 @@ namespace {
|
||||||
|
|
||||||
let RowPerThread : u32 = 4u;
|
let RowPerThread : u32 = 4u;
|
||||||
let ColPerThread : u32 = 4u;
|
let ColPerThread : u32 = 4u;
|
||||||
let TileAOuter : u32 = 64u;
|
let TileAOuter : u32 = 32u;
|
||||||
let TileBOuter : u32 = 64u;
|
let TileBOuter : u32 = 32u;
|
||||||
let TileInner : u32 = 64u;)";
|
let TileInner : u32 = 32u;)";
|
||||||
|
|
||||||
const std::string& kMatMulFloatSharedArray1D = R"(
|
const std::string& kMatMulFloatSharedArray1D = R"(
|
||||||
var<workgroup> mm_Asub : array<f32, 4096>;
|
var<workgroup> mm_Asub : array<f32, 1024>;
|
||||||
var<workgroup> mm_Bsub : array<f32, 4096>;)";
|
var<workgroup> mm_Bsub : array<f32, 1024>;)";
|
||||||
const std::string& kMatMulFloatSharedArray2D = R"(
|
const std::string& kMatMulFloatSharedArray2D = R"(
|
||||||
var<workgroup> mm_Asub : array<array<f32, 64>, 64>;
|
var<workgroup> mm_Asub : array<array<f32, 32>, 32>;
|
||||||
var<workgroup> mm_Bsub : array<array<f32, 64>, 64>;)";
|
var<workgroup> mm_Bsub : array<array<f32, 32>, 32>;)";
|
||||||
const std::string& kMatMulFloatBodyPart1 = R"(
|
const std::string& kMatMulFloatBodyPart1 = R"(
|
||||||
[[stage(compute), workgroup_size(16, 16, 1)]]
|
[[stage(compute), workgroup_size(8, 8, 1)]]
|
||||||
fn main([[builtin(local_invocation_id)]] local_id : vec3<u32>,
|
fn main([[builtin(local_invocation_id)]] local_id : vec3<u32>,
|
||||||
[[builtin(global_invocation_id)]] global_id : vec3<u32>) {
|
[[builtin(global_invocation_id)]] global_id : vec3<u32>) {
|
||||||
let tileRow : u32 = local_id.y * RowPerThread;
|
let tileRow : u32 = local_id.y * RowPerThread;
|
||||||
|
@ -95,9 +95,9 @@ namespace {
|
||||||
acc[index] = 0.;
|
acc[index] = 0.;
|
||||||
}
|
}
|
||||||
|
|
||||||
let ColPerThreadA : u32 = TileInner / 16u;
|
let ColPerThreadA : u32 = TileInner / 8u;
|
||||||
let tileColA : u32 = local_id.x * ColPerThreadA;
|
let tileColA : u32 = local_id.x * ColPerThreadA;
|
||||||
let RowPerThreadB : u32 = TileInner / 16u;
|
let RowPerThreadB : u32 = TileInner / 8u;
|
||||||
let tileRowB : u32 = local_id.y * RowPerThreadB;
|
let tileRowB : u32 = local_id.y * RowPerThreadB;
|
||||||
|
|
||||||
// Loop over shared dimension.
|
// Loop over shared dimension.
|
||||||
|
@ -229,17 +229,16 @@ namespace {
|
||||||
|
|
||||||
let RowPerThread : u32 = 4u;
|
let RowPerThread : u32 = 4u;
|
||||||
let ColPerThread : u32 = 4u;
|
let ColPerThread : u32 = 4u;
|
||||||
let TileAOuter : u32 = 64u;
|
let TileOuter : u32 = 32u;
|
||||||
let TileBOuter : u32 = 64u;
|
let TileInner : u32 = 32u;)";
|
||||||
let TileInner : u32 = 64u;)";
|
|
||||||
const std::string& kMatMulVec4SharedArray1D = R"(
|
const std::string& kMatMulVec4SharedArray1D = R"(
|
||||||
var<workgroup> mm_Asub : array<vec4<f32>, 1024>;
|
var<workgroup> mm_Asub : array<vec4<f32>, 256>;
|
||||||
var<workgroup> mm_Bsub : array<vec4<f32>, 1024>;)";
|
var<workgroup> mm_Bsub : array<vec4<f32>, 256>;)";
|
||||||
const std::string& kMatMulVec4SharedArray2D = R"(
|
const std::string& kMatMulVec4SharedArray2D = R"(
|
||||||
var<workgroup> mm_Asub : array<array<vec4<f32>, 16>, 64>;
|
var<workgroup> mm_Asub : array<array<vec4<f32>, 8>, 32>;
|
||||||
var<workgroup> mm_Bsub : array<array<vec4<f32>, 16>, 64>;)";
|
var<workgroup> mm_Bsub : array<array<vec4<f32>, 8>, 32>;)";
|
||||||
const std::string& kMatMulVec4BodyPart1 = R"(
|
const std::string& kMatMulVec4BodyPart1 = R"(
|
||||||
[[stage(compute), workgroup_size(16, 16, 1)]]
|
[[stage(compute), workgroup_size(8, 8, 1)]]
|
||||||
fn main([[builtin(local_invocation_id)]] local_id : vec3<u32>,
|
fn main([[builtin(local_invocation_id)]] local_id : vec3<u32>,
|
||||||
[[builtin(global_invocation_id)]] global_id : vec3<u32>) {
|
[[builtin(global_invocation_id)]] global_id : vec3<u32>) {
|
||||||
let tileRow : u32 = local_id.y * RowPerThread;
|
let tileRow : u32 = local_id.y * RowPerThread;
|
||||||
|
@ -262,7 +261,7 @@ namespace {
|
||||||
}
|
}
|
||||||
|
|
||||||
var globalColA : u32 = tileCol;
|
var globalColA : u32 = tileCol;
|
||||||
let RowPerThreadB : u32 = TileInner / 16u;
|
let RowPerThreadB : u32 = TileInner / 8u;
|
||||||
let tileRowB : u32 = local_id.y * RowPerThreadB;
|
let tileRowB : u32 = local_id.y * RowPerThreadB;
|
||||||
|
|
||||||
// Loop over shared dimension.
|
// Loop over shared dimension.
|
||||||
|
@ -281,7 +280,7 @@ namespace {
|
||||||
for (var innerRow : u32 = 0u; innerRow < RowPerThreadB; innerRow = innerRow + 1u) {
|
for (var innerRow : u32 = 0u; innerRow < RowPerThreadB; innerRow = innerRow + 1u) {
|
||||||
let inputRow : u32 = tileRowB + innerRow;
|
let inputRow : u32 = tileRowB + innerRow;
|
||||||
let inputCol : u32 = tileCol;
|
let inputCol : u32 = tileCol;
|
||||||
let index : u32 = inputRow * TileBOuter / ColPerThread + inputCol;
|
let index : u32 = inputRow * TileOuter / ColPerThread + inputCol;
|
||||||
mm_Bsub[index] = mm_readB(t * TileInner + inputRow, globalCol);;
|
mm_Bsub[index] = mm_readB(t * TileInner + inputRow, globalCol);;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -289,10 +288,10 @@ namespace {
|
||||||
|
|
||||||
// Compute acc values for a single thread.
|
// Compute acc values for a single thread.
|
||||||
for (var k : u32 = 0u; k < TileInner / ColPerThread; k = k + 1u) {
|
for (var k : u32 = 0u; k < TileInner / ColPerThread; k = k + 1u) {
|
||||||
BCached[0] = mm_Bsub[(k * ColPerThread) * (TileBOuter / ColPerThread) + tileCol];
|
BCached[0] = mm_Bsub[(k * ColPerThread) * (TileOuter / ColPerThread) + tileCol];
|
||||||
BCached[1] = mm_Bsub[(k * ColPerThread + 1u) * (TileBOuter / ColPerThread) + tileCol];
|
BCached[1] = mm_Bsub[(k * ColPerThread + 1u) * (TileOuter / ColPerThread) + tileCol];
|
||||||
BCached[2] = mm_Bsub[(k * ColPerThread + 2u) * (TileBOuter / ColPerThread) + tileCol];
|
BCached[2] = mm_Bsub[(k * ColPerThread + 2u) * (TileOuter / ColPerThread) + tileCol];
|
||||||
BCached[3] = mm_Bsub[(k * ColPerThread + 3u) * (TileBOuter / ColPerThread) + tileCol];
|
BCached[3] = mm_Bsub[(k * ColPerThread + 3u) * (TileOuter / ColPerThread) + tileCol];
|
||||||
|
|
||||||
for (var i : u32 = 0u; i < RowPerThread; i = i + 1u) {
|
for (var i : u32 = 0u; i < RowPerThread; i = i + 1u) {
|
||||||
ACached = mm_Asub[(tileRow + i) * (TileInner / ColPerThread) + k];)";
|
ACached = mm_Asub[(tileRow + i) * (TileInner / ColPerThread) + k];)";
|
||||||
|
|
|
@ -282,3 +282,64 @@ TEST_F(ShaderModuleValidationTest, MaximumShaderIOLocations) {
|
||||||
ASSERT_DEVICE_ERROR(utils::CreateShaderModule(device, fragmentShader.c_str()));
|
ASSERT_DEVICE_ERROR(utils::CreateShaderModule(device, fragmentShader.c_str()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Tests that we validate workgroup size limits.
|
||||||
|
TEST_F(ShaderModuleValidationTest, ComputeWorkgroupSizeLimits) {
|
||||||
|
DAWN_SKIP_TEST_IF(!HasToggleEnabled("use_tint_generator"));
|
||||||
|
|
||||||
|
auto MakeShaderWithWorkgroupSize = [this](uint32_t x, uint32_t y, uint32_t z) {
|
||||||
|
std::ostringstream ss;
|
||||||
|
ss << "[[stage(compute), workgroup_size(" << x << "," << y << "," << z
|
||||||
|
<< ")]] fn main() {}";
|
||||||
|
utils::CreateShaderModule(device, ss.str().c_str());
|
||||||
|
};
|
||||||
|
|
||||||
|
MakeShaderWithWorkgroupSize(1, 1, 1);
|
||||||
|
MakeShaderWithWorkgroupSize(kMaxComputeWorkgroupSizeX, 1, 1);
|
||||||
|
MakeShaderWithWorkgroupSize(1, kMaxComputeWorkgroupSizeY, 1);
|
||||||
|
MakeShaderWithWorkgroupSize(1, 1, kMaxComputeWorkgroupSizeZ);
|
||||||
|
|
||||||
|
ASSERT_DEVICE_ERROR(MakeShaderWithWorkgroupSize(kMaxComputeWorkgroupSizeX + 1, 1, 1));
|
||||||
|
ASSERT_DEVICE_ERROR(MakeShaderWithWorkgroupSize(1, kMaxComputeWorkgroupSizeY + 1, 1));
|
||||||
|
ASSERT_DEVICE_ERROR(MakeShaderWithWorkgroupSize(1, 1, kMaxComputeWorkgroupSizeZ + 1));
|
||||||
|
|
||||||
|
// No individual dimension exceeds its limit, but the combined size should definitely exceed the
|
||||||
|
// total invocation limit.
|
||||||
|
ASSERT_DEVICE_ERROR(MakeShaderWithWorkgroupSize(
|
||||||
|
kMaxComputeWorkgroupSizeX, kMaxComputeWorkgroupSizeY, kMaxComputeWorkgroupSizeZ));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tests that we validate workgroup storage size limits.
|
||||||
|
TEST_F(ShaderModuleValidationTest, ComputeWorkgroupStorageSizeLimits) {
|
||||||
|
DAWN_SKIP_TEST_IF(!HasToggleEnabled("use_tint_generator"));
|
||||||
|
|
||||||
|
constexpr uint32_t kVec4Size = 16;
|
||||||
|
constexpr uint32_t kMaxVec4Count = kMaxComputeWorkgroupStorageSize / kVec4Size;
|
||||||
|
constexpr uint32_t kMat4Size = 64;
|
||||||
|
constexpr uint32_t kMaxMat4Count = kMaxComputeWorkgroupStorageSize / kMat4Size;
|
||||||
|
|
||||||
|
auto MakeShaderWithWorkgroupStorage = [this](uint32_t vec4_count, uint32_t mat4_count) {
|
||||||
|
std::ostringstream ss;
|
||||||
|
std::ostringstream body;
|
||||||
|
if (vec4_count > 0) {
|
||||||
|
ss << "var<workgroup> vec4_data: array<vec4<f32>, " << vec4_count << ">;";
|
||||||
|
body << "ignore(vec4_data);";
|
||||||
|
}
|
||||||
|
if (mat4_count > 0) {
|
||||||
|
ss << "var<workgroup> mat4_data: array<mat4x4<f32>, " << mat4_count << ">;";
|
||||||
|
body << "ignore(mat4_data);";
|
||||||
|
}
|
||||||
|
ss << "[[stage(compute), workgroup_size(1)]] fn main() { " << body.str() << " }";
|
||||||
|
utils::CreateShaderModule(device, ss.str().c_str());
|
||||||
|
};
|
||||||
|
|
||||||
|
MakeShaderWithWorkgroupStorage(1, 1);
|
||||||
|
MakeShaderWithWorkgroupStorage(kMaxVec4Count, 0);
|
||||||
|
MakeShaderWithWorkgroupStorage(0, kMaxMat4Count);
|
||||||
|
MakeShaderWithWorkgroupStorage(kMaxVec4Count - 4, 1);
|
||||||
|
MakeShaderWithWorkgroupStorage(4, kMaxMat4Count - 1);
|
||||||
|
ASSERT_DEVICE_ERROR(MakeShaderWithWorkgroupStorage(kMaxVec4Count + 1, 0));
|
||||||
|
ASSERT_DEVICE_ERROR(MakeShaderWithWorkgroupStorage(kMaxVec4Count - 3, 1));
|
||||||
|
ASSERT_DEVICE_ERROR(MakeShaderWithWorkgroupStorage(0, kMaxMat4Count + 1));
|
||||||
|
ASSERT_DEVICE_ERROR(MakeShaderWithWorkgroupStorage(4, kMaxMat4Count));
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue