tint/inspector: Use a std::optional for workgroup_size
No-value represents a workgroup size that is derived from an override-expression. Bug: dawn:1504 Bug: chromium:1346929 Change-Id: Idf6caa9d052aa56e8ef1913d16d1f68d2c5844ed Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/97362 Reviewed-by: Dan Sinclair <dsinclair@chromium.org> Kokoro: Kokoro <noreply+kokoro@google.com> Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
parent
a123b892a5
commit
a1571ac403
|
@ -660,19 +660,24 @@ ResultOrError<std::unique_ptr<EntryPointMetadata>> ReflectEntryPointUsingTint(
|
||||||
DAWN_TRY_ASSIGN(metadata->stage, TintPipelineStageToShaderStage(entryPoint.stage));
|
DAWN_TRY_ASSIGN(metadata->stage, TintPipelineStageToShaderStage(entryPoint.stage));
|
||||||
|
|
||||||
if (metadata->stage == SingleShaderStage::Compute) {
|
if (metadata->stage == SingleShaderStage::Compute) {
|
||||||
DelayedInvalidIf(entryPoint.workgroup_size_x > limits.v1.maxComputeWorkgroupSizeX ||
|
auto workgroup_size = entryPoint.workgroup_size;
|
||||||
entryPoint.workgroup_size_y > limits.v1.maxComputeWorkgroupSizeY ||
|
DAWN_INVALID_IF(
|
||||||
entryPoint.workgroup_size_z > limits.v1.maxComputeWorkgroupSizeZ,
|
!workgroup_size.has_value(),
|
||||||
|
"TODO(crbug.com/dawn/1504): Dawn does not currently support @workgroup_size "
|
||||||
|
"attributes using override-expressions");
|
||||||
|
DelayedInvalidIf(workgroup_size->x > limits.v1.maxComputeWorkgroupSizeX ||
|
||||||
|
workgroup_size->y > limits.v1.maxComputeWorkgroupSizeY ||
|
||||||
|
workgroup_size->z > limits.v1.maxComputeWorkgroupSizeZ,
|
||||||
"Entry-point uses workgroup_size(%u, %u, %u) that exceeds the "
|
"Entry-point uses workgroup_size(%u, %u, %u) that exceeds the "
|
||||||
"maximum allowed (%u, %u, %u).",
|
"maximum allowed (%u, %u, %u).",
|
||||||
entryPoint.workgroup_size_x, entryPoint.workgroup_size_y,
|
workgroup_size->x, workgroup_size->y, workgroup_size->z,
|
||||||
entryPoint.workgroup_size_z, limits.v1.maxComputeWorkgroupSizeX,
|
limits.v1.maxComputeWorkgroupSizeX, limits.v1.maxComputeWorkgroupSizeY,
|
||||||
limits.v1.maxComputeWorkgroupSizeY, limits.v1.maxComputeWorkgroupSizeZ);
|
limits.v1.maxComputeWorkgroupSizeZ);
|
||||||
|
|
||||||
// Dimensions have already been validated against their individual limits above.
|
// Dimensions have already been validated against their individual limits above.
|
||||||
// Cast to uint64_t to avoid overflow in this multiplication.
|
// Cast to uint64_t to avoid overflow in this multiplication.
|
||||||
uint64_t numInvocations = static_cast<uint64_t>(entryPoint.workgroup_size_x) *
|
uint64_t numInvocations =
|
||||||
entryPoint.workgroup_size_y * entryPoint.workgroup_size_z;
|
static_cast<uint64_t>(workgroup_size->x) * workgroup_size->y * workgroup_size->z;
|
||||||
DelayedInvalidIf(numInvocations > limits.v1.maxComputeInvocationsPerWorkgroup,
|
DelayedInvalidIf(numInvocations > limits.v1.maxComputeInvocationsPerWorkgroup,
|
||||||
"The total number of workgroup invocations (%u) exceeds the "
|
"The total number of workgroup invocations (%u) exceeds the "
|
||||||
"maximum allowed (%u).",
|
"maximum allowed (%u).",
|
||||||
|
@ -684,9 +689,9 @@ ResultOrError<std::unique_ptr<EntryPointMetadata>> ReflectEntryPointUsingTint(
|
||||||
"the maximum allowed (%u bytes).",
|
"the maximum allowed (%u bytes).",
|
||||||
workgroupStorageSize, limits.v1.maxComputeWorkgroupStorageSize);
|
workgroupStorageSize, limits.v1.maxComputeWorkgroupStorageSize);
|
||||||
|
|
||||||
metadata->localWorkgroupSize.x = entryPoint.workgroup_size_x;
|
metadata->localWorkgroupSize.x = workgroup_size->x;
|
||||||
metadata->localWorkgroupSize.y = entryPoint.workgroup_size_y;
|
metadata->localWorkgroupSize.y = workgroup_size->y;
|
||||||
metadata->localWorkgroupSize.z = entryPoint.workgroup_size_z;
|
metadata->localWorkgroupSize.z = workgroup_size->z;
|
||||||
|
|
||||||
metadata->usesNumWorkgroups = entryPoint.num_workgroups_used;
|
metadata->usesNumWorkgroups = entryPoint.num_workgroups_used;
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
#ifndef SRC_TINT_INSPECTOR_ENTRY_POINT_H_
|
#ifndef SRC_TINT_INSPECTOR_ENTRY_POINT_H_
|
||||||
#define SRC_TINT_INSPECTOR_ENTRY_POINT_H_
|
#define SRC_TINT_INSPECTOR_ENTRY_POINT_H_
|
||||||
|
|
||||||
|
#include <optional>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <tuple>
|
#include <tuple>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
@ -123,6 +124,16 @@ struct OverridableConstant {
|
||||||
/// The pipeline stage
|
/// The pipeline stage
|
||||||
enum class PipelineStage { kVertex, kFragment, kCompute };
|
enum class PipelineStage { kVertex, kFragment, kCompute };
|
||||||
|
|
||||||
|
/// WorkgroupSize describes the dimensions of the workgroup grid for a compute shader.
|
||||||
|
struct WorkgroupSize {
|
||||||
|
/// The 'x' dimension of the workgroup grid
|
||||||
|
uint32_t x = 1;
|
||||||
|
/// The 'y' dimension of the workgroup grid
|
||||||
|
uint32_t y = 1;
|
||||||
|
/// The 'z' dimension of the workgroup grid
|
||||||
|
uint32_t z = 1;
|
||||||
|
};
|
||||||
|
|
||||||
/// Reflection data for an entry point in the shader.
|
/// Reflection data for an entry point in the shader.
|
||||||
struct EntryPoint {
|
struct EntryPoint {
|
||||||
/// Constructors
|
/// Constructors
|
||||||
|
@ -139,12 +150,10 @@ struct EntryPoint {
|
||||||
std::string remapped_name;
|
std::string remapped_name;
|
||||||
/// The entry point stage
|
/// The entry point stage
|
||||||
PipelineStage stage;
|
PipelineStage stage;
|
||||||
/// The workgroup x size
|
/// The workgroup size. If PipelineStage is kCompute and this holds no value, then the workgroup
|
||||||
uint32_t workgroup_size_x = 0;
|
/// size is derived from an override-expression. In this situation you first need to run the
|
||||||
/// The workgroup y size
|
/// tint::transform::SubstituteOverride transform before using the inspector.
|
||||||
uint32_t workgroup_size_y = 0;
|
std::optional<WorkgroupSize> workgroup_size;
|
||||||
/// The workgroup z size
|
|
||||||
uint32_t workgroup_size_z = 0;
|
|
||||||
/// List of the input variable accessed via this entry point.
|
/// List of the input variable accessed via this entry point.
|
||||||
std::vector<StageVariable> input_variables;
|
std::vector<StageVariable> input_variables;
|
||||||
/// List of the output variable accessed via this entry point.
|
/// List of the output variable accessed via this entry point.
|
||||||
|
@ -166,12 +175,6 @@ struct EntryPoint {
|
||||||
bool sample_index_used = false;
|
bool sample_index_used = false;
|
||||||
/// Does the entry point use the num_workgroups builtin
|
/// Does the entry point use the num_workgroups builtin
|
||||||
bool num_workgroups_used = false;
|
bool num_workgroups_used = false;
|
||||||
|
|
||||||
/// @returns the size of the workgroup in {x,y,z} format
|
|
||||||
std::tuple<uint32_t, uint32_t, uint32_t> workgroup_size() {
|
|
||||||
return std::tuple<uint32_t, uint32_t, uint32_t>(workgroup_size_x, workgroup_size_y,
|
|
||||||
workgroup_size_z);
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace tint::inspector
|
} // namespace tint::inspector
|
||||||
|
|
|
@ -148,29 +148,30 @@ std::vector<EntryPoint> Inspector::GetEntryPoints() {
|
||||||
entry_point.remapped_name = program_->Symbols().NameFor(func->symbol);
|
entry_point.remapped_name = program_->Symbols().NameFor(func->symbol);
|
||||||
|
|
||||||
switch (func->PipelineStage()) {
|
switch (func->PipelineStage()) {
|
||||||
case ast::PipelineStage::kCompute:
|
case ast::PipelineStage::kCompute: {
|
||||||
entry_point.stage = PipelineStage::kCompute;
|
entry_point.stage = PipelineStage::kCompute;
|
||||||
|
|
||||||
|
auto wgsize = sem->WorkgroupSize();
|
||||||
|
if (!wgsize[0].overridable_const && !wgsize[1].overridable_const &&
|
||||||
|
!wgsize[2].overridable_const) {
|
||||||
|
entry_point.workgroup_size = {wgsize[0].value, wgsize[1].value,
|
||||||
|
wgsize[2].value};
|
||||||
|
}
|
||||||
break;
|
break;
|
||||||
case ast::PipelineStage::kFragment:
|
}
|
||||||
|
case ast::PipelineStage::kFragment: {
|
||||||
entry_point.stage = PipelineStage::kFragment;
|
entry_point.stage = PipelineStage::kFragment;
|
||||||
break;
|
break;
|
||||||
case ast::PipelineStage::kVertex:
|
}
|
||||||
|
case ast::PipelineStage::kVertex: {
|
||||||
entry_point.stage = PipelineStage::kVertex;
|
entry_point.stage = PipelineStage::kVertex;
|
||||||
break;
|
break;
|
||||||
default:
|
}
|
||||||
|
default: {
|
||||||
TINT_UNREACHABLE(Inspector, diagnostics_)
|
TINT_UNREACHABLE(Inspector, diagnostics_)
|
||||||
<< "invalid pipeline stage for entry point '" << entry_point.name << "'";
|
<< "invalid pipeline stage for entry point '" << entry_point.name << "'";
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto wgsize = sem->WorkgroupSize();
|
|
||||||
entry_point.workgroup_size_x = wgsize[0].value;
|
|
||||||
entry_point.workgroup_size_y = wgsize[1].value;
|
|
||||||
entry_point.workgroup_size_z = wgsize[2].value;
|
|
||||||
if (wgsize[0].overridable_const || wgsize[1].overridable_const ||
|
|
||||||
wgsize[2].overridable_const) {
|
|
||||||
// TODO(crbug.com/tint/713): Handle overridable constants.
|
|
||||||
TINT_ASSERT(Inspector, false);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto* param : sem->Parameters()) {
|
for (auto* param : sem->Parameters()) {
|
||||||
|
|
|
@ -239,11 +239,11 @@ TEST_F(InspectorGetEntryPointTest, DefaultWorkgroupSize) {
|
||||||
ASSERT_FALSE(inspector.has_error()) << inspector.error();
|
ASSERT_FALSE(inspector.has_error()) << inspector.error();
|
||||||
|
|
||||||
ASSERT_EQ(1u, result.size());
|
ASSERT_EQ(1u, result.size());
|
||||||
uint32_t x, y, z;
|
auto workgroup_size = result[0].workgroup_size;
|
||||||
std::tie(x, y, z) = result[0].workgroup_size();
|
ASSERT_TRUE(workgroup_size.has_value());
|
||||||
EXPECT_EQ(8u, x);
|
EXPECT_EQ(8u, workgroup_size->x);
|
||||||
EXPECT_EQ(2u, y);
|
EXPECT_EQ(2u, workgroup_size->y);
|
||||||
EXPECT_EQ(1u, z);
|
EXPECT_EQ(1u, workgroup_size->z);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(InspectorGetEntryPointTest, NonDefaultWorkgroupSize) {
|
TEST_F(InspectorGetEntryPointTest, NonDefaultWorkgroupSize) {
|
||||||
|
@ -258,11 +258,11 @@ TEST_F(InspectorGetEntryPointTest, NonDefaultWorkgroupSize) {
|
||||||
ASSERT_FALSE(inspector.has_error()) << inspector.error();
|
ASSERT_FALSE(inspector.has_error()) << inspector.error();
|
||||||
|
|
||||||
ASSERT_EQ(1u, result.size());
|
ASSERT_EQ(1u, result.size());
|
||||||
uint32_t x, y, z;
|
auto workgroup_size = result[0].workgroup_size;
|
||||||
std::tie(x, y, z) = result[0].workgroup_size();
|
ASSERT_TRUE(workgroup_size.has_value());
|
||||||
EXPECT_EQ(8u, x);
|
EXPECT_EQ(8u, workgroup_size->x);
|
||||||
EXPECT_EQ(2u, y);
|
EXPECT_EQ(2u, workgroup_size->y);
|
||||||
EXPECT_EQ(1u, z);
|
EXPECT_EQ(1u, workgroup_size->z);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(InspectorGetEntryPointTest, NoInOutVariables) {
|
TEST_F(InspectorGetEntryPointTest, NoInOutVariables) {
|
||||||
|
|
Loading…
Reference in New Issue