[ast] Add helpers for searching a decoration list
This is a commonly used pattern. Change-Id: I698397c93c33db64c53cbe8662186e1976075b80 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/47280 Auto-Submit: James Price <jrprice@google.com> Commit-Queue: James Price <jrprice@google.com> Kokoro: Kokoro <noreply+kokoro@google.com> Reviewed-by: Ben Clayton <bclayton@google.com>
This commit is contained in:
parent
c76ec15b45
commit
a12ccb20a0
|
@ -36,6 +36,30 @@ class Decoration : public Castable<Decoration, Node> {
|
||||||
/// A list of decorations
|
/// A list of decorations
|
||||||
using DecorationList = std::vector<Decoration*>;
|
using DecorationList = std::vector<Decoration*>;
|
||||||
|
|
||||||
|
/// @param decorations the list of decorations to search
|
||||||
|
/// @returns true if `decorations` includes a decoration of type `T`
|
||||||
|
template <typename T>
|
||||||
|
bool HasDecoration(const DecorationList& decorations) {
|
||||||
|
for (auto* deco : decorations) {
|
||||||
|
if (deco->Is<T>()) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// @param decorations the list of decorations to search
|
||||||
|
/// @returns a pointer to `T` from `decorations` if found, otherwise nullptr.
|
||||||
|
template <typename T>
|
||||||
|
T* GetDecoration(const DecorationList& decorations) {
|
||||||
|
for (auto* deco : decorations) {
|
||||||
|
if (deco->Is<T>()) {
|
||||||
|
return deco->As<T>();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace ast
|
} // namespace ast
|
||||||
} // namespace tint
|
} // namespace tint
|
||||||
|
|
||||||
|
|
|
@ -49,20 +49,16 @@ Function::Function(Function&&) = default;
|
||||||
Function::~Function() = default;
|
Function::~Function() = default;
|
||||||
|
|
||||||
std::tuple<uint32_t, uint32_t, uint32_t> Function::workgroup_size() const {
|
std::tuple<uint32_t, uint32_t, uint32_t> Function::workgroup_size() const {
|
||||||
for (auto* deco : decorations_) {
|
if (auto* workgroup = GetDecoration<WorkgroupDecoration>(decorations_)) {
|
||||||
if (auto* workgroup = deco->As<WorkgroupDecoration>()) {
|
|
||||||
return workgroup->values();
|
return workgroup->values();
|
||||||
}
|
}
|
||||||
}
|
|
||||||
return {1, 1, 1};
|
return {1, 1, 1};
|
||||||
}
|
}
|
||||||
|
|
||||||
PipelineStage Function::pipeline_stage() const {
|
PipelineStage Function::pipeline_stage() const {
|
||||||
for (auto* deco : decorations_) {
|
if (auto* stage = GetDecoration<StageDecoration>(decorations_)) {
|
||||||
if (auto* stage = deco->As<StageDecoration>()) {
|
|
||||||
return stage->value();
|
return stage->value();
|
||||||
}
|
}
|
||||||
}
|
|
||||||
return PipelineStage::kNone;
|
return PipelineStage::kNone;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -63,18 +63,6 @@ class Function : public Castable<Function, Node> {
|
||||||
/// @returns the decorations attached to this function
|
/// @returns the decorations attached to this function
|
||||||
const DecorationList& decorations() const { return decorations_; }
|
const DecorationList& decorations() const { return decorations_; }
|
||||||
|
|
||||||
/// @returns the decoration with the type `T` or nullptr if this function does
|
|
||||||
/// not contain a decoration with the given type
|
|
||||||
template <typename T>
|
|
||||||
const T* find_decoration() const {
|
|
||||||
for (auto* deco : decorations()) {
|
|
||||||
if (auto* d = deco->As<T>()) {
|
|
||||||
return d;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// @returns the workgroup size {x, y, z} for the function. {1, 1, 1} will be
|
/// @returns the workgroup size {x, y, z} for the function. {1, 1, 1} will be
|
||||||
/// return if no workgroup size was set.
|
/// return if no workgroup size was set.
|
||||||
std::tuple<uint32_t, uint32_t, uint32_t> workgroup_size() const;
|
std::tuple<uint32_t, uint32_t, uint32_t> workgroup_size() const;
|
||||||
|
|
|
@ -50,12 +50,7 @@ StructMember* Struct::get_member(const Symbol& symbol) const {
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Struct::IsBlockDecorated() const {
|
bool Struct::IsBlockDecorated() const {
|
||||||
for (auto* deco : decorations_) {
|
return HasDecoration<StructBlockDecoration>(decorations_);
|
||||||
if (deco->Is<StructBlockDecoration>()) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Struct* Struct::Clone(CloneContext* ctx) const {
|
Struct* Struct::Clone(CloneContext* ctx) const {
|
||||||
|
|
|
@ -41,20 +41,14 @@ StructMember::StructMember(StructMember&&) = default;
|
||||||
StructMember::~StructMember() = default;
|
StructMember::~StructMember() = default;
|
||||||
|
|
||||||
bool StructMember::has_offset_decoration() const {
|
bool StructMember::has_offset_decoration() const {
|
||||||
for (auto* deco : decorations_) {
|
return HasDecoration<StructMemberOffsetDecoration>(decorations_);
|
||||||
if (deco->Is<StructMemberOffsetDecoration>()) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t StructMember::offset() const {
|
uint32_t StructMember::offset() const {
|
||||||
for (auto* deco : decorations_) {
|
if (auto* offset =
|
||||||
if (auto* offset = deco->As<StructMemberOffsetDecoration>()) {
|
GetDecoration<StructMemberOffsetDecoration>(decorations_)) {
|
||||||
return offset->offset();
|
return offset->offset();
|
||||||
}
|
}
|
||||||
}
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -59,49 +59,11 @@ Variable::BindingPoint Variable::binding_point() const {
|
||||||
return BindingPoint{group, binding};
|
return BindingPoint{group, binding};
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Variable::HasLocationDecoration() const {
|
|
||||||
for (auto* deco : decorations_) {
|
|
||||||
if (deco->Is<LocationDecoration>()) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool Variable::HasBuiltinDecoration() const {
|
|
||||||
for (auto* deco : decorations_) {
|
|
||||||
if (deco->Is<BuiltinDecoration>()) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool Variable::HasConstantIdDecoration() const {
|
|
||||||
for (auto* deco : decorations_) {
|
|
||||||
if (deco->Is<ConstantIdDecoration>()) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
LocationDecoration* Variable::GetLocationDecoration() const {
|
|
||||||
for (auto* deco : decorations_) {
|
|
||||||
if (deco->Is<LocationDecoration>()) {
|
|
||||||
return deco->As<LocationDecoration>();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
uint32_t Variable::constant_id() const {
|
uint32_t Variable::constant_id() const {
|
||||||
TINT_ASSERT(HasConstantIdDecoration());
|
if (auto* cid = GetDecoration<ConstantIdDecoration>(decorations_)) {
|
||||||
for (auto* deco : decorations_) {
|
|
||||||
if (auto* cid = deco->As<ConstantIdDecoration>()) {
|
|
||||||
return cid->value();
|
return cid->value();
|
||||||
}
|
}
|
||||||
}
|
TINT_ASSERT(false);
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -134,18 +134,8 @@ class Variable : public Castable<Variable, Node> {
|
||||||
/// @returns the binding point information for the variable
|
/// @returns the binding point information for the variable
|
||||||
BindingPoint binding_point() const;
|
BindingPoint binding_point() const;
|
||||||
|
|
||||||
/// @returns true if the decorations include a LocationDecoration
|
/// @returns the constant_id value for the variable. Assumes that this
|
||||||
bool HasLocationDecoration() const;
|
/// variable has a constant ID decoration.
|
||||||
/// @returns true if the decorations include a BuiltinDecoration
|
|
||||||
bool HasBuiltinDecoration() const;
|
|
||||||
/// @returns true if the decorations include a ConstantIdDecoration
|
|
||||||
bool HasConstantIdDecoration() const;
|
|
||||||
|
|
||||||
/// @returns pointer to LocationDecoration in decorations, otherwise NULL.
|
|
||||||
LocationDecoration* GetLocationDecoration() const;
|
|
||||||
|
|
||||||
/// @returns the constant_id value for the variable. Assumes that
|
|
||||||
/// HasConstantIdDecoration() has been called first.
|
|
||||||
uint32_t constant_id() const;
|
uint32_t constant_id() const;
|
||||||
|
|
||||||
/// Clones this node and all transitive child nodes using the `CloneContext`
|
/// Clones this node and all transitive child nodes using the `CloneContext`
|
||||||
|
|
|
@ -98,11 +98,12 @@ TEST_F(VariableTest, WithDecorations) {
|
||||||
create<ConstantIdDecoration>(1200),
|
create<ConstantIdDecoration>(1200),
|
||||||
});
|
});
|
||||||
|
|
||||||
EXPECT_TRUE(var->HasLocationDecoration());
|
auto& decorations = var->decorations();
|
||||||
EXPECT_TRUE(var->HasBuiltinDecoration());
|
EXPECT_TRUE(ast::HasDecoration<ast::LocationDecoration>(decorations));
|
||||||
EXPECT_TRUE(var->HasConstantIdDecoration());
|
EXPECT_TRUE(ast::HasDecoration<ast::BuiltinDecoration>(decorations));
|
||||||
|
EXPECT_TRUE(ast::HasDecoration<ast::ConstantIdDecoration>(decorations));
|
||||||
|
|
||||||
auto* location = var->GetLocationDecoration();
|
auto* location = ast::GetDecoration<ast::LocationDecoration>(decorations);
|
||||||
ASSERT_NE(nullptr, location);
|
ASSERT_NE(nullptr, location);
|
||||||
EXPECT_EQ(1u, location->value());
|
EXPECT_EQ(1u, location->value());
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
|
|
||||||
#include "src/ast/bool_literal.h"
|
#include "src/ast/bool_literal.h"
|
||||||
#include "src/ast/float_literal.h"
|
#include "src/ast/float_literal.h"
|
||||||
|
#include "src/ast/constant_id_decoration.h"
|
||||||
#include "src/ast/module.h"
|
#include "src/ast/module.h"
|
||||||
#include "src/ast/scalar_constructor_expression.h"
|
#include "src/ast/scalar_constructor_expression.h"
|
||||||
#include "src/ast/sint_literal.h"
|
#include "src/ast/sint_literal.h"
|
||||||
|
@ -203,7 +204,7 @@ std::vector<EntryPoint> Inspector::GetEntryPoints() {
|
||||||
auto* decl = var->Declaration();
|
auto* decl = var->Declaration();
|
||||||
|
|
||||||
auto name = program_->Symbols().NameFor(decl->symbol());
|
auto name = program_->Symbols().NameFor(decl->symbol());
|
||||||
if (decl->HasBuiltinDecoration()) {
|
if (ast::HasDecoration<ast::BuiltinDecoration>(decl->decorations())) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -220,7 +221,8 @@ std::vector<EntryPoint> Inspector::GetEntryPoints() {
|
||||||
stage_variable.component_type = ComponentType::kSInt;
|
stage_variable.component_type = ComponentType::kSInt;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto* location_decoration = decl->GetLocationDecoration();
|
auto* location_decoration =
|
||||||
|
ast::GetDecoration<ast::LocationDecoration>(decl->decorations());
|
||||||
if (location_decoration) {
|
if (location_decoration) {
|
||||||
stage_variable.has_location_decoration = true;
|
stage_variable.has_location_decoration = true;
|
||||||
stage_variable.location_decoration = location_decoration->value();
|
stage_variable.location_decoration = location_decoration->value();
|
||||||
|
@ -257,7 +259,7 @@ std::string Inspector::GetRemappedNameForEntryPoint(
|
||||||
std::map<uint32_t, Scalar> Inspector::GetConstantIDs() {
|
std::map<uint32_t, Scalar> Inspector::GetConstantIDs() {
|
||||||
std::map<uint32_t, Scalar> result;
|
std::map<uint32_t, Scalar> result;
|
||||||
for (auto* var : program_->AST().GlobalVariables()) {
|
for (auto* var : program_->AST().GlobalVariables()) {
|
||||||
if (!var->HasConstantIdDecoration()) {
|
if (!ast::HasDecoration<ast::ConstantIdDecoration>(var->decorations())) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "src/ast/constant_id_decoration.h"
|
||||||
#include "src/reader/wgsl/parser_impl_test_helper.h"
|
#include "src/reader/wgsl/parser_impl_test_helper.h"
|
||||||
|
|
||||||
namespace tint {
|
namespace tint {
|
||||||
|
@ -43,7 +44,8 @@ TEST_F(ParserImplTest, GlobalConstantDecl) {
|
||||||
ASSERT_NE(e->constructor(), nullptr);
|
ASSERT_NE(e->constructor(), nullptr);
|
||||||
EXPECT_TRUE(e->constructor()->Is<ast::ConstructorExpression>());
|
EXPECT_TRUE(e->constructor()->Is<ast::ConstructorExpression>());
|
||||||
|
|
||||||
EXPECT_FALSE(e.value->HasConstantIdDecoration());
|
EXPECT_FALSE(
|
||||||
|
ast::HasDecoration<ast::ConstantIdDecoration>(e.value->decorations()));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ParserImplTest, GlobalConstantDecl_MissingEqual) {
|
TEST_F(ParserImplTest, GlobalConstantDecl_MissingEqual) {
|
||||||
|
@ -123,7 +125,8 @@ TEST_F(ParserImplTest, GlobalConstantDec_ConstantId) {
|
||||||
ASSERT_NE(e->constructor(), nullptr);
|
ASSERT_NE(e->constructor(), nullptr);
|
||||||
EXPECT_TRUE(e->constructor()->Is<ast::ConstructorExpression>());
|
EXPECT_TRUE(e->constructor()->Is<ast::ConstructorExpression>());
|
||||||
|
|
||||||
EXPECT_TRUE(e.value->HasConstantIdDecoration());
|
EXPECT_TRUE(
|
||||||
|
ast::HasDecoration<ast::ConstantIdDecoration>(e.value->decorations()));
|
||||||
EXPECT_EQ(e.value->constant_id(), 7u);
|
EXPECT_EQ(e.value->constant_id(), 7u);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -310,7 +310,8 @@ bool Resolver::ValidateFunction(const ast::Function* func) {
|
||||||
func->source());
|
func->source());
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
} else if (!func->find_decoration<ast::InternalDecoration>()) {
|
} else if (!ast::HasDecoration<ast::InternalDecoration>(
|
||||||
|
func->decorations())) {
|
||||||
TINT_ICE(diagnostics_)
|
TINT_ICE(diagnostics_)
|
||||||
<< "Function " << builder_->Symbols().NameFor(func->symbol())
|
<< "Function " << builder_->Symbols().NameFor(func->symbol())
|
||||||
<< " has no body and does not have the [[internal]] decoration";
|
<< " has no body and does not have the [[internal]] decoration";
|
||||||
|
|
|
@ -1310,8 +1310,9 @@ bool GeneratorImpl::EmitExpression(std::ostream& pre,
|
||||||
}
|
}
|
||||||
|
|
||||||
bool GeneratorImpl::global_is_in_struct(const semantic::Variable* var) const {
|
bool GeneratorImpl::global_is_in_struct(const semantic::Variable* var) const {
|
||||||
if (var->Declaration()->HasLocationDecoration() ||
|
auto& decorations = var->Declaration()->decorations();
|
||||||
var->Declaration()->HasBuiltinDecoration()) {
|
if (ast::HasDecoration<ast::LocationDecoration>(decorations) ||
|
||||||
|
ast::HasDecoration<ast::BuiltinDecoration>(decorations)) {
|
||||||
return var->StorageClass() == ast::StorageClass::kInput ||
|
return var->StorageClass() == ast::StorageClass::kInput ||
|
||||||
var->StorageClass() == ast::StorageClass::kOutput;
|
var->StorageClass() == ast::StorageClass::kOutput;
|
||||||
}
|
}
|
||||||
|
@ -1463,7 +1464,7 @@ bool GeneratorImpl::EmitFunction(std::ostream& out, ast::Function* func) {
|
||||||
|
|
||||||
auto* func_sem = builder_.Sem().Get(func);
|
auto* func_sem = builder_.Sem().Get(func);
|
||||||
|
|
||||||
if (func->find_decoration<ast::InternalDecoration>()) {
|
if (ast::HasDecoration<ast::InternalDecoration>(func->decorations())) {
|
||||||
// An internal function. Do not emit.
|
// An internal function. Do not emit.
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -2825,7 +2826,7 @@ bool GeneratorImpl::EmitProgramConstVariable(std::ostream& out,
|
||||||
|
|
||||||
auto* type = builder_.Sem().Get(var)->Type();
|
auto* type = builder_.Sem().Get(var)->Type();
|
||||||
|
|
||||||
if (var->HasConstantIdDecoration()) {
|
if (ast::HasDecoration<ast::ConstantIdDecoration>(var->decorations())) {
|
||||||
auto const_id = var->constant_id();
|
auto const_id = var->constant_id();
|
||||||
|
|
||||||
out << "#ifndef WGSL_SPEC_CONSTANT_" << const_id << std::endl;
|
out << "#ifndef WGSL_SPEC_CONSTANT_" << const_id << std::endl;
|
||||||
|
|
|
@ -1562,12 +1562,15 @@ bool GeneratorImpl::EmitEntryPointFunction(ast::Function* func) {
|
||||||
}
|
}
|
||||||
|
|
||||||
bool GeneratorImpl::global_is_in_struct(const semantic::Variable* var) const {
|
bool GeneratorImpl::global_is_in_struct(const semantic::Variable* var) const {
|
||||||
|
auto& decorations = var->Declaration()->decorations();
|
||||||
bool in_or_out_struct_has_location =
|
bool in_or_out_struct_has_location =
|
||||||
var != nullptr && var->Declaration()->HasLocationDecoration() &&
|
var != nullptr &&
|
||||||
|
ast::HasDecoration<ast::LocationDecoration>(decorations) &&
|
||||||
(var->StorageClass() == ast::StorageClass::kInput ||
|
(var->StorageClass() == ast::StorageClass::kInput ||
|
||||||
var->StorageClass() == ast::StorageClass::kOutput);
|
var->StorageClass() == ast::StorageClass::kOutput);
|
||||||
bool in_struct_has_builtin =
|
bool in_struct_has_builtin =
|
||||||
var != nullptr && var->Declaration()->HasBuiltinDecoration() &&
|
var != nullptr &&
|
||||||
|
ast::HasDecoration<ast::BuiltinDecoration>(decorations) &&
|
||||||
var->StorageClass() == ast::StorageClass::kOutput;
|
var->StorageClass() == ast::StorageClass::kOutput;
|
||||||
return in_or_out_struct_has_location || in_struct_has_builtin;
|
return in_or_out_struct_has_location || in_struct_has_builtin;
|
||||||
}
|
}
|
||||||
|
@ -2249,7 +2252,7 @@ bool GeneratorImpl::EmitProgramConstVariable(const ast::Variable* var) {
|
||||||
out_ << " " << program_->Symbols().NameFor(var->symbol());
|
out_ << " " << program_->Symbols().NameFor(var->symbol());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (var->HasConstantIdDecoration()) {
|
if (ast::HasDecoration<ast::ConstantIdDecoration>(var->decorations())) {
|
||||||
out_ << " [[function_constant(" << var->constant_id() << ")]]";
|
out_ << " [[function_constant(" << var->constant_id() << ")]]";
|
||||||
} else if (var->constructor() != nullptr) {
|
} else if (var->constructor() != nullptr) {
|
||||||
out_ << " = ";
|
out_ << " = ";
|
||||||
|
|
|
@ -750,7 +750,7 @@ bool Builder::GenerateGlobalVariable(ast::Variable* var) {
|
||||||
// one
|
// one
|
||||||
// 2- If we don't have a constructor and we're an Output or Private variable
|
// 2- If we don't have a constructor and we're an Output or Private variable
|
||||||
// then WGSL requires an initializer.
|
// then WGSL requires an initializer.
|
||||||
if (var->HasConstantIdDecoration()) {
|
if (ast::HasDecoration<ast::ConstantIdDecoration>(var->decorations())) {
|
||||||
if (type_no_ac->Is<type::F32>()) {
|
if (type_no_ac->Is<type::F32>()) {
|
||||||
ast::FloatLiteral l(Source{}, type_no_ac, 0.0f);
|
ast::FloatLiteral l(Source{}, type_no_ac, 0.0f);
|
||||||
init_id = GenerateLiteralIfNeeded(var, &l);
|
init_id = GenerateLiteralIfNeeded(var, &l);
|
||||||
|
@ -1490,7 +1490,8 @@ uint32_t Builder::GenerateLiteralIfNeeded(ast::Variable* var,
|
||||||
ast::Literal* lit) {
|
ast::Literal* lit) {
|
||||||
ScalarConstant constant;
|
ScalarConstant constant;
|
||||||
|
|
||||||
if (var && var->HasConstantIdDecoration()) {
|
if (var &&
|
||||||
|
ast::HasDecoration<ast::ConstantIdDecoration>(var->decorations())) {
|
||||||
constant.is_spec_op = true;
|
constant.is_spec_op = true;
|
||||||
constant.constant_id = var->constant_id();
|
constant.constant_id = var->constant_id();
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue