tint/inspector: Use a std::optional for workgroup_size

No-value represents a workgroup size that is derived from an override-expression.

Bug: dawn:1504
Bug: chromium:1346929
Change-Id: Idf6caa9d052aa56e8ef1913d16d1f68d2c5844ed
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/97362
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton 2022-07-27 17:05:56 +00:00 committed by Dawn LUCI CQ
parent a123b892a5
commit a1571ac403
4 changed files with 56 additions and 47 deletions

View File

@ -660,19 +660,24 @@ ResultOrError<std::unique_ptr<EntryPointMetadata>> ReflectEntryPointUsingTint(
DAWN_TRY_ASSIGN(metadata->stage, TintPipelineStageToShaderStage(entryPoint.stage));
if (metadata->stage == SingleShaderStage::Compute) {
DelayedInvalidIf(entryPoint.workgroup_size_x > limits.v1.maxComputeWorkgroupSizeX ||
entryPoint.workgroup_size_y > limits.v1.maxComputeWorkgroupSizeY ||
entryPoint.workgroup_size_z > limits.v1.maxComputeWorkgroupSizeZ,
auto workgroup_size = entryPoint.workgroup_size;
DAWN_INVALID_IF(
!workgroup_size.has_value(),
"TODO(crbug.com/dawn/1504): Dawn does not currently support @workgroup_size "
"attributes using override-expressions");
DelayedInvalidIf(workgroup_size->x > limits.v1.maxComputeWorkgroupSizeX ||
workgroup_size->y > limits.v1.maxComputeWorkgroupSizeY ||
workgroup_size->z > limits.v1.maxComputeWorkgroupSizeZ,
"Entry-point uses workgroup_size(%u, %u, %u) that exceeds the "
"maximum allowed (%u, %u, %u).",
entryPoint.workgroup_size_x, entryPoint.workgroup_size_y,
entryPoint.workgroup_size_z, limits.v1.maxComputeWorkgroupSizeX,
limits.v1.maxComputeWorkgroupSizeY, limits.v1.maxComputeWorkgroupSizeZ);
workgroup_size->x, workgroup_size->y, workgroup_size->z,
limits.v1.maxComputeWorkgroupSizeX, limits.v1.maxComputeWorkgroupSizeY,
limits.v1.maxComputeWorkgroupSizeZ);
// Dimensions have already been validated against their individual limits above.
// Cast to uint64_t to avoid overflow in this multiplication.
uint64_t numInvocations = static_cast<uint64_t>(entryPoint.workgroup_size_x) *
entryPoint.workgroup_size_y * entryPoint.workgroup_size_z;
uint64_t numInvocations =
static_cast<uint64_t>(workgroup_size->x) * workgroup_size->y * workgroup_size->z;
DelayedInvalidIf(numInvocations > limits.v1.maxComputeInvocationsPerWorkgroup,
"The total number of workgroup invocations (%u) exceeds the "
"maximum allowed (%u).",
@ -684,9 +689,9 @@ ResultOrError<std::unique_ptr<EntryPointMetadata>> ReflectEntryPointUsingTint(
"the maximum allowed (%u bytes).",
workgroupStorageSize, limits.v1.maxComputeWorkgroupStorageSize);
metadata->localWorkgroupSize.x = entryPoint.workgroup_size_x;
metadata->localWorkgroupSize.y = entryPoint.workgroup_size_y;
metadata->localWorkgroupSize.z = entryPoint.workgroup_size_z;
metadata->localWorkgroupSize.x = workgroup_size->x;
metadata->localWorkgroupSize.y = workgroup_size->y;
metadata->localWorkgroupSize.z = workgroup_size->z;
metadata->usesNumWorkgroups = entryPoint.num_workgroups_used;
}

View File

@ -15,6 +15,7 @@
#ifndef SRC_TINT_INSPECTOR_ENTRY_POINT_H_
#define SRC_TINT_INSPECTOR_ENTRY_POINT_H_
#include <optional>
#include <string>
#include <tuple>
#include <vector>
@ -123,6 +124,16 @@ struct OverridableConstant {
/// The pipeline stage
enum class PipelineStage { kVertex, kFragment, kCompute };
/// WorkgroupSize describes the dimensions of the workgroup grid for a compute shader.
struct WorkgroupSize {
/// The 'x' dimension of the workgroup grid
uint32_t x = 1;
/// The 'y' dimension of the workgroup grid
uint32_t y = 1;
/// The 'z' dimension of the workgroup grid
uint32_t z = 1;
};
/// Reflection data for an entry point in the shader.
struct EntryPoint {
/// Constructors
@ -139,12 +150,10 @@ struct EntryPoint {
std::string remapped_name;
/// The entry point stage
PipelineStage stage;
/// The workgroup x size
uint32_t workgroup_size_x = 0;
/// The workgroup y size
uint32_t workgroup_size_y = 0;
/// The workgroup z size
uint32_t workgroup_size_z = 0;
/// The workgroup size. If PipelineStage is kCompute and this holds no value, then the workgroup
/// size is derived from an override-expression. In this situation you first need to run the
/// tint::transform::SubstituteOverride transform before using the inspector.
std::optional<WorkgroupSize> workgroup_size;
/// List of the input variable accessed via this entry point.
std::vector<StageVariable> input_variables;
/// List of the output variable accessed via this entry point.
@ -166,12 +175,6 @@ struct EntryPoint {
bool sample_index_used = false;
/// Does the entry point use the num_workgroups builtin
bool num_workgroups_used = false;
/// @returns the size of the workgroup in {x,y,z} format
std::tuple<uint32_t, uint32_t, uint32_t> workgroup_size() {
return std::tuple<uint32_t, uint32_t, uint32_t>(workgroup_size_x, workgroup_size_y,
workgroup_size_z);
}
};
} // namespace tint::inspector

View File

@ -148,29 +148,30 @@ std::vector<EntryPoint> Inspector::GetEntryPoints() {
entry_point.remapped_name = program_->Symbols().NameFor(func->symbol);
switch (func->PipelineStage()) {
case ast::PipelineStage::kCompute:
case ast::PipelineStage::kCompute: {
entry_point.stage = PipelineStage::kCompute;
auto wgsize = sem->WorkgroupSize();
if (!wgsize[0].overridable_const && !wgsize[1].overridable_const &&
!wgsize[2].overridable_const) {
entry_point.workgroup_size = {wgsize[0].value, wgsize[1].value,
wgsize[2].value};
}
break;
case ast::PipelineStage::kFragment:
}
case ast::PipelineStage::kFragment: {
entry_point.stage = PipelineStage::kFragment;
break;
case ast::PipelineStage::kVertex:
}
case ast::PipelineStage::kVertex: {
entry_point.stage = PipelineStage::kVertex;
break;
default:
}
default: {
TINT_UNREACHABLE(Inspector, diagnostics_)
<< "invalid pipeline stage for entry point '" << entry_point.name << "'";
break;
}
auto wgsize = sem->WorkgroupSize();
entry_point.workgroup_size_x = wgsize[0].value;
entry_point.workgroup_size_y = wgsize[1].value;
entry_point.workgroup_size_z = wgsize[2].value;
if (wgsize[0].overridable_const || wgsize[1].overridable_const ||
wgsize[2].overridable_const) {
// TODO(crbug.com/tint/713): Handle overridable constants.
TINT_ASSERT(Inspector, false);
}
}
for (auto* param : sem->Parameters()) {

View File

@ -239,11 +239,11 @@ TEST_F(InspectorGetEntryPointTest, DefaultWorkgroupSize) {
ASSERT_FALSE(inspector.has_error()) << inspector.error();
ASSERT_EQ(1u, result.size());
uint32_t x, y, z;
std::tie(x, y, z) = result[0].workgroup_size();
EXPECT_EQ(8u, x);
EXPECT_EQ(2u, y);
EXPECT_EQ(1u, z);
auto workgroup_size = result[0].workgroup_size;
ASSERT_TRUE(workgroup_size.has_value());
EXPECT_EQ(8u, workgroup_size->x);
EXPECT_EQ(2u, workgroup_size->y);
EXPECT_EQ(1u, workgroup_size->z);
}
TEST_F(InspectorGetEntryPointTest, NonDefaultWorkgroupSize) {
@ -258,11 +258,11 @@ TEST_F(InspectorGetEntryPointTest, NonDefaultWorkgroupSize) {
ASSERT_FALSE(inspector.has_error()) << inspector.error();
ASSERT_EQ(1u, result.size());
uint32_t x, y, z;
std::tie(x, y, z) = result[0].workgroup_size();
EXPECT_EQ(8u, x);
EXPECT_EQ(2u, y);
EXPECT_EQ(1u, z);
auto workgroup_size = result[0].workgroup_size;
ASSERT_TRUE(workgroup_size.has_value());
EXPECT_EQ(8u, workgroup_size->x);
EXPECT_EQ(2u, workgroup_size->y);
EXPECT_EQ(1u, workgroup_size->z);
}
TEST_F(InspectorGetEntryPointTest, NoInOutVariables) {