tint/resolver: Add ConstEval class
Extract out the methods of Resolver::EvaluateXXXValue() to a new tint::resolver::ConstEval class. Removes more bloat from Resolver, and creates a centralized class for constant evaluation, which can be referred to by the IntrinsicTable. Change-Id: I3b58882ef293fe07f019ad2138a7e9dbbac8de53 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/95951 Commit-Queue: Ben Clayton <bclayton@google.com> Reviewed-by: Antonio Maiorano <amaiorano@google.com> Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
parent
cfe07a1b33
commit
65c5c9d92b
|
@ -398,7 +398,6 @@ libtint_source_set("libtint_core_all_src") {
|
|||
"resolver/intrinsic_table.inl",
|
||||
"resolver/resolver.cc",
|
||||
"resolver/resolver.h",
|
||||
"resolver/resolver_constants.cc",
|
||||
"resolver/sem_helper.cc",
|
||||
"resolver/sem_helper.h",
|
||||
"resolver/uniformity.cc",
|
||||
|
@ -1090,6 +1089,7 @@ if (tint_build_unittests) {
|
|||
"resolver/call_validation_test.cc",
|
||||
"resolver/compound_assignment_validation_test.cc",
|
||||
"resolver/compound_statement_test.cc",
|
||||
"resolver/const_eval_test.cc",
|
||||
"resolver/control_block_validation_test.cc",
|
||||
"resolver/dependency_graph_test.cc",
|
||||
"resolver/entry_point_validation_test.cc",
|
||||
|
@ -1104,7 +1104,6 @@ if (tint_build_unittests) {
|
|||
"resolver/ptr_ref_test.cc",
|
||||
"resolver/ptr_ref_validation_test.cc",
|
||||
"resolver/resolver_behavior_test.cc",
|
||||
"resolver/resolver_constants_test.cc",
|
||||
"resolver/resolver_test.cc",
|
||||
"resolver/resolver_test_helper.cc",
|
||||
"resolver/resolver_test_helper.h",
|
||||
|
|
|
@ -258,7 +258,6 @@ set(TINT_LIB_SRCS
|
|||
resolver/intrinsic_table.cc
|
||||
resolver/intrinsic_table.h
|
||||
resolver/intrinsic_table.inl
|
||||
resolver/resolver_constants.cc
|
||||
resolver/resolver.cc
|
||||
resolver/resolver.h
|
||||
resolver/sem_helper.cc
|
||||
|
@ -774,6 +773,7 @@ if(TINT_BUILD_TESTS)
|
|||
resolver/call_validation_test.cc
|
||||
resolver/compound_assignment_validation_test.cc
|
||||
resolver/compound_statement_test.cc
|
||||
resolver/const_eval_test.cc
|
||||
resolver/control_block_validation_test.cc
|
||||
resolver/dependency_graph_test.cc
|
||||
resolver/entry_point_validation_test.cc
|
||||
|
@ -789,7 +789,6 @@ if(TINT_BUILD_TESTS)
|
|||
resolver/ptr_ref_test.cc
|
||||
resolver/ptr_ref_validation_test.cc
|
||||
resolver/resolver_behavior_test.cc
|
||||
resolver/resolver_constants_test.cc
|
||||
resolver/resolver_test_helper.cc
|
||||
resolver/resolver_test_helper.h
|
||||
resolver/resolver_test.cc
|
||||
|
|
|
@ -14,6 +14,603 @@
|
|||
|
||||
#include "src/tint/resolver/const_eval.h"
|
||||
|
||||
#include "src/tint/sem/constant.h"
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
|
||||
namespace tint::resolver::const_eval {} // namespace tint::resolver::const_eval
|
||||
#include "src/tint/program_builder.h"
|
||||
#include "src/tint/sem/abstract_float.h"
|
||||
#include "src/tint/sem/abstract_int.h"
|
||||
#include "src/tint/sem/array.h"
|
||||
#include "src/tint/sem/bool.h"
|
||||
#include "src/tint/sem/constant.h"
|
||||
#include "src/tint/sem/f16.h"
|
||||
#include "src/tint/sem/f32.h"
|
||||
#include "src/tint/sem/i32.h"
|
||||
#include "src/tint/sem/matrix.h"
|
||||
#include "src/tint/sem/member_accessor_expression.h"
|
||||
#include "src/tint/sem/type_constructor.h"
|
||||
#include "src/tint/sem/u32.h"
|
||||
#include "src/tint/sem/vector.h"
|
||||
#include "src/tint/utils/compiler_macros.h"
|
||||
#include "src/tint/utils/map.h"
|
||||
#include "src/tint/utils/transform.h"
|
||||
|
||||
using namespace tint::number_suffixes; // NOLINT
|
||||
|
||||
namespace tint::resolver {
|
||||
|
||||
namespace {
|
||||
|
||||
/// TypeDispatch is a helper for calling the function `f`, passing a single zero-value argument of
|
||||
/// the C++ type that corresponds to the sem::Type `type`. For example, calling `TypeDispatch()`
|
||||
/// with a type of `sem::I32*` will call the function f with a single argument of `i32(0)`.
|
||||
/// @returns the value returned by calling `f`.
|
||||
/// @note `type` must be a scalar or abstract numeric type. Other types will not call `f`, and will
|
||||
/// return the zero-initialized value of the return type for `f`.
|
||||
template <typename F>
|
||||
auto TypeDispatch(const sem::Type* type, F&& f) {
|
||||
return Switch(
|
||||
type, //
|
||||
[&](const sem::AbstractInt*) { return f(AInt(0)); }, //
|
||||
[&](const sem::AbstractFloat*) { return f(AFloat(0)); }, //
|
||||
[&](const sem::I32*) { return f(i32(0)); }, //
|
||||
[&](const sem::U32*) { return f(u32(0)); }, //
|
||||
[&](const sem::F32*) { return f(f32(0)); }, //
|
||||
[&](const sem::F16*) { return f(f16(0)); }, //
|
||||
[&](const sem::Bool*) { return f(static_cast<bool>(0)); });
|
||||
}
|
||||
|
||||
/// @returns `value` if `T` is not a Number, otherwise ValueOf returns the inner value of the
|
||||
/// Number.
|
||||
template <typename T>
|
||||
inline auto ValueOf(T value) {
|
||||
if constexpr (std::is_same_v<UnwrapNumber<T>, T>) {
|
||||
return value;
|
||||
} else {
|
||||
return value.value;
|
||||
}
|
||||
}
|
||||
|
||||
/// @returns true if `value` is a positive zero.
|
||||
template <typename T>
|
||||
inline bool IsPositiveZero(T value) {
|
||||
using N = UnwrapNumber<T>;
|
||||
return Number<N>(value) == Number<N>(0); // Considers sign bit
|
||||
}
|
||||
|
||||
/// Constant inherits from sem::Constant to add an private implementation method for conversion.
|
||||
struct Constant : public sem::Constant {
|
||||
/// Convert attempts to convert the constant value to the given type. On error, Convert()
|
||||
/// creates a new diagnostic message and returns a Failure.
|
||||
virtual utils::Result<const Constant*> Convert(ProgramBuilder& builder,
|
||||
const sem::Type* target_ty,
|
||||
const Source& source) const = 0;
|
||||
};
|
||||
|
||||
// Forward declaration
|
||||
const Constant* CreateComposite(ProgramBuilder& builder,
|
||||
const sem::Type* type,
|
||||
std::vector<const Constant*> elements);
|
||||
|
||||
/// Element holds a single scalar or abstract-numeric value.
|
||||
/// Element implements the Constant interface.
|
||||
template <typename T>
|
||||
struct Element : Constant {
|
||||
Element(const sem::Type* t, T v) : type(t), value(v) {}
|
||||
~Element() override = default;
|
||||
const sem::Type* Type() const override { return type; }
|
||||
std::variant<std::monostate, AInt, AFloat> Value() const override {
|
||||
if constexpr (IsFloatingPoint<UnwrapNumber<T>>) {
|
||||
return static_cast<AFloat>(value);
|
||||
} else {
|
||||
return static_cast<AInt>(value);
|
||||
}
|
||||
}
|
||||
const Constant* Index(size_t) const override { return nullptr; }
|
||||
bool AllZero() const override { return IsPositiveZero(value); }
|
||||
bool AnyZero() const override { return IsPositiveZero(value); }
|
||||
bool AllEqual() const override { return true; }
|
||||
size_t Hash() const override { return utils::Hash(type, ValueOf(value)); }
|
||||
|
||||
utils::Result<const Constant*> Convert(ProgramBuilder& builder,
|
||||
const sem::Type* target_ty,
|
||||
const Source& source) const override {
|
||||
TINT_BEGIN_DISABLE_WARNING(UNREACHABLE_CODE);
|
||||
if (target_ty == type) {
|
||||
// If the types are identical, then no conversion is needed.
|
||||
return this;
|
||||
}
|
||||
bool failed = false;
|
||||
auto* res = TypeDispatch(target_ty, [&](auto zero_to) -> const Constant* {
|
||||
// `T` is the source type, `value` is the source value.
|
||||
// `TO` is the target type.
|
||||
using TO = std::decay_t<decltype(zero_to)>;
|
||||
if constexpr (std::is_same_v<TO, bool>) {
|
||||
// [x -> bool]
|
||||
return builder.create<Element<TO>>(target_ty, !IsPositiveZero(value));
|
||||
} else if constexpr (std::is_same_v<T, bool>) {
|
||||
// [bool -> x]
|
||||
return builder.create<Element<TO>>(target_ty, TO(value ? 1 : 0));
|
||||
} else if (auto conv = CheckedConvert<TO>(value)) {
|
||||
// Conversion success
|
||||
return builder.create<Element<TO>>(target_ty, conv.Get());
|
||||
// --- Below this point are the failure cases ---
|
||||
} else if constexpr (std::is_same_v<T, AInt> || std::is_same_v<T, AFloat>) {
|
||||
// [abstract-numeric -> x] - materialization failure
|
||||
std::stringstream ss;
|
||||
ss << "value " << value << " cannot be represented as ";
|
||||
ss << "'" << builder.FriendlyName(target_ty) << "'";
|
||||
builder.Diagnostics().add_error(tint::diag::System::Resolver, ss.str(), source);
|
||||
failed = true;
|
||||
} else if constexpr (IsFloatingPoint<UnwrapNumber<TO>>) {
|
||||
// [x -> floating-point] - number not exactly representable
|
||||
// https://www.w3.org/TR/WGSL/#floating-point-conversion
|
||||
constexpr auto kInf = std::numeric_limits<double>::infinity();
|
||||
switch (conv.Failure()) {
|
||||
case ConversionFailure::kExceedsNegativeLimit:
|
||||
return builder.create<Element<TO>>(target_ty, TO(-kInf));
|
||||
case ConversionFailure::kExceedsPositiveLimit:
|
||||
return builder.create<Element<TO>>(target_ty, TO(kInf));
|
||||
}
|
||||
} else {
|
||||
// [x -> integer] - number not exactly representable
|
||||
// https://www.w3.org/TR/WGSL/#floating-point-conversion
|
||||
switch (conv.Failure()) {
|
||||
case ConversionFailure::kExceedsNegativeLimit:
|
||||
return builder.create<Element<TO>>(target_ty, TO(TO::kLowest));
|
||||
case ConversionFailure::kExceedsPositiveLimit:
|
||||
return builder.create<Element<TO>>(target_ty, TO(TO::kHighest));
|
||||
}
|
||||
}
|
||||
return nullptr; // Expression is not constant.
|
||||
});
|
||||
if (failed) {
|
||||
// A diagnostic error has been raised, and resolving should abort.
|
||||
return utils::Failure;
|
||||
}
|
||||
return res;
|
||||
TINT_END_DISABLE_WARNING(UNREACHABLE_CODE);
|
||||
}
|
||||
|
||||
sem::Type const* const type;
|
||||
const T value;
|
||||
};
|
||||
|
||||
/// Splat holds a single Constant value, duplicated as all children.
|
||||
/// Splat is used for zero-initializers, 'splat' constructors, or constructors where each element is
|
||||
/// identical. Splat may be of a vector, matrix or array type.
|
||||
/// Splat implements the Constant interface.
|
||||
struct Splat : Constant {
|
||||
Splat(const sem::Type* t, const Constant* e, size_t n) : type(t), el(e), count(n) {}
|
||||
~Splat() override = default;
|
||||
const sem::Type* Type() const override { return type; }
|
||||
std::variant<std::monostate, AInt, AFloat> Value() const override { return {}; }
|
||||
const Constant* Index(size_t i) const override { return i < count ? el : nullptr; }
|
||||
bool AllZero() const override { return el->AllZero(); }
|
||||
bool AnyZero() const override { return el->AnyZero(); }
|
||||
bool AllEqual() const override { return true; }
|
||||
size_t Hash() const override { return utils::Hash(type, el->Hash(), count); }
|
||||
|
||||
utils::Result<const Constant*> Convert(ProgramBuilder& builder,
|
||||
const sem::Type* target_ty,
|
||||
const Source& source) const override {
|
||||
// Convert the single splatted element type.
|
||||
auto conv_el = el->Convert(builder, sem::Type::ElementOf(target_ty), source);
|
||||
if (!conv_el) {
|
||||
return utils::Failure;
|
||||
}
|
||||
if (!conv_el.Get()) {
|
||||
return nullptr;
|
||||
}
|
||||
return builder.create<Splat>(target_ty, conv_el.Get(), count);
|
||||
}
|
||||
|
||||
sem::Type const* const type;
|
||||
const Constant* el;
|
||||
const size_t count;
|
||||
};
|
||||
|
||||
/// Composite holds a number of mixed child Constant values.
|
||||
/// Composite may be of a vector, matrix or array type.
|
||||
/// If each element is the same type and value, then a Splat would be a more efficient constant
|
||||
/// implementation. Use CreateComposite() to create the appropriate Constant type.
|
||||
/// Composite implements the Constant interface.
|
||||
struct Composite : Constant {
|
||||
Composite(const sem::Type* t, std::vector<const Constant*> els, bool all_0, bool any_0)
|
||||
: type(t), elements(std::move(els)), all_zero(all_0), any_zero(any_0), hash(CalcHash()) {}
|
||||
~Composite() override = default;
|
||||
const sem::Type* Type() const override { return type; }
|
||||
std::variant<std::monostate, AInt, AFloat> Value() const override { return {}; }
|
||||
const Constant* Index(size_t i) const override {
|
||||
return i < elements.size() ? elements[i] : nullptr;
|
||||
}
|
||||
bool AllZero() const override { return all_zero; }
|
||||
bool AnyZero() const override { return any_zero; }
|
||||
bool AllEqual() const override { return false; /* otherwise this should be a Splat */ }
|
||||
size_t Hash() const override { return hash; }
|
||||
|
||||
utils::Result<const Constant*> Convert(ProgramBuilder& builder,
|
||||
const sem::Type* target_ty,
|
||||
const Source& source) const override {
|
||||
// Convert each of the composite element types.
|
||||
auto* el_ty = sem::Type::ElementOf(target_ty);
|
||||
std::vector<const Constant*> conv_els;
|
||||
conv_els.reserve(elements.size());
|
||||
for (auto* el : elements) {
|
||||
auto conv_el = el->Convert(builder, el_ty, source);
|
||||
if (!conv_el) {
|
||||
return utils::Failure;
|
||||
}
|
||||
if (!conv_el.Get()) {
|
||||
return nullptr;
|
||||
}
|
||||
conv_els.emplace_back(conv_el.Get());
|
||||
}
|
||||
return CreateComposite(builder, target_ty, std::move(conv_els));
|
||||
}
|
||||
|
||||
size_t CalcHash() {
|
||||
auto h = utils::Hash(type, all_zero, any_zero);
|
||||
for (auto* el : elements) {
|
||||
utils::HashCombine(&h, el->Hash());
|
||||
}
|
||||
return h;
|
||||
}
|
||||
|
||||
sem::Type const* const type;
|
||||
const std::vector<const Constant*> elements;
|
||||
const bool all_zero;
|
||||
const bool any_zero;
|
||||
const size_t hash;
|
||||
};
|
||||
|
||||
/// CreateElement constructs and returns an Element<T>.
|
||||
template <typename T>
|
||||
const Constant* CreateElement(ProgramBuilder& builder, const sem::Type* t, T v) {
|
||||
return builder.create<Element<T>>(t, v);
|
||||
}
|
||||
|
||||
/// ZeroValue returns a Constant for the zero-value of the type `type`.
|
||||
const Constant* ZeroValue(ProgramBuilder& builder, const sem::Type* type) {
|
||||
return Switch(
|
||||
type, //
|
||||
[&](const sem::Vector* v) -> const Constant* {
|
||||
auto* zero_el = ZeroValue(builder, v->type());
|
||||
return builder.create<Splat>(type, zero_el, v->Width());
|
||||
},
|
||||
[&](const sem::Matrix* m) -> const Constant* {
|
||||
auto* zero_el = ZeroValue(builder, m->ColumnType());
|
||||
return builder.create<Splat>(type, zero_el, m->columns());
|
||||
},
|
||||
[&](const sem::Array* a) -> const Constant* {
|
||||
if (auto* zero_el = ZeroValue(builder, a->ElemType())) {
|
||||
return builder.create<Splat>(type, zero_el, a->Count());
|
||||
}
|
||||
return nullptr;
|
||||
},
|
||||
[&](const sem::Struct* s) -> const Constant* {
|
||||
std::unordered_map<sem::Type*, const Constant*> zero_by_type;
|
||||
std::vector<const Constant*> zeros;
|
||||
zeros.reserve(s->Members().size());
|
||||
for (auto* member : s->Members()) {
|
||||
auto* zero = utils::GetOrCreate(zero_by_type, member->Type(),
|
||||
[&] { return ZeroValue(builder, member->Type()); });
|
||||
if (!zero) {
|
||||
return nullptr;
|
||||
}
|
||||
zeros.emplace_back(zero);
|
||||
}
|
||||
if (zero_by_type.size() == 1) {
|
||||
// All members were of the same type, so the zero value is the same for all members.
|
||||
return builder.create<Splat>(type, zeros[0], s->Members().size());
|
||||
}
|
||||
return CreateComposite(builder, s, std::move(zeros));
|
||||
},
|
||||
[&](Default) -> const Constant* {
|
||||
return TypeDispatch(type, [&](auto zero) -> const Constant* {
|
||||
return CreateElement(builder, type, zero);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
/// Equal returns true if the constants `a` and `b` are of the same type and value.
|
||||
bool Equal(const sem::Constant* a, const sem::Constant* b) {
|
||||
if (a->Hash() != b->Hash()) {
|
||||
return false;
|
||||
}
|
||||
if (a->Type() != b->Type()) {
|
||||
return false;
|
||||
}
|
||||
return Switch(
|
||||
a->Type(), //
|
||||
[&](const sem::Vector* vec) {
|
||||
for (size_t i = 0; i < vec->Width(); i++) {
|
||||
if (!Equal(a->Index(i), b->Index(i))) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
},
|
||||
[&](const sem::Matrix* mat) {
|
||||
for (size_t i = 0; i < mat->columns(); i++) {
|
||||
if (!Equal(a->Index(i), b->Index(i))) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
},
|
||||
[&](const sem::Array* arr) {
|
||||
for (size_t i = 0; i < arr->Count(); i++) {
|
||||
if (!Equal(a->Index(i), b->Index(i))) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
},
|
||||
[&](Default) { return a->Value() == b->Value(); });
|
||||
}
|
||||
|
||||
/// CreateComposite is used to construct a constant of a vector, matrix or array type.
|
||||
/// CreateComposite examines the element values and will return either a Composite or a Splat,
|
||||
/// depending on the element types and values.
|
||||
const Constant* CreateComposite(ProgramBuilder& builder,
|
||||
const sem::Type* type,
|
||||
std::vector<const Constant*> elements) {
|
||||
if (elements.size() == 0) {
|
||||
return nullptr;
|
||||
}
|
||||
bool any_zero = false;
|
||||
bool all_zero = true;
|
||||
bool all_equal = true;
|
||||
auto* first = elements.front();
|
||||
for (auto* el : elements) {
|
||||
if (!el) {
|
||||
return nullptr;
|
||||
}
|
||||
if (!any_zero && el->AnyZero()) {
|
||||
any_zero = true;
|
||||
}
|
||||
if (all_zero && !el->AllZero()) {
|
||||
all_zero = false;
|
||||
}
|
||||
if (all_equal && el != first) {
|
||||
if (!Equal(el, first)) {
|
||||
all_equal = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (all_equal) {
|
||||
return builder.create<Splat>(type, elements[0], elements.size());
|
||||
} else {
|
||||
return builder.create<Composite>(type, std::move(elements), all_zero, any_zero);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
ConstEval::ConstEval(ProgramBuilder& b) : builder(b) {}
|
||||
|
||||
const sem::Constant* ConstEval::Literal(const sem::Type* ty,
|
||||
const ast::LiteralExpression* literal) {
|
||||
return Switch(
|
||||
literal,
|
||||
[&](const ast::BoolLiteralExpression* lit) {
|
||||
return CreateElement(builder, ty, lit->value);
|
||||
},
|
||||
[&](const ast::IntLiteralExpression* lit) -> const Constant* {
|
||||
switch (lit->suffix) {
|
||||
case ast::IntLiteralExpression::Suffix::kNone:
|
||||
return CreateElement(builder, ty, AInt(lit->value));
|
||||
case ast::IntLiteralExpression::Suffix::kI:
|
||||
return CreateElement(builder, ty, i32(lit->value));
|
||||
case ast::IntLiteralExpression::Suffix::kU:
|
||||
return CreateElement(builder, ty, u32(lit->value));
|
||||
}
|
||||
return nullptr;
|
||||
},
|
||||
[&](const ast::FloatLiteralExpression* lit) -> const Constant* {
|
||||
switch (lit->suffix) {
|
||||
case ast::FloatLiteralExpression::Suffix::kNone:
|
||||
return CreateElement(builder, ty, AFloat(lit->value));
|
||||
case ast::FloatLiteralExpression::Suffix::kF:
|
||||
return CreateElement(builder, ty, f32(lit->value));
|
||||
case ast::FloatLiteralExpression::Suffix::kH:
|
||||
return CreateElement(builder, ty, f16(lit->value));
|
||||
}
|
||||
return nullptr;
|
||||
});
|
||||
}
|
||||
|
||||
const sem::Constant* ConstEval::CtorOrConv(const sem::Type* ty,
|
||||
const std::vector<const sem::Expression*>& args) {
|
||||
// For zero value init, return 0s
|
||||
if (args.empty()) {
|
||||
return ZeroValue(builder, ty);
|
||||
}
|
||||
|
||||
if (auto* el_ty = sem::Type::ElementOf(ty); el_ty && args.size() == 1) {
|
||||
// Type constructor or conversion that takes a single argument.
|
||||
auto& src = args[0]->Declaration()->source;
|
||||
auto* arg = static_cast<const Constant*>(args[0]->ConstantValue());
|
||||
if (!arg) {
|
||||
return nullptr; // Single argument is not constant.
|
||||
}
|
||||
|
||||
if (ty->is_scalar()) { // Scalar type conversion: i32(x), u32(x), bool(x), etc
|
||||
return Convert(el_ty, arg, src).Get();
|
||||
}
|
||||
|
||||
if (arg->Type() == el_ty) {
|
||||
// Argument type matches function type. This is a splat.
|
||||
auto splat = [&](size_t n) { return builder.create<Splat>(ty, arg, n); };
|
||||
return Switch(
|
||||
ty, //
|
||||
[&](const sem::Vector* v) { return splat(v->Width()); },
|
||||
[&](const sem::Matrix* m) { return splat(m->columns()); },
|
||||
[&](const sem::Array* a) { return splat(a->Count()); });
|
||||
}
|
||||
|
||||
// Argument type and function type mismatch. This is a type conversion.
|
||||
if (auto conv = Convert(ty, arg, src)) {
|
||||
return conv.Get();
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Helper for pushing all the argument constants to `els`.
|
||||
auto args_as_constants = [&] {
|
||||
return utils::Transform(
|
||||
args, [&](auto* expr) { return static_cast<const Constant*>(expr->ConstantValue()); });
|
||||
};
|
||||
|
||||
// Multiple arguments. Must be a type constructor.
|
||||
|
||||
return Switch(
|
||||
ty, // What's the target type being constructed?
|
||||
[&](const sem::Vector*) -> const Constant* {
|
||||
// Vector can be constructed with a mix of scalars / abstract numerics and smaller
|
||||
// vectors.
|
||||
std::vector<const Constant*> els;
|
||||
els.reserve(args.size());
|
||||
for (auto* expr : args) {
|
||||
auto* arg = static_cast<const Constant*>(expr->ConstantValue());
|
||||
if (!arg) {
|
||||
return nullptr;
|
||||
}
|
||||
auto* arg_ty = arg->Type();
|
||||
if (auto* arg_vec = arg_ty->As<sem::Vector>()) {
|
||||
// Extract out vector elements.
|
||||
for (uint32_t i = 0; i < arg_vec->Width(); i++) {
|
||||
auto* el = static_cast<const Constant*>(arg->Index(i));
|
||||
if (!el) {
|
||||
return nullptr;
|
||||
}
|
||||
els.emplace_back(el);
|
||||
}
|
||||
} else {
|
||||
els.emplace_back(arg);
|
||||
}
|
||||
}
|
||||
return CreateComposite(builder, ty, std::move(els));
|
||||
},
|
||||
[&](const sem::Matrix* m) -> const Constant* {
|
||||
// Matrix can be constructed with a set of scalars / abstract numerics, or column
|
||||
// vectors.
|
||||
if (args.size() == m->columns() * m->rows()) {
|
||||
// Matrix built from scalars / abstract numerics
|
||||
std::vector<const Constant*> els;
|
||||
els.reserve(args.size());
|
||||
for (uint32_t c = 0; c < m->columns(); c++) {
|
||||
std::vector<const Constant*> column;
|
||||
column.reserve(m->rows());
|
||||
for (uint32_t r = 0; r < m->rows(); r++) {
|
||||
auto* arg =
|
||||
static_cast<const Constant*>(args[r + c * m->rows()]->ConstantValue());
|
||||
if (!arg) {
|
||||
return nullptr;
|
||||
}
|
||||
column.emplace_back(arg);
|
||||
}
|
||||
els.push_back(CreateComposite(builder, m->ColumnType(), std::move(column)));
|
||||
}
|
||||
return CreateComposite(builder, ty, std::move(els));
|
||||
}
|
||||
// Matrix built from column vectors
|
||||
return CreateComposite(builder, ty, args_as_constants());
|
||||
},
|
||||
[&](const sem::Array*) {
|
||||
// Arrays must be constructed using a list of elements
|
||||
return CreateComposite(builder, ty, args_as_constants());
|
||||
},
|
||||
[&](const sem::Struct*) {
|
||||
// Structures must be constructed using a list of elements
|
||||
return CreateComposite(builder, ty, args_as_constants());
|
||||
});
|
||||
}
|
||||
|
||||
const sem::Constant* ConstEval::Index(const sem::Expression* obj_expr,
|
||||
const sem::Expression* idx_expr) {
|
||||
auto obj_val = obj_expr->ConstantValue();
|
||||
if (!obj_val) {
|
||||
return {};
|
||||
}
|
||||
|
||||
auto idx_val = idx_expr->ConstantValue();
|
||||
if (!idx_val) {
|
||||
return {};
|
||||
}
|
||||
|
||||
uint32_t el_count = 0;
|
||||
sem::Type::ElementOf(obj_val->Type(), &el_count);
|
||||
|
||||
AInt idx = idx_val->As<AInt>();
|
||||
if (idx < 0 || idx >= el_count) {
|
||||
auto clamped = std::min<AInt::type>(std::max<AInt::type>(idx, 0), el_count - 1);
|
||||
AddWarning("index " + std::to_string(idx) + " out of bounds [0.." +
|
||||
std::to_string(el_count - 1) + "]. Clamping index to " +
|
||||
std::to_string(clamped),
|
||||
idx_expr->Declaration()->source);
|
||||
idx = clamped;
|
||||
}
|
||||
|
||||
return obj_val->Index(static_cast<size_t>(idx));
|
||||
}
|
||||
|
||||
const sem::Constant* ConstEval::MemberAccess(const sem::Expression* obj_expr,
|
||||
const sem::StructMember* member) {
|
||||
auto obj_val = obj_expr->ConstantValue();
|
||||
if (!obj_val) {
|
||||
return {};
|
||||
}
|
||||
return obj_val->Index(static_cast<size_t>(member->Index()));
|
||||
}
|
||||
|
||||
const sem::Constant* ConstEval::Swizzle(const sem::Type* ty,
|
||||
const sem::Expression* vec_expr,
|
||||
const std::vector<uint32_t>& indices) {
|
||||
auto* vec_val = vec_expr->ConstantValue();
|
||||
if (!vec_val) {
|
||||
return nullptr;
|
||||
}
|
||||
if (indices.size() == 1) {
|
||||
return static_cast<const Constant*>(vec_val->Index(static_cast<size_t>(indices[0])));
|
||||
} else {
|
||||
auto values = utils::Transform(indices, [&](uint32_t i) {
|
||||
return static_cast<const Constant*>(vec_val->Index(static_cast<size_t>(i)));
|
||||
});
|
||||
return CreateComposite(builder, ty, std::move(values));
|
||||
}
|
||||
}
|
||||
|
||||
const sem::Constant* ConstEval::Bitcast(const sem::Type*, const sem::Expression*) {
|
||||
// TODO(crbug.com/tint/1581): Implement @const intrinsics
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
utils::Result<const sem::Constant*> ConstEval::Convert(const sem::Type* target_ty,
|
||||
const sem::Constant* value,
|
||||
const Source& source) {
|
||||
if (value->Type() == target_ty) {
|
||||
return value;
|
||||
}
|
||||
auto conv = static_cast<const Constant*>(value)->Convert(builder, target_ty, source);
|
||||
if (!conv) {
|
||||
return utils::Failure;
|
||||
}
|
||||
return conv.Get();
|
||||
}
|
||||
|
||||
void ConstEval::AddError(const std::string& msg, const Source& source) const {
|
||||
builder.Diagnostics().add_error(diag::System::Resolver, msg, source);
|
||||
}
|
||||
|
||||
void ConstEval::AddWarning(const std::string& msg, const Source& source) const {
|
||||
builder.Diagnostics().add_warning(diag::System::Resolver, msg, source);
|
||||
}
|
||||
|
||||
} // namespace tint::resolver
|
||||
|
|
|
@ -16,24 +16,111 @@
|
|||
#define SRC_TINT_RESOLVER_CONST_EVAL_H_
|
||||
|
||||
#include <stddef.h>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "src/tint/utils/result.h"
|
||||
|
||||
// Forward declarations
|
||||
namespace tint {
|
||||
class ProgramBuilder;
|
||||
class Source;
|
||||
} // namespace tint
|
||||
|
||||
// Forward declarations
|
||||
namespace tint::ast {
|
||||
class LiteralExpression;
|
||||
} // namespace tint::ast
|
||||
namespace tint::sem {
|
||||
class Constant;
|
||||
class Expression;
|
||||
class StructMember;
|
||||
class Type;
|
||||
} // namespace tint::sem
|
||||
|
||||
namespace tint::resolver::const_eval {
|
||||
namespace tint::resolver {
|
||||
|
||||
/// Typedef for a constant evaluation function
|
||||
using Function = const sem::Constant*(ProgramBuilder& builder,
|
||||
sem::Constant const* const* args,
|
||||
size_t num_args);
|
||||
/// ConstEval performs shader creation-time (constant expression) expression evaluation.
|
||||
/// Methods are called from the resolver, either directly or via member-function pointers indexed by
|
||||
/// the IntrinsicTable. All child-expression nodes are guaranteed to have been already resolved
|
||||
/// before calling a method to evaluate an expression's value.
|
||||
class ConstEval {
|
||||
public:
|
||||
/// Typedef for a constant evaluation function
|
||||
using Function = const sem::Constant* (ConstEval::*)(const sem::Type* result_ty,
|
||||
sem::Expression const* const* args,
|
||||
size_t num_args);
|
||||
|
||||
} // namespace tint::resolver::const_eval
|
||||
/// The result type of a method that may raise a diagnostic error and the caller should abort
|
||||
/// resolving. Can be one of three distinct values:
|
||||
/// * A non-null sem::Constant pointer. Returned when a expression resolves to a creation time
|
||||
/// value.
|
||||
/// * A null sem::Constant pointer. Returned when a expression cannot resolve to a creation time
|
||||
/// value, but is otherwise legal.
|
||||
/// * `utils::Failure`. Returned when there was a resolver error. In this situation the method
|
||||
/// will have already reported a diagnostic error message, and the caller should abort
|
||||
/// resolving.
|
||||
using ConstantResult = utils::Result<const sem::Constant*>;
|
||||
|
||||
/// Constructor
|
||||
/// @param b the program builder
|
||||
explicit ConstEval(ProgramBuilder& b);
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Constant value evaluation methods, to be called directly from Resolver
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// @param ty the target type
|
||||
/// @param expr the input expression
|
||||
/// @return the bit-cast of the given expression to the given type, or null if the value cannot
|
||||
/// be calculated
|
||||
const sem::Constant* Bitcast(const sem::Type* ty, const sem::Expression* expr);
|
||||
|
||||
/// @param ty the target type
|
||||
/// @param args the input arguments
|
||||
/// @return the resulting type constructor or conversion, or null if the value cannot be
|
||||
/// calculated
|
||||
const sem::Constant* CtorOrConv(const sem::Type* ty,
|
||||
const std::vector<const sem::Expression*>& args);
|
||||
|
||||
/// @param obj the object being indexed
|
||||
/// @param idx the index expression
|
||||
/// @return the result of the index, or null if the value cannot be calculated
|
||||
const sem::Constant* Index(const sem::Expression* obj, const sem::Expression* idx);
|
||||
|
||||
/// @param ty the result type
|
||||
/// @param lit the literal AST node
|
||||
/// @return the constant value of the literal
|
||||
const sem::Constant* Literal(const sem::Type* ty, const ast::LiteralExpression* lit);
|
||||
|
||||
/// @param obj the object being accessed
|
||||
/// @param member the member
|
||||
/// @return the result of the member access, or null if the value cannot be calculated
|
||||
const sem::Constant* MemberAccess(const sem::Expression* obj, const sem::StructMember* member);
|
||||
|
||||
/// @param ty the result type
|
||||
/// @param vector the vector being swizzled
|
||||
/// @param indices the swizzle indices
|
||||
/// @return the result of the swizzle, or null if the value cannot be calculated
|
||||
const sem::Constant* Swizzle(const sem::Type* ty,
|
||||
const sem::Expression* vector,
|
||||
const std::vector<uint32_t>& indices);
|
||||
|
||||
/// Convert the `value` to `target_type`
|
||||
/// @param ty the result type
|
||||
/// @param value the value being converted
|
||||
/// @param source the source location of the conversion
|
||||
/// @return the converted value, or null if the value cannot be calculated
|
||||
ConstantResult Convert(const sem::Type* ty, const sem::Constant* value, const Source& source);
|
||||
|
||||
private:
|
||||
/// Adds the given error message to the diagnostics
|
||||
void AddError(const std::string& msg, const Source& source) const;
|
||||
|
||||
/// Adds the given warning message to the diagnostics
|
||||
void AddWarning(const std::string& msg, const Source& source) const;
|
||||
|
||||
ProgramBuilder& builder;
|
||||
};
|
||||
|
||||
} // namespace tint::resolver
|
||||
|
||||
#endif // SRC_TINT_RESOLVER_CONST_EVAL_H_
|
||||
|
|
|
@ -12,8 +12,6 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/tint/resolver/resolver.h"
|
||||
|
||||
#include <cmath>
|
||||
|
||||
#include "gtest/gtest.h"
|
|
@ -854,7 +854,7 @@ struct OverloadInfo {
|
|||
/// The flags for the overload
|
||||
OverloadFlags flags;
|
||||
/// The function used to evaluate the overload at shader-creation time.
|
||||
const_eval::Function* const const_eval_fn;
|
||||
ConstEval::Function const const_eval_fn;
|
||||
};
|
||||
|
||||
/// IntrinsicInfo describes a builtin function or operator overload
|
||||
|
|
|
@ -47,7 +47,7 @@ class IntrinsicTable {
|
|||
/// The semantic info for the builtin
|
||||
const sem::Builtin* sem = nullptr;
|
||||
/// The constant evaluation function
|
||||
const_eval::Function* const_eval_fn = nullptr;
|
||||
ConstEval::Function const_eval_fn = nullptr;
|
||||
};
|
||||
|
||||
/// UnaryOperator describes a resolved unary operator
|
||||
|
@ -56,6 +56,8 @@ class IntrinsicTable {
|
|||
const sem::Type* result = nullptr;
|
||||
/// The type of the parameter of the unary operator
|
||||
const sem::Type* parameter = nullptr;
|
||||
/// The constant evaluation function
|
||||
ConstEval::Function const_eval_fn = nullptr;
|
||||
};
|
||||
|
||||
/// BinaryOperator describes a resolved binary operator
|
||||
|
@ -66,6 +68,8 @@ class IntrinsicTable {
|
|||
const sem::Type* lhs = nullptr;
|
||||
/// The type of RHS parameter of the binary operator
|
||||
const sem::Type* rhs = nullptr;
|
||||
/// The constant evaluation function
|
||||
ConstEval::Function const_eval_fn = nullptr;
|
||||
};
|
||||
|
||||
/// Lookup looks for the builtin overload with the given signature, raising an error diagnostic
|
||||
|
|
|
@ -103,7 +103,7 @@ constexpr OverloadInfo kOverloads[] = {
|
|||
{{- end }}
|
||||
{{- if $o.IsDeprecated}}, OverloadFlag::kIsDeprecated{{end }}),
|
||||
/* const eval */
|
||||
{{- if $o.ConstEvalFunction }} const_eval::{{$o.ConstEvalFunction}},
|
||||
{{- if $o.ConstEvalFunction }} &ConstEval::{{$o.ConstEvalFunction}},
|
||||
{{- else }} nullptr,
|
||||
{{- end }}
|
||||
},
|
||||
|
|
|
@ -91,6 +91,7 @@ namespace tint::resolver {
|
|||
Resolver::Resolver(ProgramBuilder* builder)
|
||||
: builder_(builder),
|
||||
diagnostics_(builder->Diagnostics()),
|
||||
const_eval_(*builder),
|
||||
intrinsic_table_(IntrinsicTable::Create(*builder)),
|
||||
sem_(builder, dependencies_),
|
||||
validator_(builder, sem_) {}
|
||||
|
@ -1313,7 +1314,7 @@ const sem::Expression* Resolver::Materialize(const sem::Expression* expr,
|
|||
<< ") called on expression with no constant value";
|
||||
return nullptr;
|
||||
}
|
||||
auto materialized_val = ConvertValue(expr_val, target_ty, decl->source);
|
||||
auto materialized_val = const_eval_.Convert(target_ty, expr_val, decl->source);
|
||||
if (!materialized_val) {
|
||||
// ConvertValue() has already failed and raised an diagnostic error.
|
||||
return nullptr;
|
||||
|
@ -1422,7 +1423,7 @@ sem::Expression* Resolver::IndexAccessor(const ast::IndexAccessorExpression* exp
|
|||
ty = builder_->create<sem::Reference>(ty, ref->StorageClass(), ref->Access());
|
||||
}
|
||||
|
||||
auto val = EvaluateIndexValue(obj, idx);
|
||||
auto val = const_eval_.Index(obj, idx);
|
||||
bool has_side_effects = idx->HasSideEffects() || obj->HasSideEffects();
|
||||
auto* sem = builder_->create<sem::IndexAccessorExpression>(
|
||||
expr, ty, obj, idx, current_statement_, std::move(val), has_side_effects,
|
||||
|
@ -1441,7 +1442,7 @@ sem::Expression* Resolver::Bitcast(const ast::BitcastExpression* expr) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
auto val = EvaluateBitcastValue(inner, ty);
|
||||
auto val = const_eval_.Bitcast(ty, inner);
|
||||
auto* sem = builder_->create<sem::Expression>(expr, ty, current_statement_, std::move(val),
|
||||
inner->HasSideEffects());
|
||||
|
||||
|
@ -1489,7 +1490,7 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) {
|
|||
if (!MaterializeArguments(args, call_target)) {
|
||||
return nullptr;
|
||||
}
|
||||
auto val = EvaluateCtorOrConvValue(args, call_target->ReturnType());
|
||||
auto val = const_eval_.CtorOrConv(call_target->ReturnType(), args);
|
||||
return builder_->create<sem::Call>(expr, call_target, std::move(args), current_statement_,
|
||||
val, has_side_effects);
|
||||
};
|
||||
|
@ -1528,7 +1529,7 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) {
|
|||
if (!MaterializeArguments(args, call_target)) {
|
||||
return nullptr;
|
||||
}
|
||||
auto val = EvaluateCtorOrConvValue(args, arr);
|
||||
auto val = const_eval_.CtorOrConv(arr, args);
|
||||
return builder_->create<sem::Call>(expr, call_target, std::move(args),
|
||||
current_statement_, val, has_side_effects);
|
||||
},
|
||||
|
@ -1550,7 +1551,7 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) {
|
|||
if (!MaterializeArguments(args, call_target)) {
|
||||
return nullptr;
|
||||
}
|
||||
auto val = EvaluateCtorOrConvValue(args, str);
|
||||
auto val = const_eval_.CtorOrConv(str, args);
|
||||
return builder_->create<sem::Call>(expr, call_target, std::move(args),
|
||||
current_statement_, std::move(val),
|
||||
has_side_effects);
|
||||
|
@ -1682,28 +1683,17 @@ sem::Call* Resolver::BuiltinCall(const ast::CallExpression* expr,
|
|||
}
|
||||
|
||||
// If the builtin is @const, and all arguments have constant values, evaluate the builtin now.
|
||||
const sem::Constant* constant = nullptr;
|
||||
const sem::Constant* value = nullptr;
|
||||
if (builtin.const_eval_fn) {
|
||||
std::vector<const sem::Constant*> values(args.size());
|
||||
bool is_const = true; // all arguments have constant values
|
||||
for (size_t i = 0; i < values.size(); i++) {
|
||||
if (auto v = args[i]->ConstantValue()) {
|
||||
values[i] = std::move(v);
|
||||
} else {
|
||||
is_const = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (is_const) {
|
||||
constant = builtin.const_eval_fn(*builder_, values.data(), args.size());
|
||||
}
|
||||
value = (const_eval_.*builtin.const_eval_fn)(builtin.sem->ReturnType(), args.data(),
|
||||
args.size());
|
||||
}
|
||||
|
||||
bool has_side_effects =
|
||||
builtin.sem->HasSideEffects() ||
|
||||
std::any_of(args.begin(), args.end(), [](auto* e) { return e->HasSideEffects(); });
|
||||
auto* call = builder_->create<sem::Call>(expr, builtin.sem, std::move(args), current_statement_,
|
||||
constant, has_side_effects);
|
||||
value, has_side_effects);
|
||||
|
||||
if (current_function_) {
|
||||
current_function_->AddDirectlyCalledBuiltin(builtin.sem);
|
||||
|
@ -1856,7 +1846,7 @@ sem::Expression* Resolver::Literal(const ast::LiteralExpression* literal) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
auto val = EvaluateLiteralValue(literal, ty);
|
||||
auto val = const_eval_.Literal(ty, literal);
|
||||
return builder_->create<sem::Expression>(literal, ty, current_statement_, std::move(val),
|
||||
/* has_side_effects */ false);
|
||||
}
|
||||
|
@ -1976,7 +1966,7 @@ sem::Expression* Resolver::MemberAccessor(const ast::MemberAccessorExpression* e
|
|||
ret = builder_->create<sem::Reference>(ret, ref->StorageClass(), ref->Access());
|
||||
}
|
||||
|
||||
auto* val = EvaluateMemberAccessValue(object, member);
|
||||
auto* val = const_eval_.MemberAccess(object, member);
|
||||
return builder_->create<sem::StructMemberAccess>(expr, ret, current_statement_, val, object,
|
||||
member, has_side_effects, source_var);
|
||||
}
|
||||
|
@ -2044,7 +2034,7 @@ sem::Expression* Resolver::MemberAccessor(const ast::MemberAccessorExpression* e
|
|||
// the swizzle.
|
||||
ret = builder_->create<sem::Vector>(vec->type(), static_cast<uint32_t>(size));
|
||||
}
|
||||
auto* val = EvaluateSwizzleValue(object, ret, swizzle);
|
||||
auto* val = const_eval_.Swizzle(ret, object, swizzle);
|
||||
return builder_->create<sem::Swizzle>(expr, ret, current_statement_, val, object,
|
||||
std::move(swizzle), has_side_effects, source_var);
|
||||
}
|
||||
|
@ -2078,9 +2068,14 @@ sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) {
|
|||
}
|
||||
}
|
||||
|
||||
auto* val = EvaluateBinaryValue(lhs, rhs, op);
|
||||
const sem::Constant* value = nullptr;
|
||||
if (op.const_eval_fn) {
|
||||
const sem::Expression* args[] = {lhs, rhs};
|
||||
value = (const_eval_.*op.const_eval_fn)(op.result, args, 2u);
|
||||
}
|
||||
|
||||
bool has_side_effects = lhs->HasSideEffects() || rhs->HasSideEffects();
|
||||
auto* sem = builder_->create<sem::Expression>(expr, op.result, current_statement_, val,
|
||||
auto* sem = builder_->create<sem::Expression>(expr, op.result, current_statement_, value,
|
||||
has_side_effects);
|
||||
sem->Behaviors() = lhs->Behaviors() + rhs->Behaviors();
|
||||
|
||||
|
@ -2096,7 +2091,7 @@ sem::Expression* Resolver::UnaryOp(const ast::UnaryOpExpression* unary) {
|
|||
|
||||
const sem::Type* ty = nullptr;
|
||||
const sem::Variable* source_var = nullptr;
|
||||
const sem::Constant* val = nullptr;
|
||||
const sem::Constant* value = nullptr;
|
||||
|
||||
switch (unary->op) {
|
||||
case ast::UnaryOp::kAddressOf:
|
||||
|
@ -2149,12 +2144,14 @@ sem::Expression* Resolver::UnaryOp(const ast::UnaryOpExpression* unary) {
|
|||
}
|
||||
}
|
||||
ty = op.result;
|
||||
val = EvaluateUnaryValue(expr, op);
|
||||
if (op.const_eval_fn) {
|
||||
value = (const_eval_.*op.const_eval_fn)(ty, &expr, 1u);
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
auto* sem = builder_->create<sem::Expression>(unary, ty, current_statement_, val,
|
||||
auto* sem = builder_->create<sem::Expression>(unary, ty, current_statement_, value,
|
||||
expr->HasSideEffects(), source_var);
|
||||
sem->Behaviors() = expr->Behaviors();
|
||||
return sem;
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
#include <vector>
|
||||
|
||||
#include "src/tint/program_builder.h"
|
||||
#include "src/tint/resolver/const_eval.h"
|
||||
#include "src/tint/resolver/dependency_graph.h"
|
||||
#include "src/tint/resolver/intrinsic_table.h"
|
||||
#include "src/tint/resolver/sem_helper.h"
|
||||
|
@ -34,7 +35,6 @@
|
|||
#include "src/tint/sem/constant.h"
|
||||
#include "src/tint/sem/function.h"
|
||||
#include "src/tint/sem/struct.h"
|
||||
#include "src/tint/utils/result.h"
|
||||
#include "src/tint/utils/unique_vector.h"
|
||||
|
||||
// Forward declarations
|
||||
|
@ -204,46 +204,6 @@ class Resolver {
|
|||
sem::Expression* MemberAccessor(const ast::MemberAccessorExpression*);
|
||||
sem::Expression* UnaryOp(const ast::UnaryOpExpression*);
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Constant value evaluation methods
|
||||
///
|
||||
/// These methods are called from the expression resolving methods, and so child-expression
|
||||
/// nodes are guaranteed to have been already resolved and any constant values calculated.
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
const sem::Constant* EvaluateBinaryValue(const sem::Expression* lhs,
|
||||
const sem::Expression* rhs,
|
||||
const IntrinsicTable::BinaryOperator&);
|
||||
const sem::Constant* EvaluateBitcastValue(const sem::Expression*, const sem::Type*);
|
||||
const sem::Constant* EvaluateCtorOrConvValue(
|
||||
const std::vector<const sem::Expression*>& args,
|
||||
const sem::Type* ty); // Note: ty is not an array or structure
|
||||
const sem::Constant* EvaluateIndexValue(const sem::Expression* obj, const sem::Expression* idx);
|
||||
const sem::Constant* EvaluateLiteralValue(const ast::LiteralExpression*, const sem::Type*);
|
||||
const sem::Constant* EvaluateMemberAccessValue(const sem::Expression* obj,
|
||||
const sem::StructMember* member);
|
||||
const sem::Constant* EvaluateSwizzleValue(const sem::Expression* vector,
|
||||
const sem::Type* type,
|
||||
const std::vector<uint32_t>& indices);
|
||||
const sem::Constant* EvaluateUnaryValue(const sem::Expression*,
|
||||
const IntrinsicTable::UnaryOperator&);
|
||||
|
||||
/// The result type of a ConstantEvaluation method.
|
||||
/// Can be one of three distinct values:
|
||||
/// * A non-null sem::Constant pointer. Returned when a expression resolves to a creation time
|
||||
/// value.
|
||||
/// * A null sem::Constant pointer. Returned when a expression cannot resolve to a creation time
|
||||
/// value, but is otherwise legal.
|
||||
/// * `utils::Failure`. Returned when there was a resolver error. In this situation the method
|
||||
/// will have already reported a diagnostic error message, and the caller should abort
|
||||
/// resolving.
|
||||
using ConstantResult = utils::Result<const sem::Constant*>;
|
||||
|
||||
/// Convert the `value` to `target_type`
|
||||
/// @return the converted value
|
||||
ConstantResult ConvertValue(const sem::Constant* value,
|
||||
const sem::Type* target_type,
|
||||
const Source& source);
|
||||
|
||||
/// If `expr` is not of an abstract-numeric type, then Materialize() will just return `expr`.
|
||||
/// If `expr` is of an abstract-numeric type:
|
||||
/// * Materialize will create and return a sem::Materialize node wrapping `expr`.
|
||||
|
@ -449,6 +409,7 @@ class Resolver {
|
|||
|
||||
ProgramBuilder* const builder_;
|
||||
diag::List& diagnostics_;
|
||||
ConstEval const_eval_;
|
||||
std::unique_ptr<IntrinsicTable> const intrinsic_table_;
|
||||
DependencyGraph dependencies_;
|
||||
SemHelper sem_;
|
||||
|
|
|
@ -1,605 +0,0 @@
|
|||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/tint/resolver/resolver.h"
|
||||
|
||||
#include <optional>
|
||||
|
||||
#include "src/tint/sem/abstract_float.h"
|
||||
#include "src/tint/sem/abstract_int.h"
|
||||
#include "src/tint/sem/constant.h"
|
||||
#include "src/tint/sem/member_accessor_expression.h"
|
||||
#include "src/tint/sem/type_constructor.h"
|
||||
#include "src/tint/utils/compiler_macros.h"
|
||||
#include "src/tint/utils/map.h"
|
||||
#include "src/tint/utils/transform.h"
|
||||
|
||||
using namespace tint::number_suffixes; // NOLINT
|
||||
|
||||
namespace tint::resolver {
|
||||
|
||||
namespace {
|
||||
|
||||
/// TypeDispatch is a helper for calling the function `f`, passing a single zero-value argument of
|
||||
/// the C++ type that corresponds to the sem::Type `type`. For example, calling `TypeDispatch()`
|
||||
/// with a type of `sem::I32*` will call the function f with a single argument of `i32(0)`.
|
||||
/// @returns the value returned by calling `f`.
|
||||
/// @note `type` must be a scalar or abstract numeric type. Other types will not call `f`, and will
|
||||
/// return the zero-initialized value of the return type for `f`.
|
||||
template <typename F>
|
||||
auto TypeDispatch(const sem::Type* type, F&& f) {
|
||||
return Switch(
|
||||
type, //
|
||||
[&](const sem::AbstractInt*) { return f(AInt(0)); }, //
|
||||
[&](const sem::AbstractFloat*) { return f(AFloat(0)); }, //
|
||||
[&](const sem::I32*) { return f(i32(0)); }, //
|
||||
[&](const sem::U32*) { return f(u32(0)); }, //
|
||||
[&](const sem::F32*) { return f(f32(0)); }, //
|
||||
[&](const sem::F16*) { return f(f16(0)); }, //
|
||||
[&](const sem::Bool*) { return f(static_cast<bool>(0)); });
|
||||
}
|
||||
|
||||
/// @returns `value` if `T` is not a Number, otherwise ValueOf returns the inner value of the
|
||||
/// Number.
|
||||
template <typename T>
|
||||
inline auto ValueOf(T value) {
|
||||
if constexpr (std::is_same_v<UnwrapNumber<T>, T>) {
|
||||
return value;
|
||||
} else {
|
||||
return value.value;
|
||||
}
|
||||
}
|
||||
|
||||
/// @returns true if `value` is a positive zero.
|
||||
template <typename T>
|
||||
inline bool IsPositiveZero(T value) {
|
||||
using N = UnwrapNumber<T>;
|
||||
return Number<N>(value) == Number<N>(0); // Considers sign bit
|
||||
}
|
||||
|
||||
/// Constant inherits from sem::Constant to add an private implementation method for conversion.
|
||||
struct Constant : public sem::Constant {
|
||||
/// Convert attempts to convert the constant value to the given type. On error, Convert()
|
||||
/// creates a new diagnostic message and returns a Failure.
|
||||
virtual utils::Result<const Constant*> Convert(ProgramBuilder& builder,
|
||||
const sem::Type* target_ty,
|
||||
const Source& source) const = 0;
|
||||
};
|
||||
|
||||
// Forward declaration
|
||||
const Constant* CreateComposite(ProgramBuilder& builder,
|
||||
const sem::Type* type,
|
||||
std::vector<const Constant*> elements);
|
||||
|
||||
/// Element holds a single scalar or abstract-numeric value.
|
||||
/// Element implements the Constant interface.
|
||||
template <typename T>
|
||||
struct Element : Constant {
|
||||
Element(const sem::Type* t, T v) : type(t), value(v) {}
|
||||
~Element() override = default;
|
||||
const sem::Type* Type() const override { return type; }
|
||||
std::variant<std::monostate, AInt, AFloat> Value() const override {
|
||||
if constexpr (IsFloatingPoint<UnwrapNumber<T>>) {
|
||||
return static_cast<AFloat>(value);
|
||||
} else {
|
||||
return static_cast<AInt>(value);
|
||||
}
|
||||
}
|
||||
const Constant* Index(size_t) const override { return nullptr; }
|
||||
bool AllZero() const override { return IsPositiveZero(value); }
|
||||
bool AnyZero() const override { return IsPositiveZero(value); }
|
||||
bool AllEqual() const override { return true; }
|
||||
size_t Hash() const override { return utils::Hash(type, ValueOf(value)); }
|
||||
|
||||
utils::Result<const Constant*> Convert(ProgramBuilder& builder,
|
||||
const sem::Type* target_ty,
|
||||
const Source& source) const override {
|
||||
TINT_BEGIN_DISABLE_WARNING(UNREACHABLE_CODE);
|
||||
if (target_ty == type) {
|
||||
// If the types are identical, then no conversion is needed.
|
||||
return this;
|
||||
}
|
||||
bool failed = false;
|
||||
auto* res = TypeDispatch(target_ty, [&](auto zero_to) -> const Constant* {
|
||||
// `T` is the source type, `value` is the source value.
|
||||
// `TO` is the target type.
|
||||
using TO = std::decay_t<decltype(zero_to)>;
|
||||
if constexpr (std::is_same_v<TO, bool>) {
|
||||
// [x -> bool]
|
||||
return builder.create<Element<TO>>(target_ty, !IsPositiveZero(value));
|
||||
} else if constexpr (std::is_same_v<T, bool>) {
|
||||
// [bool -> x]
|
||||
return builder.create<Element<TO>>(target_ty, TO(value ? 1 : 0));
|
||||
} else if (auto conv = CheckedConvert<TO>(value)) {
|
||||
// Conversion success
|
||||
return builder.create<Element<TO>>(target_ty, conv.Get());
|
||||
// --- Below this point are the failure cases ---
|
||||
} else if constexpr (std::is_same_v<T, AInt> || std::is_same_v<T, AFloat>) {
|
||||
// [abstract-numeric -> x] - materialization failure
|
||||
std::stringstream ss;
|
||||
ss << "value " << value << " cannot be represented as ";
|
||||
ss << "'" << builder.FriendlyName(target_ty) << "'";
|
||||
builder.Diagnostics().add_error(tint::diag::System::Resolver, ss.str(), source);
|
||||
failed = true;
|
||||
} else if constexpr (IsFloatingPoint<UnwrapNumber<TO>>) {
|
||||
// [x -> floating-point] - number not exactly representable
|
||||
// https://www.w3.org/TR/WGSL/#floating-point-conversion
|
||||
constexpr auto kInf = std::numeric_limits<double>::infinity();
|
||||
switch (conv.Failure()) {
|
||||
case ConversionFailure::kExceedsNegativeLimit:
|
||||
return builder.create<Element<TO>>(target_ty, TO(-kInf));
|
||||
case ConversionFailure::kExceedsPositiveLimit:
|
||||
return builder.create<Element<TO>>(target_ty, TO(kInf));
|
||||
}
|
||||
} else {
|
||||
// [x -> integer] - number not exactly representable
|
||||
// https://www.w3.org/TR/WGSL/#floating-point-conversion
|
||||
switch (conv.Failure()) {
|
||||
case ConversionFailure::kExceedsNegativeLimit:
|
||||
return builder.create<Element<TO>>(target_ty, TO(TO::kLowest));
|
||||
case ConversionFailure::kExceedsPositiveLimit:
|
||||
return builder.create<Element<TO>>(target_ty, TO(TO::kHighest));
|
||||
}
|
||||
}
|
||||
return nullptr; // Expression is not constant.
|
||||
});
|
||||
if (failed) {
|
||||
// A diagnostic error has been raised, and resolving should abort.
|
||||
return utils::Failure;
|
||||
}
|
||||
return res;
|
||||
TINT_END_DISABLE_WARNING(UNREACHABLE_CODE);
|
||||
}
|
||||
|
||||
sem::Type const* const type;
|
||||
const T value;
|
||||
};
|
||||
|
||||
/// Splat holds a single Constant value, duplicated as all children.
|
||||
/// Splat is used for zero-initializers, 'splat' constructors, or constructors where each element is
|
||||
/// identical. Splat may be of a vector, matrix or array type.
|
||||
/// Splat implements the Constant interface.
|
||||
struct Splat : Constant {
|
||||
Splat(const sem::Type* t, const Constant* e, size_t n) : type(t), el(e), count(n) {}
|
||||
~Splat() override = default;
|
||||
const sem::Type* Type() const override { return type; }
|
||||
std::variant<std::monostate, AInt, AFloat> Value() const override { return {}; }
|
||||
const Constant* Index(size_t i) const override { return i < count ? el : nullptr; }
|
||||
bool AllZero() const override { return el->AllZero(); }
|
||||
bool AnyZero() const override { return el->AnyZero(); }
|
||||
bool AllEqual() const override { return true; }
|
||||
size_t Hash() const override { return utils::Hash(type, el->Hash(), count); }
|
||||
|
||||
utils::Result<const Constant*> Convert(ProgramBuilder& builder,
|
||||
const sem::Type* target_ty,
|
||||
const Source& source) const override {
|
||||
// Convert the single splatted element type.
|
||||
auto conv_el = el->Convert(builder, sem::Type::ElementOf(target_ty), source);
|
||||
if (!conv_el) {
|
||||
return utils::Failure;
|
||||
}
|
||||
if (!conv_el.Get()) {
|
||||
return nullptr;
|
||||
}
|
||||
return builder.create<Splat>(target_ty, conv_el.Get(), count);
|
||||
}
|
||||
|
||||
sem::Type const* const type;
|
||||
const Constant* el;
|
||||
const size_t count;
|
||||
};
|
||||
|
||||
/// Composite holds a number of mixed child Constant values.
|
||||
/// Composite may be of a vector, matrix or array type.
|
||||
/// If each element is the same type and value, then a Splat would be a more efficient constant
|
||||
/// implementation. Use CreateComposite() to create the appropriate Constant type.
|
||||
/// Composite implements the Constant interface.
|
||||
struct Composite : Constant {
|
||||
Composite(const sem::Type* t, std::vector<const Constant*> els, bool all_0, bool any_0)
|
||||
: type(t), elements(std::move(els)), all_zero(all_0), any_zero(any_0), hash(CalcHash()) {}
|
||||
~Composite() override = default;
|
||||
const sem::Type* Type() const override { return type; }
|
||||
std::variant<std::monostate, AInt, AFloat> Value() const override { return {}; }
|
||||
const Constant* Index(size_t i) const override {
|
||||
return i < elements.size() ? elements[i] : nullptr;
|
||||
}
|
||||
bool AllZero() const override { return all_zero; }
|
||||
bool AnyZero() const override { return any_zero; }
|
||||
bool AllEqual() const override { return false; /* otherwise this should be a Splat */ }
|
||||
size_t Hash() const override { return hash; }
|
||||
|
||||
utils::Result<const Constant*> Convert(ProgramBuilder& builder,
|
||||
const sem::Type* target_ty,
|
||||
const Source& source) const override {
|
||||
// Convert each of the composite element types.
|
||||
auto* el_ty = sem::Type::ElementOf(target_ty);
|
||||
std::vector<const Constant*> conv_els;
|
||||
conv_els.reserve(elements.size());
|
||||
for (auto* el : elements) {
|
||||
auto conv_el = el->Convert(builder, el_ty, source);
|
||||
if (!conv_el) {
|
||||
return utils::Failure;
|
||||
}
|
||||
if (!conv_el.Get()) {
|
||||
return nullptr;
|
||||
}
|
||||
conv_els.emplace_back(conv_el.Get());
|
||||
}
|
||||
return CreateComposite(builder, target_ty, std::move(conv_els));
|
||||
}
|
||||
|
||||
size_t CalcHash() {
|
||||
auto h = utils::Hash(type, all_zero, any_zero);
|
||||
for (auto* el : elements) {
|
||||
utils::HashCombine(&h, el->Hash());
|
||||
}
|
||||
return h;
|
||||
}
|
||||
|
||||
sem::Type const* const type;
|
||||
const std::vector<const Constant*> elements;
|
||||
const bool all_zero;
|
||||
const bool any_zero;
|
||||
const size_t hash;
|
||||
};
|
||||
|
||||
/// CreateElement constructs and returns an Element<T>.
|
||||
template <typename T>
|
||||
const Constant* CreateElement(ProgramBuilder& builder, const sem::Type* t, T v) {
|
||||
return builder.create<Element<T>>(t, v);
|
||||
}
|
||||
|
||||
/// ZeroValue returns a Constant for the zero-value of the type `type`.
|
||||
const Constant* ZeroValue(ProgramBuilder& builder, const sem::Type* type) {
|
||||
return Switch(
|
||||
type, //
|
||||
[&](const sem::Vector* v) -> const Constant* {
|
||||
auto* zero_el = ZeroValue(builder, v->type());
|
||||
return builder.create<Splat>(type, zero_el, v->Width());
|
||||
},
|
||||
[&](const sem::Matrix* m) -> const Constant* {
|
||||
auto* zero_el = ZeroValue(builder, m->ColumnType());
|
||||
return builder.create<Splat>(type, zero_el, m->columns());
|
||||
},
|
||||
[&](const sem::Array* a) -> const Constant* {
|
||||
if (auto* zero_el = ZeroValue(builder, a->ElemType())) {
|
||||
return builder.create<Splat>(type, zero_el, a->Count());
|
||||
}
|
||||
return nullptr;
|
||||
},
|
||||
[&](const sem::Struct* s) -> const Constant* {
|
||||
std::unordered_map<sem::Type*, const Constant*> zero_by_type;
|
||||
std::vector<const Constant*> zeros;
|
||||
zeros.reserve(s->Members().size());
|
||||
for (auto* member : s->Members()) {
|
||||
auto* zero = utils::GetOrCreate(zero_by_type, member->Type(),
|
||||
[&] { return ZeroValue(builder, member->Type()); });
|
||||
if (!zero) {
|
||||
return nullptr;
|
||||
}
|
||||
zeros.emplace_back(zero);
|
||||
}
|
||||
if (zero_by_type.size() == 1) {
|
||||
// All members were of the same type, so the zero value is the same for all members.
|
||||
return builder.create<Splat>(type, zeros[0], s->Members().size());
|
||||
}
|
||||
return CreateComposite(builder, s, std::move(zeros));
|
||||
},
|
||||
[&](Default) -> const Constant* {
|
||||
return TypeDispatch(type, [&](auto zero) -> const Constant* {
|
||||
return CreateElement(builder, type, zero);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
/// Equal returns true if the constants `a` and `b` are of the same type and value.
|
||||
bool Equal(const sem::Constant* a, const sem::Constant* b) {
|
||||
if (a->Hash() != b->Hash()) {
|
||||
return false;
|
||||
}
|
||||
if (a->Type() != b->Type()) {
|
||||
return false;
|
||||
}
|
||||
return Switch(
|
||||
a->Type(), //
|
||||
[&](const sem::Vector* vec) {
|
||||
for (size_t i = 0; i < vec->Width(); i++) {
|
||||
if (!Equal(a->Index(i), b->Index(i))) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
},
|
||||
[&](const sem::Matrix* mat) {
|
||||
for (size_t i = 0; i < mat->columns(); i++) {
|
||||
if (!Equal(a->Index(i), b->Index(i))) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
},
|
||||
[&](const sem::Array* arr) {
|
||||
for (size_t i = 0; i < arr->Count(); i++) {
|
||||
if (!Equal(a->Index(i), b->Index(i))) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
},
|
||||
[&](Default) { return a->Value() == b->Value(); });
|
||||
}
|
||||
|
||||
/// CreateComposite is used to construct a constant of a vector, matrix or array type.
|
||||
/// CreateComposite examines the element values and will return either a Composite or a Splat,
|
||||
/// depending on the element types and values.
|
||||
const Constant* CreateComposite(ProgramBuilder& builder,
|
||||
const sem::Type* type,
|
||||
std::vector<const Constant*> elements) {
|
||||
if (elements.size() == 0) {
|
||||
return nullptr;
|
||||
}
|
||||
bool any_zero = false;
|
||||
bool all_zero = true;
|
||||
bool all_equal = true;
|
||||
auto* first = elements.front();
|
||||
for (auto* el : elements) {
|
||||
if (!el) {
|
||||
return nullptr;
|
||||
}
|
||||
if (!any_zero && el->AnyZero()) {
|
||||
any_zero = true;
|
||||
}
|
||||
if (all_zero && !el->AllZero()) {
|
||||
all_zero = false;
|
||||
}
|
||||
if (all_equal && el != first) {
|
||||
if (!Equal(el, first)) {
|
||||
all_equal = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (all_equal) {
|
||||
return builder.create<Splat>(type, elements[0], elements.size());
|
||||
} else {
|
||||
return builder.create<Composite>(type, std::move(elements), all_zero, any_zero);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
const sem::Constant* Resolver::EvaluateLiteralValue(const ast::LiteralExpression* literal,
|
||||
const sem::Type* type) {
|
||||
return Switch(
|
||||
literal,
|
||||
[&](const ast::BoolLiteralExpression* lit) {
|
||||
return CreateElement(*builder_, type, lit->value);
|
||||
},
|
||||
[&](const ast::IntLiteralExpression* lit) -> const Constant* {
|
||||
switch (lit->suffix) {
|
||||
case ast::IntLiteralExpression::Suffix::kNone:
|
||||
return CreateElement(*builder_, type, AInt(lit->value));
|
||||
case ast::IntLiteralExpression::Suffix::kI:
|
||||
return CreateElement(*builder_, type, i32(lit->value));
|
||||
case ast::IntLiteralExpression::Suffix::kU:
|
||||
return CreateElement(*builder_, type, u32(lit->value));
|
||||
}
|
||||
return nullptr;
|
||||
},
|
||||
[&](const ast::FloatLiteralExpression* lit) -> const Constant* {
|
||||
switch (lit->suffix) {
|
||||
case ast::FloatLiteralExpression::Suffix::kNone:
|
||||
return CreateElement(*builder_, type, AFloat(lit->value));
|
||||
case ast::FloatLiteralExpression::Suffix::kF:
|
||||
return CreateElement(*builder_, type, f32(lit->value));
|
||||
case ast::FloatLiteralExpression::Suffix::kH:
|
||||
return CreateElement(*builder_, type, f16(lit->value));
|
||||
}
|
||||
return nullptr;
|
||||
});
|
||||
}
|
||||
|
||||
const sem::Constant* Resolver::EvaluateCtorOrConvValue(
|
||||
const std::vector<const sem::Expression*>& args,
|
||||
const sem::Type* ty) {
|
||||
// For zero value init, return 0s
|
||||
if (args.empty()) {
|
||||
return ZeroValue(*builder_, ty);
|
||||
}
|
||||
|
||||
if (auto* el_ty = sem::Type::ElementOf(ty); el_ty && args.size() == 1) {
|
||||
// Type constructor or conversion that takes a single argument.
|
||||
auto& src = args[0]->Declaration()->source;
|
||||
auto* arg = static_cast<const Constant*>(args[0]->ConstantValue());
|
||||
if (!arg) {
|
||||
return nullptr; // Single argument is not constant.
|
||||
}
|
||||
|
||||
if (ty->is_scalar()) { // Scalar type conversion: i32(x), u32(x), bool(x), etc
|
||||
return ConvertValue(arg, el_ty, src).Get();
|
||||
}
|
||||
|
||||
if (arg->Type() == el_ty) {
|
||||
// Argument type matches function type. This is a splat.
|
||||
auto splat = [&](size_t n) { return builder_->create<Splat>(ty, arg, n); };
|
||||
return Switch(
|
||||
ty, //
|
||||
[&](const sem::Vector* v) { return splat(v->Width()); },
|
||||
[&](const sem::Matrix* m) { return splat(m->columns()); },
|
||||
[&](const sem::Array* a) { return splat(a->Count()); });
|
||||
}
|
||||
|
||||
// Argument type and function type mismatch. This is a type conversion.
|
||||
if (auto conv = ConvertValue(arg, ty, src)) {
|
||||
return conv.Get();
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Helper for pushing all the argument constants to `els`.
|
||||
auto args_as_constants = [&] {
|
||||
return utils::Transform(
|
||||
args, [&](auto* expr) { return static_cast<const Constant*>(expr->ConstantValue()); });
|
||||
};
|
||||
|
||||
// Multiple arguments. Must be a type constructor.
|
||||
|
||||
return Switch(
|
||||
ty, // What's the target type being constructed?
|
||||
[&](const sem::Vector*) -> const Constant* {
|
||||
// Vector can be constructed with a mix of scalars / abstract numerics and smaller
|
||||
// vectors.
|
||||
std::vector<const Constant*> els;
|
||||
els.reserve(args.size());
|
||||
for (auto* expr : args) {
|
||||
auto* arg = static_cast<const Constant*>(expr->ConstantValue());
|
||||
if (!arg) {
|
||||
return nullptr;
|
||||
}
|
||||
auto* arg_ty = arg->Type();
|
||||
if (auto* arg_vec = arg_ty->As<sem::Vector>()) {
|
||||
// Extract out vector elements.
|
||||
for (uint32_t i = 0; i < arg_vec->Width(); i++) {
|
||||
auto* el = static_cast<const Constant*>(arg->Index(i));
|
||||
if (!el) {
|
||||
return nullptr;
|
||||
}
|
||||
els.emplace_back(el);
|
||||
}
|
||||
} else {
|
||||
els.emplace_back(arg);
|
||||
}
|
||||
}
|
||||
return CreateComposite(*builder_, ty, std::move(els));
|
||||
},
|
||||
[&](const sem::Matrix* m) -> const Constant* {
|
||||
// Matrix can be constructed with a set of scalars / abstract numerics, or column
|
||||
// vectors.
|
||||
if (args.size() == m->columns() * m->rows()) {
|
||||
// Matrix built from scalars / abstract numerics
|
||||
std::vector<const Constant*> els;
|
||||
els.reserve(args.size());
|
||||
for (uint32_t c = 0; c < m->columns(); c++) {
|
||||
std::vector<const Constant*> column;
|
||||
column.reserve(m->rows());
|
||||
for (uint32_t r = 0; r < m->rows(); r++) {
|
||||
auto* arg =
|
||||
static_cast<const Constant*>(args[r + c * m->rows()]->ConstantValue());
|
||||
if (!arg) {
|
||||
return nullptr;
|
||||
}
|
||||
column.emplace_back(arg);
|
||||
}
|
||||
els.push_back(CreateComposite(*builder_, m->ColumnType(), std::move(column)));
|
||||
}
|
||||
return CreateComposite(*builder_, ty, std::move(els));
|
||||
}
|
||||
// Matrix built from column vectors
|
||||
return CreateComposite(*builder_, ty, args_as_constants());
|
||||
},
|
||||
[&](const sem::Array*) {
|
||||
// Arrays must be constructed using a list of elements
|
||||
return CreateComposite(*builder_, ty, args_as_constants());
|
||||
},
|
||||
[&](const sem::Struct*) {
|
||||
// Structures must be constructed using a list of elements
|
||||
return CreateComposite(*builder_, ty, args_as_constants());
|
||||
});
|
||||
}
|
||||
|
||||
const sem::Constant* Resolver::EvaluateIndexValue(const sem::Expression* obj_expr,
|
||||
const sem::Expression* idx_expr) {
|
||||
auto obj_val = obj_expr->ConstantValue();
|
||||
if (!obj_val) {
|
||||
return {};
|
||||
}
|
||||
|
||||
auto idx_val = idx_expr->ConstantValue();
|
||||
if (!idx_val) {
|
||||
return {};
|
||||
}
|
||||
|
||||
uint32_t el_count = 0;
|
||||
sem::Type::ElementOf(obj_val->Type(), &el_count);
|
||||
|
||||
AInt idx = idx_val->As<AInt>();
|
||||
if (idx < 0 || idx >= el_count) {
|
||||
auto clamped = std::min<AInt::type>(std::max<AInt::type>(idx, 0), el_count - 1);
|
||||
AddWarning("index " + std::to_string(idx) + " out of bounds [0.." +
|
||||
std::to_string(el_count - 1) + "]. Clamping index to " +
|
||||
std::to_string(clamped),
|
||||
idx_expr->Declaration()->source);
|
||||
idx = clamped;
|
||||
}
|
||||
|
||||
return obj_val->Index(static_cast<size_t>(idx));
|
||||
}
|
||||
|
||||
const sem::Constant* Resolver::EvaluateMemberAccessValue(const sem::Expression* obj_expr,
|
||||
const sem::StructMember* member) {
|
||||
auto obj_val = obj_expr->ConstantValue();
|
||||
if (!obj_val) {
|
||||
return {};
|
||||
}
|
||||
return obj_val->Index(static_cast<size_t>(member->Index()));
|
||||
}
|
||||
|
||||
const sem::Constant* Resolver::EvaluateSwizzleValue(const sem::Expression* vec_expr,
|
||||
const sem::Type* type,
|
||||
const std::vector<uint32_t>& indices) {
|
||||
auto* vec_val = vec_expr->ConstantValue();
|
||||
if (!vec_val) {
|
||||
return nullptr;
|
||||
}
|
||||
if (indices.size() == 1) {
|
||||
return static_cast<const Constant*>(vec_val->Index(static_cast<size_t>(indices[0])));
|
||||
} else {
|
||||
auto values = utils::Transform(
|
||||
indices, [&](uint32_t i) { return static_cast<const Constant*>(vec_val->Index(i)); });
|
||||
return CreateComposite(*builder_, type, std::move(values));
|
||||
}
|
||||
}
|
||||
|
||||
const sem::Constant* Resolver::EvaluateBitcastValue(const sem::Expression*, const sem::Type*) {
|
||||
// TODO(crbug.com/tint/1581): Implement @const intrinsics
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const sem::Constant* Resolver::EvaluateBinaryValue(const sem::Expression*,
|
||||
const sem::Expression*,
|
||||
const IntrinsicTable::BinaryOperator&) {
|
||||
// TODO(crbug.com/tint/1581): Implement @const intrinsics
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const sem::Constant* Resolver::EvaluateUnaryValue(const sem::Expression*,
|
||||
const IntrinsicTable::UnaryOperator&) {
|
||||
// TODO(crbug.com/tint/1581): Implement @const intrinsics
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
utils::Result<const sem::Constant*> Resolver::ConvertValue(const sem::Constant* value,
|
||||
const sem::Type* target_ty,
|
||||
const Source& source) {
|
||||
if (value->Type() == target_ty) {
|
||||
return value;
|
||||
}
|
||||
auto conv = static_cast<const Constant*>(value)->Convert(*builder_, target_ty, source);
|
||||
if (!conv) {
|
||||
return utils::Failure;
|
||||
}
|
||||
return conv.Get();
|
||||
}
|
||||
|
||||
} // namespace tint::resolver
|
Loading…
Reference in New Issue