diff --git a/src/tint/castable.h b/src/tint/castable.h index acb9a1837a..71e3ab86fc 100644 --- a/src/tint/castable.h +++ b/src/tint/castable.h @@ -21,9 +21,7 @@ #include #include "src/tint/traits.h" -#include "src/tint/utils/bitcast.h" #include "src/tint/utils/crc32.h" -#include "src/tint/utils/defer.h" #if defined(__clang__) /// Temporarily disable certain warnings when using Castable API diff --git a/src/tint/reader/wgsl/parser_impl.cc b/src/tint/reader/wgsl/parser_impl.cc index e4eab291af..693e1b3c7c 100644 --- a/src/tint/reader/wgsl/parser_impl.cc +++ b/src/tint/reader/wgsl/parser_impl.cc @@ -42,6 +42,7 @@ #include "src/tint/type/multisampled_texture.h" #include "src/tint/type/sampled_texture.h" #include "src/tint/type/texture_dimension.h" +#include "src/tint/utils/defer.h" #include "src/tint/utils/reverse.h" #include "src/tint/utils/string.h" #include "src/tint/utils/string_stream.h" diff --git a/src/tint/resolver/uniformity.cc b/src/tint/resolver/uniformity.cc index cf53c4c0a6..200221a602 100644 --- a/src/tint/resolver/uniformity.cc +++ b/src/tint/resolver/uniformity.cc @@ -39,6 +39,7 @@ #include "src/tint/sem/while_statement.h" #include "src/tint/switch.h" #include "src/tint/utils/block_allocator.h" +#include "src/tint/utils/defer.h" #include "src/tint/utils/map.h" #include "src/tint/utils/string_stream.h" #include "src/tint/utils/unique_vector.h" diff --git a/src/tint/switch.h b/src/tint/switch.h index 9ae8d3b03a..51161958af 100644 --- a/src/tint/switch.h +++ b/src/tint/switch.h @@ -19,6 +19,8 @@ #include #include "src/tint/castable.h" +#include "src/tint/utils/bitcast.h" +#include "src/tint/utils/defer.h" namespace tint { @@ -62,126 +64,6 @@ constexpr int IndexOfDefaultCase() { } } -/// The implementation of Switch() for non-Default cases. -/// Switch splits the cases into two a low and high block of cases, and quickly rules out blocks -/// that cannot match by comparing the HashCode of the object and the cases in the block. If a block -/// of cases may match the given object's type, then that block is split into two, and the process -/// recurses. When NonDefaultCases() is called with a single case, then As<> will be used to -/// dynamically cast to the case type and if the cast succeeds, then the case handler is called. -/// @returns true if a case handler was found, otherwise false. -template -inline bool NonDefaultCases([[maybe_unused]] T* object, - const TypeInfo* type, - [[maybe_unused]] RETURN_TYPE* result, - std::tuple&& cases) { - using Cases = std::tuple; - - static constexpr bool kHasReturnType = !std::is_same_v; - static constexpr size_t kNumCases = sizeof...(CASES); - - if constexpr (kNumCases == 0) { - // No cases. Nothing to do. - return false; - } else if constexpr (kNumCases == 1) { // NOLINT: cpplint doesn't understand - // `else if constexpr` - // Single case. - using CaseFunc = std::tuple_element_t<0, Cases>; - static_assert(!IsDefaultCase, "NonDefaultCases called with a Default case"); - // Attempt to dynamically cast the object to the handler type. If that succeeds, call the - // case handler with the cast object. - using CaseType = SwitchCaseType; - if (type->Is()) { - auto* ptr = static_cast(object); - if constexpr (kHasReturnType) { - new (result) RETURN_TYPE(static_cast(std::get<0>(cases)(ptr))); - } else { - std::get<0>(cases)(ptr); - } - return true; - } - return false; - } else { - // Multiple cases. - // Check the hashcode bits to see if there's any possibility of a case matching in these - // cases. If there isn't, we can skip all these cases. - if (MaybeAnyOf(TypeInfo::CombinedHashCodeOf...>(), - type->full_hashcode)) { - // Split the cases into two, and recurse. - constexpr size_t kMid = kNumCases / 2; - return NonDefaultCases(object, type, result, traits::Slice<0, kMid>(cases)) || - NonDefaultCases(object, type, result, - traits::Slice(cases)); - } else { - return false; - } - } -} - -/// The implementation of Switch() for all cases. -/// @see NonDefaultCases -template -inline void SwitchCases(T* object, RETURN_TYPE* result, std::tuple&& cases) { - using Cases = std::tuple; - - static constexpr int kDefaultIndex = detail::IndexOfDefaultCase(); - static constexpr bool kHasDefaultCase = kDefaultIndex >= 0; - static constexpr bool kHasReturnType = !std::is_same_v; - - // Static assertions - static constexpr bool kDefaultIsOK = - kDefaultIndex == -1 || kDefaultIndex == static_cast(std::tuple_size_v - 1); - static constexpr bool kReturnIsOK = - kHasDefaultCase || !kHasReturnType || std::is_constructible_v; - static_assert(kDefaultIsOK, "Default case must be last in Switch()"); - static_assert(kReturnIsOK, - "Switch() requires either a Default case or a return type that is either void or " - "default-constructable"); - - // If the static asserts have fired, don't bother spewing more errors below - static constexpr bool kAllOK = kDefaultIsOK && kReturnIsOK; - if constexpr (kAllOK) { - if (object) { - auto* type = &object->TypeInfo(); - if constexpr (kHasDefaultCase) { - // Evaluate non-default cases. - if (!detail::NonDefaultCases(object, type, result, - traits::Slice<0, kDefaultIndex>(cases))) { - // Nothing matched. Evaluate default case. - if constexpr (kHasReturnType) { - new (result) RETURN_TYPE( - static_cast(std::get(cases)({}))); - } else { - std::get(cases)({}); - } - } - } else { - if (!detail::NonDefaultCases(object, type, result, std::move(cases))) { - // Nothing matched. No default case. - if constexpr (kHasReturnType) { - new (result) RETURN_TYPE(); - } - } - } - } else { - // Object is nullptr, so no cases can match - if constexpr (kHasDefaultCase) { - // Evaluate default case. - if constexpr (kHasReturnType) { - new (result) - RETURN_TYPE(static_cast(std::get(cases)({}))); - } else { - std::get(cases)({}); - } - } else { - // No default case, no case can match. - if constexpr (kHasReturnType) { - new (result) RETURN_TYPE(); - } - } - } - } -} - /// Resolves to T if T is not nullptr_t, otherwise resolves to Ignore. template using NullptrToIgnore = std::conditional_t, Ignore, T>; @@ -282,21 +164,95 @@ namespace tint { template inline auto Switch(T* object, CASES&&... cases) { using ReturnType = detail::SwitchReturnType...>; + static constexpr int kDefaultIndex = detail::IndexOfDefaultCase>(); + static constexpr bool kHasDefaultCase = kDefaultIndex >= 0; static constexpr bool kHasReturnType = !std::is_same_v; + // Static assertions + static constexpr bool kDefaultIsOK = + kDefaultIndex == -1 || kDefaultIndex == static_cast(sizeof...(CASES) - 1); + static constexpr bool kReturnIsOK = + kHasDefaultCase || !kHasReturnType || std::is_constructible_v; + static_assert(kDefaultIsOK, "Default case must be last in Switch()"); + static_assert(kReturnIsOK, + "Switch() requires either a Default case or a return type that is either void or " + "default-constructable"); + + if (!object) { // Object is nullptr, so no cases can match + if constexpr (kHasDefaultCase) { + // Evaluate default case. + auto&& default_case = + std::get(std::forward_as_tuple(std::forward(cases)...)); + return static_cast(default_case(Default{})); + } else { + // No default case, no case can match. + if constexpr (kHasReturnType) { + return ReturnType{}; + } else { + return; + } + } + } + + // Replacement for std::aligned_storage as this is broken on earlier versions of MSVC. + using ReturnTypeOrU8 = std::conditional_t; + struct alignas(alignof(ReturnTypeOrU8)) ReturnStorage { + uint8_t data[sizeof(ReturnTypeOrU8)]; + }; + ReturnStorage storage; + auto* result = utils::Bitcast(&storage); + + const TypeInfo& type_info = object->TypeInfo(); + + // Examines the parameter type of the case function. + // If the parameter is a pointer type that `object` is of, or derives from, then that case + // function is called with `object` cast to that type, and `try_case` returns true. + // If the parameter is of type `Default`, then that case function is called and `try_case` + // returns true. + // Otherwise `try_case` returns false. + // If the case function is called and it returns a value, then this is copy constructed to the + // `result` pointer. + auto try_case = [&](auto&& case_fn) { + using CaseFunc = std::decay_t; + using CaseType = detail::SwitchCaseType; + if constexpr (std::is_same_v) { + if constexpr (kHasReturnType) { + new (result) ReturnType(static_cast(case_fn(Default{}))); + } else { + case_fn(Default{}); + } + return true; + } else { + if (type_info.Is()) { + auto* v = static_cast(object); + if constexpr (kHasReturnType) { + new (result) ReturnType(static_cast(case_fn(v))); + } else { + case_fn(v); + } + return true; + } + } + return false; + }; + + // Use a logical-or fold expression to try each of the cases in turn, until one matches the + // object type or a Default is reached. `handled` is true if a case function was called. + bool handled = ((try_case(std::forward(cases)) || ...)); + if constexpr (kHasReturnType) { - // Replacement for std::aligned_storage as this is broken on earlier versions of MSVC. - struct alignas(alignof(ReturnType)) ReturnStorage { - uint8_t data[sizeof(ReturnType)]; - }; - ReturnStorage storage; - auto* res = utils::Bitcast(&storage); - TINT_DEFER(res->~ReturnType()); - detail::SwitchCases(object, res, std::forward_as_tuple(std::forward(cases)...)); - return *res; - } else { - detail::SwitchCases(object, nullptr, - std::forward_as_tuple(std::forward(cases)...)); + if constexpr (kHasDefaultCase) { + // Default case means there must be a returned value. + // No need to check handled, no requirement for a zero-initializer of ReturnType. + TINT_DEFER(result->~ReturnType()); + return *result; + } else { + if (handled) { + TINT_DEFER(result->~ReturnType()); + return *result; + } + return ReturnType{}; + } } } diff --git a/src/tint/utils/slice.h b/src/tint/utils/slice.h index 719a53e5ff..325c470557 100644 --- a/src/tint/utils/slice.h +++ b/src/tint/utils/slice.h @@ -20,6 +20,7 @@ #include "src/tint/castable.h" #include "src/tint/traits.h" +#include "src/tint/utils/bitcast.h" namespace tint::utils { diff --git a/src/tint/writer/wgsl/generator_impl.cc b/src/tint/writer/wgsl/generator_impl.cc index 0a167eb753..f98cf86389 100644 --- a/src/tint/writer/wgsl/generator_impl.cc +++ b/src/tint/writer/wgsl/generator_impl.cc @@ -35,6 +35,7 @@ #include "src/tint/sem/struct.h" #include "src/tint/sem/switch_statement.h" #include "src/tint/switch.h" +#include "src/tint/utils/defer.h" #include "src/tint/utils/math.h" #include "src/tint/utils/scoped_assignment.h" #include "src/tint/writer/float_to_string.h"