inspector: reflect workgroup storage size

This reflects the total size of all workgroup storage-class variables
referenced transitively by an entry point.

Bug: tint:919
Change-Id: If3a217fea5a875ac18db6de1579f004e368fbb7b
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/57740
Reviewed-by: Ben Clayton <bclayton@google.com>
Reviewed-by: David Neto <dneto@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ken Rockot <rockot@google.com>
This commit is contained in:
Ken Rockot 2021-07-19 20:30:09 +00:00 committed by Tint LUCI CQ
parent b291cfced9
commit ac9db206eb
9 changed files with 237 additions and 86 deletions

View File

@ -42,6 +42,7 @@
#include "src/sem/variable.h" #include "src/sem/variable.h"
#include "src/sem/vector_type.h" #include "src/sem/vector_type.h"
#include "src/sem/void_type.h" #include "src/sem/void_type.h"
#include "src/utils/math.h"
namespace tint { namespace tint {
namespace inspector { namespace inspector {
@ -534,6 +535,31 @@ std::vector<SamplerTexturePair> Inspector::GetSamplerTextureUses(
return it->second; return it->second;
} }
uint32_t Inspector::GetWorkgroupStorageSize(const std::string& entry_point) {
auto* func = FindEntryPointByName(entry_point);
if (!func) {
return 0;
}
uint32_t total_size = 0;
auto* func_sem = program_->Sem().Get(func);
for (const sem::Variable* var : func_sem->ReferencedModuleVariables()) {
if (var->StorageClass() == ast::StorageClass::kWorkgroup) {
uint32_t align = 0;
uint32_t size = 0;
var->Type()->UnwrapRef()->GetDefaultAlignAndSize(align, size);
// This essentially matches std430 layout rules from GLSL, which are in
// turn specified as an upper bound for Vulkan layout sizing. Since D3D
// and Metal are even less specific, we assume Vulkan behavior as a
// good-enough approximation everywhere.
total_size += utils::RoundUp(align, size);
}
}
return total_size;
}
ast::Function* Inspector::FindEntryPointByName(const std::string& name) { ast::Function* Inspector::FindEntryPointByName(const std::string& name) {
auto* func = program_->AST().Functions().Find(program_->Symbols().Get(name)); auto* func = program_->AST().Functions().Find(program_->Symbols().Get(name));
if (!func) { if (!func) {

View File

@ -132,6 +132,11 @@ class Inspector {
std::vector<SamplerTexturePair> GetSamplerTextureUses( std::vector<SamplerTexturePair> GetSamplerTextureUses(
const std::string& entry_point); const std::string& entry_point);
/// @param entry_point name of the entry point to get information about.
/// @returns the total size in bytes of all Workgroup storage-class storage
/// referenced transitively by the entry point.
uint32_t GetWorkgroupStorageSize(const std::string& entry_point);
private: private:
const Program* program_; const Program* program_;
std::string error_; std::string error_;

View File

@ -134,6 +134,9 @@ class InspectorGetExternalTextureResourceBindingsTest : public InspectorBuilder,
class InspectorGetSamplerTextureUsesTest : public InspectorBuilder, class InspectorGetSamplerTextureUsesTest : public InspectorBuilder,
public testing::Test {}; public testing::Test {};
class InspectorGetWorkgroupStorageSizeTest : public InspectorBuilder,
public testing::Test {};
TEST_F(InspectorGetEntryPointTest, NoFunctions) { TEST_F(InspectorGetEntryPointTest, NoFunctions) {
Inspector& inspector = Build(); Inspector& inspector = Build();
@ -549,7 +552,7 @@ TEST_F(InspectorGetEntryPointTest, OverridableConstantUnreferenced) {
TEST_F(InspectorGetEntryPointTest, OverridableConstantReferencedByEntryPoint) { TEST_F(InspectorGetEntryPointTest, OverridableConstantReferencedByEntryPoint) {
AddOverridableConstantWithoutID<float>("foo", ty.f32(), nullptr); AddOverridableConstantWithoutID<float>("foo", ty.f32(), nullptr);
MakeConstReferenceBodyFunction( MakePlainGlobalReferenceBodyFunction(
"ep_func", "foo", ty.f32(), "ep_func", "foo", ty.f32(),
{Stage(ast::PipelineStage::kCompute), WorkgroupSize(1)}); {Stage(ast::PipelineStage::kCompute), WorkgroupSize(1)});
@ -564,7 +567,7 @@ TEST_F(InspectorGetEntryPointTest, OverridableConstantReferencedByEntryPoint) {
TEST_F(InspectorGetEntryPointTest, OverridableConstantReferencedByCallee) { TEST_F(InspectorGetEntryPointTest, OverridableConstantReferencedByCallee) {
AddOverridableConstantWithoutID<float>("foo", ty.f32(), nullptr); AddOverridableConstantWithoutID<float>("foo", ty.f32(), nullptr);
MakeConstReferenceBodyFunction("callee_func", "foo", ty.f32(), {}); MakePlainGlobalReferenceBodyFunction("callee_func", "foo", ty.f32(), {});
MakeCallerBodyFunction( MakeCallerBodyFunction(
"ep_func", {"callee_func"}, "ep_func", {"callee_func"},
{Stage(ast::PipelineStage::kCompute), WorkgroupSize(1)}); {Stage(ast::PipelineStage::kCompute), WorkgroupSize(1)});
@ -581,7 +584,7 @@ TEST_F(InspectorGetEntryPointTest, OverridableConstantReferencedByCallee) {
TEST_F(InspectorGetEntryPointTest, OverridableConstantSomeReferenced) { TEST_F(InspectorGetEntryPointTest, OverridableConstantSomeReferenced) {
AddOverridableConstantWithID<float>("foo", 1, ty.f32(), nullptr); AddOverridableConstantWithID<float>("foo", 1, ty.f32(), nullptr);
AddOverridableConstantWithID<float>("bar", 2, ty.f32(), nullptr); AddOverridableConstantWithID<float>("bar", 2, ty.f32(), nullptr);
MakeConstReferenceBodyFunction("callee_func", "foo", ty.f32(), {}); MakePlainGlobalReferenceBodyFunction("callee_func", "foo", ty.f32(), {});
MakeCallerBodyFunction( MakeCallerBodyFunction(
"ep_func", {"callee_func"}, "ep_func", {"callee_func"},
{Stage(ast::PipelineStage::kCompute), WorkgroupSize(1)}); {Stage(ast::PipelineStage::kCompute), WorkgroupSize(1)});
@ -2397,6 +2400,93 @@ TEST_F(InspectorGetSamplerTextureUsesTest, InFunction) {
EXPECT_EQ(0u, result[0].texture_binding_point.binding); EXPECT_EQ(0u, result[0].texture_binding_point.binding);
} }
TEST_F(InspectorGetWorkgroupStorageSizeTest, Empty) {
MakeEmptyBodyFunction("ep_func",
ast::DecorationList{Stage(ast::PipelineStage::kCompute),
WorkgroupSize(1)});
Inspector& inspector = Build();
EXPECT_EQ(0u, inspector.GetWorkgroupStorageSize("ep_func"));
}
TEST_F(InspectorGetWorkgroupStorageSizeTest, Simple) {
AddWorkgroupStorage("wg_f32", ty.f32());
MakePlainGlobalReferenceBodyFunction("f32_func", "wg_f32", ty.f32(), {});
MakeCallerBodyFunction("ep_func", {"f32_func"},
ast::DecorationList{
Stage(ast::PipelineStage::kCompute),
WorkgroupSize(1),
});
Inspector& inspector = Build();
EXPECT_EQ(4u, inspector.GetWorkgroupStorageSize("ep_func"));
}
TEST_F(InspectorGetWorkgroupStorageSizeTest, CompoundTypes) {
// This struct should occupy 68 bytes. 4 from the i32 field, and another 64
// from the 4-element array with 16-byte stride.
ast::Struct* wg_struct_type = MakeStructType(
"WgStruct", {ty.i32(), ty.array(ty.i32(), 4, /*stride=*/16)},
/*is_block=*/false);
AddWorkgroupStorage("wg_struct_var", ty.Of(wg_struct_type));
MakeStructVariableReferenceBodyFunction("wg_struct_func", "wg_struct_var",
{{0, ty.i32()}});
// Plus another 4 bytes from this other workgroup-class f32.
AddWorkgroupStorage("wg_f32", ty.f32());
MakePlainGlobalReferenceBodyFunction("f32_func", "wg_f32", ty.f32(), {});
MakeCallerBodyFunction("ep_func", {"wg_struct_func", "f32_func"},
ast::DecorationList{
Stage(ast::PipelineStage::kCompute),
WorkgroupSize(1),
});
Inspector& inspector = Build();
EXPECT_EQ(72u, inspector.GetWorkgroupStorageSize("ep_func"));
}
TEST_F(InspectorGetWorkgroupStorageSizeTest, AlignmentPadding) {
// vec3<f32> has an alignment of 16 but a size of 12. We leverage this to test
// that our padded size calculation for workgroup storage is accurate.
AddWorkgroupStorage("wg_vec3", ty.vec3<f32>());
MakePlainGlobalReferenceBodyFunction("wg_func", "wg_vec3", ty.vec3<f32>(),
{});
MakeCallerBodyFunction("ep_func", {"wg_func"},
ast::DecorationList{
Stage(ast::PipelineStage::kCompute),
WorkgroupSize(1),
});
Inspector& inspector = Build();
EXPECT_EQ(16u, inspector.GetWorkgroupStorageSize("ep_func"));
}
TEST_F(InspectorGetWorkgroupStorageSizeTest, StructAlignment) {
// Per WGSL spec, a struct's size is the offset its last member plus the size
// of its last member, rounded up to the alignment of its largest member. So
// here the struct is expected to occupy 1024 bytes of workgroup storage.
ast::Struct* wg_struct_type = MakeStructTypeFromMembers(
"WgStruct",
{MakeStructMember(0, ty.f32(),
{create<ast::StructMemberAlignDecoration>(1024)})},
/*is_block=*/false);
AddWorkgroupStorage("wg_struct_var", ty.Of(wg_struct_type));
MakeStructVariableReferenceBodyFunction("wg_struct_func", "wg_struct_var",
{{0, ty.f32()}});
MakeCallerBodyFunction("ep_func", {"wg_struct_func"},
ast::DecorationList{
Stage(ast::PipelineStage::kCompute),
WorkgroupSize(1),
});
Inspector& inspector = Build();
EXPECT_EQ(1024u, inspector.GetWorkgroupStorageSize("ep_func"));
}
} // namespace } // namespace
} // namespace inspector } // namespace inspector
} // namespace tint } // namespace tint

View File

@ -60,7 +60,7 @@ ast::Struct* InspectorBuilder::MakeInOutStruct(
return Structure(name, members); return Structure(name, members);
} }
ast::Function* InspectorBuilder::MakeConstReferenceBodyFunction( ast::Function* InspectorBuilder::MakePlainGlobalReferenceBodyFunction(
std::string func, std::string func,
std::string var, std::string var,
ast::Type* type, ast::Type* type,
@ -93,15 +93,27 @@ ast::Struct* InspectorBuilder::MakeStructType(
bool is_block) { bool is_block) {
ast::StructMemberList members; ast::StructMemberList members;
for (auto* type : member_types) { for (auto* type : member_types) {
members.push_back(Member(StructMemberName(members.size(), type), type)); members.push_back(MakeStructMember(members.size(), type, {}));
} }
return MakeStructTypeFromMembers(name, std::move(members), is_block);
}
ast::Struct* InspectorBuilder::MakeStructTypeFromMembers(
const std::string& name,
ast::StructMemberList members,
bool is_block) {
ast::DecorationList decos; ast::DecorationList decos;
if (is_block) { if (is_block) {
decos.push_back(create<ast::StructBlockDecoration>()); decos.push_back(create<ast::StructBlockDecoration>());
} }
return Structure(name, std::move(members), decos);
}
return Structure(name, members, decos); ast::StructMember* InspectorBuilder::MakeStructMember(
size_t index,
ast::Type* type,
ast::DecorationList decorations) {
return Member(StructMemberName(index, type), type, std::move(decorations));
} }
ast::Struct* InspectorBuilder::MakeUniformBufferType( ast::Struct* InspectorBuilder::MakeUniformBufferType(
@ -128,6 +140,11 @@ void InspectorBuilder::AddUniformBuffer(const std::string& name,
}); });
} }
void InspectorBuilder::AddWorkgroupStorage(const std::string& name,
ast::Type* type) {
Global(name, type, ast::StorageClass::kWorkgroup);
}
void InspectorBuilder::AddStorageBuffer(const std::string& name, void InspectorBuilder::AddStorageBuffer(const std::string& name,
ast::Type* type, ast::Type* type,
ast::Access access, ast::Access access,

View File

@ -139,13 +139,14 @@ class InspectorBuilder : public ProgramBuilder {
}); });
} }
/// Generates a function that references module constant /// Generates a function that references module-scoped, plain-typed constant
/// or variable.
/// @param func name of the function created /// @param func name of the function created
/// @param var name of the constant to be reference /// @param var name of the constant to be reference
/// @param type type of the const being referenced /// @param type type of the const being referenced
/// @param decorations the function decorations /// @param decorations the function decorations
/// @returns a function object /// @returns a function object
ast::Function* MakeConstReferenceBodyFunction( ast::Function* MakePlainGlobalReferenceBodyFunction(
std::string func, std::string func,
std::string var, std::string var,
ast::Type* type, ast::Type* type,
@ -172,6 +173,24 @@ class InspectorBuilder : public ProgramBuilder {
std::vector<ast::Type*> member_types, std::vector<ast::Type*> member_types,
bool is_block); bool is_block);
/// Generates a struct type from a list of member nodes.
/// @param name name for the struct type
/// @param members a vector of members
/// @param is_block whether or not to decorate as a Block
/// @returns a struct type
ast::Struct* MakeStructTypeFromMembers(const std::string& name,
ast::StructMemberList members,
bool is_block);
/// Generates a struct member with a specified index and type.
/// @param index index of the field within the struct
/// @param type the type of the member field
/// @param decorations a list of decorations to apply to the member field
/// @returns a struct member
ast::StructMember* MakeStructMember(size_t index,
ast::Type* type,
ast::DecorationList decorations);
/// Generates types appropriate for using in an uniform buffer /// Generates types appropriate for using in an uniform buffer
/// @param name name for the type /// @param name name for the type
/// @param member_types a vector of member types /// @param member_types a vector of member types
@ -197,6 +216,11 @@ class InspectorBuilder : public ProgramBuilder {
uint32_t group, uint32_t group,
uint32_t binding); uint32_t binding);
/// Adds a workgroup storage variable to the program
/// @param name the name of the variable
/// @param type the type of the variable
void AddWorkgroupStorage(const std::string& name, ast::Type* type);
/// Adds a storage buffer variable to the program /// Adds a storage buffer variable to the program
/// @param name the name of the variable /// @param name the name of the variable
/// @param type the type to use /// @param type the type to use

View File

@ -723,7 +723,7 @@ bool Resolver::ValidateStorageClassLayout(const sem::Struct* str,
auto required_alignment_of = [&](const sem::Type* ty) { auto required_alignment_of = [&](const sem::Type* ty) {
uint32_t actual_align = 0; uint32_t actual_align = 0;
uint32_t actual_size = 0; uint32_t actual_size = 0;
DefaultAlignAndSize(ty, actual_align, actual_size); ty->GetDefaultAlignAndSize(actual_align, actual_size);
uint32_t required_align = actual_align; uint32_t required_align = actual_align;
if (is_uniform_struct_or_array(ty)) { if (is_uniform_struct_or_array(ty)) {
required_align = utils::RoundUp(16u, actual_align); required_align = utils::RoundUp(16u, actual_align);
@ -3750,69 +3750,6 @@ void Resolver::CreateSemanticNodes() const {
} }
} }
bool Resolver::DefaultAlignAndSize(const sem::Type* ty,
uint32_t& align,
uint32_t& size) {
static constexpr uint32_t vector_size[] = {
/* padding */ 0,
/* padding */ 0,
/*vec2*/ 8,
/*vec3*/ 12,
/*vec4*/ 16,
};
static constexpr uint32_t vector_align[] = {
/* padding */ 0,
/* padding */ 0,
/*vec2*/ 8,
/*vec3*/ 16,
/*vec4*/ 16,
};
if (ty->is_scalar()) {
// Note: Also captures booleans, but these are not host-shareable.
align = 4;
size = 4;
return true;
}
if (auto* vec = ty->As<sem::Vector>()) {
if (vec->size() < 2 || vec->size() > 4) {
TINT_UNREACHABLE(Resolver, diagnostics_)
<< "Invalid vector size: vec" << vec->size();
return false;
}
align = vector_align[vec->size()];
size = vector_size[vec->size()];
return true;
}
if (auto* mat = ty->As<sem::Matrix>()) {
if (mat->columns() < 2 || mat->columns() > 4 || mat->rows() < 2 ||
mat->rows() > 4) {
TINT_UNREACHABLE(Resolver, diagnostics_)
<< "Invalid matrix size: mat" << mat->columns() << "x" << mat->rows();
return false;
}
align = vector_align[mat->rows()];
size = vector_align[mat->rows()] * mat->columns();
return true;
}
if (auto* s = ty->As<sem::Struct>()) {
align = s->Align();
size = s->Size();
return true;
}
if (auto* a = ty->As<sem::Array>()) {
align = a->Align();
size = a->SizeInBytes();
return true;
}
if (auto* a = ty->As<sem::Atomic>()) {
return DefaultAlignAndSize(a->Type(), align, size);
}
TINT_UNREACHABLE(Resolver, diagnostics_)
<< "invalid type " << ty->TypeInfo().name;
return false;
}
sem::Array* Resolver::Array(const ast::Array* arr) { sem::Array* Resolver::Array(const ast::Array* arr) {
auto source = arr->source(); auto source = arr->source();
@ -3821,7 +3758,7 @@ sem::Array* Resolver::Array(const ast::Array* arr) {
return nullptr; return nullptr;
} }
if (!IsPlain(el_ty)) { // Check must come before DefaultAlignAndSize() if (!IsPlain(el_ty)) { // Check must come before GetDefaultAlignAndSize()
AddError(el_ty->FriendlyName(builder_->Symbols()) + AddError(el_ty->FriendlyName(builder_->Symbols()) +
" cannot be used as an element type of an array", " cannot be used as an element type of an array",
source); source);
@ -3830,9 +3767,7 @@ sem::Array* Resolver::Array(const ast::Array* arr) {
uint32_t el_align = 0; uint32_t el_align = 0;
uint32_t el_size = 0; uint32_t el_size = 0;
if (!DefaultAlignAndSize(el_ty, el_align, el_size)) { el_ty->GetDefaultAlignAndSize(el_align, el_size);
return nullptr;
}
if (!ValidateNoDuplicateDecorations(arr->decorations())) { if (!ValidateNoDuplicateDecorations(arr->decorations())) {
return nullptr; return nullptr;
@ -4040,9 +3975,7 @@ sem::Struct* Resolver::Structure(const ast::Struct* str) {
uint32_t offset = struct_size; uint32_t offset = struct_size;
uint32_t align = 0; uint32_t align = 0;
uint32_t size = 0; uint32_t size = 0;
if (!DefaultAlignAndSize(type, align, size)) { type->GetDefaultAlignAndSize(align, size);
return nullptr;
}
if (!ValidateNoDuplicateDecorations(member->decorations())) { if (!ValidateNoDuplicateDecorations(member->decorations())) {
return nullptr; return nullptr;

View File

@ -375,13 +375,6 @@ class Resolver {
sem::Type* ty, sem::Type* ty,
const Source& usage); const Source& usage);
/// @param align the output default alignment in bytes for the type `ty`
/// @param size the output default size in bytes for the type `ty`
/// @returns true on success, false on error
bool DefaultAlignAndSize(const sem::Type* ty,
uint32_t& align,
uint32_t& size);
/// @param storage_class the storage class /// @param storage_class the storage class
/// @returns the default access control for the given storage class /// @returns the default access control for the given storage class
ast::Access DefaultAccessForStorageClass(ast::StorageClass storage_class); ast::Access DefaultAccessForStorageClass(ast::StorageClass storage_class);

View File

@ -14,6 +14,9 @@
#include "src/sem/type.h" #include "src/sem/type.h"
#include "src/debug.h"
#include "src/sem/array.h"
#include "src/sem/atomic_type.h"
#include "src/sem/bool_type.h" #include "src/sem/bool_type.h"
#include "src/sem/f32_type.h" #include "src/sem/f32_type.h"
#include "src/sem/i32_type.h" #include "src/sem/i32_type.h"
@ -21,6 +24,7 @@
#include "src/sem/pointer_type.h" #include "src/sem/pointer_type.h"
#include "src/sem/reference_type.h" #include "src/sem/reference_type.h"
#include "src/sem/sampler_type.h" #include "src/sem/sampler_type.h"
#include "src/sem/struct.h"
#include "src/sem/texture_type.h" #include "src/sem/texture_type.h"
#include "src/sem/u32_type.h" #include "src/sem/u32_type.h"
#include "src/sem/vector_type.h" #include "src/sem/vector_type.h"
@ -52,6 +56,61 @@ const Type* Type::UnwrapRef() const {
return type; return type;
} }
void Type::GetDefaultAlignAndSize(uint32_t& align, uint32_t& size) const {
TINT_ASSERT(Semantic, !As<Reference>());
TINT_ASSERT(Semantic, !As<Pointer>());
static constexpr uint32_t vector_size[] = {
/* padding */ 0,
/* padding */ 0,
/*vec2*/ 8,
/*vec3*/ 12,
/*vec4*/ 16,
};
static constexpr uint32_t vector_align[] = {
/* padding */ 0,
/* padding */ 0,
/*vec2*/ 8,
/*vec3*/ 16,
/*vec4*/ 16,
};
if (is_scalar()) {
// Note: Also captures booleans, but these are not host-shareable.
align = 4;
size = 4;
return;
}
if (auto* vec = As<Vector>()) {
TINT_ASSERT(Semantic, vec->size() >= 2 && vec->size() <= 4);
align = vector_align[vec->size()];
size = vector_size[vec->size()];
return;
}
if (auto* mat = As<Matrix>()) {
TINT_ASSERT(Semantic, mat->columns() >= 2 && mat->columns() <= 4);
TINT_ASSERT(Semantic, mat->rows() >= 2 && mat->rows() <= 4);
align = vector_align[mat->rows()];
size = vector_align[mat->rows()] * mat->columns();
return;
}
if (auto* s = As<Struct>()) {
align = s->Align();
size = s->Size();
return;
}
if (auto* a = As<Array>()) {
align = a->Align();
size = a->SizeInBytes();
return;
}
if (auto* a = As<Atomic>()) {
return a->Type()->GetDefaultAlignAndSize(align, size);
}
TINT_ASSERT(Semantic, false);
}
bool Type::is_scalar() const { bool Type::is_scalar() const {
return IsAnyOf<F32, U32, I32, Bool>(); return IsAnyOf<F32, U32, I32, Bool>();
} }

View File

@ -52,6 +52,10 @@ class Type : public Castable<Type, Node> {
/// @returns the inner type if this is a reference, `this` otherwise /// @returns the inner type if this is a reference, `this` otherwise
const Type* UnwrapRef() const; const Type* UnwrapRef() const;
/// @param align the output default alignment in bytes for this type.
/// @param size the output default size in bytes for this type.
void GetDefaultAlignAndSize(uint32_t& align, uint32_t& size) const;
/// @returns true if this type is a scalar /// @returns true if this type is a scalar
bool is_scalar() const; bool is_scalar() const;
/// @returns true if this type is a numeric scalar /// @returns true if this type is a numeric scalar