diff --git a/src/ast/variable.cc b/src/ast/variable.cc index cab6350fc8..25a11f9ea6 100644 --- a/src/ast/variable.cc +++ b/src/ast/variable.cc @@ -45,6 +45,19 @@ Variable::Variable(Variable&&) = default; Variable::~Variable() = default; +Variable::BindingPoint Variable::binding_point() const { + GroupDecoration* group = nullptr; + BindingDecoration* binding = nullptr; + for (auto* deco : decorations()) { + if (auto* g = deco->As()) { + group = g; + } else if (auto* b = deco->As()) { + binding = b; + } + } + return BindingPoint{group, binding}; +} + bool Variable::HasLocationDecoration() const { for (auto* deco : decorations_) { if (deco->Is()) { diff --git a/src/ast/variable.h b/src/ast/variable.h index f5f37cb58f..5f809e83bd 100644 --- a/src/ast/variable.h +++ b/src/ast/variable.h @@ -25,6 +25,8 @@ namespace tint { namespace ast { +class BindingDecoration; +class GroupDecoration; class LocationDecoration; /// A Variable statement. @@ -76,6 +78,18 @@ class LocationDecoration; /// The storage class for a formal parameter is always StorageClass::kNone. class Variable : public Castable { public: + /// BindingPoint holds a group and binding decoration. + struct BindingPoint { + /// The `[[group]]` part of the binding point + GroupDecoration* group = nullptr; + /// The `[[binding]]` part of the binding point + BindingDecoration* binding = nullptr; + + /// @returns true if the BindingPoint has a valid group and binding + /// decoration. + inline operator bool() const { return group && binding; } + }; + /// Create a variable /// @param source the variable source /// @param sym the variable symbol @@ -117,6 +131,9 @@ class Variable : public Castable { /// @returns the decorations attached to this variable const DecorationList& decorations() const { return decorations_; } + /// @returns the binding point information for the variable + BindingPoint binding_point() const; + /// @returns true if the decorations include a LocationDecoration bool HasLocationDecoration() const; /// @returns true if the decorations include a BuiltinDecoration diff --git a/src/ast/variable_test.cc b/src/ast/variable_test.cc index 2651f1a1e3..258f988814 100644 --- a/src/ast/variable_test.cc +++ b/src/ast/variable_test.cc @@ -107,6 +107,47 @@ TEST_F(VariableTest, WithDecorations) { EXPECT_EQ(1u, location->value()); } +TEST_F(VariableTest, BindingPoint) { + auto* var = Var("my_var", ty.i32(), StorageClass::kFunction, nullptr, + DecorationList{ + create(2), + create(1), + }); + EXPECT_TRUE(var->binding_point()); + ASSERT_NE(var->binding_point().binding, nullptr); + ASSERT_NE(var->binding_point().group, nullptr); + EXPECT_EQ(var->binding_point().binding->value(), 2u); + EXPECT_EQ(var->binding_point().group->value(), 1u); +} + +TEST_F(VariableTest, BindingPointoDecorations) { + auto* var = Var("my_var", ty.i32(), StorageClass::kFunction, nullptr, + DecorationList{}); + EXPECT_FALSE(var->binding_point()); + EXPECT_EQ(var->binding_point().group, nullptr); + EXPECT_EQ(var->binding_point().binding, nullptr); +} + +TEST_F(VariableTest, BindingPointMissingGroupDecoration) { + auto* var = Var("my_var", ty.i32(), StorageClass::kFunction, nullptr, + DecorationList{ + create(2), + }); + EXPECT_FALSE(var->binding_point()); + ASSERT_NE(var->binding_point().binding, nullptr); + EXPECT_EQ(var->binding_point().binding->value(), 2u); + EXPECT_EQ(var->binding_point().group, nullptr); +} + +TEST_F(VariableTest, BindingPointMissingBindingDecoration) { + auto* var = Var("my_var", ty.i32(), StorageClass::kFunction, nullptr, + DecorationList{create(1)}); + EXPECT_FALSE(var->binding_point()); + ASSERT_NE(var->binding_point().group, nullptr); + EXPECT_EQ(var->binding_point().group->value(), 1u); + EXPECT_EQ(var->binding_point().binding, nullptr); +} + TEST_F(VariableTest, ConstantId) { auto* var = Var("my_var", ty.i32(), StorageClass::kFunction, nullptr, DecorationList{ diff --git a/src/semantic/function.h b/src/semantic/function.h index fbb5934221..efc6fbc48d 100644 --- a/src/semantic/function.h +++ b/src/semantic/function.h @@ -18,6 +18,7 @@ #include #include +#include "src/ast/variable.h" #include "src/semantic/call_target.h" namespace tint { @@ -39,16 +40,9 @@ class Variable; /// Function holds the semantic information for function nodes. class Function : public Castable { public: - /// Information about a binding - struct BindingInfo { - /// The binding decoration - ast::BindingDecoration* binding = nullptr; - /// The group decoration - ast::GroupDecoration* group = nullptr; - }; - - /// A vector of [Variable*, BindingInfo] pairs - using VariableBindings = std::vector>; + /// A vector of [Variable*, ast::Variable::BindingPoint] pairs + using VariableBindings = + std::vector>; /// Constructor /// @param declaration the ast::Function diff --git a/src/semantic/sem_function.cc b/src/semantic/sem_function.cc index 8e4dcbd07a..dab7f1ce36 100644 --- a/src/semantic/sem_function.cc +++ b/src/semantic/sem_function.cc @@ -38,21 +38,6 @@ ParameterList GetParameters(ast::Function* ast) { return parameters; } -std::tuple GetBindingAndGroup( - const Variable* var) { - ast::BindingDecoration* binding = nullptr; - ast::GroupDecoration* group = nullptr; - for (auto* deco : var->Declaration()->decorations()) { - if (auto* b = deco->As()) { - binding = b; - } - if (auto* s = deco->As()) { - group = s; - } - } - return {binding, group}; -} - } // namespace Function::Function(ast::Function* declaration, @@ -92,14 +77,9 @@ Function::VariableBindings Function::ReferencedUniformVariables() const { continue; } - ast::BindingDecoration* binding = nullptr; - ast::GroupDecoration* group = nullptr; - std::tie(binding, group) = GetBindingAndGroup(var); - if (binding == nullptr || group == nullptr) { - continue; + if (auto binding_point = var->Declaration()->binding_point()) { + ret.push_back({var, binding_point}); } - - ret.push_back({var, BindingInfo{binding, group}}); } return ret; } @@ -112,14 +92,9 @@ Function::VariableBindings Function::ReferencedStorageBufferVariables() const { continue; } - ast::BindingDecoration* binding = nullptr; - ast::GroupDecoration* group = nullptr; - std::tie(binding, group) = GetBindingAndGroup(var); - if (binding == nullptr || group == nullptr) { - continue; + if (auto binding_point = var->Declaration()->binding_point()) { + ret.push_back({var, binding_point}); } - - ret.push_back({var, BindingInfo{binding, group}}); } return ret; } @@ -168,14 +143,9 @@ Function::VariableBindings Function::ReferencedStorageTextureVariables() const { continue; } - ast::BindingDecoration* binding = nullptr; - ast::GroupDecoration* group = nullptr; - std::tie(binding, group) = GetBindingAndGroup(var); - if (binding == nullptr || group == nullptr) { - continue; + if (auto binding_point = var->Declaration()->binding_point()) { + ret.push_back({var, binding_point}); } - - ret.push_back({var, BindingInfo{binding, group}}); } return ret; } @@ -191,14 +161,9 @@ Function::VariableBindings Function::ReferencedDepthTextureVariables() const { continue; } - ast::BindingDecoration* binding = nullptr; - ast::GroupDecoration* group = nullptr; - std::tie(binding, group) = GetBindingAndGroup(var); - if (binding == nullptr || group == nullptr) { - continue; + if (auto binding_point = var->Declaration()->binding_point()) { + ret.push_back({var, binding_point}); } - - ret.push_back({var, BindingInfo{binding, group}}); } return ret; } @@ -239,14 +204,9 @@ Function::VariableBindings Function::ReferencedSamplerVariablesImpl( continue; } - ast::BindingDecoration* binding = nullptr; - ast::GroupDecoration* group = nullptr; - std::tie(binding, group) = GetBindingAndGroup(var); - if (binding == nullptr || group == nullptr) { - continue; + if (auto binding_point = var->Declaration()->binding_point()) { + ret.push_back({var, binding_point}); } - - ret.push_back({var, BindingInfo{binding, group}}); } return ret; } @@ -270,14 +230,9 @@ Function::VariableBindings Function::ReferencedSampledTextureVariablesImpl( continue; } - ast::BindingDecoration* binding = nullptr; - ast::GroupDecoration* group = nullptr; - std::tie(binding, group) = GetBindingAndGroup(var); - if (binding == nullptr || group == nullptr) { - continue; + if (auto binding_point = var->Declaration()->binding_point()) { + ret.push_back({var, binding_point}); } - - ret.push_back({var, BindingInfo{binding, group}}); } return ret;