From f33f1b41ff057741e8616379ce2bdf515aca88cd Mon Sep 17 00:00:00 2001 From: Ben Clayton Date: Fri, 25 Feb 2022 20:24:42 +0000 Subject: [PATCH] castable: Make Switch() smarter about return types Infer the return type by finding the common type across all cases. Types that derive from CastableBase will automatically infer to the common base class. Change-Id: I2112ca1abae34e55396685e9ebf2da12f8a6e3fc Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/80320 Reviewed-by: Antonio Maiorano Kokoro: Kokoro Commit-Queue: Ben Clayton Auto-Submit: Ben Clayton --- src/tint/castable.h | 90 ++++++++++- src/tint/castable_test.cc | 315 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 399 insertions(+), 6 deletions(-) diff --git a/src/tint/castable.h b/src/tint/castable.h index 51cad0093b..05144f79e3 100644 --- a/src/tint/castable.h +++ b/src/tint/castable.h @@ -607,7 +607,7 @@ inline bool NonDefaultCases(T* object, if (type->Is(&TypeInfo::Of())) { auto* ptr = static_cast(object); if constexpr (kHasReturnType) { - *result = std::get<0>(cases)(ptr); + *result = static_cast(std::get<0>(cases)(ptr)); } else { std::get<0>(cases)(ptr); } @@ -654,7 +654,8 @@ inline void SwitchCases(T* object, traits::Slice<0, kDefaultIndex>(cases))) { // Nothing matched. Evaluate default case. if constexpr (kHasReturnType) { - *result = std::get(cases)({}); + *result = + static_cast(std::get(cases)({})); } else { std::get(cases)({}); } @@ -667,7 +668,7 @@ inline void SwitchCases(T* object, if constexpr (kHasDefaultCase) { // Evaluate default case. if constexpr (kHasReturnType) { - *result = std::get(cases)({}); + *result = static_cast(std::get(cases)({})); } else { std::get(cases)({}); } @@ -675,6 +676,81 @@ inline void SwitchCases(T* object, } } +/// Resolves to T if T is not nullptr_t, otherwise resolves to Ignore. +template +using NullptrToIgnore = + std::conditional_t, Ignore, T>; + +/// Resolves to `const TYPE` if any of `CASE_RETURN_TYPES` are const or +/// pointer-to-const, otherwise resolves to TYPE. +template +using PropagateReturnConst = std::conditional_t< + // Are any of the pointer-stripped types const? + (std::is_const_v> || ...), + const TYPE, // Yes: Apply const to TYPE + TYPE>; // No: Passthrough + +/// SwitchReturnTypeImpl is the implementation of SwitchReturnType +template +struct SwitchReturnTypeImpl; + +/// SwitchReturnTypeImpl specialization for non-castable case types and an +/// explicitly specified return type. +template +struct SwitchReturnTypeImpl { + /// Resolves to `REQUESTED_TYPE` + using type = REQUESTED_TYPE; +}; + +/// SwitchReturnTypeImpl specialization for non-castable case types and an +/// inferred return type. +template +struct SwitchReturnTypeImpl { + /// Resolves to the common type for all the cases return types. + using type = std::common_type_t; +}; + +/// SwitchReturnTypeImpl specialization for castable case types and an +/// explicitly specified return type. +template +struct SwitchReturnTypeImpl { + public: + /// Resolves to `const REQUESTED_TYPE*` or `REQUESTED_TYPE*` + using type = PropagateReturnConst, + CASE_RETURN_TYPES...>*; +}; + +/// SwitchReturnTypeImpl specialization for castable case types and an infered +/// return type. +template +struct SwitchReturnTypeImpl { + private: + using InferredType = CastableCommonBase< + detail::NullptrToIgnore>...>; + + public: + /// `const T*` or `T*`, where T is the common base type for all the castable + /// case types. + using type = PropagateReturnConst*; +}; + +/// Resolves to the return type for a Switch() with the requested return type +/// `REQUESTED_TYPE` and case statement return types. If `REQUESTED_TYPE` is +/// Infer then the return type will be inferred from the case return types. +template +using SwitchReturnType = typename SwitchReturnTypeImpl< + IsCastable>...>, + REQUESTED_TYPE, + CASE_RETURN_TYPES...>::type; + } // namespace detail /// Switch is used to dispatch one of the provided callback case handler @@ -712,10 +788,12 @@ inline void SwitchCases(T* object, /// @param cases the switch cases /// @return the value returned by the called case. If no cases matched, then the /// zero value for the consistent case type. -template +template inline auto Switch(T* object, CASES&&... cases) { - using Cases = std::tuple; - using ReturnType = traits::ReturnType>; + using ReturnType = + detail::SwitchReturnType...>; static constexpr bool kHasReturnType = !std::is_same_v; if constexpr (kHasReturnType) { diff --git a/src/tint/castable_test.cc b/src/tint/castable_test.cc index e5698f9b72..7ed66cbfa9 100644 --- a/src/tint/castable_test.cc +++ b/src/tint/castable_test.cc @@ -380,6 +380,321 @@ TEST(Castable, SwitchMatchFirst) { } } +TEST(Castable, SwitchReturnValueWithDefault) { + std::unique_ptr frog = std::make_unique(); + std::unique_ptr bear = std::make_unique(); + std::unique_ptr gecko = std::make_unique(); + { + const char* result = Switch( + frog.get(), // + [](Mammal*) { return "mammal"; }, // + [](Amphibian*) { return "amphibian"; }, // + [](Default) { return "unknown"; }); + static_assert(std::is_same_v); + EXPECT_EQ(std::string(result), "amphibian"); + } + { + const char* result = Switch( + bear.get(), // + [](Mammal*) { return "mammal"; }, // + [](Amphibian*) { return "amphibian"; }, // + [](Default) { return "unknown"; }); + static_assert(std::is_same_v); + EXPECT_EQ(std::string(result), "mammal"); + } + { + const char* result = Switch( + gecko.get(), // + [](Mammal*) { return "mammal"; }, // + [](Amphibian*) { return "amphibian"; }, // + [](Default) { return "unknown"; }); + static_assert(std::is_same_v); + EXPECT_EQ(std::string(result), "unknown"); + } +} + +TEST(Castable, SwitchReturnValueWithoutDefault) { + std::unique_ptr frog = std::make_unique(); + std::unique_ptr bear = std::make_unique(); + std::unique_ptr gecko = std::make_unique(); + { + const char* result = Switch( + frog.get(), // + [](Mammal*) { return "mammal"; }, // + [](Amphibian*) { return "amphibian"; }); + static_assert(std::is_same_v); + EXPECT_EQ(std::string(result), "amphibian"); + } + { + const char* result = Switch( + bear.get(), // + [](Mammal*) { return "mammal"; }, // + [](Amphibian*) { return "amphibian"; }); + static_assert(std::is_same_v); + EXPECT_EQ(std::string(result), "mammal"); + } + { + auto* result = Switch( + gecko.get(), // + [](Mammal*) { return "mammal"; }, // + [](Amphibian*) { return "amphibian"; }); + static_assert(std::is_same_v); + EXPECT_EQ(result, nullptr); + } +} + +TEST(Castable, SwitchInferPODReturnTypeWithDefault) { + std::unique_ptr frog = std::make_unique(); + std::unique_ptr bear = std::make_unique(); + std::unique_ptr gecko = std::make_unique(); + { + auto result = Switch( + frog.get(), // + [](Mammal*) { return 1; }, // + [](Amphibian*) { return 2.0f; }, // + [](Default) { return 3.0; }); + static_assert(std::is_same_v); + EXPECT_EQ(result, 2.0); + } + { + auto result = Switch( + bear.get(), // + [](Mammal*) { return 1.0; }, // + [](Amphibian*) { return 2.0f; }, // + [](Default) { return 3; }); + static_assert(std::is_same_v); + EXPECT_EQ(result, 1.0); + } + { + auto result = Switch( + gecko.get(), // + [](Mammal*) { return 1.0f; }, // + [](Amphibian*) { return 2; }, // + [](Default) { return 3.0; }); + static_assert(std::is_same_v); + EXPECT_EQ(result, 3.0); + } +} + +TEST(Castable, SwitchInferPODReturnTypeWithoutDefault) { + std::unique_ptr frog = std::make_unique(); + std::unique_ptr bear = std::make_unique(); + std::unique_ptr gecko = std::make_unique(); + { + auto result = Switch( + frog.get(), // + [](Mammal*) { return 1; }, // + [](Amphibian*) { return 2.0f; }); + static_assert(std::is_same_v); + EXPECT_EQ(result, 2.0f); + } + { + auto result = Switch( + bear.get(), // + [](Mammal*) { return 1.0f; }, // + [](Amphibian*) { return 2; }); + static_assert(std::is_same_v); + EXPECT_EQ(result, 1.0f); + } + { + auto result = Switch( + gecko.get(), // + [](Mammal*) { return 1.0; }, // + [](Amphibian*) { return 2.0f; }); + static_assert(std::is_same_v); + EXPECT_EQ(result, 0.0); + } +} + +TEST(Castable, SwitchInferCastableReturnTypeWithDefault) { + std::unique_ptr frog = std::make_unique(); + std::unique_ptr bear = std::make_unique(); + std::unique_ptr gecko = std::make_unique(); + { + auto* result = Switch( + frog.get(), // + [](Mammal* p) { return p; }, // + [](Amphibian*) { return nullptr; }, // + [](Default) { return nullptr; }); + static_assert(std::is_same_v); + EXPECT_EQ(result, nullptr); + } + { + auto* result = Switch( + bear.get(), // + [](Mammal* p) { return p; }, // + [](Amphibian* p) { return const_cast(p); }, + [](Default) { return nullptr; }); + static_assert(std::is_same_v); + EXPECT_EQ(result, bear.get()); + } + { + auto* result = Switch( + gecko.get(), // + [](Mammal* p) { return p; }, // + [](Amphibian* p) { return p; }, // + [](Default) -> CastableBase* { return nullptr; }); + static_assert(std::is_same_v); + EXPECT_EQ(result, nullptr); + } +} + +TEST(Castable, SwitchInferCastableReturnTypeWithoutDefault) { + std::unique_ptr frog = std::make_unique(); + std::unique_ptr bear = std::make_unique(); + std::unique_ptr gecko = std::make_unique(); + { + auto* result = Switch( + frog.get(), // + [](Mammal* p) { return p; }, // + [](Amphibian*) { return nullptr; }); + static_assert(std::is_same_v); + EXPECT_EQ(result, nullptr); + } + { + auto* result = Switch( + bear.get(), // + [](Mammal* p) { return p; }, // + [](Amphibian* p) { return const_cast(p); }); // + static_assert(std::is_same_v); + EXPECT_EQ(result, bear.get()); + } + { + auto* result = Switch( + gecko.get(), // + [](Mammal* p) { return p; }, // + [](Amphibian* p) { return p; }); + static_assert(std::is_same_v); + EXPECT_EQ(result, nullptr); + } +} + +TEST(Castable, SwitchExplicitPODReturnTypeWithDefault) { + std::unique_ptr frog = std::make_unique(); + std::unique_ptr bear = std::make_unique(); + std::unique_ptr gecko = std::make_unique(); + { + auto result = Switch( + frog.get(), // + [](Mammal*) { return 1; }, // + [](Amphibian*) { return 2.0f; }, // + [](Default) { return 3.0; }); + static_assert(std::is_same_v); + EXPECT_EQ(result, 2.0f); + } + { + auto result = Switch( + bear.get(), // + [](Mammal*) { return 1; }, // + [](Amphibian*) { return 2; }, // + [](Default) { return 3; }); + static_assert(std::is_same_v); + EXPECT_EQ(result, 1.0f); + } + { + auto result = Switch( + gecko.get(), // + [](Mammal*) { return 1.0f; }, // + [](Amphibian*) { return 2.0f; }, // + [](Default) { return 3.0f; }); + static_assert(std::is_same_v); + EXPECT_EQ(result, 3.0f); + } +} + +TEST(Castable, SwitchExplicitPODReturnTypeWithoutDefault) { + std::unique_ptr frog = std::make_unique(); + std::unique_ptr bear = std::make_unique(); + std::unique_ptr gecko = std::make_unique(); + { + auto result = Switch( + frog.get(), // + [](Mammal*) { return 1; }, // + [](Amphibian*) { return 2.0f; }); + static_assert(std::is_same_v); + EXPECT_EQ(result, 2.0f); + } + { + auto result = Switch( + bear.get(), // + [](Mammal*) { return 1.0f; }, // + [](Amphibian*) { return 2; }); + static_assert(std::is_same_v); + EXPECT_EQ(result, 1.0f); + } + { + auto result = Switch( + gecko.get(), // + [](Mammal*) { return 1.0; }, // + [](Amphibian*) { return 2.0f; }); + static_assert(std::is_same_v); + EXPECT_EQ(result, 0.0); + } +} + +TEST(Castable, SwitchExplicitCastableReturnTypeWithDefault) { + std::unique_ptr frog = std::make_unique(); + std::unique_ptr bear = std::make_unique(); + std::unique_ptr gecko = std::make_unique(); + { + auto* result = Switch( + frog.get(), // + [](Mammal* p) { return p; }, // + [](Amphibian*) { return nullptr; }, // + [](Default) { return nullptr; }); + static_assert(std::is_same_v); + EXPECT_EQ(result, nullptr); + } + { + auto* result = Switch( + bear.get(), // + [](Mammal* p) { return p; }, // + [](Amphibian* p) { return const_cast(p); }, + [](Default) { return nullptr; }); + static_assert(std::is_same_v); + EXPECT_EQ(result, bear.get()); + } + { + auto* result = Switch( + gecko.get(), // + [](Mammal* p) { return p; }, // + [](Amphibian* p) { return p; }, // + [](Default) { return nullptr; }); + static_assert(std::is_same_v); + EXPECT_EQ(result, nullptr); + } +} + +TEST(Castable, SwitchExplicitCastableReturnTypeWithoutDefault) { + std::unique_ptr frog = std::make_unique(); + std::unique_ptr bear = std::make_unique(); + std::unique_ptr gecko = std::make_unique(); + { + auto* result = Switch( + frog.get(), // + [](Mammal* p) { return p; }, // + [](Amphibian*) { return nullptr; }); + static_assert(std::is_same_v); + EXPECT_EQ(result, nullptr); + } + { + auto* result = Switch( + bear.get(), // + [](Mammal* p) { return p; }, // + [](Amphibian* p) { return const_cast(p); }); // + static_assert(std::is_same_v); + EXPECT_EQ(result, bear.get()); + } + { + auto* result = Switch( + gecko.get(), // + [](Mammal* p) { return p; }, // + [](Amphibian* p) { return p; }); + static_assert(std::is_same_v); + EXPECT_EQ(result, nullptr); + } +} + TEST(Castable, SwitchNull) { Animal* null = nullptr; Switch(