[ast] Add helpers for searching a decoration list

This is a commonly used pattern.

Change-Id: I698397c93c33db64c53cbe8662186e1976075b80
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/47280
Auto-Submit: James Price <jrprice@google.com>
Commit-Queue: James Price <jrprice@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
This commit is contained in:
James Price
2021-04-08 21:53:27 +00:00
committed by Commit Bot service account
parent c76ec15b45
commit a12ccb20a0
14 changed files with 69 additions and 108 deletions

View File

@@ -36,6 +36,30 @@ class Decoration : public Castable<Decoration, Node> {
/// A list of decorations
using DecorationList = std::vector<Decoration*>;
/// @param decorations the list of decorations to search
/// @returns true if `decorations` includes a decoration of type `T`
template <typename T>
bool HasDecoration(const DecorationList& decorations) {
for (auto* deco : decorations) {
if (deco->Is<T>()) {
return true;
}
}
return false;
}
/// @param decorations the list of decorations to search
/// @returns a pointer to `T` from `decorations` if found, otherwise nullptr.
template <typename T>
T* GetDecoration(const DecorationList& decorations) {
for (auto* deco : decorations) {
if (deco->Is<T>()) {
return deco->As<T>();
}
}
return nullptr;
}
} // namespace ast
} // namespace tint

View File

@@ -49,19 +49,15 @@ Function::Function(Function&&) = default;
Function::~Function() = default;
std::tuple<uint32_t, uint32_t, uint32_t> Function::workgroup_size() const {
for (auto* deco : decorations_) {
if (auto* workgroup = deco->As<WorkgroupDecoration>()) {
return workgroup->values();
}
if (auto* workgroup = GetDecoration<WorkgroupDecoration>(decorations_)) {
return workgroup->values();
}
return {1, 1, 1};
}
PipelineStage Function::pipeline_stage() const {
for (auto* deco : decorations_) {
if (auto* stage = deco->As<StageDecoration>()) {
return stage->value();
}
if (auto* stage = GetDecoration<StageDecoration>(decorations_)) {
return stage->value();
}
return PipelineStage::kNone;
}

View File

@@ -63,18 +63,6 @@ class Function : public Castable<Function, Node> {
/// @returns the decorations attached to this function
const DecorationList& decorations() const { return decorations_; }
/// @returns the decoration with the type `T` or nullptr if this function does
/// not contain a decoration with the given type
template <typename T>
const T* find_decoration() const {
for (auto* deco : decorations()) {
if (auto* d = deco->As<T>()) {
return d;
}
}
return nullptr;
}
/// @returns the workgroup size {x, y, z} for the function. {1, 1, 1} will be
/// return if no workgroup size was set.
std::tuple<uint32_t, uint32_t, uint32_t> workgroup_size() const;

View File

@@ -50,12 +50,7 @@ StructMember* Struct::get_member(const Symbol& symbol) const {
}
bool Struct::IsBlockDecorated() const {
for (auto* deco : decorations_) {
if (deco->Is<StructBlockDecoration>()) {
return true;
}
}
return false;
return HasDecoration<StructBlockDecoration>(decorations_);
}
Struct* Struct::Clone(CloneContext* ctx) const {

View File

@@ -41,19 +41,13 @@ StructMember::StructMember(StructMember&&) = default;
StructMember::~StructMember() = default;
bool StructMember::has_offset_decoration() const {
for (auto* deco : decorations_) {
if (deco->Is<StructMemberOffsetDecoration>()) {
return true;
}
}
return false;
return HasDecoration<StructMemberOffsetDecoration>(decorations_);
}
uint32_t StructMember::offset() const {
for (auto* deco : decorations_) {
if (auto* offset = deco->As<StructMemberOffsetDecoration>()) {
return offset->offset();
}
if (auto* offset =
GetDecoration<StructMemberOffsetDecoration>(decorations_)) {
return offset->offset();
}
return 0;
}

View File

@@ -59,49 +59,11 @@ Variable::BindingPoint Variable::binding_point() const {
return BindingPoint{group, binding};
}
bool Variable::HasLocationDecoration() const {
for (auto* deco : decorations_) {
if (deco->Is<LocationDecoration>()) {
return true;
}
}
return false;
}
bool Variable::HasBuiltinDecoration() const {
for (auto* deco : decorations_) {
if (deco->Is<BuiltinDecoration>()) {
return true;
}
}
return false;
}
bool Variable::HasConstantIdDecoration() const {
for (auto* deco : decorations_) {
if (deco->Is<ConstantIdDecoration>()) {
return true;
}
}
return false;
}
LocationDecoration* Variable::GetLocationDecoration() const {
for (auto* deco : decorations_) {
if (deco->Is<LocationDecoration>()) {
return deco->As<LocationDecoration>();
}
}
return nullptr;
}
uint32_t Variable::constant_id() const {
TINT_ASSERT(HasConstantIdDecoration());
for (auto* deco : decorations_) {
if (auto* cid = deco->As<ConstantIdDecoration>()) {
return cid->value();
}
if (auto* cid = GetDecoration<ConstantIdDecoration>(decorations_)) {
return cid->value();
}
TINT_ASSERT(false);
return 0;
}

View File

@@ -134,18 +134,8 @@ class Variable : public Castable<Variable, Node> {
/// @returns the binding point information for the variable
BindingPoint binding_point() const;
/// @returns true if the decorations include a LocationDecoration
bool HasLocationDecoration() const;
/// @returns true if the decorations include a BuiltinDecoration
bool HasBuiltinDecoration() const;
/// @returns true if the decorations include a ConstantIdDecoration
bool HasConstantIdDecoration() const;
/// @returns pointer to LocationDecoration in decorations, otherwise NULL.
LocationDecoration* GetLocationDecoration() const;
/// @returns the constant_id value for the variable. Assumes that
/// HasConstantIdDecoration() has been called first.
/// @returns the constant_id value for the variable. Assumes that this
/// variable has a constant ID decoration.
uint32_t constant_id() const;
/// Clones this node and all transitive child nodes using the `CloneContext`

View File

@@ -98,11 +98,12 @@ TEST_F(VariableTest, WithDecorations) {
create<ConstantIdDecoration>(1200),
});
EXPECT_TRUE(var->HasLocationDecoration());
EXPECT_TRUE(var->HasBuiltinDecoration());
EXPECT_TRUE(var->HasConstantIdDecoration());
auto& decorations = var->decorations();
EXPECT_TRUE(ast::HasDecoration<ast::LocationDecoration>(decorations));
EXPECT_TRUE(ast::HasDecoration<ast::BuiltinDecoration>(decorations));
EXPECT_TRUE(ast::HasDecoration<ast::ConstantIdDecoration>(decorations));
auto* location = var->GetLocationDecoration();
auto* location = ast::GetDecoration<ast::LocationDecoration>(decorations);
ASSERT_NE(nullptr, location);
EXPECT_EQ(1u, location->value());
}