[ast] Add helpers for searching a decoration list

This is a commonly used pattern.

Change-Id: I698397c93c33db64c53cbe8662186e1976075b80
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/47280
Auto-Submit: James Price <jrprice@google.com>
Commit-Queue: James Price <jrprice@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
This commit is contained in:
James Price 2021-04-08 21:53:27 +00:00 committed by Commit Bot service account
parent c76ec15b45
commit a12ccb20a0
14 changed files with 69 additions and 108 deletions

View File

@ -36,6 +36,30 @@ class Decoration : public Castable<Decoration, Node> {
/// A list of decorations
using DecorationList = std::vector<Decoration*>;
/// @param decorations the list of decorations to search
/// @returns true if `decorations` includes a decoration of type `T`
template <typename T>
bool HasDecoration(const DecorationList& decorations) {
for (auto* deco : decorations) {
if (deco->Is<T>()) {
return true;
}
}
return false;
}
/// @param decorations the list of decorations to search
/// @returns a pointer to `T` from `decorations` if found, otherwise nullptr.
template <typename T>
T* GetDecoration(const DecorationList& decorations) {
for (auto* deco : decorations) {
if (deco->Is<T>()) {
return deco->As<T>();
}
}
return nullptr;
}
} // namespace ast
} // namespace tint

View File

@ -49,19 +49,15 @@ Function::Function(Function&&) = default;
Function::~Function() = default;
std::tuple<uint32_t, uint32_t, uint32_t> Function::workgroup_size() const {
for (auto* deco : decorations_) {
if (auto* workgroup = deco->As<WorkgroupDecoration>()) {
return workgroup->values();
}
if (auto* workgroup = GetDecoration<WorkgroupDecoration>(decorations_)) {
return workgroup->values();
}
return {1, 1, 1};
}
PipelineStage Function::pipeline_stage() const {
for (auto* deco : decorations_) {
if (auto* stage = deco->As<StageDecoration>()) {
return stage->value();
}
if (auto* stage = GetDecoration<StageDecoration>(decorations_)) {
return stage->value();
}
return PipelineStage::kNone;
}

View File

@ -63,18 +63,6 @@ class Function : public Castable<Function, Node> {
/// @returns the decorations attached to this function
const DecorationList& decorations() const { return decorations_; }
/// @returns the decoration with the type `T` or nullptr if this function does
/// not contain a decoration with the given type
template <typename T>
const T* find_decoration() const {
for (auto* deco : decorations()) {
if (auto* d = deco->As<T>()) {
return d;
}
}
return nullptr;
}
/// @returns the workgroup size {x, y, z} for the function. {1, 1, 1} will be
/// return if no workgroup size was set.
std::tuple<uint32_t, uint32_t, uint32_t> workgroup_size() const;

View File

@ -50,12 +50,7 @@ StructMember* Struct::get_member(const Symbol& symbol) const {
}
bool Struct::IsBlockDecorated() const {
for (auto* deco : decorations_) {
if (deco->Is<StructBlockDecoration>()) {
return true;
}
}
return false;
return HasDecoration<StructBlockDecoration>(decorations_);
}
Struct* Struct::Clone(CloneContext* ctx) const {

View File

@ -41,19 +41,13 @@ StructMember::StructMember(StructMember&&) = default;
StructMember::~StructMember() = default;
bool StructMember::has_offset_decoration() const {
for (auto* deco : decorations_) {
if (deco->Is<StructMemberOffsetDecoration>()) {
return true;
}
}
return false;
return HasDecoration<StructMemberOffsetDecoration>(decorations_);
}
uint32_t StructMember::offset() const {
for (auto* deco : decorations_) {
if (auto* offset = deco->As<StructMemberOffsetDecoration>()) {
return offset->offset();
}
if (auto* offset =
GetDecoration<StructMemberOffsetDecoration>(decorations_)) {
return offset->offset();
}
return 0;
}

View File

@ -59,49 +59,11 @@ Variable::BindingPoint Variable::binding_point() const {
return BindingPoint{group, binding};
}
bool Variable::HasLocationDecoration() const {
for (auto* deco : decorations_) {
if (deco->Is<LocationDecoration>()) {
return true;
}
}
return false;
}
bool Variable::HasBuiltinDecoration() const {
for (auto* deco : decorations_) {
if (deco->Is<BuiltinDecoration>()) {
return true;
}
}
return false;
}
bool Variable::HasConstantIdDecoration() const {
for (auto* deco : decorations_) {
if (deco->Is<ConstantIdDecoration>()) {
return true;
}
}
return false;
}
LocationDecoration* Variable::GetLocationDecoration() const {
for (auto* deco : decorations_) {
if (deco->Is<LocationDecoration>()) {
return deco->As<LocationDecoration>();
}
}
return nullptr;
}
uint32_t Variable::constant_id() const {
TINT_ASSERT(HasConstantIdDecoration());
for (auto* deco : decorations_) {
if (auto* cid = deco->As<ConstantIdDecoration>()) {
return cid->value();
}
if (auto* cid = GetDecoration<ConstantIdDecoration>(decorations_)) {
return cid->value();
}
TINT_ASSERT(false);
return 0;
}

View File

@ -134,18 +134,8 @@ class Variable : public Castable<Variable, Node> {
/// @returns the binding point information for the variable
BindingPoint binding_point() const;
/// @returns true if the decorations include a LocationDecoration
bool HasLocationDecoration() const;
/// @returns true if the decorations include a BuiltinDecoration
bool HasBuiltinDecoration() const;
/// @returns true if the decorations include a ConstantIdDecoration
bool HasConstantIdDecoration() const;
/// @returns pointer to LocationDecoration in decorations, otherwise NULL.
LocationDecoration* GetLocationDecoration() const;
/// @returns the constant_id value for the variable. Assumes that
/// HasConstantIdDecoration() has been called first.
/// @returns the constant_id value for the variable. Assumes that this
/// variable has a constant ID decoration.
uint32_t constant_id() const;
/// Clones this node and all transitive child nodes using the `CloneContext`

View File

@ -98,11 +98,12 @@ TEST_F(VariableTest, WithDecorations) {
create<ConstantIdDecoration>(1200),
});
EXPECT_TRUE(var->HasLocationDecoration());
EXPECT_TRUE(var->HasBuiltinDecoration());
EXPECT_TRUE(var->HasConstantIdDecoration());
auto& decorations = var->decorations();
EXPECT_TRUE(ast::HasDecoration<ast::LocationDecoration>(decorations));
EXPECT_TRUE(ast::HasDecoration<ast::BuiltinDecoration>(decorations));
EXPECT_TRUE(ast::HasDecoration<ast::ConstantIdDecoration>(decorations));
auto* location = var->GetLocationDecoration();
auto* location = ast::GetDecoration<ast::LocationDecoration>(decorations);
ASSERT_NE(nullptr, location);
EXPECT_EQ(1u, location->value());
}

View File

@ -18,6 +18,7 @@
#include "src/ast/bool_literal.h"
#include "src/ast/float_literal.h"
#include "src/ast/constant_id_decoration.h"
#include "src/ast/module.h"
#include "src/ast/scalar_constructor_expression.h"
#include "src/ast/sint_literal.h"
@ -203,7 +204,7 @@ std::vector<EntryPoint> Inspector::GetEntryPoints() {
auto* decl = var->Declaration();
auto name = program_->Symbols().NameFor(decl->symbol());
if (decl->HasBuiltinDecoration()) {
if (ast::HasDecoration<ast::BuiltinDecoration>(decl->decorations())) {
continue;
}
@ -220,7 +221,8 @@ std::vector<EntryPoint> Inspector::GetEntryPoints() {
stage_variable.component_type = ComponentType::kSInt;
}
auto* location_decoration = decl->GetLocationDecoration();
auto* location_decoration =
ast::GetDecoration<ast::LocationDecoration>(decl->decorations());
if (location_decoration) {
stage_variable.has_location_decoration = true;
stage_variable.location_decoration = location_decoration->value();
@ -257,7 +259,7 @@ std::string Inspector::GetRemappedNameForEntryPoint(
std::map<uint32_t, Scalar> Inspector::GetConstantIDs() {
std::map<uint32_t, Scalar> result;
for (auto* var : program_->AST().GlobalVariables()) {
if (!var->HasConstantIdDecoration()) {
if (!ast::HasDecoration<ast::ConstantIdDecoration>(var->decorations())) {
continue;
}

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/ast/constant_id_decoration.h"
#include "src/reader/wgsl/parser_impl_test_helper.h"
namespace tint {
@ -43,7 +44,8 @@ TEST_F(ParserImplTest, GlobalConstantDecl) {
ASSERT_NE(e->constructor(), nullptr);
EXPECT_TRUE(e->constructor()->Is<ast::ConstructorExpression>());
EXPECT_FALSE(e.value->HasConstantIdDecoration());
EXPECT_FALSE(
ast::HasDecoration<ast::ConstantIdDecoration>(e.value->decorations()));
}
TEST_F(ParserImplTest, GlobalConstantDecl_MissingEqual) {
@ -123,7 +125,8 @@ TEST_F(ParserImplTest, GlobalConstantDec_ConstantId) {
ASSERT_NE(e->constructor(), nullptr);
EXPECT_TRUE(e->constructor()->Is<ast::ConstructorExpression>());
EXPECT_TRUE(e.value->HasConstantIdDecoration());
EXPECT_TRUE(
ast::HasDecoration<ast::ConstantIdDecoration>(e.value->decorations()));
EXPECT_EQ(e.value->constant_id(), 7u);
}

View File

@ -310,7 +310,8 @@ bool Resolver::ValidateFunction(const ast::Function* func) {
func->source());
return false;
}
} else if (!func->find_decoration<ast::InternalDecoration>()) {
} else if (!ast::HasDecoration<ast::InternalDecoration>(
func->decorations())) {
TINT_ICE(diagnostics_)
<< "Function " << builder_->Symbols().NameFor(func->symbol())
<< " has no body and does not have the [[internal]] decoration";

View File

@ -1310,8 +1310,9 @@ bool GeneratorImpl::EmitExpression(std::ostream& pre,
}
bool GeneratorImpl::global_is_in_struct(const semantic::Variable* var) const {
if (var->Declaration()->HasLocationDecoration() ||
var->Declaration()->HasBuiltinDecoration()) {
auto& decorations = var->Declaration()->decorations();
if (ast::HasDecoration<ast::LocationDecoration>(decorations) ||
ast::HasDecoration<ast::BuiltinDecoration>(decorations)) {
return var->StorageClass() == ast::StorageClass::kInput ||
var->StorageClass() == ast::StorageClass::kOutput;
}
@ -1463,7 +1464,7 @@ bool GeneratorImpl::EmitFunction(std::ostream& out, ast::Function* func) {
auto* func_sem = builder_.Sem().Get(func);
if (func->find_decoration<ast::InternalDecoration>()) {
if (ast::HasDecoration<ast::InternalDecoration>(func->decorations())) {
// An internal function. Do not emit.
return true;
}
@ -2825,7 +2826,7 @@ bool GeneratorImpl::EmitProgramConstVariable(std::ostream& out,
auto* type = builder_.Sem().Get(var)->Type();
if (var->HasConstantIdDecoration()) {
if (ast::HasDecoration<ast::ConstantIdDecoration>(var->decorations())) {
auto const_id = var->constant_id();
out << "#ifndef WGSL_SPEC_CONSTANT_" << const_id << std::endl;

View File

@ -1562,12 +1562,15 @@ bool GeneratorImpl::EmitEntryPointFunction(ast::Function* func) {
}
bool GeneratorImpl::global_is_in_struct(const semantic::Variable* var) const {
auto& decorations = var->Declaration()->decorations();
bool in_or_out_struct_has_location =
var != nullptr && var->Declaration()->HasLocationDecoration() &&
var != nullptr &&
ast::HasDecoration<ast::LocationDecoration>(decorations) &&
(var->StorageClass() == ast::StorageClass::kInput ||
var->StorageClass() == ast::StorageClass::kOutput);
bool in_struct_has_builtin =
var != nullptr && var->Declaration()->HasBuiltinDecoration() &&
var != nullptr &&
ast::HasDecoration<ast::BuiltinDecoration>(decorations) &&
var->StorageClass() == ast::StorageClass::kOutput;
return in_or_out_struct_has_location || in_struct_has_builtin;
}
@ -2249,7 +2252,7 @@ bool GeneratorImpl::EmitProgramConstVariable(const ast::Variable* var) {
out_ << " " << program_->Symbols().NameFor(var->symbol());
}
if (var->HasConstantIdDecoration()) {
if (ast::HasDecoration<ast::ConstantIdDecoration>(var->decorations())) {
out_ << " [[function_constant(" << var->constant_id() << ")]]";
} else if (var->constructor() != nullptr) {
out_ << " = ";

View File

@ -750,7 +750,7 @@ bool Builder::GenerateGlobalVariable(ast::Variable* var) {
// one
// 2- If we don't have a constructor and we're an Output or Private variable
// then WGSL requires an initializer.
if (var->HasConstantIdDecoration()) {
if (ast::HasDecoration<ast::ConstantIdDecoration>(var->decorations())) {
if (type_no_ac->Is<type::F32>()) {
ast::FloatLiteral l(Source{}, type_no_ac, 0.0f);
init_id = GenerateLiteralIfNeeded(var, &l);
@ -1490,7 +1490,8 @@ uint32_t Builder::GenerateLiteralIfNeeded(ast::Variable* var,
ast::Literal* lit) {
ScalarConstant constant;
if (var && var->HasConstantIdDecoration()) {
if (var &&
ast::HasDecoration<ast::ConstantIdDecoration>(var->decorations())) {
constant.is_spec_op = true;
constant.constant_id = var->constant_id();
}