From b68e8aa65848dc17cf86dd61fed6d45bfb64d06e Mon Sep 17 00:00:00 2001 From: Ben Clayton Date: Wed, 2 Feb 2022 14:38:32 +0000 Subject: [PATCH] Optimize tint::IsAnyOf<>() for many types Split IsAnyOf() into log(n) stages, where each stage performs a hashcode check. Previously there was a single hash test across the bitwise-or of all the types being considered. If this passed, then each type would be tested with Is() individually. With this change, the list of types will be recursively split into two, which each block hash-code checked. This is repeated until we reach fewer than 4 types to check, where the test decays to using Is() for each type. Also renamed `combined_hashcode` to `full_hashcode`, and used the term CombinedHash for new helpers that bitwise-or the hashes from a number of types. Bug: tint:1383 Change-Id: Id056b9f7a9792430bd75ce554cb5fe73221ca4c7 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/78580 Reviewed-by: Antonio Maiorano Commit-Queue: Ben Clayton Kokoro: Ben Clayton --- src/castable.cc | 2 +- src/castable.h | 153 +++++++++++++++++++++----------------- src/traits.h | 52 ++++++++++++- src/traits_test.cc | 181 ++++++++++++++++++++++++++++++++++----------- 4 files changed, 273 insertions(+), 115 deletions(-) diff --git a/src/castable.cc b/src/castable.cc index e63981f6aa..02c3ebcd80 100644 --- a/src/castable.cc +++ b/src/castable.cc @@ -23,7 +23,7 @@ const TypeInfo detail::TypeInfoOf::info{ nullptr, "CastableBase", tint::TypeInfo::HashCodeOf(), - tint::TypeInfo::HashCodeOf(), + tint::TypeInfo::FullHashCodeOf(), }; } // namespace tint diff --git a/src/castable.h b/src/castable.h index 280cc1754a..3104492965 100644 --- a/src/castable.h +++ b/src/castable.h @@ -17,6 +17,7 @@ #include #include +#include #include #include "src/traits.h" @@ -61,7 +62,7 @@ struct TypeInfoOf; &tint::detail::TypeInfoOf::info, \ #CLASS, \ tint::TypeInfo::HashCodeOf(), \ - tint::TypeInfo::CombinedHashCodeOf(), \ + tint::TypeInfo::FullHashCodeOf(), \ }; \ TINT_CASTABLE_POP_DISABLE_WARNINGS() @@ -86,17 +87,17 @@ struct TypeInfo { const char* name; /// The type hash code const HashCode hashcode; - /// The type hash code or'd with the base class' combined hash code - const HashCode combined_hashcode; + /// The type hash code bitwise-or'd with all ancestor's hashcodes. + const HashCode full_hashcode; /// @param type the test type info /// @returns true if the class with this TypeInfo is of, or derives from the /// class with the given TypeInfo. inline bool Is(const tint::TypeInfo* type) const { // Optimization: Check whether the all the bits of the type's hashcode can - // be found in the combined_hashcode. If a single bit is missing, then we + // be found in the full_hashcode. If a single bit is missing, then we // can quickly tell that that this TypeInfo does not derive from `type`. - if ((combined_hashcode & type->hashcode) != type->hashcode) { + if ((full_hashcode & type->hashcode) != type->hashcode) { return false; } @@ -145,6 +146,8 @@ struct TypeInfo { /// multiple hashcodes are bitwise-or'd together. template static constexpr HashCode HashCodeOf() { + static_assert(traits::IsTypeOrDerived::value, + "T is not Castable"); /// Use the compiler's "pretty" function name, which includes the template /// type, to obtain a unique hash value. #ifdef _MSC_VER @@ -161,13 +164,75 @@ struct TypeInfo { /// @returns the hashcode of the given type, bitwise-or'd with the hashcodes /// of all base classes. template - static constexpr HashCode CombinedHashCodeOf() { + static constexpr HashCode FullHashCodeOf() { if constexpr (std::is_same_v) { return HashCodeOf(); } else { - return HashCodeOf() | CombinedHashCodeOf(); + return HashCodeOf() | FullHashCodeOf(); } } + + /// @returns the bitwise-or'd hashcodes of all the types of the tuple `TUPLE`. + /// @see HashCodeOf + template + static constexpr HashCode CombinedHashCodeOfTuple() { + constexpr auto kCount = std::tuple_size_v; + if constexpr (kCount == 0) { + return 0; + } else if constexpr (kCount == 1) { + return HashCodeOf>(); + } else { + constexpr auto kMid = kCount / 2; + return CombinedHashCodeOfTuple>() | + CombinedHashCodeOfTuple< + traits::SliceTuple>(); + } + } + + /// @returns the bitwise-or'd hashcodes of all the template parameter types. + /// @see HashCodeOf + template + static constexpr HashCode CombinedHashCodeOf() { + return CombinedHashCodeOfTuple>(); + } + + /// @returns true if this TypeInfo is of, or derives from any of the types in + /// `TUPLE`. + template + inline bool IsAnyOfTuple() const { + constexpr auto kCount = std::tuple_size_v; + if constexpr (kCount == 0) { + return false; + } else if constexpr (kCount == 1) { + return Is(&Of>()); + } else if constexpr (kCount == 2) { + return Is(&Of>()) || + Is(&Of>()); + } else if constexpr (kCount == 3) { + return Is(&Of>()) || + Is(&Of>()) || + Is(&Of>()); + } else { + // Optimization: Compare the object's hashcode to the bitwise-or of all + // the tested type's hashcodes. If there's no intersection of bits in + // the two masks, then we can guarantee that the type is not in `TO`. + if (full_hashcode & TypeInfo::CombinedHashCodeOfTuple()) { + // Possibly one of the types in `TUPLE`. + // Split the search in two, and scan each block. + static constexpr auto kMid = kCount / 2; + return IsAnyOfTuple>() || + IsAnyOfTuple>(); + } + return false; + } + } + + /// @returns true if this TypeInfo is of, or derives from any of the types in + /// `TYPES`. + template + inline bool IsAnyOf() const { + return IsAnyOfTuple>(); + } }; namespace detail { @@ -181,10 +246,6 @@ struct TypeInfoOf { static const TypeInfo info; }; -// Forward declaration -template -struct IsAnyOf; - /// A placeholder structure used for template parameters that need a default /// type, but can always be automatically inferred. struct Infer; @@ -204,39 +265,29 @@ inline bool Is(FROM* obj) { } /// @returns true if `obj` is a valid pointer, and is of, or derives from the -/// class `TO`, and pred(const TO*) returns true +/// type `TYPE`, and pred(const TYPE*) returns true /// @param obj the object to test from -/// @param pred predicate function with signature `bool(const TO*)` called iff -/// object is of, or derives from the class `TO`. +/// @param pred predicate function with signature `bool(const TYPE*)` called iff +/// object is of, or derives from the class `TYPE`. /// @see CastFlags -template -inline bool Is(FROM* obj, Pred&& pred) { - return Is(obj) && - pred(static_cast*>(obj)); +inline bool Is(OBJ* obj, Pred&& pred) { + return Is(obj) && + pred(static_cast*>(obj)); } -/// @returns true if `obj` is of, or derives from any of the `TO` -/// classes. -/// @param obj the object to cast from -template -inline bool IsAnyOf(FROM* obj) { +/// @returns true if `obj` is a valid pointer, and is of, or derives from any of +/// the types in `TYPES`.OBJ +/// @param obj the object to query. +template +inline bool IsAnyOf(OBJ* obj) { if (!obj) { return false; } - // Optimization: Compare the object's combined_hashcode to the bitwise-or of - // all the tested type's hashcodes. If there's no intersection of bits in the - // two masks, then we can guarantee that the type is not in `TO`. - using Helper = detail::IsAnyOf; - auto* type = &obj->TypeInfo(); - auto hashcode = type->combined_hashcode; - if ((Helper::kHashCodes & hashcode) == 0) { - return false; - } - // Possibly one of the types in `TO`. Continue to testing against each type. - return Helper::template Exec(type); + return obj->TypeInfo().template IsAnyOf(); } /// @returns obj dynamically cast to the type `TO` or `nullptr` if @@ -402,38 +453,6 @@ class Castable : public BASE { } }; -namespace detail { -/// Helper for Castable::IsAnyOf -template -struct IsAnyOf { - /// The bitwise-or of all typeinfo hashcodes - static constexpr auto kHashCodes = - TypeInfo::HashCodeOf() | IsAnyOf::kHashCodes; - - /// @param type castable object type to test - /// @returns true if `obj` is of, or derives from any of `[TO_FIRST, - /// ...TO_REST]` - template - static bool Exec(const TypeInfo* type) { - return TypeInfo::Is(type) || - IsAnyOf::template Exec(type); - } -}; -/// Terminal specialization -template -struct IsAnyOf { - /// The bitwise-or of all typeinfo hashcodes - static constexpr auto kHashCodes = TypeInfo::HashCodeOf(); - - /// @param type castable object type to test - /// @returns true if `obj` is of, or derives from TO - template - static bool Exec(const TypeInfo* type) { - return TypeInfo::Is(type); - } -}; -} // namespace detail - } // namespace tint TINT_CASTABLE_POP_DISABLE_WARNINGS(); diff --git a/src/traits.h b/src/traits.h index 5b933b9790..4d2cbab364 100644 --- a/src/traits.h +++ b/src/traits.h @@ -16,9 +16,9 @@ #define SRC_TRAITS_H_ #include +#include -namespace tint { -namespace traits { +namespace tint::traits { /// Convience type definition for std::decay::type template @@ -109,7 +109,51 @@ using EnableIfIsType = EnableIf::value, T>; template using EnableIfIsNotType = EnableIf::value, T>; -} // namespace traits -} // namespace tint +/// @returns the std::index_sequence with all the indices shifted by OFFSET. +template +constexpr auto Shift(std::index_sequence) { + return std::integer_sequence{}; +} + +/// @returns a std::integer_sequence with the integers `[OFFSET..OFFSET+COUNT)` +template +constexpr auto Range() { + return Shift(std::make_index_sequence{}); +} + +namespace detail { + +/// @returns the tuple `t` swizzled by `INDICES` +template +constexpr auto Swizzle(TUPLE&& t, std::index_sequence) { + return std::make_tuple(std::get(std::forward(t))...); +} + +/// @returns a nullptr of the tuple type `TUPLE` swizzled by `INDICES`. +/// @note: This function is intended to be used in a `decltype()` expression, +/// and returns a pointer-to-tuple as the tuple may hold non-constructable +/// types. +template +constexpr auto* SwizzlePtrTy(std::index_sequence) { + using Swizzled = std::tuple...>; + return static_cast(nullptr); +} + +} // namespace detail + +/// @returns the slice of the tuple `t` with the tuple elements +/// `[OFFSET..OFFSET+COUNT)` +template +constexpr auto Slice(TUPLE&& t) { + return detail::Swizzle(std::forward(t), Range()); +} + +/// Resolves to the slice of the tuple `t` with the tuple elements +/// `[OFFSET..OFFSET+COUNT)` +template +using SliceTuple = std::remove_pointer_t(Range()))>; + +} // namespace tint::traits #endif // SRC_TRAITS_H_ diff --git a/src/traits_test.cc b/src/traits_test.cc index 43cd209477..74538666d1 100644 --- a/src/traits_test.cc +++ b/src/traits_test.cc @@ -28,13 +28,12 @@ void F3(int, S, float) {} TEST(ParamType, Function) { F1({}); // Avoid unused method warning F3(0, {}, 0); // Avoid unused method warning - static_assert(std::is_same, S>::value, ""); - static_assert(std::is_same, int>::value, ""); - static_assert(std::is_same, S>::value, ""); - static_assert(std::is_same, float>::value, - ""); - static_assert(std::is_same, void>::value, ""); - static_assert(std::is_same, void>::value, ""); + static_assert(std::is_same_v, S>, ""); + static_assert(std::is_same_v, int>, ""); + static_assert(std::is_same_v, S>, ""); + static_assert(std::is_same_v, float>, ""); + static_assert(std::is_same_v, void>, ""); + static_assert(std::is_same_v, void>, ""); static_assert(SignatureOfT::parameter_count == 1, ""); static_assert(SignatureOfT::parameter_count == 3, ""); } @@ -47,14 +46,12 @@ TEST(ParamType, Method) { }; C().F1({}); // Avoid unused method warning C().F3(0, {}, 0); // Avoid unused method warning - static_assert(std::is_same, S>::value, ""); - static_assert(std::is_same, int>::value, - ""); - static_assert(std::is_same, S>::value, ""); - static_assert(std::is_same, float>::value, - ""); - static_assert(std::is_same, void>::value, ""); - static_assert(std::is_same, void>::value, ""); + static_assert(std::is_same_v, S>, ""); + static_assert(std::is_same_v, int>, ""); + static_assert(std::is_same_v, S>, ""); + static_assert(std::is_same_v, float>, ""); + static_assert(std::is_same_v, void>, ""); + static_assert(std::is_same_v, void>, ""); static_assert(SignatureOfT::parameter_count == 1, ""); static_assert(SignatureOfT::parameter_count == 3, ""); } @@ -67,14 +64,12 @@ TEST(ParamType, ConstMethod) { }; C().F1({}); // Avoid unused method warning C().F3(0, {}, 0); // Avoid unused method warning - static_assert(std::is_same, S>::value, ""); - static_assert(std::is_same, int>::value, - ""); - static_assert(std::is_same, S>::value, ""); - static_assert(std::is_same, float>::value, - ""); - static_assert(std::is_same, void>::value, ""); - static_assert(std::is_same, void>::value, ""); + static_assert(std::is_same_v, S>, ""); + static_assert(std::is_same_v, int>, ""); + static_assert(std::is_same_v, S>, ""); + static_assert(std::is_same_v, float>, ""); + static_assert(std::is_same_v, void>, ""); + static_assert(std::is_same_v, void>, ""); static_assert(SignatureOfT::parameter_count == 1, ""); static_assert(SignatureOfT::parameter_count == 3, ""); } @@ -87,14 +82,12 @@ TEST(ParamType, StaticMethod) { }; C::F1({}); // Avoid unused method warning C::F3(0, {}, 0); // Avoid unused method warning - static_assert(std::is_same, S>::value, ""); - static_assert(std::is_same, int>::value, - ""); - static_assert(std::is_same, S>::value, ""); - static_assert(std::is_same, float>::value, - ""); - static_assert(std::is_same, void>::value, ""); - static_assert(std::is_same, void>::value, ""); + static_assert(std::is_same_v, S>, ""); + static_assert(std::is_same_v, int>, ""); + static_assert(std::is_same_v, S>, ""); + static_assert(std::is_same_v, float>, ""); + static_assert(std::is_same_v, void>, ""); + static_assert(std::is_same_v, void>, ""); static_assert(SignatureOfT::parameter_count == 1, ""); static_assert(SignatureOfT::parameter_count == 3, ""); } @@ -102,12 +95,12 @@ TEST(ParamType, StaticMethod) { TEST(ParamType, FunctionLike) { using F1 = std::function; using F3 = std::function; - static_assert(std::is_same, S>::value, ""); - static_assert(std::is_same, int>::value, ""); - static_assert(std::is_same, S>::value, ""); - static_assert(std::is_same, float>::value, ""); - static_assert(std::is_same, void>::value, ""); - static_assert(std::is_same, void>::value, ""); + static_assert(std::is_same_v, S>, ""); + static_assert(std::is_same_v, int>, ""); + static_assert(std::is_same_v, S>, ""); + static_assert(std::is_same_v, float>, ""); + static_assert(std::is_same_v, void>, ""); + static_assert(std::is_same_v, void>, ""); static_assert(SignatureOfT::parameter_count == 1, ""); static_assert(SignatureOfT::parameter_count == 3, ""); } @@ -115,15 +108,117 @@ TEST(ParamType, FunctionLike) { TEST(ParamType, Lambda) { auto l1 = [](S) {}; auto l3 = [](int, S, float) {}; - static_assert(std::is_same, S>::value, ""); - static_assert(std::is_same, int>::value, ""); - static_assert(std::is_same, S>::value, ""); - static_assert(std::is_same, float>::value, ""); - static_assert(std::is_same, void>::value, ""); - static_assert(std::is_same, void>::value, ""); + static_assert(std::is_same_v, S>, ""); + static_assert(std::is_same_v, int>, ""); + static_assert(std::is_same_v, S>, ""); + static_assert(std::is_same_v, float>, ""); + static_assert(std::is_same_v, void>, ""); + static_assert(std::is_same_v, void>, ""); static_assert(SignatureOfT::parameter_count == 1, ""); static_assert(SignatureOfT::parameter_count == 3, ""); } +TEST(Slice, Empty) { + auto sliced = Slice<0, 0>(std::make_tuple<>()); + static_assert(std::tuple_size_v == 0, ""); +} + +TEST(Slice, SingleElementSliceEmpty) { + auto sliced = Slice<0, 0>(std::make_tuple(1)); + static_assert(std::tuple_size_v == 0, ""); +} + +TEST(Slice, SingleElementSliceFull) { + auto sliced = Slice<0, 1>(std::make_tuple(1)); + static_assert(std::tuple_size_v == 1, ""); + static_assert(std::is_same_v, int>, + ""); + EXPECT_EQ(std::get<0>(sliced), 1); +} + +TEST(Slice, MixedTupleSliceEmpty) { + auto sliced = Slice<1, 0>(std::make_tuple(1, true, 2.0f)); + static_assert(std::tuple_size_v == 0, ""); +} + +TEST(Slice, MixedTupleSliceFull) { + auto sliced = Slice<0, 3>(std::make_tuple(1, true, 2.0f)); + static_assert(std::tuple_size_v == 3, ""); + static_assert(std::is_same_v, int>, + ""); + static_assert(std::is_same_v, bool>, + ""); + static_assert( + std::is_same_v, float>, ""); + EXPECT_EQ(std::get<0>(sliced), 1); + EXPECT_EQ(std::get<1>(sliced), true); + EXPECT_EQ(std::get<2>(sliced), 2.0f); +} + +TEST(Slice, MixedTupleSliceLowPart) { + auto sliced = Slice<0, 2>(std::make_tuple(1, true, 2.0f)); + static_assert(std::tuple_size_v == 2, ""); + static_assert(std::is_same_v, int>, + ""); + static_assert(std::is_same_v, bool>, + ""); + EXPECT_EQ(std::get<0>(sliced), 1); + EXPECT_EQ(std::get<1>(sliced), true); +} + +TEST(Slice, MixedTupleSliceHighPart) { + auto sliced = Slice<1, 2>(std::make_tuple(1, true, 2.0f)); + static_assert(std::tuple_size_v == 2, ""); + static_assert(std::is_same_v, bool>, + ""); + static_assert( + std::is_same_v, float>, ""); + EXPECT_EQ(std::get<0>(sliced), true); + EXPECT_EQ(std::get<1>(sliced), 2.0f); +} + +TEST(SliceTuple, Empty) { + using sliced = SliceTuple<0, 0, std::tuple<>>; + static_assert(std::tuple_size_v == 0, ""); +} + +TEST(SliceTuple, SingleElementSliceEmpty) { + using sliced = SliceTuple<0, 0, std::tuple>; + static_assert(std::tuple_size_v == 0, ""); +} + +TEST(SliceTuple, SingleElementSliceFull) { + using sliced = SliceTuple<0, 1, std::tuple>; + static_assert(std::tuple_size_v == 1, ""); + static_assert(std::is_same_v, int>, ""); +} + +TEST(SliceTuple, MixedTupleSliceEmpty) { + using sliced = SliceTuple<1, 0, std::tuple>; + static_assert(std::tuple_size_v == 0, ""); +} + +TEST(SliceTuple, MixedTupleSliceFull) { + using sliced = SliceTuple<0, 3, std::tuple>; + static_assert(std::tuple_size_v == 3, ""); + static_assert(std::is_same_v, int>, ""); + static_assert(std::is_same_v, bool>, ""); + static_assert(std::is_same_v, float>, ""); +} + +TEST(SliceTuple, MixedTupleSliceLowPart) { + using sliced = SliceTuple<0, 2, std::tuple>; + static_assert(std::tuple_size_v == 2, ""); + static_assert(std::is_same_v, int>, ""); + static_assert(std::is_same_v, bool>, ""); +} + +TEST(SliceTuple, MixedTupleSliceHighPart) { + using sliced = SliceTuple<1, 2, std::tuple>; + static_assert(std::tuple_size_v == 2, ""); + static_assert(std::is_same_v, bool>, ""); + static_assert(std::is_same_v, float>, ""); +} + } // namespace traits } // namespace tint