[ast] Add helpers for searching a decoration list
This is a commonly used pattern. Change-Id: I698397c93c33db64c53cbe8662186e1976075b80 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/47280 Auto-Submit: James Price <jrprice@google.com> Commit-Queue: James Price <jrprice@google.com> Kokoro: Kokoro <noreply+kokoro@google.com> Reviewed-by: Ben Clayton <bclayton@google.com>
This commit is contained in:
parent
c76ec15b45
commit
a12ccb20a0
|
@ -36,6 +36,30 @@ class Decoration : public Castable<Decoration, Node> {
|
|||
/// A list of decorations
|
||||
using DecorationList = std::vector<Decoration*>;
|
||||
|
||||
/// @param decorations the list of decorations to search
|
||||
/// @returns true if `decorations` includes a decoration of type `T`
|
||||
template <typename T>
|
||||
bool HasDecoration(const DecorationList& decorations) {
|
||||
for (auto* deco : decorations) {
|
||||
if (deco->Is<T>()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/// @param decorations the list of decorations to search
|
||||
/// @returns a pointer to `T` from `decorations` if found, otherwise nullptr.
|
||||
template <typename T>
|
||||
T* GetDecoration(const DecorationList& decorations) {
|
||||
for (auto* deco : decorations) {
|
||||
if (deco->Is<T>()) {
|
||||
return deco->As<T>();
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
} // namespace ast
|
||||
} // namespace tint
|
||||
|
||||
|
|
|
@ -49,20 +49,16 @@ Function::Function(Function&&) = default;
|
|||
Function::~Function() = default;
|
||||
|
||||
std::tuple<uint32_t, uint32_t, uint32_t> Function::workgroup_size() const {
|
||||
for (auto* deco : decorations_) {
|
||||
if (auto* workgroup = deco->As<WorkgroupDecoration>()) {
|
||||
if (auto* workgroup = GetDecoration<WorkgroupDecoration>(decorations_)) {
|
||||
return workgroup->values();
|
||||
}
|
||||
}
|
||||
return {1, 1, 1};
|
||||
}
|
||||
|
||||
PipelineStage Function::pipeline_stage() const {
|
||||
for (auto* deco : decorations_) {
|
||||
if (auto* stage = deco->As<StageDecoration>()) {
|
||||
if (auto* stage = GetDecoration<StageDecoration>(decorations_)) {
|
||||
return stage->value();
|
||||
}
|
||||
}
|
||||
return PipelineStage::kNone;
|
||||
}
|
||||
|
||||
|
|
|
@ -63,18 +63,6 @@ class Function : public Castable<Function, Node> {
|
|||
/// @returns the decorations attached to this function
|
||||
const DecorationList& decorations() const { return decorations_; }
|
||||
|
||||
/// @returns the decoration with the type `T` or nullptr if this function does
|
||||
/// not contain a decoration with the given type
|
||||
template <typename T>
|
||||
const T* find_decoration() const {
|
||||
for (auto* deco : decorations()) {
|
||||
if (auto* d = deco->As<T>()) {
|
||||
return d;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
/// @returns the workgroup size {x, y, z} for the function. {1, 1, 1} will be
|
||||
/// return if no workgroup size was set.
|
||||
std::tuple<uint32_t, uint32_t, uint32_t> workgroup_size() const;
|
||||
|
|
|
@ -50,12 +50,7 @@ StructMember* Struct::get_member(const Symbol& symbol) const {
|
|||
}
|
||||
|
||||
bool Struct::IsBlockDecorated() const {
|
||||
for (auto* deco : decorations_) {
|
||||
if (deco->Is<StructBlockDecoration>()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
return HasDecoration<StructBlockDecoration>(decorations_);
|
||||
}
|
||||
|
||||
Struct* Struct::Clone(CloneContext* ctx) const {
|
||||
|
|
|
@ -41,20 +41,14 @@ StructMember::StructMember(StructMember&&) = default;
|
|||
StructMember::~StructMember() = default;
|
||||
|
||||
bool StructMember::has_offset_decoration() const {
|
||||
for (auto* deco : decorations_) {
|
||||
if (deco->Is<StructMemberOffsetDecoration>()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
return HasDecoration<StructMemberOffsetDecoration>(decorations_);
|
||||
}
|
||||
|
||||
uint32_t StructMember::offset() const {
|
||||
for (auto* deco : decorations_) {
|
||||
if (auto* offset = deco->As<StructMemberOffsetDecoration>()) {
|
||||
if (auto* offset =
|
||||
GetDecoration<StructMemberOffsetDecoration>(decorations_)) {
|
||||
return offset->offset();
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
|
|
@ -59,49 +59,11 @@ Variable::BindingPoint Variable::binding_point() const {
|
|||
return BindingPoint{group, binding};
|
||||
}
|
||||
|
||||
bool Variable::HasLocationDecoration() const {
|
||||
for (auto* deco : decorations_) {
|
||||
if (deco->Is<LocationDecoration>()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool Variable::HasBuiltinDecoration() const {
|
||||
for (auto* deco : decorations_) {
|
||||
if (deco->Is<BuiltinDecoration>()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool Variable::HasConstantIdDecoration() const {
|
||||
for (auto* deco : decorations_) {
|
||||
if (deco->Is<ConstantIdDecoration>()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
LocationDecoration* Variable::GetLocationDecoration() const {
|
||||
for (auto* deco : decorations_) {
|
||||
if (deco->Is<LocationDecoration>()) {
|
||||
return deco->As<LocationDecoration>();
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
uint32_t Variable::constant_id() const {
|
||||
TINT_ASSERT(HasConstantIdDecoration());
|
||||
for (auto* deco : decorations_) {
|
||||
if (auto* cid = deco->As<ConstantIdDecoration>()) {
|
||||
if (auto* cid = GetDecoration<ConstantIdDecoration>(decorations_)) {
|
||||
return cid->value();
|
||||
}
|
||||
}
|
||||
TINT_ASSERT(false);
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
|
|
@ -134,18 +134,8 @@ class Variable : public Castable<Variable, Node> {
|
|||
/// @returns the binding point information for the variable
|
||||
BindingPoint binding_point() const;
|
||||
|
||||
/// @returns true if the decorations include a LocationDecoration
|
||||
bool HasLocationDecoration() const;
|
||||
/// @returns true if the decorations include a BuiltinDecoration
|
||||
bool HasBuiltinDecoration() const;
|
||||
/// @returns true if the decorations include a ConstantIdDecoration
|
||||
bool HasConstantIdDecoration() const;
|
||||
|
||||
/// @returns pointer to LocationDecoration in decorations, otherwise NULL.
|
||||
LocationDecoration* GetLocationDecoration() const;
|
||||
|
||||
/// @returns the constant_id value for the variable. Assumes that
|
||||
/// HasConstantIdDecoration() has been called first.
|
||||
/// @returns the constant_id value for the variable. Assumes that this
|
||||
/// variable has a constant ID decoration.
|
||||
uint32_t constant_id() const;
|
||||
|
||||
/// Clones this node and all transitive child nodes using the `CloneContext`
|
||||
|
|
|
@ -98,11 +98,12 @@ TEST_F(VariableTest, WithDecorations) {
|
|||
create<ConstantIdDecoration>(1200),
|
||||
});
|
||||
|
||||
EXPECT_TRUE(var->HasLocationDecoration());
|
||||
EXPECT_TRUE(var->HasBuiltinDecoration());
|
||||
EXPECT_TRUE(var->HasConstantIdDecoration());
|
||||
auto& decorations = var->decorations();
|
||||
EXPECT_TRUE(ast::HasDecoration<ast::LocationDecoration>(decorations));
|
||||
EXPECT_TRUE(ast::HasDecoration<ast::BuiltinDecoration>(decorations));
|
||||
EXPECT_TRUE(ast::HasDecoration<ast::ConstantIdDecoration>(decorations));
|
||||
|
||||
auto* location = var->GetLocationDecoration();
|
||||
auto* location = ast::GetDecoration<ast::LocationDecoration>(decorations);
|
||||
ASSERT_NE(nullptr, location);
|
||||
EXPECT_EQ(1u, location->value());
|
||||
}
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
|
||||
#include "src/ast/bool_literal.h"
|
||||
#include "src/ast/float_literal.h"
|
||||
#include "src/ast/constant_id_decoration.h"
|
||||
#include "src/ast/module.h"
|
||||
#include "src/ast/scalar_constructor_expression.h"
|
||||
#include "src/ast/sint_literal.h"
|
||||
|
@ -203,7 +204,7 @@ std::vector<EntryPoint> Inspector::GetEntryPoints() {
|
|||
auto* decl = var->Declaration();
|
||||
|
||||
auto name = program_->Symbols().NameFor(decl->symbol());
|
||||
if (decl->HasBuiltinDecoration()) {
|
||||
if (ast::HasDecoration<ast::BuiltinDecoration>(decl->decorations())) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -220,7 +221,8 @@ std::vector<EntryPoint> Inspector::GetEntryPoints() {
|
|||
stage_variable.component_type = ComponentType::kSInt;
|
||||
}
|
||||
|
||||
auto* location_decoration = decl->GetLocationDecoration();
|
||||
auto* location_decoration =
|
||||
ast::GetDecoration<ast::LocationDecoration>(decl->decorations());
|
||||
if (location_decoration) {
|
||||
stage_variable.has_location_decoration = true;
|
||||
stage_variable.location_decoration = location_decoration->value();
|
||||
|
@ -257,7 +259,7 @@ std::string Inspector::GetRemappedNameForEntryPoint(
|
|||
std::map<uint32_t, Scalar> Inspector::GetConstantIDs() {
|
||||
std::map<uint32_t, Scalar> result;
|
||||
for (auto* var : program_->AST().GlobalVariables()) {
|
||||
if (!var->HasConstantIdDecoration()) {
|
||||
if (!ast::HasDecoration<ast::ConstantIdDecoration>(var->decorations())) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/ast/constant_id_decoration.h"
|
||||
#include "src/reader/wgsl/parser_impl_test_helper.h"
|
||||
|
||||
namespace tint {
|
||||
|
@ -43,7 +44,8 @@ TEST_F(ParserImplTest, GlobalConstantDecl) {
|
|||
ASSERT_NE(e->constructor(), nullptr);
|
||||
EXPECT_TRUE(e->constructor()->Is<ast::ConstructorExpression>());
|
||||
|
||||
EXPECT_FALSE(e.value->HasConstantIdDecoration());
|
||||
EXPECT_FALSE(
|
||||
ast::HasDecoration<ast::ConstantIdDecoration>(e.value->decorations()));
|
||||
}
|
||||
|
||||
TEST_F(ParserImplTest, GlobalConstantDecl_MissingEqual) {
|
||||
|
@ -123,7 +125,8 @@ TEST_F(ParserImplTest, GlobalConstantDec_ConstantId) {
|
|||
ASSERT_NE(e->constructor(), nullptr);
|
||||
EXPECT_TRUE(e->constructor()->Is<ast::ConstructorExpression>());
|
||||
|
||||
EXPECT_TRUE(e.value->HasConstantIdDecoration());
|
||||
EXPECT_TRUE(
|
||||
ast::HasDecoration<ast::ConstantIdDecoration>(e.value->decorations()));
|
||||
EXPECT_EQ(e.value->constant_id(), 7u);
|
||||
}
|
||||
|
||||
|
|
|
@ -310,7 +310,8 @@ bool Resolver::ValidateFunction(const ast::Function* func) {
|
|||
func->source());
|
||||
return false;
|
||||
}
|
||||
} else if (!func->find_decoration<ast::InternalDecoration>()) {
|
||||
} else if (!ast::HasDecoration<ast::InternalDecoration>(
|
||||
func->decorations())) {
|
||||
TINT_ICE(diagnostics_)
|
||||
<< "Function " << builder_->Symbols().NameFor(func->symbol())
|
||||
<< " has no body and does not have the [[internal]] decoration";
|
||||
|
|
|
@ -1310,8 +1310,9 @@ bool GeneratorImpl::EmitExpression(std::ostream& pre,
|
|||
}
|
||||
|
||||
bool GeneratorImpl::global_is_in_struct(const semantic::Variable* var) const {
|
||||
if (var->Declaration()->HasLocationDecoration() ||
|
||||
var->Declaration()->HasBuiltinDecoration()) {
|
||||
auto& decorations = var->Declaration()->decorations();
|
||||
if (ast::HasDecoration<ast::LocationDecoration>(decorations) ||
|
||||
ast::HasDecoration<ast::BuiltinDecoration>(decorations)) {
|
||||
return var->StorageClass() == ast::StorageClass::kInput ||
|
||||
var->StorageClass() == ast::StorageClass::kOutput;
|
||||
}
|
||||
|
@ -1463,7 +1464,7 @@ bool GeneratorImpl::EmitFunction(std::ostream& out, ast::Function* func) {
|
|||
|
||||
auto* func_sem = builder_.Sem().Get(func);
|
||||
|
||||
if (func->find_decoration<ast::InternalDecoration>()) {
|
||||
if (ast::HasDecoration<ast::InternalDecoration>(func->decorations())) {
|
||||
// An internal function. Do not emit.
|
||||
return true;
|
||||
}
|
||||
|
@ -2825,7 +2826,7 @@ bool GeneratorImpl::EmitProgramConstVariable(std::ostream& out,
|
|||
|
||||
auto* type = builder_.Sem().Get(var)->Type();
|
||||
|
||||
if (var->HasConstantIdDecoration()) {
|
||||
if (ast::HasDecoration<ast::ConstantIdDecoration>(var->decorations())) {
|
||||
auto const_id = var->constant_id();
|
||||
|
||||
out << "#ifndef WGSL_SPEC_CONSTANT_" << const_id << std::endl;
|
||||
|
|
|
@ -1562,12 +1562,15 @@ bool GeneratorImpl::EmitEntryPointFunction(ast::Function* func) {
|
|||
}
|
||||
|
||||
bool GeneratorImpl::global_is_in_struct(const semantic::Variable* var) const {
|
||||
auto& decorations = var->Declaration()->decorations();
|
||||
bool in_or_out_struct_has_location =
|
||||
var != nullptr && var->Declaration()->HasLocationDecoration() &&
|
||||
var != nullptr &&
|
||||
ast::HasDecoration<ast::LocationDecoration>(decorations) &&
|
||||
(var->StorageClass() == ast::StorageClass::kInput ||
|
||||
var->StorageClass() == ast::StorageClass::kOutput);
|
||||
bool in_struct_has_builtin =
|
||||
var != nullptr && var->Declaration()->HasBuiltinDecoration() &&
|
||||
var != nullptr &&
|
||||
ast::HasDecoration<ast::BuiltinDecoration>(decorations) &&
|
||||
var->StorageClass() == ast::StorageClass::kOutput;
|
||||
return in_or_out_struct_has_location || in_struct_has_builtin;
|
||||
}
|
||||
|
@ -2249,7 +2252,7 @@ bool GeneratorImpl::EmitProgramConstVariable(const ast::Variable* var) {
|
|||
out_ << " " << program_->Symbols().NameFor(var->symbol());
|
||||
}
|
||||
|
||||
if (var->HasConstantIdDecoration()) {
|
||||
if (ast::HasDecoration<ast::ConstantIdDecoration>(var->decorations())) {
|
||||
out_ << " [[function_constant(" << var->constant_id() << ")]]";
|
||||
} else if (var->constructor() != nullptr) {
|
||||
out_ << " = ";
|
||||
|
|
|
@ -750,7 +750,7 @@ bool Builder::GenerateGlobalVariable(ast::Variable* var) {
|
|||
// one
|
||||
// 2- If we don't have a constructor and we're an Output or Private variable
|
||||
// then WGSL requires an initializer.
|
||||
if (var->HasConstantIdDecoration()) {
|
||||
if (ast::HasDecoration<ast::ConstantIdDecoration>(var->decorations())) {
|
||||
if (type_no_ac->Is<type::F32>()) {
|
||||
ast::FloatLiteral l(Source{}, type_no_ac, 0.0f);
|
||||
init_id = GenerateLiteralIfNeeded(var, &l);
|
||||
|
@ -1490,7 +1490,8 @@ uint32_t Builder::GenerateLiteralIfNeeded(ast::Variable* var,
|
|||
ast::Literal* lit) {
|
||||
ScalarConstant constant;
|
||||
|
||||
if (var && var->HasConstantIdDecoration()) {
|
||||
if (var &&
|
||||
ast::HasDecoration<ast::ConstantIdDecoration>(var->decorations())) {
|
||||
constant.is_spec_op = true;
|
||||
constant.constant_id = var->constant_id();
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue