diff --git a/src/tint/castable.h b/src/tint/castable.h index 51cad0093b..05144f79e3 100644 --- a/src/tint/castable.h +++ b/src/tint/castable.h @@ -607,7 +607,7 @@ inline bool NonDefaultCases(T* object, if (type->Is(&TypeInfo::Of())) { auto* ptr = static_cast(object); if constexpr (kHasReturnType) { - *result = std::get<0>(cases)(ptr); + *result = static_cast(std::get<0>(cases)(ptr)); } else { std::get<0>(cases)(ptr); } @@ -654,7 +654,8 @@ inline void SwitchCases(T* object, traits::Slice<0, kDefaultIndex>(cases))) { // Nothing matched. Evaluate default case. if constexpr (kHasReturnType) { - *result = std::get(cases)({}); + *result = + static_cast(std::get(cases)({})); } else { std::get(cases)({}); } @@ -667,7 +668,7 @@ inline void SwitchCases(T* object, if constexpr (kHasDefaultCase) { // Evaluate default case. if constexpr (kHasReturnType) { - *result = std::get(cases)({}); + *result = static_cast(std::get(cases)({})); } else { std::get(cases)({}); } @@ -675,6 +676,81 @@ inline void SwitchCases(T* object, } } +/// Resolves to T if T is not nullptr_t, otherwise resolves to Ignore. +template +using NullptrToIgnore = + std::conditional_t, Ignore, T>; + +/// Resolves to `const TYPE` if any of `CASE_RETURN_TYPES` are const or +/// pointer-to-const, otherwise resolves to TYPE. +template +using PropagateReturnConst = std::conditional_t< + // Are any of the pointer-stripped types const? + (std::is_const_v> || ...), + const TYPE, // Yes: Apply const to TYPE + TYPE>; // No: Passthrough + +/// SwitchReturnTypeImpl is the implementation of SwitchReturnType +template +struct SwitchReturnTypeImpl; + +/// SwitchReturnTypeImpl specialization for non-castable case types and an +/// explicitly specified return type. +template +struct SwitchReturnTypeImpl { + /// Resolves to `REQUESTED_TYPE` + using type = REQUESTED_TYPE; +}; + +/// SwitchReturnTypeImpl specialization for non-castable case types and an +/// inferred return type. +template +struct SwitchReturnTypeImpl { + /// Resolves to the common type for all the cases return types. + using type = std::common_type_t; +}; + +/// SwitchReturnTypeImpl specialization for castable case types and an +/// explicitly specified return type. +template +struct SwitchReturnTypeImpl { + public: + /// Resolves to `const REQUESTED_TYPE*` or `REQUESTED_TYPE*` + using type = PropagateReturnConst, + CASE_RETURN_TYPES...>*; +}; + +/// SwitchReturnTypeImpl specialization for castable case types and an infered +/// return type. +template +struct SwitchReturnTypeImpl { + private: + using InferredType = CastableCommonBase< + detail::NullptrToIgnore>...>; + + public: + /// `const T*` or `T*`, where T is the common base type for all the castable + /// case types. + using type = PropagateReturnConst*; +}; + +/// Resolves to the return type for a Switch() with the requested return type +/// `REQUESTED_TYPE` and case statement return types. If `REQUESTED_TYPE` is +/// Infer then the return type will be inferred from the case return types. +template +using SwitchReturnType = typename SwitchReturnTypeImpl< + IsCastable>...>, + REQUESTED_TYPE, + CASE_RETURN_TYPES...>::type; + } // namespace detail /// Switch is used to dispatch one of the provided callback case handler @@ -712,10 +788,12 @@ inline void SwitchCases(T* object, /// @param cases the switch cases /// @return the value returned by the called case. If no cases matched, then the /// zero value for the consistent case type. -template +template inline auto Switch(T* object, CASES&&... cases) { - using Cases = std::tuple; - using ReturnType = traits::ReturnType>; + using ReturnType = + detail::SwitchReturnType...>; static constexpr bool kHasReturnType = !std::is_same_v; if constexpr (kHasReturnType) { diff --git a/src/tint/castable_test.cc b/src/tint/castable_test.cc index e5698f9b72..7ed66cbfa9 100644 --- a/src/tint/castable_test.cc +++ b/src/tint/castable_test.cc @@ -380,6 +380,321 @@ TEST(Castable, SwitchMatchFirst) { } } +TEST(Castable, SwitchReturnValueWithDefault) { + std::unique_ptr frog = std::make_unique(); + std::unique_ptr bear = std::make_unique(); + std::unique_ptr gecko = std::make_unique(); + { + const char* result = Switch( + frog.get(), // + [](Mammal*) { return "mammal"; }, // + [](Amphibian*) { return "amphibian"; }, // + [](Default) { return "unknown"; }); + static_assert(std::is_same_v); + EXPECT_EQ(std::string(result), "amphibian"); + } + { + const char* result = Switch( + bear.get(), // + [](Mammal*) { return "mammal"; }, // + [](Amphibian*) { return "amphibian"; }, // + [](Default) { return "unknown"; }); + static_assert(std::is_same_v); + EXPECT_EQ(std::string(result), "mammal"); + } + { + const char* result = Switch( + gecko.get(), // + [](Mammal*) { return "mammal"; }, // + [](Amphibian*) { return "amphibian"; }, // + [](Default) { return "unknown"; }); + static_assert(std::is_same_v); + EXPECT_EQ(std::string(result), "unknown"); + } +} + +TEST(Castable, SwitchReturnValueWithoutDefault) { + std::unique_ptr frog = std::make_unique(); + std::unique_ptr bear = std::make_unique(); + std::unique_ptr gecko = std::make_unique(); + { + const char* result = Switch( + frog.get(), // + [](Mammal*) { return "mammal"; }, // + [](Amphibian*) { return "amphibian"; }); + static_assert(std::is_same_v); + EXPECT_EQ(std::string(result), "amphibian"); + } + { + const char* result = Switch( + bear.get(), // + [](Mammal*) { return "mammal"; }, // + [](Amphibian*) { return "amphibian"; }); + static_assert(std::is_same_v); + EXPECT_EQ(std::string(result), "mammal"); + } + { + auto* result = Switch( + gecko.get(), // + [](Mammal*) { return "mammal"; }, // + [](Amphibian*) { return "amphibian"; }); + static_assert(std::is_same_v); + EXPECT_EQ(result, nullptr); + } +} + +TEST(Castable, SwitchInferPODReturnTypeWithDefault) { + std::unique_ptr frog = std::make_unique(); + std::unique_ptr bear = std::make_unique(); + std::unique_ptr gecko = std::make_unique(); + { + auto result = Switch( + frog.get(), // + [](Mammal*) { return 1; }, // + [](Amphibian*) { return 2.0f; }, // + [](Default) { return 3.0; }); + static_assert(std::is_same_v); + EXPECT_EQ(result, 2.0); + } + { + auto result = Switch( + bear.get(), // + [](Mammal*) { return 1.0; }, // + [](Amphibian*) { return 2.0f; }, // + [](Default) { return 3; }); + static_assert(std::is_same_v); + EXPECT_EQ(result, 1.0); + } + { + auto result = Switch( + gecko.get(), // + [](Mammal*) { return 1.0f; }, // + [](Amphibian*) { return 2; }, // + [](Default) { return 3.0; }); + static_assert(std::is_same_v); + EXPECT_EQ(result, 3.0); + } +} + +TEST(Castable, SwitchInferPODReturnTypeWithoutDefault) { + std::unique_ptr frog = std::make_unique(); + std::unique_ptr bear = std::make_unique(); + std::unique_ptr gecko = std::make_unique(); + { + auto result = Switch( + frog.get(), // + [](Mammal*) { return 1; }, // + [](Amphibian*) { return 2.0f; }); + static_assert(std::is_same_v); + EXPECT_EQ(result, 2.0f); + } + { + auto result = Switch( + bear.get(), // + [](Mammal*) { return 1.0f; }, // + [](Amphibian*) { return 2; }); + static_assert(std::is_same_v); + EXPECT_EQ(result, 1.0f); + } + { + auto result = Switch( + gecko.get(), // + [](Mammal*) { return 1.0; }, // + [](Amphibian*) { return 2.0f; }); + static_assert(std::is_same_v); + EXPECT_EQ(result, 0.0); + } +} + +TEST(Castable, SwitchInferCastableReturnTypeWithDefault) { + std::unique_ptr frog = std::make_unique(); + std::unique_ptr bear = std::make_unique(); + std::unique_ptr gecko = std::make_unique(); + { + auto* result = Switch( + frog.get(), // + [](Mammal* p) { return p; }, // + [](Amphibian*) { return nullptr; }, // + [](Default) { return nullptr; }); + static_assert(std::is_same_v); + EXPECT_EQ(result, nullptr); + } + { + auto* result = Switch( + bear.get(), // + [](Mammal* p) { return p; }, // + [](Amphibian* p) { return const_cast(p); }, + [](Default) { return nullptr; }); + static_assert(std::is_same_v); + EXPECT_EQ(result, bear.get()); + } + { + auto* result = Switch( + gecko.get(), // + [](Mammal* p) { return p; }, // + [](Amphibian* p) { return p; }, // + [](Default) -> CastableBase* { return nullptr; }); + static_assert(std::is_same_v); + EXPECT_EQ(result, nullptr); + } +} + +TEST(Castable, SwitchInferCastableReturnTypeWithoutDefault) { + std::unique_ptr frog = std::make_unique(); + std::unique_ptr bear = std::make_unique(); + std::unique_ptr gecko = std::make_unique(); + { + auto* result = Switch( + frog.get(), // + [](Mammal* p) { return p; }, // + [](Amphibian*) { return nullptr; }); + static_assert(std::is_same_v); + EXPECT_EQ(result, nullptr); + } + { + auto* result = Switch( + bear.get(), // + [](Mammal* p) { return p; }, // + [](Amphibian* p) { return const_cast(p); }); // + static_assert(std::is_same_v); + EXPECT_EQ(result, bear.get()); + } + { + auto* result = Switch( + gecko.get(), // + [](Mammal* p) { return p; }, // + [](Amphibian* p) { return p; }); + static_assert(std::is_same_v); + EXPECT_EQ(result, nullptr); + } +} + +TEST(Castable, SwitchExplicitPODReturnTypeWithDefault) { + std::unique_ptr frog = std::make_unique(); + std::unique_ptr bear = std::make_unique(); + std::unique_ptr gecko = std::make_unique(); + { + auto result = Switch( + frog.get(), // + [](Mammal*) { return 1; }, // + [](Amphibian*) { return 2.0f; }, // + [](Default) { return 3.0; }); + static_assert(std::is_same_v); + EXPECT_EQ(result, 2.0f); + } + { + auto result = Switch( + bear.get(), // + [](Mammal*) { return 1; }, // + [](Amphibian*) { return 2; }, // + [](Default) { return 3; }); + static_assert(std::is_same_v); + EXPECT_EQ(result, 1.0f); + } + { + auto result = Switch( + gecko.get(), // + [](Mammal*) { return 1.0f; }, // + [](Amphibian*) { return 2.0f; }, // + [](Default) { return 3.0f; }); + static_assert(std::is_same_v); + EXPECT_EQ(result, 3.0f); + } +} + +TEST(Castable, SwitchExplicitPODReturnTypeWithoutDefault) { + std::unique_ptr frog = std::make_unique(); + std::unique_ptr bear = std::make_unique(); + std::unique_ptr gecko = std::make_unique(); + { + auto result = Switch( + frog.get(), // + [](Mammal*) { return 1; }, // + [](Amphibian*) { return 2.0f; }); + static_assert(std::is_same_v); + EXPECT_EQ(result, 2.0f); + } + { + auto result = Switch( + bear.get(), // + [](Mammal*) { return 1.0f; }, // + [](Amphibian*) { return 2; }); + static_assert(std::is_same_v); + EXPECT_EQ(result, 1.0f); + } + { + auto result = Switch( + gecko.get(), // + [](Mammal*) { return 1.0; }, // + [](Amphibian*) { return 2.0f; }); + static_assert(std::is_same_v); + EXPECT_EQ(result, 0.0); + } +} + +TEST(Castable, SwitchExplicitCastableReturnTypeWithDefault) { + std::unique_ptr frog = std::make_unique(); + std::unique_ptr bear = std::make_unique(); + std::unique_ptr gecko = std::make_unique(); + { + auto* result = Switch( + frog.get(), // + [](Mammal* p) { return p; }, // + [](Amphibian*) { return nullptr; }, // + [](Default) { return nullptr; }); + static_assert(std::is_same_v); + EXPECT_EQ(result, nullptr); + } + { + auto* result = Switch( + bear.get(), // + [](Mammal* p) { return p; }, // + [](Amphibian* p) { return const_cast(p); }, + [](Default) { return nullptr; }); + static_assert(std::is_same_v); + EXPECT_EQ(result, bear.get()); + } + { + auto* result = Switch( + gecko.get(), // + [](Mammal* p) { return p; }, // + [](Amphibian* p) { return p; }, // + [](Default) { return nullptr; }); + static_assert(std::is_same_v); + EXPECT_EQ(result, nullptr); + } +} + +TEST(Castable, SwitchExplicitCastableReturnTypeWithoutDefault) { + std::unique_ptr frog = std::make_unique(); + std::unique_ptr bear = std::make_unique(); + std::unique_ptr gecko = std::make_unique(); + { + auto* result = Switch( + frog.get(), // + [](Mammal* p) { return p; }, // + [](Amphibian*) { return nullptr; }); + static_assert(std::is_same_v); + EXPECT_EQ(result, nullptr); + } + { + auto* result = Switch( + bear.get(), // + [](Mammal* p) { return p; }, // + [](Amphibian* p) { return const_cast(p); }); // + static_assert(std::is_same_v); + EXPECT_EQ(result, bear.get()); + } + { + auto* result = Switch( + gecko.get(), // + [](Mammal* p) { return p; }, // + [](Amphibian* p) { return p; }); + static_assert(std::is_same_v); + EXPECT_EQ(result, nullptr); + } +} + TEST(Castable, SwitchNull) { Animal* null = nullptr; Switch(