Extract constant convert methods.

This CL pulls the convert methods out into standalone methods inside the
resolver and de-couples from the constants.

Bug: tint:1718
Change-Id: Id566704687b2d74e05eae860477552f88f6a06b9
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/114120
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
This commit is contained in:
dan sinclair 2022-12-14 14:07:07 +00:00 committed by Dawn LUCI CQ
parent d9deee0a95
commit d586d91478
1 changed files with 143 additions and 115 deletions

View File

@ -253,12 +253,6 @@ class ImplConstant : public Castable<ImplConstant, constant::Constant> {
public:
ImplConstant() = default;
~ImplConstant() override = default;
/// Convert attempts to convert the constant value to the given type. On error, Convert()
/// creates a new diagnostic message and returns a Failure.
virtual utils::Result<const ImplConstant*> Convert(ProgramBuilder& builder,
const type::Type* target_ty,
const Source& source) const = 0;
};
/// A result templated with a ImplConstant.
@ -297,62 +291,6 @@ class Scalar : public Castable<Scalar<T>, ImplConstant> {
bool AllEqual() const override { return true; }
size_t Hash() const override { return utils::Hash(type, ValueOf(value)); }
ImplResult Convert(ProgramBuilder& builder,
const type::Type* target_ty,
const Source& source) const override {
TINT_BEGIN_DISABLE_WARNING(UNREACHABLE_CODE);
if (target_ty == type) {
// If the types are identical, then no conversion is needed.
return this;
}
return ZeroTypeDispatch(target_ty, [&](auto zero_to) -> ImplResult {
// `value` is the source value.
// `FROM` is the source type.
// `TO` is the target type.
using TO = std::decay_t<decltype(zero_to)>;
using FROM = T;
if constexpr (std::is_same_v<TO, bool>) {
// [x -> bool]
return builder.create<Scalar<TO>>(target_ty, !IsPositiveZero(value));
} else if constexpr (std::is_same_v<FROM, bool>) {
// [bool -> x]
return builder.create<Scalar<TO>>(target_ty, TO(value ? 1 : 0));
} else if (auto conv = CheckedConvert<TO>(value)) {
// Conversion success
return builder.create<Scalar<TO>>(target_ty, conv.Get());
// --- Below this point are the failure cases ---
} else if constexpr (IsAbstract<FROM>) {
// [abstract-numeric -> x] - materialization failure
builder.Diagnostics().add_error(
tint::diag::System::Resolver,
OverflowErrorMessage(value, builder.FriendlyName(target_ty)), source);
return utils::Failure;
} else if constexpr (IsFloatingPoint<TO>) {
// [x -> floating-point] - number not exactly representable
// https://www.w3.org/TR/WGSL/#floating-point-conversion
builder.Diagnostics().add_error(
tint::diag::System::Resolver,
OverflowErrorMessage(value, builder.FriendlyName(target_ty)), source);
return utils::Failure;
} else if constexpr (IsFloatingPoint<FROM>) {
// [floating-point -> integer] - number not exactly representable
// https://www.w3.org/TR/WGSL/#floating-point-conversion
switch (conv.Failure()) {
case ConversionFailure::kExceedsNegativeLimit:
return builder.create<Scalar<TO>>(target_ty, TO::Lowest());
case ConversionFailure::kExceedsPositiveLimit:
return builder.create<Scalar<TO>>(target_ty, TO::Highest());
}
} else if constexpr (IsIntegral<FROM>) {
// [integer -> integer] - number not exactly representable
// Static cast
return builder.create<Scalar<TO>>(target_ty, static_cast<TO>(value));
}
return nullptr; // Expression is not constant.
});
TINT_END_DISABLE_WARNING(UNREACHABLE_CODE);
}
type::Type const* const type;
const T value;
};
@ -373,23 +311,6 @@ class Splat : public Castable<Splat, ImplConstant> {
bool AllEqual() const override { return true; }
size_t Hash() const override { return utils::Hash(type, el->Hash(), count); }
ImplResult Convert(ProgramBuilder& builder,
const type::Type* target_ty,
const Source& source) const override {
// Convert the single splatted element type.
// Note: This file is the only place where `constant::Constant`s are created, so this
// static_cast is safe.
auto conv_el = static_cast<const ImplConstant*>(el)->Convert(
builder, type::Type::ElementOf(target_ty), source);
if (!conv_el) {
return utils::Failure;
}
if (!conv_el.Get()) {
return nullptr;
}
return builder.create<Splat>(target_ty, conv_el.Get(), count);
}
type::Type const* const type;
const constant::Constant* el;
const size_t count;
@ -418,41 +339,6 @@ class Composite : public Castable<Composite, ImplConstant> {
bool AllEqual() const override { return false; /* otherwise this should be a Splat */ }
size_t Hash() const override { return hash; }
ImplResult Convert(ProgramBuilder& builder,
const type::Type* target_ty,
const Source& source) const override {
// Convert each of the composite element types.
utils::Vector<const constant::Constant*, 4> conv_els;
conv_els.Reserve(elements.Length());
std::function<const type::Type*(size_t idx)> target_el_ty;
if (auto* str = target_ty->As<type::Struct>()) {
if (str->Members().Length() != elements.Length()) {
TINT_ICE(Resolver, builder.Diagnostics())
<< "const-eval conversion of structure has mismatched element counts";
return utils::Failure;
}
target_el_ty = [str](size_t idx) { return str->Members()[idx]->Type(); };
} else {
auto* el_ty = type::Type::ElementOf(target_ty);
target_el_ty = [el_ty](size_t) { return el_ty; };
}
for (auto* el : elements) {
// Note: This file is the only place where `constant::Constant`s are created, so the
// static_cast is safe.
auto conv_el = static_cast<const ImplConstant*>(el)->Convert(
builder, target_el_ty(conv_els.Length()), source);
if (!conv_el) {
return utils::Failure;
}
if (!conv_el.Get()) {
return nullptr;
}
conv_els.Push(conv_el.Get());
}
return CreateComposite(builder, target_ty, std::move(conv_els));
}
size_t CalcHash() {
auto h = utils::Hash(type, all_zero, any_zero);
for (auto* el : elements) {
@ -468,6 +354,148 @@ class Composite : public Castable<Composite, ImplConstant> {
const size_t hash;
};
template <typename T>
ImplResult ScalarConvert(const Scalar<T>* scalar,
ProgramBuilder& builder,
const type::Type* target_ty,
const Source& source) {
TINT_BEGIN_DISABLE_WARNING(UNREACHABLE_CODE);
if (target_ty == scalar->type) {
// If the types are identical, then no conversion is needed.
return scalar;
}
return ZeroTypeDispatch(target_ty, [&](auto zero_to) -> ImplResult {
// `value` is the source value.
// `FROM` is the source type.
// `TO` is the target type.
using TO = std::decay_t<decltype(zero_to)>;
using FROM = T;
if constexpr (std::is_same_v<TO, bool>) {
// [x -> bool]
return builder.create<Scalar<TO>>(target_ty, !IsPositiveZero(scalar->value));
} else if constexpr (std::is_same_v<FROM, bool>) {
// [bool -> x]
return builder.create<Scalar<TO>>(target_ty, TO(scalar->value ? 1 : 0));
} else if (auto conv = CheckedConvert<TO>(scalar->value)) {
// Conversion success
return builder.create<Scalar<TO>>(target_ty, conv.Get());
// --- Below this point are the failure cases ---
} else if constexpr (IsAbstract<FROM>) {
// [abstract-numeric -> x] - materialization failure
builder.Diagnostics().add_error(
tint::diag::System::Resolver,
OverflowErrorMessage(scalar->value, builder.FriendlyName(target_ty)), source);
return utils::Failure;
} else if constexpr (IsFloatingPoint<TO>) {
// [x -> floating-point] - number not exactly representable
// https://www.w3.org/TR/WGSL/#floating-point-conversion
builder.Diagnostics().add_error(
tint::diag::System::Resolver,
OverflowErrorMessage(scalar->value, builder.FriendlyName(target_ty)), source);
return utils::Failure;
} else if constexpr (IsFloatingPoint<FROM>) {
// [floating-point -> integer] - number not exactly representable
// https://www.w3.org/TR/WGSL/#floating-point-conversion
switch (conv.Failure()) {
case ConversionFailure::kExceedsNegativeLimit:
return builder.create<Scalar<TO>>(target_ty, TO::Lowest());
case ConversionFailure::kExceedsPositiveLimit:
return builder.create<Scalar<TO>>(target_ty, TO::Highest());
}
} else if constexpr (IsIntegral<FROM>) {
// [integer -> integer] - number not exactly representable
// Static cast
return builder.create<Scalar<TO>>(target_ty, static_cast<TO>(scalar->value));
}
return nullptr; // Expression is not constant.
});
TINT_END_DISABLE_WARNING(UNREACHABLE_CODE);
}
// Forward declare
ImplResult ConvertInternal(const constant::Constant* c,
ProgramBuilder& builder,
const type::Type* target_ty,
const Source& source);
ImplResult SplatConvert(const Splat* splat,
ProgramBuilder& builder,
const type::Type* target_ty,
const Source& source) {
// Convert the single splatted element type.
auto conv_el = ConvertInternal(splat->el, builder, type::Type::ElementOf(target_ty), source);
if (!conv_el) {
return utils::Failure;
}
if (!conv_el.Get()) {
return nullptr;
}
return builder.create<Splat>(target_ty, conv_el.Get(), splat->count);
}
ImplResult CompositeConvert(const Composite* composite,
ProgramBuilder& builder,
const type::Type* target_ty,
const Source& source) {
// Convert each of the composite element types.
utils::Vector<const constant::Constant*, 4> conv_els;
conv_els.Reserve(composite->elements.Length());
std::function<const type::Type*(size_t idx)> target_el_ty;
if (auto* str = target_ty->As<type::Struct>()) {
if (str->Members().Length() != composite->elements.Length()) {
TINT_ICE(Resolver, builder.Diagnostics())
<< "const-eval conversion of structure has mismatched element counts";
return utils::Failure;
}
target_el_ty = [str](size_t idx) { return str->Members()[idx]->Type(); };
} else {
auto* el_ty = type::Type::ElementOf(target_ty);
target_el_ty = [el_ty](size_t) { return el_ty; };
}
for (auto* el : composite->elements) {
auto conv_el = ConvertInternal(el, builder, target_el_ty(conv_els.Length()), source);
if (!conv_el) {
return utils::Failure;
}
if (!conv_el.Get()) {
return nullptr;
}
conv_els.Push(conv_el.Get());
}
return CreateComposite(builder, target_ty, std::move(conv_els));
}
ImplResult ConvertInternal(const constant::Constant* c,
ProgramBuilder& builder,
const type::Type* target_ty,
const Source& source) {
return Switch(
c,
[&](const Scalar<tint::AFloat>* val) {
return ScalarConvert(val, builder, target_ty, source);
},
[&](const Scalar<tint::AInt>* val) {
return ScalarConvert(val, builder, target_ty, source);
},
[&](const Scalar<tint::u32>* val) {
return ScalarConvert(val, builder, target_ty, source);
},
[&](const Scalar<tint::i32>* val) {
return ScalarConvert(val, builder, target_ty, source);
},
[&](const Scalar<tint::f32>* val) {
return ScalarConvert(val, builder, target_ty, source);
},
[&](const Scalar<tint::f16>* val) {
return ScalarConvert(val, builder, target_ty, source);
},
[&](const Scalar<bool>* val) { return ScalarConvert(val, builder, target_ty, source); },
[&](const Splat* val) { return SplatConvert(val, builder, target_ty, source); },
[&](const Composite* val) { return CompositeConvert(val, builder, target_ty, source); });
}
} // namespace
} // namespace tint::resolver
@ -3707,7 +3735,7 @@ ConstEval::Result ConstEval::Convert(const type::Type* target_ty,
if (value->Type() == target_ty) {
return value;
}
return static_cast<const ImplConstant*>(value)->Convert(builder, target_ty, source);
return ConvertInternal(value, builder, target_ty, source);
}
void ConstEval::AddError(const std::string& msg, const Source& source) const {