tint: Castable - support non-default-constructable return types

If the Switch() has a default case, then allow support for return types that do not have a default constructor.

Bug: tint:1504
Change-Id: I671ea78fe976138a786e2e0472e1e5f99afa0c5d
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/89022
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
This commit is contained in:
Ben Clayton 2022-05-05 19:18:00 +00:00 committed by Dawn LUCI CQ
parent 07602e8a8a
commit f7357f89a3
2 changed files with 85 additions and 26 deletions

View File

@ -21,7 +21,9 @@
#include <utility> #include <utility>
#include "src/tint/traits.h" #include "src/tint/traits.h"
#include "src/tint/utils/bitcast.h"
#include "src/tint/utils/crc32.h" #include "src/tint/utils/crc32.h"
#include "src/tint/utils/defer.h"
#if defined(__clang__) #if defined(__clang__)
/// Temporarily disable certain warnings when using Castable API /// Temporarily disable certain warnings when using Castable API
@ -588,7 +590,7 @@ inline bool NonDefaultCases(T* object,
if (type->Is(&TypeInfo::Of<CaseType>())) { if (type->Is(&TypeInfo::Of<CaseType>())) {
auto* ptr = static_cast<CaseType*>(object); auto* ptr = static_cast<CaseType*>(object);
if constexpr (kHasReturnType) { if constexpr (kHasReturnType) {
*result = static_cast<RETURN_TYPE>(std::get<0>(cases)(ptr)); new (result) RETURN_TYPE(static_cast<RETURN_TYPE>(std::get<0>(cases)(ptr)));
} else { } else {
std::get<0>(cases)(ptr); std::get<0>(cases)(ptr);
} }
@ -617,36 +619,61 @@ inline bool NonDefaultCases(T* object,
template <typename T, typename RETURN_TYPE, typename... CASES> template <typename T, typename RETURN_TYPE, typename... CASES>
inline void SwitchCases(T* object, RETURN_TYPE* result, std::tuple<CASES...>&& cases) { inline void SwitchCases(T* object, RETURN_TYPE* result, std::tuple<CASES...>&& cases) {
using Cases = std::tuple<CASES...>; using Cases = std::tuple<CASES...>;
static constexpr int kDefaultIndex = detail::IndexOfDefaultCase<Cases>(); static constexpr int kDefaultIndex = detail::IndexOfDefaultCase<Cases>();
static_assert(kDefaultIndex == -1 || kDefaultIndex == std::tuple_size_v<Cases> - 1,
"Default case must be last in Switch()");
static constexpr bool kHasDefaultCase = kDefaultIndex >= 0; static constexpr bool kHasDefaultCase = kDefaultIndex >= 0;
static constexpr bool kHasReturnType = !std::is_same_v<RETURN_TYPE, void>; static constexpr bool kHasReturnType = !std::is_same_v<RETURN_TYPE, void>;
if (object) { // Static assertions
auto* type = &object->TypeInfo(); static constexpr bool kDefaultIsOK =
if constexpr (kHasDefaultCase) { kDefaultIndex == -1 || kDefaultIndex == std::tuple_size_v<Cases> - 1;
// Evaluate non-default cases. static constexpr bool kReturnIsOK =
if (!detail::NonDefaultCases<T>(object, type, result, kHasDefaultCase || !kHasReturnType || std::is_constructible_v<RETURN_TYPE>;
traits::Slice<0, kDefaultIndex>(cases))) { static_assert(kDefaultIsOK, "Default case must be last in Switch()");
// Nothing matched. Evaluate default case. static_assert(kReturnIsOK,
if constexpr (kHasReturnType) { "Switch() requires either a Default case or a return type that is either void or "
*result = static_cast<RETURN_TYPE>(std::get<kDefaultIndex>(cases)({})); "default-constructable");
} else {
std::get<kDefaultIndex>(cases)({}); // If the static asserts have fired, don't bother spewing more errors below
static constexpr bool kAllOK = kDefaultIsOK && kReturnIsOK;
if constexpr (kAllOK) {
if (object) {
auto* type = &object->TypeInfo();
if constexpr (kHasDefaultCase) {
// Evaluate non-default cases.
if (!detail::NonDefaultCases<T>(object, type, result,
traits::Slice<0, kDefaultIndex>(cases))) {
// Nothing matched. Evaluate default case.
if constexpr (kHasReturnType) {
new (result) RETURN_TYPE(
static_cast<RETURN_TYPE>(std::get<kDefaultIndex>(cases)({})));
} else {
std::get<kDefaultIndex>(cases)({});
}
}
} else {
if (!detail::NonDefaultCases<T>(object, type, result, std::move(cases))) {
// Nothing matched. No default case.
if constexpr (kHasReturnType) {
new (result) RETURN_TYPE();
}
} }
} }
} else { } else {
detail::NonDefaultCases<T>(object, type, result, std::move(cases)); // Object is nullptr, so no cases can match
} if constexpr (kHasDefaultCase) {
} else { // Evaluate default case.
// Object is nullptr, so no cases can match if constexpr (kHasReturnType) {
if constexpr (kHasDefaultCase) { new (result)
// Evaluate default case. RETURN_TYPE(static_cast<RETURN_TYPE>(std::get<kDefaultIndex>(cases)({})));
if constexpr (kHasReturnType) { } else {
*result = static_cast<RETURN_TYPE>(std::get<kDefaultIndex>(cases)({})); std::get<kDefaultIndex>(cases)({});
}
} else { } else {
std::get<kDefaultIndex>(cases)({}); // No default case, no case can match.
if constexpr (kHasReturnType) {
new (result) RETURN_TYPE();
}
} }
} }
} }
@ -760,9 +787,15 @@ inline auto Switch(T* object, CASES&&... cases) {
static constexpr bool kHasReturnType = !std::is_same_v<ReturnType, void>; static constexpr bool kHasReturnType = !std::is_same_v<ReturnType, void>;
if constexpr (kHasReturnType) { if constexpr (kHasReturnType) {
ReturnType res = {}; // Replacement for std::aligned_storage as this is broken on earlier versions of MSVC.
detail::SwitchCases(object, &res, std::forward_as_tuple(std::forward<CASES>(cases)...)); struct alignas(alignof(ReturnType)) ReturnStorage {
return res; uint8_t data[sizeof(ReturnType)];
};
ReturnStorage storage;
auto* res = utils::Bitcast<ReturnType*>(&storage);
TINT_DEFER(res->~ReturnType());
detail::SwitchCases(object, res, std::forward_as_tuple(std::forward<CASES>(cases)...));
return *res;
} else { } else {
detail::SwitchCases<T, void>(object, nullptr, detail::SwitchCases<T, void>(object, nullptr,
std::forward_as_tuple(std::forward<CASES>(cases)...)); std::forward_as_tuple(std::forward<CASES>(cases)...));

View File

@ -710,6 +710,32 @@ TEST(Castable, SwitchNullNoDefault) {
EXPECT_TRUE(default_called); EXPECT_TRUE(default_called);
} }
TEST(Castable, SwitchReturnNoDefaultConstructor) {
struct Object {
explicit Object(int v) : value(v) {}
int value;
};
std::unique_ptr<Animal> frog = std::make_unique<Frog>();
{
auto result = Switch(
frog.get(), //
[](Mammal*) { return Object(1); }, //
[](Amphibian*) { return Object(2); }, //
[](Default) { return Object(3); });
static_assert(std::is_same_v<decltype(result), Object>);
EXPECT_EQ(result.value, 2);
}
{
auto result = Switch(
frog.get(), //
[](Mammal*) { return Object(1); }, //
[](Default) { return Object(3); });
static_assert(std::is_same_v<decltype(result), Object>);
EXPECT_EQ(result.value, 3);
}
}
// IsCastable static tests // IsCastable static tests
static_assert(IsCastable<CastableBase>); static_assert(IsCastable<CastableBase>);
static_assert(IsCastable<Animal>); static_assert(IsCastable<Animal>);