diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn index fb513279cb..1b2e7d592b 100644 --- a/src/tint/BUILD.gn +++ b/src/tint/BUILD.gn @@ -398,7 +398,6 @@ libtint_source_set("libtint_core_all_src") { "resolver/intrinsic_table.inl", "resolver/resolver.cc", "resolver/resolver.h", - "resolver/resolver_constants.cc", "resolver/sem_helper.cc", "resolver/sem_helper.h", "resolver/uniformity.cc", @@ -1090,6 +1089,7 @@ if (tint_build_unittests) { "resolver/call_validation_test.cc", "resolver/compound_assignment_validation_test.cc", "resolver/compound_statement_test.cc", + "resolver/const_eval_test.cc", "resolver/control_block_validation_test.cc", "resolver/dependency_graph_test.cc", "resolver/entry_point_validation_test.cc", @@ -1104,7 +1104,6 @@ if (tint_build_unittests) { "resolver/ptr_ref_test.cc", "resolver/ptr_ref_validation_test.cc", "resolver/resolver_behavior_test.cc", - "resolver/resolver_constants_test.cc", "resolver/resolver_test.cc", "resolver/resolver_test_helper.cc", "resolver/resolver_test_helper.h", diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt index 3e7154c80d..1d22340509 100644 --- a/src/tint/CMakeLists.txt +++ b/src/tint/CMakeLists.txt @@ -258,7 +258,6 @@ set(TINT_LIB_SRCS resolver/intrinsic_table.cc resolver/intrinsic_table.h resolver/intrinsic_table.inl - resolver/resolver_constants.cc resolver/resolver.cc resolver/resolver.h resolver/sem_helper.cc @@ -774,6 +773,7 @@ if(TINT_BUILD_TESTS) resolver/call_validation_test.cc resolver/compound_assignment_validation_test.cc resolver/compound_statement_test.cc + resolver/const_eval_test.cc resolver/control_block_validation_test.cc resolver/dependency_graph_test.cc resolver/entry_point_validation_test.cc @@ -789,7 +789,6 @@ if(TINT_BUILD_TESTS) resolver/ptr_ref_test.cc resolver/ptr_ref_validation_test.cc resolver/resolver_behavior_test.cc - resolver/resolver_constants_test.cc resolver/resolver_test_helper.cc resolver/resolver_test_helper.h resolver/resolver_test.cc diff --git a/src/tint/resolver/const_eval.cc b/src/tint/resolver/const_eval.cc index fac2c7db78..95d8ba9996 100644 --- a/src/tint/resolver/const_eval.cc +++ b/src/tint/resolver/const_eval.cc @@ -14,6 +14,603 @@ #include "src/tint/resolver/const_eval.h" -#include "src/tint/sem/constant.h" +#include +#include +#include +#include +#include +#include -namespace tint::resolver::const_eval {} // namespace tint::resolver::const_eval +#include "src/tint/program_builder.h" +#include "src/tint/sem/abstract_float.h" +#include "src/tint/sem/abstract_int.h" +#include "src/tint/sem/array.h" +#include "src/tint/sem/bool.h" +#include "src/tint/sem/constant.h" +#include "src/tint/sem/f16.h" +#include "src/tint/sem/f32.h" +#include "src/tint/sem/i32.h" +#include "src/tint/sem/matrix.h" +#include "src/tint/sem/member_accessor_expression.h" +#include "src/tint/sem/type_constructor.h" +#include "src/tint/sem/u32.h" +#include "src/tint/sem/vector.h" +#include "src/tint/utils/compiler_macros.h" +#include "src/tint/utils/map.h" +#include "src/tint/utils/transform.h" + +using namespace tint::number_suffixes; // NOLINT + +namespace tint::resolver { + +namespace { + +/// TypeDispatch is a helper for calling the function `f`, passing a single zero-value argument of +/// the C++ type that corresponds to the sem::Type `type`. For example, calling `TypeDispatch()` +/// with a type of `sem::I32*` will call the function f with a single argument of `i32(0)`. +/// @returns the value returned by calling `f`. +/// @note `type` must be a scalar or abstract numeric type. Other types will not call `f`, and will +/// return the zero-initialized value of the return type for `f`. +template +auto TypeDispatch(const sem::Type* type, F&& f) { + return Switch( + type, // + [&](const sem::AbstractInt*) { return f(AInt(0)); }, // + [&](const sem::AbstractFloat*) { return f(AFloat(0)); }, // + [&](const sem::I32*) { return f(i32(0)); }, // + [&](const sem::U32*) { return f(u32(0)); }, // + [&](const sem::F32*) { return f(f32(0)); }, // + [&](const sem::F16*) { return f(f16(0)); }, // + [&](const sem::Bool*) { return f(static_cast(0)); }); +} + +/// @returns `value` if `T` is not a Number, otherwise ValueOf returns the inner value of the +/// Number. +template +inline auto ValueOf(T value) { + if constexpr (std::is_same_v, T>) { + return value; + } else { + return value.value; + } +} + +/// @returns true if `value` is a positive zero. +template +inline bool IsPositiveZero(T value) { + using N = UnwrapNumber; + return Number(value) == Number(0); // Considers sign bit +} + +/// Constant inherits from sem::Constant to add an private implementation method for conversion. +struct Constant : public sem::Constant { + /// Convert attempts to convert the constant value to the given type. On error, Convert() + /// creates a new diagnostic message and returns a Failure. + virtual utils::Result Convert(ProgramBuilder& builder, + const sem::Type* target_ty, + const Source& source) const = 0; +}; + +// Forward declaration +const Constant* CreateComposite(ProgramBuilder& builder, + const sem::Type* type, + std::vector elements); + +/// Element holds a single scalar or abstract-numeric value. +/// Element implements the Constant interface. +template +struct Element : Constant { + Element(const sem::Type* t, T v) : type(t), value(v) {} + ~Element() override = default; + const sem::Type* Type() const override { return type; } + std::variant Value() const override { + if constexpr (IsFloatingPoint>) { + return static_cast(value); + } else { + return static_cast(value); + } + } + const Constant* Index(size_t) const override { return nullptr; } + bool AllZero() const override { return IsPositiveZero(value); } + bool AnyZero() const override { return IsPositiveZero(value); } + bool AllEqual() const override { return true; } + size_t Hash() const override { return utils::Hash(type, ValueOf(value)); } + + utils::Result Convert(ProgramBuilder& builder, + const sem::Type* target_ty, + const Source& source) const override { + TINT_BEGIN_DISABLE_WARNING(UNREACHABLE_CODE); + if (target_ty == type) { + // If the types are identical, then no conversion is needed. + return this; + } + bool failed = false; + auto* res = TypeDispatch(target_ty, [&](auto zero_to) -> const Constant* { + // `T` is the source type, `value` is the source value. + // `TO` is the target type. + using TO = std::decay_t; + if constexpr (std::is_same_v) { + // [x -> bool] + return builder.create>(target_ty, !IsPositiveZero(value)); + } else if constexpr (std::is_same_v) { + // [bool -> x] + return builder.create>(target_ty, TO(value ? 1 : 0)); + } else if (auto conv = CheckedConvert(value)) { + // Conversion success + return builder.create>(target_ty, conv.Get()); + // --- Below this point are the failure cases --- + } else if constexpr (std::is_same_v || std::is_same_v) { + // [abstract-numeric -> x] - materialization failure + std::stringstream ss; + ss << "value " << value << " cannot be represented as "; + ss << "'" << builder.FriendlyName(target_ty) << "'"; + builder.Diagnostics().add_error(tint::diag::System::Resolver, ss.str(), source); + failed = true; + } else if constexpr (IsFloatingPoint>) { + // [x -> floating-point] - number not exactly representable + // https://www.w3.org/TR/WGSL/#floating-point-conversion + constexpr auto kInf = std::numeric_limits::infinity(); + switch (conv.Failure()) { + case ConversionFailure::kExceedsNegativeLimit: + return builder.create>(target_ty, TO(-kInf)); + case ConversionFailure::kExceedsPositiveLimit: + return builder.create>(target_ty, TO(kInf)); + } + } else { + // [x -> integer] - number not exactly representable + // https://www.w3.org/TR/WGSL/#floating-point-conversion + switch (conv.Failure()) { + case ConversionFailure::kExceedsNegativeLimit: + return builder.create>(target_ty, TO(TO::kLowest)); + case ConversionFailure::kExceedsPositiveLimit: + return builder.create>(target_ty, TO(TO::kHighest)); + } + } + return nullptr; // Expression is not constant. + }); + if (failed) { + // A diagnostic error has been raised, and resolving should abort. + return utils::Failure; + } + return res; + TINT_END_DISABLE_WARNING(UNREACHABLE_CODE); + } + + sem::Type const* const type; + const T value; +}; + +/// Splat holds a single Constant value, duplicated as all children. +/// Splat is used for zero-initializers, 'splat' constructors, or constructors where each element is +/// identical. Splat may be of a vector, matrix or array type. +/// Splat implements the Constant interface. +struct Splat : Constant { + Splat(const sem::Type* t, const Constant* e, size_t n) : type(t), el(e), count(n) {} + ~Splat() override = default; + const sem::Type* Type() const override { return type; } + std::variant Value() const override { return {}; } + const Constant* Index(size_t i) const override { return i < count ? el : nullptr; } + bool AllZero() const override { return el->AllZero(); } + bool AnyZero() const override { return el->AnyZero(); } + bool AllEqual() const override { return true; } + size_t Hash() const override { return utils::Hash(type, el->Hash(), count); } + + utils::Result Convert(ProgramBuilder& builder, + const sem::Type* target_ty, + const Source& source) const override { + // Convert the single splatted element type. + auto conv_el = el->Convert(builder, sem::Type::ElementOf(target_ty), source); + if (!conv_el) { + return utils::Failure; + } + if (!conv_el.Get()) { + return nullptr; + } + return builder.create(target_ty, conv_el.Get(), count); + } + + sem::Type const* const type; + const Constant* el; + const size_t count; +}; + +/// Composite holds a number of mixed child Constant values. +/// Composite may be of a vector, matrix or array type. +/// If each element is the same type and value, then a Splat would be a more efficient constant +/// implementation. Use CreateComposite() to create the appropriate Constant type. +/// Composite implements the Constant interface. +struct Composite : Constant { + Composite(const sem::Type* t, std::vector els, bool all_0, bool any_0) + : type(t), elements(std::move(els)), all_zero(all_0), any_zero(any_0), hash(CalcHash()) {} + ~Composite() override = default; + const sem::Type* Type() const override { return type; } + std::variant Value() const override { return {}; } + const Constant* Index(size_t i) const override { + return i < elements.size() ? elements[i] : nullptr; + } + bool AllZero() const override { return all_zero; } + bool AnyZero() const override { return any_zero; } + bool AllEqual() const override { return false; /* otherwise this should be a Splat */ } + size_t Hash() const override { return hash; } + + utils::Result Convert(ProgramBuilder& builder, + const sem::Type* target_ty, + const Source& source) const override { + // Convert each of the composite element types. + auto* el_ty = sem::Type::ElementOf(target_ty); + std::vector conv_els; + conv_els.reserve(elements.size()); + for (auto* el : elements) { + auto conv_el = el->Convert(builder, el_ty, source); + if (!conv_el) { + return utils::Failure; + } + if (!conv_el.Get()) { + return nullptr; + } + conv_els.emplace_back(conv_el.Get()); + } + return CreateComposite(builder, target_ty, std::move(conv_els)); + } + + size_t CalcHash() { + auto h = utils::Hash(type, all_zero, any_zero); + for (auto* el : elements) { + utils::HashCombine(&h, el->Hash()); + } + return h; + } + + sem::Type const* const type; + const std::vector elements; + const bool all_zero; + const bool any_zero; + const size_t hash; +}; + +/// CreateElement constructs and returns an Element. +template +const Constant* CreateElement(ProgramBuilder& builder, const sem::Type* t, T v) { + return builder.create>(t, v); +} + +/// ZeroValue returns a Constant for the zero-value of the type `type`. +const Constant* ZeroValue(ProgramBuilder& builder, const sem::Type* type) { + return Switch( + type, // + [&](const sem::Vector* v) -> const Constant* { + auto* zero_el = ZeroValue(builder, v->type()); + return builder.create(type, zero_el, v->Width()); + }, + [&](const sem::Matrix* m) -> const Constant* { + auto* zero_el = ZeroValue(builder, m->ColumnType()); + return builder.create(type, zero_el, m->columns()); + }, + [&](const sem::Array* a) -> const Constant* { + if (auto* zero_el = ZeroValue(builder, a->ElemType())) { + return builder.create(type, zero_el, a->Count()); + } + return nullptr; + }, + [&](const sem::Struct* s) -> const Constant* { + std::unordered_map zero_by_type; + std::vector zeros; + zeros.reserve(s->Members().size()); + for (auto* member : s->Members()) { + auto* zero = utils::GetOrCreate(zero_by_type, member->Type(), + [&] { return ZeroValue(builder, member->Type()); }); + if (!zero) { + return nullptr; + } + zeros.emplace_back(zero); + } + if (zero_by_type.size() == 1) { + // All members were of the same type, so the zero value is the same for all members. + return builder.create(type, zeros[0], s->Members().size()); + } + return CreateComposite(builder, s, std::move(zeros)); + }, + [&](Default) -> const Constant* { + return TypeDispatch(type, [&](auto zero) -> const Constant* { + return CreateElement(builder, type, zero); + }); + }); +} + +/// Equal returns true if the constants `a` and `b` are of the same type and value. +bool Equal(const sem::Constant* a, const sem::Constant* b) { + if (a->Hash() != b->Hash()) { + return false; + } + if (a->Type() != b->Type()) { + return false; + } + return Switch( + a->Type(), // + [&](const sem::Vector* vec) { + for (size_t i = 0; i < vec->Width(); i++) { + if (!Equal(a->Index(i), b->Index(i))) { + return false; + } + } + return true; + }, + [&](const sem::Matrix* mat) { + for (size_t i = 0; i < mat->columns(); i++) { + if (!Equal(a->Index(i), b->Index(i))) { + return false; + } + } + return true; + }, + [&](const sem::Array* arr) { + for (size_t i = 0; i < arr->Count(); i++) { + if (!Equal(a->Index(i), b->Index(i))) { + return false; + } + } + return true; + }, + [&](Default) { return a->Value() == b->Value(); }); +} + +/// CreateComposite is used to construct a constant of a vector, matrix or array type. +/// CreateComposite examines the element values and will return either a Composite or a Splat, +/// depending on the element types and values. +const Constant* CreateComposite(ProgramBuilder& builder, + const sem::Type* type, + std::vector elements) { + if (elements.size() == 0) { + return nullptr; + } + bool any_zero = false; + bool all_zero = true; + bool all_equal = true; + auto* first = elements.front(); + for (auto* el : elements) { + if (!el) { + return nullptr; + } + if (!any_zero && el->AnyZero()) { + any_zero = true; + } + if (all_zero && !el->AllZero()) { + all_zero = false; + } + if (all_equal && el != first) { + if (!Equal(el, first)) { + all_equal = false; + } + } + } + if (all_equal) { + return builder.create(type, elements[0], elements.size()); + } else { + return builder.create(type, std::move(elements), all_zero, any_zero); + } +} + +} // namespace + +ConstEval::ConstEval(ProgramBuilder& b) : builder(b) {} + +const sem::Constant* ConstEval::Literal(const sem::Type* ty, + const ast::LiteralExpression* literal) { + return Switch( + literal, + [&](const ast::BoolLiteralExpression* lit) { + return CreateElement(builder, ty, lit->value); + }, + [&](const ast::IntLiteralExpression* lit) -> const Constant* { + switch (lit->suffix) { + case ast::IntLiteralExpression::Suffix::kNone: + return CreateElement(builder, ty, AInt(lit->value)); + case ast::IntLiteralExpression::Suffix::kI: + return CreateElement(builder, ty, i32(lit->value)); + case ast::IntLiteralExpression::Suffix::kU: + return CreateElement(builder, ty, u32(lit->value)); + } + return nullptr; + }, + [&](const ast::FloatLiteralExpression* lit) -> const Constant* { + switch (lit->suffix) { + case ast::FloatLiteralExpression::Suffix::kNone: + return CreateElement(builder, ty, AFloat(lit->value)); + case ast::FloatLiteralExpression::Suffix::kF: + return CreateElement(builder, ty, f32(lit->value)); + case ast::FloatLiteralExpression::Suffix::kH: + return CreateElement(builder, ty, f16(lit->value)); + } + return nullptr; + }); +} + +const sem::Constant* ConstEval::CtorOrConv(const sem::Type* ty, + const std::vector& args) { + // For zero value init, return 0s + if (args.empty()) { + return ZeroValue(builder, ty); + } + + if (auto* el_ty = sem::Type::ElementOf(ty); el_ty && args.size() == 1) { + // Type constructor or conversion that takes a single argument. + auto& src = args[0]->Declaration()->source; + auto* arg = static_cast(args[0]->ConstantValue()); + if (!arg) { + return nullptr; // Single argument is not constant. + } + + if (ty->is_scalar()) { // Scalar type conversion: i32(x), u32(x), bool(x), etc + return Convert(el_ty, arg, src).Get(); + } + + if (arg->Type() == el_ty) { + // Argument type matches function type. This is a splat. + auto splat = [&](size_t n) { return builder.create(ty, arg, n); }; + return Switch( + ty, // + [&](const sem::Vector* v) { return splat(v->Width()); }, + [&](const sem::Matrix* m) { return splat(m->columns()); }, + [&](const sem::Array* a) { return splat(a->Count()); }); + } + + // Argument type and function type mismatch. This is a type conversion. + if (auto conv = Convert(ty, arg, src)) { + return conv.Get(); + } + + return nullptr; + } + + // Helper for pushing all the argument constants to `els`. + auto args_as_constants = [&] { + return utils::Transform( + args, [&](auto* expr) { return static_cast(expr->ConstantValue()); }); + }; + + // Multiple arguments. Must be a type constructor. + + return Switch( + ty, // What's the target type being constructed? + [&](const sem::Vector*) -> const Constant* { + // Vector can be constructed with a mix of scalars / abstract numerics and smaller + // vectors. + std::vector els; + els.reserve(args.size()); + for (auto* expr : args) { + auto* arg = static_cast(expr->ConstantValue()); + if (!arg) { + return nullptr; + } + auto* arg_ty = arg->Type(); + if (auto* arg_vec = arg_ty->As()) { + // Extract out vector elements. + for (uint32_t i = 0; i < arg_vec->Width(); i++) { + auto* el = static_cast(arg->Index(i)); + if (!el) { + return nullptr; + } + els.emplace_back(el); + } + } else { + els.emplace_back(arg); + } + } + return CreateComposite(builder, ty, std::move(els)); + }, + [&](const sem::Matrix* m) -> const Constant* { + // Matrix can be constructed with a set of scalars / abstract numerics, or column + // vectors. + if (args.size() == m->columns() * m->rows()) { + // Matrix built from scalars / abstract numerics + std::vector els; + els.reserve(args.size()); + for (uint32_t c = 0; c < m->columns(); c++) { + std::vector column; + column.reserve(m->rows()); + for (uint32_t r = 0; r < m->rows(); r++) { + auto* arg = + static_cast(args[r + c * m->rows()]->ConstantValue()); + if (!arg) { + return nullptr; + } + column.emplace_back(arg); + } + els.push_back(CreateComposite(builder, m->ColumnType(), std::move(column))); + } + return CreateComposite(builder, ty, std::move(els)); + } + // Matrix built from column vectors + return CreateComposite(builder, ty, args_as_constants()); + }, + [&](const sem::Array*) { + // Arrays must be constructed using a list of elements + return CreateComposite(builder, ty, args_as_constants()); + }, + [&](const sem::Struct*) { + // Structures must be constructed using a list of elements + return CreateComposite(builder, ty, args_as_constants()); + }); +} + +const sem::Constant* ConstEval::Index(const sem::Expression* obj_expr, + const sem::Expression* idx_expr) { + auto obj_val = obj_expr->ConstantValue(); + if (!obj_val) { + return {}; + } + + auto idx_val = idx_expr->ConstantValue(); + if (!idx_val) { + return {}; + } + + uint32_t el_count = 0; + sem::Type::ElementOf(obj_val->Type(), &el_count); + + AInt idx = idx_val->As(); + if (idx < 0 || idx >= el_count) { + auto clamped = std::min(std::max(idx, 0), el_count - 1); + AddWarning("index " + std::to_string(idx) + " out of bounds [0.." + + std::to_string(el_count - 1) + "]. Clamping index to " + + std::to_string(clamped), + idx_expr->Declaration()->source); + idx = clamped; + } + + return obj_val->Index(static_cast(idx)); +} + +const sem::Constant* ConstEval::MemberAccess(const sem::Expression* obj_expr, + const sem::StructMember* member) { + auto obj_val = obj_expr->ConstantValue(); + if (!obj_val) { + return {}; + } + return obj_val->Index(static_cast(member->Index())); +} + +const sem::Constant* ConstEval::Swizzle(const sem::Type* ty, + const sem::Expression* vec_expr, + const std::vector& indices) { + auto* vec_val = vec_expr->ConstantValue(); + if (!vec_val) { + return nullptr; + } + if (indices.size() == 1) { + return static_cast(vec_val->Index(static_cast(indices[0]))); + } else { + auto values = utils::Transform(indices, [&](uint32_t i) { + return static_cast(vec_val->Index(static_cast(i))); + }); + return CreateComposite(builder, ty, std::move(values)); + } +} + +const sem::Constant* ConstEval::Bitcast(const sem::Type*, const sem::Expression*) { + // TODO(crbug.com/tint/1581): Implement @const intrinsics + return nullptr; +} + +utils::Result ConstEval::Convert(const sem::Type* target_ty, + const sem::Constant* value, + const Source& source) { + if (value->Type() == target_ty) { + return value; + } + auto conv = static_cast(value)->Convert(builder, target_ty, source); + if (!conv) { + return utils::Failure; + } + return conv.Get(); +} + +void ConstEval::AddError(const std::string& msg, const Source& source) const { + builder.Diagnostics().add_error(diag::System::Resolver, msg, source); +} + +void ConstEval::AddWarning(const std::string& msg, const Source& source) const { + builder.Diagnostics().add_warning(diag::System::Resolver, msg, source); +} + +} // namespace tint::resolver diff --git a/src/tint/resolver/const_eval.h b/src/tint/resolver/const_eval.h index 3792e353a3..540908bb1d 100644 --- a/src/tint/resolver/const_eval.h +++ b/src/tint/resolver/const_eval.h @@ -16,24 +16,111 @@ #define SRC_TINT_RESOLVER_CONST_EVAL_H_ #include +#include +#include + +#include "src/tint/utils/result.h" // Forward declarations namespace tint { class ProgramBuilder; +class Source; } // namespace tint - -// Forward declarations +namespace tint::ast { +class LiteralExpression; +} // namespace tint::ast namespace tint::sem { class Constant; +class Expression; +class StructMember; +class Type; } // namespace tint::sem -namespace tint::resolver::const_eval { +namespace tint::resolver { -/// Typedef for a constant evaluation function -using Function = const sem::Constant*(ProgramBuilder& builder, - sem::Constant const* const* args, - size_t num_args); +/// ConstEval performs shader creation-time (constant expression) expression evaluation. +/// Methods are called from the resolver, either directly or via member-function pointers indexed by +/// the IntrinsicTable. All child-expression nodes are guaranteed to have been already resolved +/// before calling a method to evaluate an expression's value. +class ConstEval { + public: + /// Typedef for a constant evaluation function + using Function = const sem::Constant* (ConstEval::*)(const sem::Type* result_ty, + sem::Expression const* const* args, + size_t num_args); -} // namespace tint::resolver::const_eval + /// The result type of a method that may raise a diagnostic error and the caller should abort + /// resolving. Can be one of three distinct values: + /// * A non-null sem::Constant pointer. Returned when a expression resolves to a creation time + /// value. + /// * A null sem::Constant pointer. Returned when a expression cannot resolve to a creation time + /// value, but is otherwise legal. + /// * `utils::Failure`. Returned when there was a resolver error. In this situation the method + /// will have already reported a diagnostic error message, and the caller should abort + /// resolving. + using ConstantResult = utils::Result; + + /// Constructor + /// @param b the program builder + explicit ConstEval(ProgramBuilder& b); + + //////////////////////////////////////////////////////////////////////////////////////////////// + // Constant value evaluation methods, to be called directly from Resolver + //////////////////////////////////////////////////////////////////////////////////////////////// + + /// @param ty the target type + /// @param expr the input expression + /// @return the bit-cast of the given expression to the given type, or null if the value cannot + /// be calculated + const sem::Constant* Bitcast(const sem::Type* ty, const sem::Expression* expr); + + /// @param ty the target type + /// @param args the input arguments + /// @return the resulting type constructor or conversion, or null if the value cannot be + /// calculated + const sem::Constant* CtorOrConv(const sem::Type* ty, + const std::vector& args); + + /// @param obj the object being indexed + /// @param idx the index expression + /// @return the result of the index, or null if the value cannot be calculated + const sem::Constant* Index(const sem::Expression* obj, const sem::Expression* idx); + + /// @param ty the result type + /// @param lit the literal AST node + /// @return the constant value of the literal + const sem::Constant* Literal(const sem::Type* ty, const ast::LiteralExpression* lit); + + /// @param obj the object being accessed + /// @param member the member + /// @return the result of the member access, or null if the value cannot be calculated + const sem::Constant* MemberAccess(const sem::Expression* obj, const sem::StructMember* member); + + /// @param ty the result type + /// @param vector the vector being swizzled + /// @param indices the swizzle indices + /// @return the result of the swizzle, or null if the value cannot be calculated + const sem::Constant* Swizzle(const sem::Type* ty, + const sem::Expression* vector, + const std::vector& indices); + + /// Convert the `value` to `target_type` + /// @param ty the result type + /// @param value the value being converted + /// @param source the source location of the conversion + /// @return the converted value, or null if the value cannot be calculated + ConstantResult Convert(const sem::Type* ty, const sem::Constant* value, const Source& source); + + private: + /// Adds the given error message to the diagnostics + void AddError(const std::string& msg, const Source& source) const; + + /// Adds the given warning message to the diagnostics + void AddWarning(const std::string& msg, const Source& source) const; + + ProgramBuilder& builder; +}; + +} // namespace tint::resolver #endif // SRC_TINT_RESOLVER_CONST_EVAL_H_ diff --git a/src/tint/resolver/resolver_constants_test.cc b/src/tint/resolver/const_eval_test.cc similarity index 99% rename from src/tint/resolver/resolver_constants_test.cc rename to src/tint/resolver/const_eval_test.cc index 1774a165ca..1afb8d5683 100644 --- a/src/tint/resolver/resolver_constants_test.cc +++ b/src/tint/resolver/const_eval_test.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "src/tint/resolver/resolver.h" - #include #include "gtest/gtest.h" diff --git a/src/tint/resolver/intrinsic_table.cc b/src/tint/resolver/intrinsic_table.cc index d746a6c2fb..9f4210f861 100644 --- a/src/tint/resolver/intrinsic_table.cc +++ b/src/tint/resolver/intrinsic_table.cc @@ -854,7 +854,7 @@ struct OverloadInfo { /// The flags for the overload OverloadFlags flags; /// The function used to evaluate the overload at shader-creation time. - const_eval::Function* const const_eval_fn; + ConstEval::Function const const_eval_fn; }; /// IntrinsicInfo describes a builtin function or operator overload diff --git a/src/tint/resolver/intrinsic_table.h b/src/tint/resolver/intrinsic_table.h index d10f7bb14b..269a935b90 100644 --- a/src/tint/resolver/intrinsic_table.h +++ b/src/tint/resolver/intrinsic_table.h @@ -47,7 +47,7 @@ class IntrinsicTable { /// The semantic info for the builtin const sem::Builtin* sem = nullptr; /// The constant evaluation function - const_eval::Function* const_eval_fn = nullptr; + ConstEval::Function const_eval_fn = nullptr; }; /// UnaryOperator describes a resolved unary operator @@ -56,6 +56,8 @@ class IntrinsicTable { const sem::Type* result = nullptr; /// The type of the parameter of the unary operator const sem::Type* parameter = nullptr; + /// The constant evaluation function + ConstEval::Function const_eval_fn = nullptr; }; /// BinaryOperator describes a resolved binary operator @@ -66,6 +68,8 @@ class IntrinsicTable { const sem::Type* lhs = nullptr; /// The type of RHS parameter of the binary operator const sem::Type* rhs = nullptr; + /// The constant evaluation function + ConstEval::Function const_eval_fn = nullptr; }; /// Lookup looks for the builtin overload with the given signature, raising an error diagnostic diff --git a/src/tint/resolver/intrinsic_table.inl.tmpl b/src/tint/resolver/intrinsic_table.inl.tmpl index 663013ed7f..51d04d454f 100644 --- a/src/tint/resolver/intrinsic_table.inl.tmpl +++ b/src/tint/resolver/intrinsic_table.inl.tmpl @@ -103,7 +103,7 @@ constexpr OverloadInfo kOverloads[] = { {{- end }} {{- if $o.IsDeprecated}}, OverloadFlag::kIsDeprecated{{end }}), /* const eval */ -{{- if $o.ConstEvalFunction }} const_eval::{{$o.ConstEvalFunction}}, +{{- if $o.ConstEvalFunction }} &ConstEval::{{$o.ConstEvalFunction}}, {{- else }} nullptr, {{- end }} }, diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc index 0df9a8f0d1..e364fa6dfe 100644 --- a/src/tint/resolver/resolver.cc +++ b/src/tint/resolver/resolver.cc @@ -91,6 +91,7 @@ namespace tint::resolver { Resolver::Resolver(ProgramBuilder* builder) : builder_(builder), diagnostics_(builder->Diagnostics()), + const_eval_(*builder), intrinsic_table_(IntrinsicTable::Create(*builder)), sem_(builder, dependencies_), validator_(builder, sem_) {} @@ -1313,7 +1314,7 @@ const sem::Expression* Resolver::Materialize(const sem::Expression* expr, << ") called on expression with no constant value"; return nullptr; } - auto materialized_val = ConvertValue(expr_val, target_ty, decl->source); + auto materialized_val = const_eval_.Convert(target_ty, expr_val, decl->source); if (!materialized_val) { // ConvertValue() has already failed and raised an diagnostic error. return nullptr; @@ -1422,7 +1423,7 @@ sem::Expression* Resolver::IndexAccessor(const ast::IndexAccessorExpression* exp ty = builder_->create(ty, ref->StorageClass(), ref->Access()); } - auto val = EvaluateIndexValue(obj, idx); + auto val = const_eval_.Index(obj, idx); bool has_side_effects = idx->HasSideEffects() || obj->HasSideEffects(); auto* sem = builder_->create( expr, ty, obj, idx, current_statement_, std::move(val), has_side_effects, @@ -1441,7 +1442,7 @@ sem::Expression* Resolver::Bitcast(const ast::BitcastExpression* expr) { return nullptr; } - auto val = EvaluateBitcastValue(inner, ty); + auto val = const_eval_.Bitcast(ty, inner); auto* sem = builder_->create(expr, ty, current_statement_, std::move(val), inner->HasSideEffects()); @@ -1489,7 +1490,7 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) { if (!MaterializeArguments(args, call_target)) { return nullptr; } - auto val = EvaluateCtorOrConvValue(args, call_target->ReturnType()); + auto val = const_eval_.CtorOrConv(call_target->ReturnType(), args); return builder_->create(expr, call_target, std::move(args), current_statement_, val, has_side_effects); }; @@ -1528,7 +1529,7 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) { if (!MaterializeArguments(args, call_target)) { return nullptr; } - auto val = EvaluateCtorOrConvValue(args, arr); + auto val = const_eval_.CtorOrConv(arr, args); return builder_->create(expr, call_target, std::move(args), current_statement_, val, has_side_effects); }, @@ -1550,7 +1551,7 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) { if (!MaterializeArguments(args, call_target)) { return nullptr; } - auto val = EvaluateCtorOrConvValue(args, str); + auto val = const_eval_.CtorOrConv(str, args); return builder_->create(expr, call_target, std::move(args), current_statement_, std::move(val), has_side_effects); @@ -1682,28 +1683,17 @@ sem::Call* Resolver::BuiltinCall(const ast::CallExpression* expr, } // If the builtin is @const, and all arguments have constant values, evaluate the builtin now. - const sem::Constant* constant = nullptr; + const sem::Constant* value = nullptr; if (builtin.const_eval_fn) { - std::vector values(args.size()); - bool is_const = true; // all arguments have constant values - for (size_t i = 0; i < values.size(); i++) { - if (auto v = args[i]->ConstantValue()) { - values[i] = std::move(v); - } else { - is_const = false; - break; - } - } - if (is_const) { - constant = builtin.const_eval_fn(*builder_, values.data(), args.size()); - } + value = (const_eval_.*builtin.const_eval_fn)(builtin.sem->ReturnType(), args.data(), + args.size()); } bool has_side_effects = builtin.sem->HasSideEffects() || std::any_of(args.begin(), args.end(), [](auto* e) { return e->HasSideEffects(); }); auto* call = builder_->create(expr, builtin.sem, std::move(args), current_statement_, - constant, has_side_effects); + value, has_side_effects); if (current_function_) { current_function_->AddDirectlyCalledBuiltin(builtin.sem); @@ -1856,7 +1846,7 @@ sem::Expression* Resolver::Literal(const ast::LiteralExpression* literal) { return nullptr; } - auto val = EvaluateLiteralValue(literal, ty); + auto val = const_eval_.Literal(ty, literal); return builder_->create(literal, ty, current_statement_, std::move(val), /* has_side_effects */ false); } @@ -1976,7 +1966,7 @@ sem::Expression* Resolver::MemberAccessor(const ast::MemberAccessorExpression* e ret = builder_->create(ret, ref->StorageClass(), ref->Access()); } - auto* val = EvaluateMemberAccessValue(object, member); + auto* val = const_eval_.MemberAccess(object, member); return builder_->create(expr, ret, current_statement_, val, object, member, has_side_effects, source_var); } @@ -2044,7 +2034,7 @@ sem::Expression* Resolver::MemberAccessor(const ast::MemberAccessorExpression* e // the swizzle. ret = builder_->create(vec->type(), static_cast(size)); } - auto* val = EvaluateSwizzleValue(object, ret, swizzle); + auto* val = const_eval_.Swizzle(ret, object, swizzle); return builder_->create(expr, ret, current_statement_, val, object, std::move(swizzle), has_side_effects, source_var); } @@ -2078,9 +2068,14 @@ sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) { } } - auto* val = EvaluateBinaryValue(lhs, rhs, op); + const sem::Constant* value = nullptr; + if (op.const_eval_fn) { + const sem::Expression* args[] = {lhs, rhs}; + value = (const_eval_.*op.const_eval_fn)(op.result, args, 2u); + } + bool has_side_effects = lhs->HasSideEffects() || rhs->HasSideEffects(); - auto* sem = builder_->create(expr, op.result, current_statement_, val, + auto* sem = builder_->create(expr, op.result, current_statement_, value, has_side_effects); sem->Behaviors() = lhs->Behaviors() + rhs->Behaviors(); @@ -2096,7 +2091,7 @@ sem::Expression* Resolver::UnaryOp(const ast::UnaryOpExpression* unary) { const sem::Type* ty = nullptr; const sem::Variable* source_var = nullptr; - const sem::Constant* val = nullptr; + const sem::Constant* value = nullptr; switch (unary->op) { case ast::UnaryOp::kAddressOf: @@ -2149,12 +2144,14 @@ sem::Expression* Resolver::UnaryOp(const ast::UnaryOpExpression* unary) { } } ty = op.result; - val = EvaluateUnaryValue(expr, op); + if (op.const_eval_fn) { + value = (const_eval_.*op.const_eval_fn)(ty, &expr, 1u); + } break; } } - auto* sem = builder_->create(unary, ty, current_statement_, val, + auto* sem = builder_->create(unary, ty, current_statement_, value, expr->HasSideEffects(), source_var); sem->Behaviors() = expr->Behaviors(); return sem; diff --git a/src/tint/resolver/resolver.h b/src/tint/resolver/resolver.h index 20f487c6d8..245512749f 100644 --- a/src/tint/resolver/resolver.h +++ b/src/tint/resolver/resolver.h @@ -24,6 +24,7 @@ #include #include "src/tint/program_builder.h" +#include "src/tint/resolver/const_eval.h" #include "src/tint/resolver/dependency_graph.h" #include "src/tint/resolver/intrinsic_table.h" #include "src/tint/resolver/sem_helper.h" @@ -34,7 +35,6 @@ #include "src/tint/sem/constant.h" #include "src/tint/sem/function.h" #include "src/tint/sem/struct.h" -#include "src/tint/utils/result.h" #include "src/tint/utils/unique_vector.h" // Forward declarations @@ -204,46 +204,6 @@ class Resolver { sem::Expression* MemberAccessor(const ast::MemberAccessorExpression*); sem::Expression* UnaryOp(const ast::UnaryOpExpression*); - //////////////////////////////////////////////////////////////////////////////////////////////// - /// Constant value evaluation methods - /// - /// These methods are called from the expression resolving methods, and so child-expression - /// nodes are guaranteed to have been already resolved and any constant values calculated. - //////////////////////////////////////////////////////////////////////////////////////////////// - const sem::Constant* EvaluateBinaryValue(const sem::Expression* lhs, - const sem::Expression* rhs, - const IntrinsicTable::BinaryOperator&); - const sem::Constant* EvaluateBitcastValue(const sem::Expression*, const sem::Type*); - const sem::Constant* EvaluateCtorOrConvValue( - const std::vector& args, - const sem::Type* ty); // Note: ty is not an array or structure - const sem::Constant* EvaluateIndexValue(const sem::Expression* obj, const sem::Expression* idx); - const sem::Constant* EvaluateLiteralValue(const ast::LiteralExpression*, const sem::Type*); - const sem::Constant* EvaluateMemberAccessValue(const sem::Expression* obj, - const sem::StructMember* member); - const sem::Constant* EvaluateSwizzleValue(const sem::Expression* vector, - const sem::Type* type, - const std::vector& indices); - const sem::Constant* EvaluateUnaryValue(const sem::Expression*, - const IntrinsicTable::UnaryOperator&); - - /// The result type of a ConstantEvaluation method. - /// Can be one of three distinct values: - /// * A non-null sem::Constant pointer. Returned when a expression resolves to a creation time - /// value. - /// * A null sem::Constant pointer. Returned when a expression cannot resolve to a creation time - /// value, but is otherwise legal. - /// * `utils::Failure`. Returned when there was a resolver error. In this situation the method - /// will have already reported a diagnostic error message, and the caller should abort - /// resolving. - using ConstantResult = utils::Result; - - /// Convert the `value` to `target_type` - /// @return the converted value - ConstantResult ConvertValue(const sem::Constant* value, - const sem::Type* target_type, - const Source& source); - /// If `expr` is not of an abstract-numeric type, then Materialize() will just return `expr`. /// If `expr` is of an abstract-numeric type: /// * Materialize will create and return a sem::Materialize node wrapping `expr`. @@ -449,6 +409,7 @@ class Resolver { ProgramBuilder* const builder_; diag::List& diagnostics_; + ConstEval const_eval_; std::unique_ptr const intrinsic_table_; DependencyGraph dependencies_; SemHelper sem_; diff --git a/src/tint/resolver/resolver_constants.cc b/src/tint/resolver/resolver_constants.cc deleted file mode 100644 index c5798be254..0000000000 --- a/src/tint/resolver/resolver_constants.cc +++ /dev/null @@ -1,605 +0,0 @@ -// Copyright 2021 The Tint Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "src/tint/resolver/resolver.h" - -#include - -#include "src/tint/sem/abstract_float.h" -#include "src/tint/sem/abstract_int.h" -#include "src/tint/sem/constant.h" -#include "src/tint/sem/member_accessor_expression.h" -#include "src/tint/sem/type_constructor.h" -#include "src/tint/utils/compiler_macros.h" -#include "src/tint/utils/map.h" -#include "src/tint/utils/transform.h" - -using namespace tint::number_suffixes; // NOLINT - -namespace tint::resolver { - -namespace { - -/// TypeDispatch is a helper for calling the function `f`, passing a single zero-value argument of -/// the C++ type that corresponds to the sem::Type `type`. For example, calling `TypeDispatch()` -/// with a type of `sem::I32*` will call the function f with a single argument of `i32(0)`. -/// @returns the value returned by calling `f`. -/// @note `type` must be a scalar or abstract numeric type. Other types will not call `f`, and will -/// return the zero-initialized value of the return type for `f`. -template -auto TypeDispatch(const sem::Type* type, F&& f) { - return Switch( - type, // - [&](const sem::AbstractInt*) { return f(AInt(0)); }, // - [&](const sem::AbstractFloat*) { return f(AFloat(0)); }, // - [&](const sem::I32*) { return f(i32(0)); }, // - [&](const sem::U32*) { return f(u32(0)); }, // - [&](const sem::F32*) { return f(f32(0)); }, // - [&](const sem::F16*) { return f(f16(0)); }, // - [&](const sem::Bool*) { return f(static_cast(0)); }); -} - -/// @returns `value` if `T` is not a Number, otherwise ValueOf returns the inner value of the -/// Number. -template -inline auto ValueOf(T value) { - if constexpr (std::is_same_v, T>) { - return value; - } else { - return value.value; - } -} - -/// @returns true if `value` is a positive zero. -template -inline bool IsPositiveZero(T value) { - using N = UnwrapNumber; - return Number(value) == Number(0); // Considers sign bit -} - -/// Constant inherits from sem::Constant to add an private implementation method for conversion. -struct Constant : public sem::Constant { - /// Convert attempts to convert the constant value to the given type. On error, Convert() - /// creates a new diagnostic message and returns a Failure. - virtual utils::Result Convert(ProgramBuilder& builder, - const sem::Type* target_ty, - const Source& source) const = 0; -}; - -// Forward declaration -const Constant* CreateComposite(ProgramBuilder& builder, - const sem::Type* type, - std::vector elements); - -/// Element holds a single scalar or abstract-numeric value. -/// Element implements the Constant interface. -template -struct Element : Constant { - Element(const sem::Type* t, T v) : type(t), value(v) {} - ~Element() override = default; - const sem::Type* Type() const override { return type; } - std::variant Value() const override { - if constexpr (IsFloatingPoint>) { - return static_cast(value); - } else { - return static_cast(value); - } - } - const Constant* Index(size_t) const override { return nullptr; } - bool AllZero() const override { return IsPositiveZero(value); } - bool AnyZero() const override { return IsPositiveZero(value); } - bool AllEqual() const override { return true; } - size_t Hash() const override { return utils::Hash(type, ValueOf(value)); } - - utils::Result Convert(ProgramBuilder& builder, - const sem::Type* target_ty, - const Source& source) const override { - TINT_BEGIN_DISABLE_WARNING(UNREACHABLE_CODE); - if (target_ty == type) { - // If the types are identical, then no conversion is needed. - return this; - } - bool failed = false; - auto* res = TypeDispatch(target_ty, [&](auto zero_to) -> const Constant* { - // `T` is the source type, `value` is the source value. - // `TO` is the target type. - using TO = std::decay_t; - if constexpr (std::is_same_v) { - // [x -> bool] - return builder.create>(target_ty, !IsPositiveZero(value)); - } else if constexpr (std::is_same_v) { - // [bool -> x] - return builder.create>(target_ty, TO(value ? 1 : 0)); - } else if (auto conv = CheckedConvert(value)) { - // Conversion success - return builder.create>(target_ty, conv.Get()); - // --- Below this point are the failure cases --- - } else if constexpr (std::is_same_v || std::is_same_v) { - // [abstract-numeric -> x] - materialization failure - std::stringstream ss; - ss << "value " << value << " cannot be represented as "; - ss << "'" << builder.FriendlyName(target_ty) << "'"; - builder.Diagnostics().add_error(tint::diag::System::Resolver, ss.str(), source); - failed = true; - } else if constexpr (IsFloatingPoint>) { - // [x -> floating-point] - number not exactly representable - // https://www.w3.org/TR/WGSL/#floating-point-conversion - constexpr auto kInf = std::numeric_limits::infinity(); - switch (conv.Failure()) { - case ConversionFailure::kExceedsNegativeLimit: - return builder.create>(target_ty, TO(-kInf)); - case ConversionFailure::kExceedsPositiveLimit: - return builder.create>(target_ty, TO(kInf)); - } - } else { - // [x -> integer] - number not exactly representable - // https://www.w3.org/TR/WGSL/#floating-point-conversion - switch (conv.Failure()) { - case ConversionFailure::kExceedsNegativeLimit: - return builder.create>(target_ty, TO(TO::kLowest)); - case ConversionFailure::kExceedsPositiveLimit: - return builder.create>(target_ty, TO(TO::kHighest)); - } - } - return nullptr; // Expression is not constant. - }); - if (failed) { - // A diagnostic error has been raised, and resolving should abort. - return utils::Failure; - } - return res; - TINT_END_DISABLE_WARNING(UNREACHABLE_CODE); - } - - sem::Type const* const type; - const T value; -}; - -/// Splat holds a single Constant value, duplicated as all children. -/// Splat is used for zero-initializers, 'splat' constructors, or constructors where each element is -/// identical. Splat may be of a vector, matrix or array type. -/// Splat implements the Constant interface. -struct Splat : Constant { - Splat(const sem::Type* t, const Constant* e, size_t n) : type(t), el(e), count(n) {} - ~Splat() override = default; - const sem::Type* Type() const override { return type; } - std::variant Value() const override { return {}; } - const Constant* Index(size_t i) const override { return i < count ? el : nullptr; } - bool AllZero() const override { return el->AllZero(); } - bool AnyZero() const override { return el->AnyZero(); } - bool AllEqual() const override { return true; } - size_t Hash() const override { return utils::Hash(type, el->Hash(), count); } - - utils::Result Convert(ProgramBuilder& builder, - const sem::Type* target_ty, - const Source& source) const override { - // Convert the single splatted element type. - auto conv_el = el->Convert(builder, sem::Type::ElementOf(target_ty), source); - if (!conv_el) { - return utils::Failure; - } - if (!conv_el.Get()) { - return nullptr; - } - return builder.create(target_ty, conv_el.Get(), count); - } - - sem::Type const* const type; - const Constant* el; - const size_t count; -}; - -/// Composite holds a number of mixed child Constant values. -/// Composite may be of a vector, matrix or array type. -/// If each element is the same type and value, then a Splat would be a more efficient constant -/// implementation. Use CreateComposite() to create the appropriate Constant type. -/// Composite implements the Constant interface. -struct Composite : Constant { - Composite(const sem::Type* t, std::vector els, bool all_0, bool any_0) - : type(t), elements(std::move(els)), all_zero(all_0), any_zero(any_0), hash(CalcHash()) {} - ~Composite() override = default; - const sem::Type* Type() const override { return type; } - std::variant Value() const override { return {}; } - const Constant* Index(size_t i) const override { - return i < elements.size() ? elements[i] : nullptr; - } - bool AllZero() const override { return all_zero; } - bool AnyZero() const override { return any_zero; } - bool AllEqual() const override { return false; /* otherwise this should be a Splat */ } - size_t Hash() const override { return hash; } - - utils::Result Convert(ProgramBuilder& builder, - const sem::Type* target_ty, - const Source& source) const override { - // Convert each of the composite element types. - auto* el_ty = sem::Type::ElementOf(target_ty); - std::vector conv_els; - conv_els.reserve(elements.size()); - for (auto* el : elements) { - auto conv_el = el->Convert(builder, el_ty, source); - if (!conv_el) { - return utils::Failure; - } - if (!conv_el.Get()) { - return nullptr; - } - conv_els.emplace_back(conv_el.Get()); - } - return CreateComposite(builder, target_ty, std::move(conv_els)); - } - - size_t CalcHash() { - auto h = utils::Hash(type, all_zero, any_zero); - for (auto* el : elements) { - utils::HashCombine(&h, el->Hash()); - } - return h; - } - - sem::Type const* const type; - const std::vector elements; - const bool all_zero; - const bool any_zero; - const size_t hash; -}; - -/// CreateElement constructs and returns an Element. -template -const Constant* CreateElement(ProgramBuilder& builder, const sem::Type* t, T v) { - return builder.create>(t, v); -} - -/// ZeroValue returns a Constant for the zero-value of the type `type`. -const Constant* ZeroValue(ProgramBuilder& builder, const sem::Type* type) { - return Switch( - type, // - [&](const sem::Vector* v) -> const Constant* { - auto* zero_el = ZeroValue(builder, v->type()); - return builder.create(type, zero_el, v->Width()); - }, - [&](const sem::Matrix* m) -> const Constant* { - auto* zero_el = ZeroValue(builder, m->ColumnType()); - return builder.create(type, zero_el, m->columns()); - }, - [&](const sem::Array* a) -> const Constant* { - if (auto* zero_el = ZeroValue(builder, a->ElemType())) { - return builder.create(type, zero_el, a->Count()); - } - return nullptr; - }, - [&](const sem::Struct* s) -> const Constant* { - std::unordered_map zero_by_type; - std::vector zeros; - zeros.reserve(s->Members().size()); - for (auto* member : s->Members()) { - auto* zero = utils::GetOrCreate(zero_by_type, member->Type(), - [&] { return ZeroValue(builder, member->Type()); }); - if (!zero) { - return nullptr; - } - zeros.emplace_back(zero); - } - if (zero_by_type.size() == 1) { - // All members were of the same type, so the zero value is the same for all members. - return builder.create(type, zeros[0], s->Members().size()); - } - return CreateComposite(builder, s, std::move(zeros)); - }, - [&](Default) -> const Constant* { - return TypeDispatch(type, [&](auto zero) -> const Constant* { - return CreateElement(builder, type, zero); - }); - }); -} - -/// Equal returns true if the constants `a` and `b` are of the same type and value. -bool Equal(const sem::Constant* a, const sem::Constant* b) { - if (a->Hash() != b->Hash()) { - return false; - } - if (a->Type() != b->Type()) { - return false; - } - return Switch( - a->Type(), // - [&](const sem::Vector* vec) { - for (size_t i = 0; i < vec->Width(); i++) { - if (!Equal(a->Index(i), b->Index(i))) { - return false; - } - } - return true; - }, - [&](const sem::Matrix* mat) { - for (size_t i = 0; i < mat->columns(); i++) { - if (!Equal(a->Index(i), b->Index(i))) { - return false; - } - } - return true; - }, - [&](const sem::Array* arr) { - for (size_t i = 0; i < arr->Count(); i++) { - if (!Equal(a->Index(i), b->Index(i))) { - return false; - } - } - return true; - }, - [&](Default) { return a->Value() == b->Value(); }); -} - -/// CreateComposite is used to construct a constant of a vector, matrix or array type. -/// CreateComposite examines the element values and will return either a Composite or a Splat, -/// depending on the element types and values. -const Constant* CreateComposite(ProgramBuilder& builder, - const sem::Type* type, - std::vector elements) { - if (elements.size() == 0) { - return nullptr; - } - bool any_zero = false; - bool all_zero = true; - bool all_equal = true; - auto* first = elements.front(); - for (auto* el : elements) { - if (!el) { - return nullptr; - } - if (!any_zero && el->AnyZero()) { - any_zero = true; - } - if (all_zero && !el->AllZero()) { - all_zero = false; - } - if (all_equal && el != first) { - if (!Equal(el, first)) { - all_equal = false; - } - } - } - if (all_equal) { - return builder.create(type, elements[0], elements.size()); - } else { - return builder.create(type, std::move(elements), all_zero, any_zero); - } -} - -} // namespace - -const sem::Constant* Resolver::EvaluateLiteralValue(const ast::LiteralExpression* literal, - const sem::Type* type) { - return Switch( - literal, - [&](const ast::BoolLiteralExpression* lit) { - return CreateElement(*builder_, type, lit->value); - }, - [&](const ast::IntLiteralExpression* lit) -> const Constant* { - switch (lit->suffix) { - case ast::IntLiteralExpression::Suffix::kNone: - return CreateElement(*builder_, type, AInt(lit->value)); - case ast::IntLiteralExpression::Suffix::kI: - return CreateElement(*builder_, type, i32(lit->value)); - case ast::IntLiteralExpression::Suffix::kU: - return CreateElement(*builder_, type, u32(lit->value)); - } - return nullptr; - }, - [&](const ast::FloatLiteralExpression* lit) -> const Constant* { - switch (lit->suffix) { - case ast::FloatLiteralExpression::Suffix::kNone: - return CreateElement(*builder_, type, AFloat(lit->value)); - case ast::FloatLiteralExpression::Suffix::kF: - return CreateElement(*builder_, type, f32(lit->value)); - case ast::FloatLiteralExpression::Suffix::kH: - return CreateElement(*builder_, type, f16(lit->value)); - } - return nullptr; - }); -} - -const sem::Constant* Resolver::EvaluateCtorOrConvValue( - const std::vector& args, - const sem::Type* ty) { - // For zero value init, return 0s - if (args.empty()) { - return ZeroValue(*builder_, ty); - } - - if (auto* el_ty = sem::Type::ElementOf(ty); el_ty && args.size() == 1) { - // Type constructor or conversion that takes a single argument. - auto& src = args[0]->Declaration()->source; - auto* arg = static_cast(args[0]->ConstantValue()); - if (!arg) { - return nullptr; // Single argument is not constant. - } - - if (ty->is_scalar()) { // Scalar type conversion: i32(x), u32(x), bool(x), etc - return ConvertValue(arg, el_ty, src).Get(); - } - - if (arg->Type() == el_ty) { - // Argument type matches function type. This is a splat. - auto splat = [&](size_t n) { return builder_->create(ty, arg, n); }; - return Switch( - ty, // - [&](const sem::Vector* v) { return splat(v->Width()); }, - [&](const sem::Matrix* m) { return splat(m->columns()); }, - [&](const sem::Array* a) { return splat(a->Count()); }); - } - - // Argument type and function type mismatch. This is a type conversion. - if (auto conv = ConvertValue(arg, ty, src)) { - return conv.Get(); - } - - return nullptr; - } - - // Helper for pushing all the argument constants to `els`. - auto args_as_constants = [&] { - return utils::Transform( - args, [&](auto* expr) { return static_cast(expr->ConstantValue()); }); - }; - - // Multiple arguments. Must be a type constructor. - - return Switch( - ty, // What's the target type being constructed? - [&](const sem::Vector*) -> const Constant* { - // Vector can be constructed with a mix of scalars / abstract numerics and smaller - // vectors. - std::vector els; - els.reserve(args.size()); - for (auto* expr : args) { - auto* arg = static_cast(expr->ConstantValue()); - if (!arg) { - return nullptr; - } - auto* arg_ty = arg->Type(); - if (auto* arg_vec = arg_ty->As()) { - // Extract out vector elements. - for (uint32_t i = 0; i < arg_vec->Width(); i++) { - auto* el = static_cast(arg->Index(i)); - if (!el) { - return nullptr; - } - els.emplace_back(el); - } - } else { - els.emplace_back(arg); - } - } - return CreateComposite(*builder_, ty, std::move(els)); - }, - [&](const sem::Matrix* m) -> const Constant* { - // Matrix can be constructed with a set of scalars / abstract numerics, or column - // vectors. - if (args.size() == m->columns() * m->rows()) { - // Matrix built from scalars / abstract numerics - std::vector els; - els.reserve(args.size()); - for (uint32_t c = 0; c < m->columns(); c++) { - std::vector column; - column.reserve(m->rows()); - for (uint32_t r = 0; r < m->rows(); r++) { - auto* arg = - static_cast(args[r + c * m->rows()]->ConstantValue()); - if (!arg) { - return nullptr; - } - column.emplace_back(arg); - } - els.push_back(CreateComposite(*builder_, m->ColumnType(), std::move(column))); - } - return CreateComposite(*builder_, ty, std::move(els)); - } - // Matrix built from column vectors - return CreateComposite(*builder_, ty, args_as_constants()); - }, - [&](const sem::Array*) { - // Arrays must be constructed using a list of elements - return CreateComposite(*builder_, ty, args_as_constants()); - }, - [&](const sem::Struct*) { - // Structures must be constructed using a list of elements - return CreateComposite(*builder_, ty, args_as_constants()); - }); -} - -const sem::Constant* Resolver::EvaluateIndexValue(const sem::Expression* obj_expr, - const sem::Expression* idx_expr) { - auto obj_val = obj_expr->ConstantValue(); - if (!obj_val) { - return {}; - } - - auto idx_val = idx_expr->ConstantValue(); - if (!idx_val) { - return {}; - } - - uint32_t el_count = 0; - sem::Type::ElementOf(obj_val->Type(), &el_count); - - AInt idx = idx_val->As(); - if (idx < 0 || idx >= el_count) { - auto clamped = std::min(std::max(idx, 0), el_count - 1); - AddWarning("index " + std::to_string(idx) + " out of bounds [0.." + - std::to_string(el_count - 1) + "]. Clamping index to " + - std::to_string(clamped), - idx_expr->Declaration()->source); - idx = clamped; - } - - return obj_val->Index(static_cast(idx)); -} - -const sem::Constant* Resolver::EvaluateMemberAccessValue(const sem::Expression* obj_expr, - const sem::StructMember* member) { - auto obj_val = obj_expr->ConstantValue(); - if (!obj_val) { - return {}; - } - return obj_val->Index(static_cast(member->Index())); -} - -const sem::Constant* Resolver::EvaluateSwizzleValue(const sem::Expression* vec_expr, - const sem::Type* type, - const std::vector& indices) { - auto* vec_val = vec_expr->ConstantValue(); - if (!vec_val) { - return nullptr; - } - if (indices.size() == 1) { - return static_cast(vec_val->Index(static_cast(indices[0]))); - } else { - auto values = utils::Transform( - indices, [&](uint32_t i) { return static_cast(vec_val->Index(i)); }); - return CreateComposite(*builder_, type, std::move(values)); - } -} - -const sem::Constant* Resolver::EvaluateBitcastValue(const sem::Expression*, const sem::Type*) { - // TODO(crbug.com/tint/1581): Implement @const intrinsics - return nullptr; -} - -const sem::Constant* Resolver::EvaluateBinaryValue(const sem::Expression*, - const sem::Expression*, - const IntrinsicTable::BinaryOperator&) { - // TODO(crbug.com/tint/1581): Implement @const intrinsics - return nullptr; -} - -const sem::Constant* Resolver::EvaluateUnaryValue(const sem::Expression*, - const IntrinsicTable::UnaryOperator&) { - // TODO(crbug.com/tint/1581): Implement @const intrinsics - return nullptr; -} - -utils::Result Resolver::ConvertValue(const sem::Constant* value, - const sem::Type* target_ty, - const Source& source) { - if (value->Type() == target_ty) { - return value; - } - auto conv = static_cast(value)->Convert(*builder_, target_ty, source); - if (!conv) { - return utils::Failure; - } - return conv.Get(); -} - -} // namespace tint::resolver