From d586d914780ba28ebb3ff93547262dbb73c8c0de Mon Sep 17 00:00:00 2001 From: dan sinclair Date: Wed, 14 Dec 2022 14:07:07 +0000 Subject: [PATCH] Extract constant convert methods. This CL pulls the convert methods out into standalone methods inside the resolver and de-couples from the constants. Bug: tint:1718 Change-Id: Id566704687b2d74e05eae860477552f88f6a06b9 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/114120 Reviewed-by: Ben Clayton Kokoro: Kokoro Commit-Queue: Dan Sinclair --- src/tint/resolver/const_eval.cc | 258 ++++++++++++++++++-------------- 1 file changed, 143 insertions(+), 115 deletions(-) diff --git a/src/tint/resolver/const_eval.cc b/src/tint/resolver/const_eval.cc index 94d60d4c58..d7de6cd509 100644 --- a/src/tint/resolver/const_eval.cc +++ b/src/tint/resolver/const_eval.cc @@ -253,12 +253,6 @@ class ImplConstant : public Castable { public: ImplConstant() = default; ~ImplConstant() override = default; - - /// Convert attempts to convert the constant value to the given type. On error, Convert() - /// creates a new diagnostic message and returns a Failure. - virtual utils::Result Convert(ProgramBuilder& builder, - const type::Type* target_ty, - const Source& source) const = 0; }; /// A result templated with a ImplConstant. @@ -297,62 +291,6 @@ class Scalar : public Castable, ImplConstant> { bool AllEqual() const override { return true; } size_t Hash() const override { return utils::Hash(type, ValueOf(value)); } - ImplResult Convert(ProgramBuilder& builder, - const type::Type* target_ty, - const Source& source) const override { - TINT_BEGIN_DISABLE_WARNING(UNREACHABLE_CODE); - if (target_ty == type) { - // If the types are identical, then no conversion is needed. - return this; - } - return ZeroTypeDispatch(target_ty, [&](auto zero_to) -> ImplResult { - // `value` is the source value. - // `FROM` is the source type. - // `TO` is the target type. - using TO = std::decay_t; - using FROM = T; - if constexpr (std::is_same_v) { - // [x -> bool] - return builder.create>(target_ty, !IsPositiveZero(value)); - } else if constexpr (std::is_same_v) { - // [bool -> x] - return builder.create>(target_ty, TO(value ? 1 : 0)); - } else if (auto conv = CheckedConvert(value)) { - // Conversion success - return builder.create>(target_ty, conv.Get()); - // --- Below this point are the failure cases --- - } else if constexpr (IsAbstract) { - // [abstract-numeric -> x] - materialization failure - builder.Diagnostics().add_error( - tint::diag::System::Resolver, - OverflowErrorMessage(value, builder.FriendlyName(target_ty)), source); - return utils::Failure; - } else if constexpr (IsFloatingPoint) { - // [x -> floating-point] - number not exactly representable - // https://www.w3.org/TR/WGSL/#floating-point-conversion - builder.Diagnostics().add_error( - tint::diag::System::Resolver, - OverflowErrorMessage(value, builder.FriendlyName(target_ty)), source); - return utils::Failure; - } else if constexpr (IsFloatingPoint) { - // [floating-point -> integer] - number not exactly representable - // https://www.w3.org/TR/WGSL/#floating-point-conversion - switch (conv.Failure()) { - case ConversionFailure::kExceedsNegativeLimit: - return builder.create>(target_ty, TO::Lowest()); - case ConversionFailure::kExceedsPositiveLimit: - return builder.create>(target_ty, TO::Highest()); - } - } else if constexpr (IsIntegral) { - // [integer -> integer] - number not exactly representable - // Static cast - return builder.create>(target_ty, static_cast(value)); - } - return nullptr; // Expression is not constant. - }); - TINT_END_DISABLE_WARNING(UNREACHABLE_CODE); - } - type::Type const* const type; const T value; }; @@ -373,23 +311,6 @@ class Splat : public Castable { bool AllEqual() const override { return true; } size_t Hash() const override { return utils::Hash(type, el->Hash(), count); } - ImplResult Convert(ProgramBuilder& builder, - const type::Type* target_ty, - const Source& source) const override { - // Convert the single splatted element type. - // Note: This file is the only place where `constant::Constant`s are created, so this - // static_cast is safe. - auto conv_el = static_cast(el)->Convert( - builder, type::Type::ElementOf(target_ty), source); - if (!conv_el) { - return utils::Failure; - } - if (!conv_el.Get()) { - return nullptr; - } - return builder.create(target_ty, conv_el.Get(), count); - } - type::Type const* const type; const constant::Constant* el; const size_t count; @@ -418,41 +339,6 @@ class Composite : public Castable { bool AllEqual() const override { return false; /* otherwise this should be a Splat */ } size_t Hash() const override { return hash; } - ImplResult Convert(ProgramBuilder& builder, - const type::Type* target_ty, - const Source& source) const override { - // Convert each of the composite element types. - utils::Vector conv_els; - conv_els.Reserve(elements.Length()); - std::function target_el_ty; - if (auto* str = target_ty->As()) { - if (str->Members().Length() != elements.Length()) { - TINT_ICE(Resolver, builder.Diagnostics()) - << "const-eval conversion of structure has mismatched element counts"; - return utils::Failure; - } - target_el_ty = [str](size_t idx) { return str->Members()[idx]->Type(); }; - } else { - auto* el_ty = type::Type::ElementOf(target_ty); - target_el_ty = [el_ty](size_t) { return el_ty; }; - } - - for (auto* el : elements) { - // Note: This file is the only place where `constant::Constant`s are created, so the - // static_cast is safe. - auto conv_el = static_cast(el)->Convert( - builder, target_el_ty(conv_els.Length()), source); - if (!conv_el) { - return utils::Failure; - } - if (!conv_el.Get()) { - return nullptr; - } - conv_els.Push(conv_el.Get()); - } - return CreateComposite(builder, target_ty, std::move(conv_els)); - } - size_t CalcHash() { auto h = utils::Hash(type, all_zero, any_zero); for (auto* el : elements) { @@ -468,6 +354,148 @@ class Composite : public Castable { const size_t hash; }; +template +ImplResult ScalarConvert(const Scalar* scalar, + ProgramBuilder& builder, + const type::Type* target_ty, + const Source& source) { + TINT_BEGIN_DISABLE_WARNING(UNREACHABLE_CODE); + if (target_ty == scalar->type) { + // If the types are identical, then no conversion is needed. + return scalar; + } + return ZeroTypeDispatch(target_ty, [&](auto zero_to) -> ImplResult { + // `value` is the source value. + // `FROM` is the source type. + // `TO` is the target type. + using TO = std::decay_t; + using FROM = T; + if constexpr (std::is_same_v) { + // [x -> bool] + return builder.create>(target_ty, !IsPositiveZero(scalar->value)); + } else if constexpr (std::is_same_v) { + // [bool -> x] + return builder.create>(target_ty, TO(scalar->value ? 1 : 0)); + } else if (auto conv = CheckedConvert(scalar->value)) { + // Conversion success + return builder.create>(target_ty, conv.Get()); + // --- Below this point are the failure cases --- + } else if constexpr (IsAbstract) { + // [abstract-numeric -> x] - materialization failure + builder.Diagnostics().add_error( + tint::diag::System::Resolver, + OverflowErrorMessage(scalar->value, builder.FriendlyName(target_ty)), source); + return utils::Failure; + } else if constexpr (IsFloatingPoint) { + // [x -> floating-point] - number not exactly representable + // https://www.w3.org/TR/WGSL/#floating-point-conversion + builder.Diagnostics().add_error( + tint::diag::System::Resolver, + OverflowErrorMessage(scalar->value, builder.FriendlyName(target_ty)), source); + return utils::Failure; + } else if constexpr (IsFloatingPoint) { + // [floating-point -> integer] - number not exactly representable + // https://www.w3.org/TR/WGSL/#floating-point-conversion + switch (conv.Failure()) { + case ConversionFailure::kExceedsNegativeLimit: + return builder.create>(target_ty, TO::Lowest()); + case ConversionFailure::kExceedsPositiveLimit: + return builder.create>(target_ty, TO::Highest()); + } + } else if constexpr (IsIntegral) { + // [integer -> integer] - number not exactly representable + // Static cast + return builder.create>(target_ty, static_cast(scalar->value)); + } + return nullptr; // Expression is not constant. + }); + TINT_END_DISABLE_WARNING(UNREACHABLE_CODE); +} + +// Forward declare +ImplResult ConvertInternal(const constant::Constant* c, + ProgramBuilder& builder, + const type::Type* target_ty, + const Source& source); + +ImplResult SplatConvert(const Splat* splat, + ProgramBuilder& builder, + const type::Type* target_ty, + const Source& source) { + // Convert the single splatted element type. + auto conv_el = ConvertInternal(splat->el, builder, type::Type::ElementOf(target_ty), source); + if (!conv_el) { + return utils::Failure; + } + if (!conv_el.Get()) { + return nullptr; + } + return builder.create(target_ty, conv_el.Get(), splat->count); +} + +ImplResult CompositeConvert(const Composite* composite, + ProgramBuilder& builder, + const type::Type* target_ty, + const Source& source) { + // Convert each of the composite element types. + utils::Vector conv_els; + conv_els.Reserve(composite->elements.Length()); + + std::function target_el_ty; + if (auto* str = target_ty->As()) { + if (str->Members().Length() != composite->elements.Length()) { + TINT_ICE(Resolver, builder.Diagnostics()) + << "const-eval conversion of structure has mismatched element counts"; + return utils::Failure; + } + target_el_ty = [str](size_t idx) { return str->Members()[idx]->Type(); }; + } else { + auto* el_ty = type::Type::ElementOf(target_ty); + target_el_ty = [el_ty](size_t) { return el_ty; }; + } + + for (auto* el : composite->elements) { + auto conv_el = ConvertInternal(el, builder, target_el_ty(conv_els.Length()), source); + if (!conv_el) { + return utils::Failure; + } + if (!conv_el.Get()) { + return nullptr; + } + conv_els.Push(conv_el.Get()); + } + return CreateComposite(builder, target_ty, std::move(conv_els)); +} + +ImplResult ConvertInternal(const constant::Constant* c, + ProgramBuilder& builder, + const type::Type* target_ty, + const Source& source) { + return Switch( + c, + [&](const Scalar* val) { + return ScalarConvert(val, builder, target_ty, source); + }, + [&](const Scalar* val) { + return ScalarConvert(val, builder, target_ty, source); + }, + [&](const Scalar* val) { + return ScalarConvert(val, builder, target_ty, source); + }, + [&](const Scalar* val) { + return ScalarConvert(val, builder, target_ty, source); + }, + [&](const Scalar* val) { + return ScalarConvert(val, builder, target_ty, source); + }, + [&](const Scalar* val) { + return ScalarConvert(val, builder, target_ty, source); + }, + [&](const Scalar* val) { return ScalarConvert(val, builder, target_ty, source); }, + [&](const Splat* val) { return SplatConvert(val, builder, target_ty, source); }, + [&](const Composite* val) { return CompositeConvert(val, builder, target_ty, source); }); +} + } // namespace } // namespace tint::resolver @@ -3707,7 +3735,7 @@ ConstEval::Result ConstEval::Convert(const type::Type* target_ty, if (value->Type() == target_ty) { return value; } - return static_cast(value)->Convert(builder, target_ty, source); + return ConvertInternal(value, builder, target_ty, source); } void ConstEval::AddError(const std::string& msg, const Source& source) const {