diff --git a/src/tint/castable.h b/src/tint/castable.h index 82fcf6bde0..d6597fc942 100644 --- a/src/tint/castable.h +++ b/src/tint/castable.h @@ -86,11 +86,36 @@ enum CastFlags { kDontErrorOnImpossibleCast = 1, }; +/// The type of a hash code +using HashCode = uint64_t; + +/// Maybe checks to see if an object with the full hashcode @p object_full_hashcode could +/// potentially be of, or derive from the type with the hashcode @p query_hashcode. +/// @param type_hashcode the hashcode of the type +/// @param object_full_hashcode the full hashcode of the object being queried +/// @returns true if the object with the given full hashcode could be one of the template types. +inline bool Maybe(HashCode type_hashcode, HashCode object_full_hashcode) { + return (object_full_hashcode & type_hashcode) == type_hashcode; +} + +/// MaybeAnyOf checks to see if an object with the full hashcode @p object_full_hashcode could +/// potentially be of, or derive from the types with the combined hashcode @p combined_hashcode. +/// @param combined_hashcode the bitwise OR'd hashcodes of the types +/// @param object_full_hashcode the full hashcode of the object being queried +/// @returns true if the object with the given full hashcode could be one of the template types. +inline bool MaybeAnyOf(HashCode combined_hashcode, HashCode object_full_hashcode) { + // Compare the object's hashcode to the bitwise-or of all the tested type's hashcodes. If + // there's no intersection of bits in the two masks, then we can guarantee that the type is not + // in `TO`. + HashCode mask = object_full_hashcode & combined_hashcode; + // HashCodeOf() ensures that two bits are always set for every hash, so we can quickly + // eliminate the bitmask where only one bit is set. + HashCode two_bits = mask & (mask - 1); + return two_bits != 0; +} + /// TypeInfo holds type information for a Castable type. struct TypeInfo { - /// The type of a hash code - using HashCode = uint64_t; - /// The base class of this type const TypeInfo* base; /// The type name @@ -133,10 +158,7 @@ struct TypeInfo { /// @returns true if the class with this TypeInfo is of, or derives from the /// class with the given TypeInfo. inline bool Is(const tint::TypeInfo* type) const { - // Optimization: Check whether the all the bits of the type's hashcode can be found in the - // full_hashcode. If a single bit is missing, then we can quickly tell that that this - // TypeInfo does not derive from `type`. - if ((full_hashcode & type->hashcode) != type->hashcode) { + if (!Maybe(type->hashcode, full_hashcode)) { return false; } @@ -221,14 +243,7 @@ struct TypeInfo { } else if constexpr (kCount == 1) { return Is(&Of>()); } else { - // Optimization: Compare the object's hashcode to the bitwise-or of all the tested - // type's hashcodes. If there's no intersection of bits in the two masks, then we can - // guarantee that the type is not in `TO`. - HashCode mask = full_hashcode & TypeInfo::CombinedHashCodeOfTuple(); - // HashCodeOf() ensures that two bits are always set for every hash, so we can quickly - // eliminate the bitmask where only one bit is set. - HashCode two_bits = mask & (mask - 1); - if (two_bits) { + if (MaybeAnyOf(TypeInfo::CombinedHashCodeOfTuple(), full_hashcode)) { // Possibly one of the types in `TUPLE`. // Split the search in two, and scan each block. static constexpr auto kMid = kCount / 2; @@ -597,13 +612,8 @@ inline bool NonDefaultCases([[maybe_unused]] T* object, // Multiple cases. // Check the hashcode bits to see if there's any possibility of a case matching in these // cases. If there isn't, we can skip all these cases. - TypeInfo::HashCode mask = - type->full_hashcode & TypeInfo::CombinedHashCodeOf...>(); - // HashCodeOf() ensures that two bits are always set for every hash, so we can quickly - // eliminate the bitmask where only one bit is set. - TypeInfo::HashCode two_bits = mask & (mask - 1); - if (two_bits) { - // There's a possibility. We need to scan further. + if (MaybeAnyOf(TypeInfo::CombinedHashCodeOf...>(), + type->full_hashcode)) { // Split the cases into two, and recurse. constexpr size_t kMid = kNumCases / 2; return NonDefaultCases(object, type, result, traits::Slice<0, kMid>(cases)) ||