diff --git a/src/castable.h b/src/castable.h index 0dada0d953..34d98d50b7 100644 --- a/src/castable.h +++ b/src/castable.h @@ -94,7 +94,7 @@ struct IsAnyOf; /// class `TO` /// @param obj the object to test from template -bool Is(FROM* obj) { +inline bool Is(FROM* obj) { constexpr const bool downcast = std::is_base_of::value; constexpr const bool upcast = std::is_base_of::value; constexpr const bool nocast = std::is_same::value; @@ -111,6 +111,28 @@ bool Is(FROM* obj) { return obj->TypeInfo().Is(TypeInfo::Of>()); } +/// @returns true if `obj` is a valid pointer, and is of, or derives from the +/// class `TO`, and pred(const TO*) returns true +/// @param obj the object to test from +/// @param pred predicate function with signature `bool(const TO*)` called iff +/// object is of, or derives from the class `TO`. +template +inline bool Is(FROM* obj, Pred&& pred) { + constexpr const bool downcast = std::is_base_of::value; + constexpr const bool upcast = std::is_base_of::value; + constexpr const bool nocast = std::is_same::value; + static_assert(upcast || downcast || nocast, "impossible cast"); + + if (obj == nullptr) { + return false; + } + + bool is_type = upcast || nocast || + obj->TypeInfo().Is(TypeInfo::Of>()); + + return is_type && pred(static_cast*>(obj)); +} + /// @returns true if `obj` is of, or derives from any of the `TO` /// classes. /// @param obj the object to cast from @@ -150,6 +172,15 @@ class CastableBase { return tint::Is(this); } + /// @returns true if this object is of, or derives from the class `TO` and + /// pred(const TO*) returns true + /// @param pred predicate function with signature `bool(const TO*)` called iff + /// object is of, or derives from the class `TO`. + template + inline bool Is(Pred&& pred) const { + return tint::Is(this, std::forward(pred)); + } + /// @returns true if this object is of, or derives from any of the `TO` /// classes. template @@ -219,6 +250,16 @@ class Castable : public BASE { return tint::Is(static_cast(this)); } + /// @returns true if this object is of, or derives from the class `TO` and + /// pred(const TO*) returns true + /// @param pred predicate function with signature `bool(const TO*)` called iff + /// object is of, or derives from the class `TO`. + template + inline bool Is(Pred&& pred) const { + return tint::Is(static_cast(this), + std::forward(pred)); + } + /// @returns true if this object is of, or derives from any of the `TO` /// classes. template diff --git a/src/castable_test.cc b/src/castable_test.cc index 05f5e4ab11..1f3f546e70 100644 --- a/src/castable_test.cc +++ b/src/castable_test.cc @@ -73,6 +73,25 @@ TEST(CastableBase, Is) { ASSERT_TRUE(gecko->Is()); } +TEST(CastableBase, IsWithPredicate) { + std::unique_ptr frog = std::make_unique(); + + frog->Is([&frog](const Animal* a) { + EXPECT_EQ(a, frog.get()); + return true; + }); + + ASSERT_TRUE((frog->Is([](const Animal* a) { return true; }))); + ASSERT_FALSE((frog->Is([](const Animal* a) { return false; }))); + + // Predicate not called if cast is invalid + auto expect_not_called = [] { FAIL() << "Should not be called"; }; + ASSERT_FALSE((frog->Is([&](const Animal* a) { + expect_not_called(); + return true; + }))); +} + TEST(CastableBase, IsAnyOf) { std::unique_ptr frog = std::make_unique(); std::unique_ptr bear = std::make_unique(); @@ -138,6 +157,25 @@ TEST(Castable, Is) { ASSERT_TRUE(gecko->Is()); } +TEST(Castable, IsWithPredicate) { + std::unique_ptr frog = std::make_unique(); + + frog->Is([&frog](const Animal* a) { + EXPECT_EQ(a, frog.get()); + return true; + }); + + ASSERT_TRUE((frog->Is([](const Animal* a) { return true; }))); + ASSERT_FALSE((frog->Is([](const Animal* a) { return false; }))); + + // Predicate not called if cast is invalid + auto expect_not_called = [] { FAIL() << "Should not be called"; }; + ASSERT_FALSE((frog->Is([&](const Animal* a) { + expect_not_called(); + return true; + }))); +} + TEST(Castable, As) { std::unique_ptr frog = std::make_unique(); std::unique_ptr bear = std::make_unique(); diff --git a/src/type/type.cc b/src/type/type.cc index ecd79eaeab..2a7d6ba991 100644 --- a/src/type/type.cc +++ b/src/type/type.cc @@ -79,11 +79,13 @@ bool Type::is_float_scalar() const { } bool Type::is_float_matrix() const { - return Is() && As()->type()->is_float_scalar(); + return Is( + [](const Matrix* m) { return m->type()->is_float_scalar(); }); } bool Type::is_float_vector() const { - return Is() && As()->type()->is_float_scalar(); + return Is( + [](const Vector* v) { return v->type()->is_float_scalar(); }); } bool Type::is_float_scalar_or_vector() const { @@ -95,19 +97,19 @@ bool Type::is_integer_scalar() const { } bool Type::is_unsigned_integer_vector() const { - return Is() && As()->type()->Is(); + return Is([](const Vector* v) { return v->type()->Is(); }); } bool Type::is_signed_integer_vector() const { - return Is() && As()->type()->Is(); + return Is([](const Vector* v) { return v->type()->Is(); }); } bool Type::is_unsigned_scalar_or_vector() const { - return Is() || (Is() && As()->type()->Is()); + return Is() || is_unsigned_integer_vector(); } bool Type::is_signed_scalar_or_vector() const { - return Is() || (Is() && As()->type()->Is()); + return Is() || is_signed_integer_vector(); } bool Type::is_integer_scalar_or_vector() const { @@ -115,7 +117,7 @@ bool Type::is_integer_scalar_or_vector() const { } bool Type::is_bool_vector() const { - return Is() && As()->type()->Is(); + return Is([](const Vector* v) { return v->type()->Is(); }); } bool Type::is_bool_scalar_or_vector() const {