tint/inspector: Use a std::optional for workgroup_size

No-value represents a workgroup size that is derived from an override-expression.

Bug: dawn:1504
Bug: chromium:1346929
Change-Id: Idf6caa9d052aa56e8ef1913d16d1f68d2c5844ed
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/97362
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton 2022-07-27 17:05:56 +00:00 committed by Dawn LUCI CQ
parent a123b892a5
commit a1571ac403
4 changed files with 56 additions and 47 deletions

View File

@ -660,19 +660,24 @@ ResultOrError<std::unique_ptr<EntryPointMetadata>> ReflectEntryPointUsingTint(
DAWN_TRY_ASSIGN(metadata->stage, TintPipelineStageToShaderStage(entryPoint.stage)); DAWN_TRY_ASSIGN(metadata->stage, TintPipelineStageToShaderStage(entryPoint.stage));
if (metadata->stage == SingleShaderStage::Compute) { if (metadata->stage == SingleShaderStage::Compute) {
DelayedInvalidIf(entryPoint.workgroup_size_x > limits.v1.maxComputeWorkgroupSizeX || auto workgroup_size = entryPoint.workgroup_size;
entryPoint.workgroup_size_y > limits.v1.maxComputeWorkgroupSizeY || DAWN_INVALID_IF(
entryPoint.workgroup_size_z > limits.v1.maxComputeWorkgroupSizeZ, !workgroup_size.has_value(),
"TODO(crbug.com/dawn/1504): Dawn does not currently support @workgroup_size "
"attributes using override-expressions");
DelayedInvalidIf(workgroup_size->x > limits.v1.maxComputeWorkgroupSizeX ||
workgroup_size->y > limits.v1.maxComputeWorkgroupSizeY ||
workgroup_size->z > limits.v1.maxComputeWorkgroupSizeZ,
"Entry-point uses workgroup_size(%u, %u, %u) that exceeds the " "Entry-point uses workgroup_size(%u, %u, %u) that exceeds the "
"maximum allowed (%u, %u, %u).", "maximum allowed (%u, %u, %u).",
entryPoint.workgroup_size_x, entryPoint.workgroup_size_y, workgroup_size->x, workgroup_size->y, workgroup_size->z,
entryPoint.workgroup_size_z, limits.v1.maxComputeWorkgroupSizeX, limits.v1.maxComputeWorkgroupSizeX, limits.v1.maxComputeWorkgroupSizeY,
limits.v1.maxComputeWorkgroupSizeY, limits.v1.maxComputeWorkgroupSizeZ); limits.v1.maxComputeWorkgroupSizeZ);
// Dimensions have already been validated against their individual limits above. // Dimensions have already been validated against their individual limits above.
// Cast to uint64_t to avoid overflow in this multiplication. // Cast to uint64_t to avoid overflow in this multiplication.
uint64_t numInvocations = static_cast<uint64_t>(entryPoint.workgroup_size_x) * uint64_t numInvocations =
entryPoint.workgroup_size_y * entryPoint.workgroup_size_z; static_cast<uint64_t>(workgroup_size->x) * workgroup_size->y * workgroup_size->z;
DelayedInvalidIf(numInvocations > limits.v1.maxComputeInvocationsPerWorkgroup, DelayedInvalidIf(numInvocations > limits.v1.maxComputeInvocationsPerWorkgroup,
"The total number of workgroup invocations (%u) exceeds the " "The total number of workgroup invocations (%u) exceeds the "
"maximum allowed (%u).", "maximum allowed (%u).",
@ -684,9 +689,9 @@ ResultOrError<std::unique_ptr<EntryPointMetadata>> ReflectEntryPointUsingTint(
"the maximum allowed (%u bytes).", "the maximum allowed (%u bytes).",
workgroupStorageSize, limits.v1.maxComputeWorkgroupStorageSize); workgroupStorageSize, limits.v1.maxComputeWorkgroupStorageSize);
metadata->localWorkgroupSize.x = entryPoint.workgroup_size_x; metadata->localWorkgroupSize.x = workgroup_size->x;
metadata->localWorkgroupSize.y = entryPoint.workgroup_size_y; metadata->localWorkgroupSize.y = workgroup_size->y;
metadata->localWorkgroupSize.z = entryPoint.workgroup_size_z; metadata->localWorkgroupSize.z = workgroup_size->z;
metadata->usesNumWorkgroups = entryPoint.num_workgroups_used; metadata->usesNumWorkgroups = entryPoint.num_workgroups_used;
} }

View File

@ -15,6 +15,7 @@
#ifndef SRC_TINT_INSPECTOR_ENTRY_POINT_H_ #ifndef SRC_TINT_INSPECTOR_ENTRY_POINT_H_
#define SRC_TINT_INSPECTOR_ENTRY_POINT_H_ #define SRC_TINT_INSPECTOR_ENTRY_POINT_H_
#include <optional>
#include <string> #include <string>
#include <tuple> #include <tuple>
#include <vector> #include <vector>
@ -123,6 +124,16 @@ struct OverridableConstant {
/// The pipeline stage /// The pipeline stage
enum class PipelineStage { kVertex, kFragment, kCompute }; enum class PipelineStage { kVertex, kFragment, kCompute };
/// WorkgroupSize describes the dimensions of the workgroup grid for a compute shader.
struct WorkgroupSize {
/// The 'x' dimension of the workgroup grid
uint32_t x = 1;
/// The 'y' dimension of the workgroup grid
uint32_t y = 1;
/// The 'z' dimension of the workgroup grid
uint32_t z = 1;
};
/// Reflection data for an entry point in the shader. /// Reflection data for an entry point in the shader.
struct EntryPoint { struct EntryPoint {
/// Constructors /// Constructors
@ -139,12 +150,10 @@ struct EntryPoint {
std::string remapped_name; std::string remapped_name;
/// The entry point stage /// The entry point stage
PipelineStage stage; PipelineStage stage;
/// The workgroup x size /// The workgroup size. If PipelineStage is kCompute and this holds no value, then the workgroup
uint32_t workgroup_size_x = 0; /// size is derived from an override-expression. In this situation you first need to run the
/// The workgroup y size /// tint::transform::SubstituteOverride transform before using the inspector.
uint32_t workgroup_size_y = 0; std::optional<WorkgroupSize> workgroup_size;
/// The workgroup z size
uint32_t workgroup_size_z = 0;
/// List of the input variable accessed via this entry point. /// List of the input variable accessed via this entry point.
std::vector<StageVariable> input_variables; std::vector<StageVariable> input_variables;
/// List of the output variable accessed via this entry point. /// List of the output variable accessed via this entry point.
@ -166,12 +175,6 @@ struct EntryPoint {
bool sample_index_used = false; bool sample_index_used = false;
/// Does the entry point use the num_workgroups builtin /// Does the entry point use the num_workgroups builtin
bool num_workgroups_used = false; bool num_workgroups_used = false;
/// @returns the size of the workgroup in {x,y,z} format
std::tuple<uint32_t, uint32_t, uint32_t> workgroup_size() {
return std::tuple<uint32_t, uint32_t, uint32_t>(workgroup_size_x, workgroup_size_y,
workgroup_size_z);
}
}; };
} // namespace tint::inspector } // namespace tint::inspector

View File

@ -148,29 +148,30 @@ std::vector<EntryPoint> Inspector::GetEntryPoints() {
entry_point.remapped_name = program_->Symbols().NameFor(func->symbol); entry_point.remapped_name = program_->Symbols().NameFor(func->symbol);
switch (func->PipelineStage()) { switch (func->PipelineStage()) {
case ast::PipelineStage::kCompute: case ast::PipelineStage::kCompute: {
entry_point.stage = PipelineStage::kCompute; entry_point.stage = PipelineStage::kCompute;
auto wgsize = sem->WorkgroupSize();
if (!wgsize[0].overridable_const && !wgsize[1].overridable_const &&
!wgsize[2].overridable_const) {
entry_point.workgroup_size = {wgsize[0].value, wgsize[1].value,
wgsize[2].value};
}
break; break;
case ast::PipelineStage::kFragment: }
case ast::PipelineStage::kFragment: {
entry_point.stage = PipelineStage::kFragment; entry_point.stage = PipelineStage::kFragment;
break; break;
case ast::PipelineStage::kVertex: }
case ast::PipelineStage::kVertex: {
entry_point.stage = PipelineStage::kVertex; entry_point.stage = PipelineStage::kVertex;
break; break;
default: }
default: {
TINT_UNREACHABLE(Inspector, diagnostics_) TINT_UNREACHABLE(Inspector, diagnostics_)
<< "invalid pipeline stage for entry point '" << entry_point.name << "'"; << "invalid pipeline stage for entry point '" << entry_point.name << "'";
break; break;
} }
auto wgsize = sem->WorkgroupSize();
entry_point.workgroup_size_x = wgsize[0].value;
entry_point.workgroup_size_y = wgsize[1].value;
entry_point.workgroup_size_z = wgsize[2].value;
if (wgsize[0].overridable_const || wgsize[1].overridable_const ||
wgsize[2].overridable_const) {
// TODO(crbug.com/tint/713): Handle overridable constants.
TINT_ASSERT(Inspector, false);
} }
for (auto* param : sem->Parameters()) { for (auto* param : sem->Parameters()) {

View File

@ -239,11 +239,11 @@ TEST_F(InspectorGetEntryPointTest, DefaultWorkgroupSize) {
ASSERT_FALSE(inspector.has_error()) << inspector.error(); ASSERT_FALSE(inspector.has_error()) << inspector.error();
ASSERT_EQ(1u, result.size()); ASSERT_EQ(1u, result.size());
uint32_t x, y, z; auto workgroup_size = result[0].workgroup_size;
std::tie(x, y, z) = result[0].workgroup_size(); ASSERT_TRUE(workgroup_size.has_value());
EXPECT_EQ(8u, x); EXPECT_EQ(8u, workgroup_size->x);
EXPECT_EQ(2u, y); EXPECT_EQ(2u, workgroup_size->y);
EXPECT_EQ(1u, z); EXPECT_EQ(1u, workgroup_size->z);
} }
TEST_F(InspectorGetEntryPointTest, NonDefaultWorkgroupSize) { TEST_F(InspectorGetEntryPointTest, NonDefaultWorkgroupSize) {
@ -258,11 +258,11 @@ TEST_F(InspectorGetEntryPointTest, NonDefaultWorkgroupSize) {
ASSERT_FALSE(inspector.has_error()) << inspector.error(); ASSERT_FALSE(inspector.has_error()) << inspector.error();
ASSERT_EQ(1u, result.size()); ASSERT_EQ(1u, result.size());
uint32_t x, y, z; auto workgroup_size = result[0].workgroup_size;
std::tie(x, y, z) = result[0].workgroup_size(); ASSERT_TRUE(workgroup_size.has_value());
EXPECT_EQ(8u, x); EXPECT_EQ(8u, workgroup_size->x);
EXPECT_EQ(2u, y); EXPECT_EQ(2u, workgroup_size->y);
EXPECT_EQ(1u, z); EXPECT_EQ(1u, workgroup_size->z);
} }
TEST_F(InspectorGetEntryPointTest, NoInOutVariables) { TEST_F(InspectorGetEntryPointTest, NoInOutVariables) {