// Copyright 2022 The Tint Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "src/tint/resolver/const_eval.h" #include #include #include #include #include #include #include #include "src/tint/program_builder.h" #include "src/tint/sem/abstract_float.h" #include "src/tint/sem/abstract_int.h" #include "src/tint/sem/array.h" #include "src/tint/sem/bool.h" #include "src/tint/sem/constant.h" #include "src/tint/sem/f16.h" #include "src/tint/sem/f32.h" #include "src/tint/sem/i32.h" #include "src/tint/sem/matrix.h" #include "src/tint/sem/member_accessor_expression.h" #include "src/tint/sem/type_initializer.h" #include "src/tint/sem/u32.h" #include "src/tint/sem/vector.h" #include "src/tint/utils/compiler_macros.h" #include "src/tint/utils/map.h" #include "src/tint/utils/scoped_assignment.h" #include "src/tint/utils/transform.h" using namespace tint::number_suffixes; // NOLINT namespace tint::resolver { namespace { /// Returns the first element of a parameter pack template T First(T&& first, ...) { return std::forward(first); } /// Helper that calls `f` passing in the value of all `cs`. /// Calls `f` with all constants cast to the type of the first `cs` argument. template auto Dispatch_iu32(F&& f, CONSTANTS&&... cs) { return Switch( First(cs...)->Type(), // [&](const sem::I32*) { return f(cs->template As()...); }, [&](const sem::U32*) { return f(cs->template As()...); }); } /// Helper that calls `f` passing in the value of all `cs`. 
/// Calls `f` with all constants cast to the type of the first `cs` argument. template auto Dispatch_ia_iu32(F&& f, CONSTANTS&&... cs) { return Switch( First(cs...)->Type(), // [&](const sem::AbstractInt*) { return f(cs->template As()...); }, [&](const sem::I32*) { return f(cs->template As()...); }, [&](const sem::U32*) { return f(cs->template As()...); }); } /// Helper that calls `f` passing in the value of all `cs`. /// Calls `f` with all constants cast to the type of the first `cs` argument. template auto Dispatch_ia_iu32_bool(F&& f, CONSTANTS&&... cs) { return Switch( First(cs...)->Type(), // [&](const sem::AbstractInt*) { return f(cs->template As()...); }, [&](const sem::I32*) { return f(cs->template As()...); }, [&](const sem::U32*) { return f(cs->template As()...); }, [&](const sem::Bool*) { return f(cs->template As()...); }); } /// Helper that calls `f` passing in the value of all `cs`. /// Calls `f` with all constants cast to the type of the first `cs` argument. template auto Dispatch_fia_fi32_f16(F&& f, CONSTANTS&&... cs) { return Switch( First(cs...)->Type(), // [&](const sem::AbstractInt*) { return f(cs->template As()...); }, [&](const sem::AbstractFloat*) { return f(cs->template As()...); }, [&](const sem::F32*) { return f(cs->template As()...); }, [&](const sem::I32*) { return f(cs->template As()...); }, [&](const sem::F16*) { return f(cs->template As()...); }); } /// Helper that calls `f` passing in the value of all `cs`. /// Calls `f` with all constants cast to the type of the first `cs` argument. template auto Dispatch_fia_fiu32_f16(F&& f, CONSTANTS&&... 
cs) { return Switch( First(cs...)->Type(), // [&](const sem::AbstractInt*) { return f(cs->template As()...); }, [&](const sem::AbstractFloat*) { return f(cs->template As()...); }, [&](const sem::F32*) { return f(cs->template As()...); }, [&](const sem::I32*) { return f(cs->template As()...); }, [&](const sem::U32*) { return f(cs->template As()...); }, [&](const sem::F16*) { return f(cs->template As()...); }); } /// Helper that calls `f` passing in the value of all `cs`. /// Calls `f` with all constants cast to the type of the first `cs` argument. template auto Dispatch_fia_fiu32_f16_bool(F&& f, CONSTANTS&&... cs) { return Switch( First(cs...)->Type(), // [&](const sem::AbstractInt*) { return f(cs->template As()...); }, [&](const sem::AbstractFloat*) { return f(cs->template As()...); }, [&](const sem::F32*) { return f(cs->template As()...); }, [&](const sem::I32*) { return f(cs->template As()...); }, [&](const sem::U32*) { return f(cs->template As()...); }, [&](const sem::F16*) { return f(cs->template As()...); }, [&](const sem::Bool*) { return f(cs->template As()...); }); } /// Helper that calls `f` passing in the value of all `cs`. /// Calls `f` with all constants cast to the type of the first `cs` argument. template auto Dispatch_fa_f32_f16(F&& f, CONSTANTS&&... cs) { return Switch( First(cs...)->Type(), // [&](const sem::AbstractFloat*) { return f(cs->template As()...); }, [&](const sem::F32*) { return f(cs->template As()...); }, [&](const sem::F16*) { return f(cs->template As()...); }); } /// Helper that calls `f` passing in the value of all `cs`. /// Calls `f` with all constants cast to the type of the first `cs` argument. template auto Dispatch_bool(F&& f, CONSTANTS&&... cs) { return f(cs->template As()...); } /// ZeroTypeDispatch is a helper for calling the function `f`, passing a single zero-value argument /// of the C++ type that corresponds to the sem::Type `type`. 
For example, calling /// `ZeroTypeDispatch()` with a type of `sem::I32*` will call the function f with a single argument /// of `i32(0)`. /// @returns the value returned by calling `f`. /// @note `type` must be a scalar or abstract numeric type. Other types will not call `f`, and will /// return the zero-initialized value of the return type for `f`. template auto ZeroTypeDispatch(const sem::Type* type, F&& f) { return Switch( type, // [&](const sem::AbstractInt*) { return f(AInt(0)); }, // [&](const sem::AbstractFloat*) { return f(AFloat(0)); }, // [&](const sem::I32*) { return f(i32(0)); }, // [&](const sem::U32*) { return f(u32(0)); }, // [&](const sem::F32*) { return f(f32(0)); }, // [&](const sem::F16*) { return f(f16(0)); }, // [&](const sem::Bool*) { return f(static_cast(0)); }); } /// @returns `value` if `T` is not a Number, otherwise ValueOf returns the inner value of the /// Number. template inline auto ValueOf(T value) { if constexpr (std::is_same_v, T>) { return value; } else { return value.value; } } /// @returns true if `value` is a positive zero. template inline bool IsPositiveZero(T value) { using N = UnwrapNumber; return Number(value) == Number(0); // Considers sign bit } template std::string OverflowErrorMessage(NumberT lhs, const char* op, NumberT rhs) { std::stringstream ss; ss << "'" << lhs.value << " " << op << " " << rhs.value << "' cannot be represented as '" << FriendlyName() << "'"; return ss.str(); } /// @returns the number of consecutive leading bits in `@p e` set to `@p bit_value_to_count`. template auto CountLeadingBits(T e, T bit_value_to_count) -> std::make_unsigned_t { using UT = std::make_unsigned_t; constexpr UT kNumBits = sizeof(UT) * 8; constexpr UT kLeftMost = UT{1} << (kNumBits - 1); const UT b = bit_value_to_count == 0 ? 
UT{0} : kLeftMost; auto v = static_cast(e); auto count = UT{0}; while ((count < kNumBits) && ((v & kLeftMost) == b)) { ++count; v <<= 1; } return count; } /// @returns the number of consecutive trailing bits set to `@p bit_value_to_count` in `@p e` template auto CountTrailingBits(T e, T bit_value_to_count) -> std::make_unsigned_t { using UT = std::make_unsigned_t; constexpr UT kNumBits = sizeof(UT) * 8; constexpr UT kRightMost = UT{1}; const UT b = static_cast(bit_value_to_count); auto v = static_cast(e); auto count = UT{0}; while ((count < kNumBits) && ((v & kRightMost) == b)) { ++count; v >>= 1; } return count; } /// ImplConstant inherits from sem::Constant to add an private implementation method for conversion. struct ImplConstant : public sem::Constant { /// Convert attempts to convert the constant value to the given type. On error, Convert() /// creates a new diagnostic message and returns a Failure. virtual utils::Result Convert(ProgramBuilder& builder, const sem::Type* target_ty, const Source& source) const = 0; }; /// A result templated with a ImplConstant. using ImplResult = utils::Result; // Forward declaration const ImplConstant* CreateComposite(ProgramBuilder& builder, const sem::Type* type, utils::VectorRef elements); /// Element holds a single scalar or abstract-numeric value. /// Element implements the Constant interface. 
template struct Element : ImplConstant { static_assert(!std::is_same_v, T> || std::is_same_v, "T must be a Number or bool"); Element(const sem::Type* t, T v) : type(t), value(v) {} ~Element() override = default; const sem::Type* Type() const override { return type; } std::variant Value() const override { if constexpr (IsFloatingPoint>) { return static_cast(value); } else { return static_cast(value); } } const sem::Constant* Index(size_t) const override { return nullptr; } bool AllZero() const override { return IsPositiveZero(value); } bool AnyZero() const override { return IsPositiveZero(value); } bool AllEqual() const override { return true; } size_t Hash() const override { return utils::Hash(type, ValueOf(value)); } ImplResult Convert(ProgramBuilder& builder, const sem::Type* target_ty, const Source& source) const override { TINT_BEGIN_DISABLE_WARNING(UNREACHABLE_CODE); if (target_ty == type) { // If the types are identical, then no conversion is needed. return this; } return ZeroTypeDispatch(target_ty, [&](auto zero_to) -> ImplResult { // `value` is the source value. // `FROM` is the source type. // `TO` is the target type. using TO = std::decay_t; using FROM = T; if constexpr (std::is_same_v) { // [x -> bool] return builder.create>(target_ty, !IsPositiveZero(value)); } else if constexpr (std::is_same_v) { // [bool -> x] return builder.create>(target_ty, TO(value ? 
1 : 0)); } else if (auto conv = CheckedConvert(value)) { // Conversion success return builder.create>(target_ty, conv.Get()); // --- Below this point are the failure cases --- } else if constexpr (IsAbstract) { // [abstract-numeric -> x] - materialization failure std::stringstream ss; ss << "value " << value << " cannot be represented as "; ss << "'" << builder.FriendlyName(target_ty) << "'"; builder.Diagnostics().add_error(tint::diag::System::Resolver, ss.str(), source); return utils::Failure; } else if constexpr (IsFloatingPoint) { // [x -> floating-point] - number not exactly representable // https://www.w3.org/TR/WGSL/#floating-point-conversion switch (conv.Failure()) { case ConversionFailure::kExceedsNegativeLimit: return builder.create>(target_ty, -TO::Inf()); case ConversionFailure::kExceedsPositiveLimit: return builder.create>(target_ty, TO::Inf()); } } else if constexpr (IsFloatingPoint) { // [floating-point -> integer] - number not exactly representable // https://www.w3.org/TR/WGSL/#floating-point-conversion switch (conv.Failure()) { case ConversionFailure::kExceedsNegativeLimit: return builder.create>(target_ty, TO::Lowest()); case ConversionFailure::kExceedsPositiveLimit: return builder.create>(target_ty, TO::Highest()); } } else if constexpr (IsIntegral) { // [integer -> integer] - number not exactly representable // Static cast return builder.create>(target_ty, static_cast(value)); } return nullptr; // Expression is not constant. }); TINT_END_DISABLE_WARNING(UNREACHABLE_CODE); } sem::Type const* const type; const T value; }; /// Splat holds a single Constant value, duplicated as all children. /// Splat is used for zero-initializers, 'splat' initializers, or initializers where each element is /// identical. Splat may be of a vector, matrix or array type. /// Splat implements the Constant interface. 
struct Splat : ImplConstant { Splat(const sem::Type* t, const sem::Constant* e, size_t n) : type(t), el(e), count(n) {} ~Splat() override = default; const sem::Type* Type() const override { return type; } std::variant Value() const override { return {}; } const sem::Constant* Index(size_t i) const override { return i < count ? el : nullptr; } bool AllZero() const override { return el->AllZero(); } bool AnyZero() const override { return el->AnyZero(); } bool AllEqual() const override { return true; } size_t Hash() const override { return utils::Hash(type, el->Hash(), count); } ImplResult Convert(ProgramBuilder& builder, const sem::Type* target_ty, const Source& source) const override { // Convert the single splatted element type. // Note: This file is the only place where `sem::Constant`s are created, so this static_cast // is safe. auto conv_el = static_cast(el)->Convert( builder, sem::Type::ElementOf(target_ty), source); if (!conv_el) { return utils::Failure; } if (!conv_el.Get()) { return nullptr; } return builder.create(target_ty, conv_el.Get(), count); } sem::Type const* const type; const sem::Constant* el; const size_t count; }; /// Composite holds a number of mixed child Constant values. /// Composite may be of a vector, matrix or array type. /// If each element is the same type and value, then a Splat would be a more efficient constant /// implementation. Use CreateComposite() to create the appropriate Constant type. /// Composite implements the Constant interface. struct Composite : ImplConstant { Composite(const sem::Type* t, utils::VectorRef els, bool all_0, bool any_0) : type(t), elements(std::move(els)), all_zero(all_0), any_zero(any_0), hash(CalcHash()) {} ~Composite() override = default; const sem::Type* Type() const override { return type; } std::variant Value() const override { return {}; } const sem::Constant* Index(size_t i) const override { return i < elements.Length() ? 
elements[i] : nullptr; } bool AllZero() const override { return all_zero; } bool AnyZero() const override { return any_zero; } bool AllEqual() const override { return false; /* otherwise this should be a Splat */ } size_t Hash() const override { return hash; } ImplResult Convert(ProgramBuilder& builder, const sem::Type* target_ty, const Source& source) const override { // Convert each of the composite element types. auto* el_ty = sem::Type::ElementOf(target_ty); utils::Vector conv_els; conv_els.Reserve(elements.Length()); for (auto* el : elements) { // Note: This file is the only place where `sem::Constant`s are created, so this // static_cast is safe. auto conv_el = static_cast(el)->Convert(builder, el_ty, source); if (!conv_el) { return utils::Failure; } if (!conv_el.Get()) { return nullptr; } conv_els.Push(conv_el.Get()); } return CreateComposite(builder, target_ty, std::move(conv_els)); } size_t CalcHash() { auto h = utils::Hash(type, all_zero, any_zero); for (auto* el : elements) { h = utils::HashCombine(h, el->Hash()); } return h; } sem::Type const* const type; const utils::Vector elements; const bool all_zero; const bool any_zero; const size_t hash; }; /// CreateElement constructs and returns an Element. template const ImplConstant* CreateElement(ProgramBuilder& builder, const sem::Type* t, T v) { return builder.create>(t, v); } /// ZeroValue returns a Constant for the zero-value of the type `type`. 
const ImplConstant* ZeroValue(ProgramBuilder& builder, const sem::Type* type) {
    return Switch(
        type,  //
        [&](const sem::Vector* v) -> const ImplConstant* {
            // Every lane holds the zero of the vector's element type.
            return builder.create<Splat>(type, ZeroValue(builder, v->type()), v->Width());
        },
        [&](const sem::Matrix* m) -> const ImplConstant* {
            // Every column holds the zero column vector.
            return builder.create<Splat>(type, ZeroValue(builder, m->ColumnType()), m->columns());
        },
        [&](const sem::Array* a) -> const ImplConstant* {
            auto count = a->ConstantCount();
            if (!count) {
                return nullptr;  // The element count is not a constant.
            }
            auto* zero_el = ZeroValue(builder, a->ElemType());
            if (!zero_el) {
                return nullptr;
            }
            return builder.create<Splat>(type, zero_el, count.value());
        },
        [&](const sem::Struct* s) -> const ImplConstant* {
            // Zero values are memoized per member type, so members that share a type share a
            // constant. If every member has the same type, the struct becomes a Splat.
            std::unordered_map<const sem::Type*, const ImplConstant*> zero_by_type;
            utils::Vector<const sem::Constant*, 4> zeros;
            zeros.Reserve(s->Members().size());
            for (auto* member : s->Members()) {
                const sem::Type* member_ty = member->Type();
                const ImplConstant* zero = nullptr;
                if (auto it = zero_by_type.find(member_ty); it != zero_by_type.end()) {
                    zero = it->second;
                } else {
                    zero = ZeroValue(builder, member_ty);
                    zero_by_type.emplace(member_ty, zero);
                }
                if (zero == nullptr) {
                    return nullptr;
                }
                zeros.Push(zero);
            }
            const bool single_member_type = zero_by_type.size() == 1;
            if (single_member_type) {
                return builder.create<Splat>(type, zeros[0], s->Members().size());
            }
            return CreateComposite(builder, s, std::move(zeros));
        },
        [&](Default) -> const ImplConstant* {
            // Scalars and abstract numerics: wrap the type's C++ zero in an Element.
            return ZeroTypeDispatch(type, [&](auto zero) -> const ImplConstant* {
                return CreateElement(builder, type, zero);
            });
        });
}

/// Equal returns true if the constants `a` and `b` are of the same type and value.
bool Equal(const sem::Constant* a, const sem::Constant* b) { if (a->Hash() != b->Hash()) { return false; } if (a->Type() != b->Type()) { return false; } return Switch( a->Type(), // [&](const sem::Vector* vec) { for (size_t i = 0; i < vec->Width(); i++) { if (!Equal(a->Index(i), b->Index(i))) { return false; } } return true; }, [&](const sem::Matrix* mat) { for (size_t i = 0; i < mat->columns(); i++) { if (!Equal(a->Index(i), b->Index(i))) { return false; } } return true; }, [&](const sem::Array* arr) { if (auto count = arr->ConstantCount()) { for (size_t i = 0; i < count; i++) { if (!Equal(a->Index(i), b->Index(i))) { return false; } } return true; } return false; }, [&](Default) { return a->Value() == b->Value(); }); } /// CreateComposite is used to construct a constant of a vector, matrix or array type. /// CreateComposite examines the element values and will return either a Composite or a Splat, /// depending on the element types and values. const ImplConstant* CreateComposite(ProgramBuilder& builder, const sem::Type* type, utils::VectorRef elements) { if (elements.IsEmpty()) { return nullptr; } bool any_zero = false; bool all_zero = true; bool all_equal = true; auto* first = elements.Front(); for (auto* el : elements) { if (!el) { return nullptr; } if (!any_zero && el->AnyZero()) { any_zero = true; } if (all_zero && !el->AllZero()) { all_zero = false; } if (all_equal && el != first) { if (!Equal(el, first)) { all_equal = false; } } } if (all_equal) { return builder.create(type, elements[0], elements.Length()); } else { return builder.create(type, std::move(elements), all_zero, any_zero); } } namespace detail { /// Implementation of TransformElements template ImplResult TransformElements(ProgramBuilder& builder, const sem::Type* composite_ty, F&& f, size_t index, CONSTANTS&&... 
cs) { uint32_t n = 0; auto* ty = First(cs...)->Type(); auto* el_ty = sem::Type::ElementOf(ty, &n); if (el_ty == ty) { constexpr bool kHasIndexParam = traits::IsType>; if constexpr (kHasIndexParam) { return f(cs..., index); } else { return f(cs...); } } utils::Vector els; els.Reserve(n); for (uint32_t i = 0; i < n; i++) { if (auto el = detail::TransformElements(builder, sem::Type::ElementOf(composite_ty), std::forward(f), index + i, cs->Index(i)...)) { els.Push(el.Get()); } else { return el.Failure(); } } return CreateComposite(builder, composite_ty, std::move(els)); } } // namespace detail /// TransformElements constructs a new constant of type `composite_ty` by applying the /// transformation function `f` on each of the most deeply nested elements of 'cs'. Assumes that all /// input constants `cs` are of the same arity (all scalars or all vectors of the same size). /// If `f`'s last argument is a `size_t`, then the index of the most deeply nested element inside /// the most deeply nested aggregate type will be passed in. template ImplResult TransformElements(ProgramBuilder& builder, const sem::Type* composite_ty, F&& f, CONSTANTS&&... cs) { return detail::TransformElements(builder, composite_ty, f, 0, cs...); } /// TransformBinaryElements constructs a new constant of type `composite_ty` by applying the /// transformation function 'f' on each of the most deeply nested elements of both `c0` and `c1`. /// Unlike TransformElements, this function handles the constants being of different arity, e.g. /// vector-scalar, scalar-vector. 
template ImplResult TransformBinaryElements(ProgramBuilder& builder, const sem::Type* composite_ty, F&& f, const sem::Constant* c0, const sem::Constant* c1) { uint32_t n0 = 0, n1 = 0; sem::Type::ElementOf(c0->Type(), &n0); sem::Type::ElementOf(c1->Type(), &n1); uint32_t max_n = std::max(n0, n1); // If arity of both constants is 1, invoke callback if (max_n == 1u) { return f(c0, c1); } utils::Vector els; els.Reserve(max_n); for (uint32_t i = 0; i < max_n; i++) { auto nested_or_self = [&](auto& c, uint32_t num_elems) { if (num_elems == 1) { return c; } return c->Index(i); }; if (auto el = TransformBinaryElements(builder, sem::Type::ElementOf(composite_ty), std::forward(f), nested_or_self(c0, n0), nested_or_self(c1, n1))) { els.Push(el.Get()); } else { return el.Failure(); } } return CreateComposite(builder, composite_ty, std::move(els)); } } // namespace ConstEval::ConstEval(ProgramBuilder& b) : builder(b) {} template utils::Result ConstEval::Add(NumberT a, NumberT b) { NumberT result; if constexpr (IsAbstract) { // Check for over/underflow for abstract values if (auto r = CheckedAdd(a, b)) { result = r->value; } else { AddError(OverflowErrorMessage(a, "+", b), *current_source); return utils::Failure; } } else { using T = UnwrapNumber; auto add_values = [](T lhs, T rhs) { if constexpr (std::is_integral_v && std::is_signed_v) { // Ensure no UB for signed overflow using UT = std::make_unsigned_t; return static_cast(static_cast(lhs) + static_cast(rhs)); } else { return lhs + rhs; } }; result = add_values(a.value, b.value); } return result; } template utils::Result ConstEval::Mul(NumberT a, NumberT b) { using T = UnwrapNumber; NumberT result; if constexpr (IsAbstract) { // Check for over/underflow for abstract values if (auto r = CheckedMul(a, b)) { result = r->value; } else { AddError(OverflowErrorMessage(a, "*", b), *current_source); return utils::Failure; } } else { auto mul_values = [](T lhs, T rhs) { if constexpr (std::is_integral_v && std::is_signed_v) { // For 
signed integrals, avoid C++ UB by multiplying as unsigned using UT = std::make_unsigned_t; return static_cast(static_cast(lhs) * static_cast(rhs)); } else { return lhs * rhs; } }; result = mul_values(a.value, b.value); } return result; } template utils::Result ConstEval::Dot2(NumberT a1, NumberT a2, NumberT b1, NumberT b2) { auto r1 = Mul(a1, b1); if (!r1) { return utils::Failure; } auto r2 = Mul(a2, b2); if (!r2) { return utils::Failure; } auto r = Add(r1.Get(), r2.Get()); if (!r) { return utils::Failure; } return r; } template utils::Result ConstEval::Dot3(NumberT a1, NumberT a2, NumberT a3, NumberT b1, NumberT b2, NumberT b3) { auto r1 = Mul(a1, b1); if (!r1) { return utils::Failure; } auto r2 = Mul(a2, b2); if (!r2) { return utils::Failure; } auto r3 = Mul(a3, b3); if (!r3) { return utils::Failure; } auto r = Add(r1.Get(), r2.Get()); if (!r) { return utils::Failure; } r = Add(r.Get(), r3.Get()); if (!r) { return utils::Failure; } return r; } template utils::Result ConstEval::Dot4(NumberT a1, NumberT a2, NumberT a3, NumberT a4, NumberT b1, NumberT b2, NumberT b3, NumberT b4) { auto r1 = Mul(a1, b1); if (!r1) { return utils::Failure; } auto r2 = Mul(a2, b2); if (!r2) { return utils::Failure; } auto r3 = Mul(a3, b3); if (!r3) { return utils::Failure; } auto r4 = Mul(a4, b4); if (!r4) { return utils::Failure; } auto r = Add(r1.Get(), r2.Get()); if (!r) { return utils::Failure; } r = Add(r.Get(), r3.Get()); if (!r) { return utils::Failure; } r = Add(r.Get(), r4.Get()); if (!r) { return utils::Failure; } return r; } auto ConstEval::AddFunc(const sem::Type* elem_ty) { return [=](auto a1, auto a2) -> ImplResult { if (auto r = Add(a1, a2)) { return CreateElement(builder, elem_ty, r.Get()); } return utils::Failure; }; } auto ConstEval::MulFunc(const sem::Type* elem_ty) { return [=](auto a1, auto a2) -> ImplResult { if (auto r = Mul(a1, a2)) { return CreateElement(builder, elem_ty, r.Get()); } return utils::Failure; }; } auto ConstEval::Dot2Func(const sem::Type* elem_ty) 
{ return [=](auto a1, auto a2, auto b1, auto b2) -> ImplResult { if (auto r = Dot2(a1, a2, b1, b2)) { return CreateElement(builder, elem_ty, r.Get()); } return utils::Failure; }; } auto ConstEval::Dot3Func(const sem::Type* elem_ty) { return [=](auto a1, auto a2, auto a3, auto b1, auto b2, auto b3) -> ImplResult { if (auto r = Dot3(a1, a2, a3, b1, b2, b3)) { return CreateElement(builder, elem_ty, r.Get()); } return utils::Failure; }; } auto ConstEval::Dot4Func(const sem::Type* elem_ty) { return [=](auto a1, auto a2, auto a3, auto a4, auto b1, auto b2, auto b3, auto b4) -> ImplResult { if (auto r = Dot4(a1, a2, a3, a4, b1, b2, b3, b4)) { return CreateElement(builder, elem_ty, r.Get()); } return utils::Failure; }; } ConstEval::Result ConstEval::Literal(const sem::Type* ty, const ast::LiteralExpression* literal) { return Switch( literal, [&](const ast::BoolLiteralExpression* lit) { return CreateElement(builder, ty, lit->value); }, [&](const ast::IntLiteralExpression* lit) -> ImplResult { switch (lit->suffix) { case ast::IntLiteralExpression::Suffix::kNone: return CreateElement(builder, ty, AInt(lit->value)); case ast::IntLiteralExpression::Suffix::kI: return CreateElement(builder, ty, i32(lit->value)); case ast::IntLiteralExpression::Suffix::kU: return CreateElement(builder, ty, u32(lit->value)); } return nullptr; }, [&](const ast::FloatLiteralExpression* lit) -> ImplResult { switch (lit->suffix) { case ast::FloatLiteralExpression::Suffix::kNone: return CreateElement(builder, ty, AFloat(lit->value)); case ast::FloatLiteralExpression::Suffix::kF: return CreateElement(builder, ty, f32(lit->value)); case ast::FloatLiteralExpression::Suffix::kH: return CreateElement(builder, ty, f16(lit->value)); } return nullptr; }); } ConstEval::Result ConstEval::ArrayOrStructInit(const sem::Type* ty, utils::VectorRef args) { if (args.IsEmpty()) { return ZeroValue(builder, ty); } if (args.Length() == 1 && args[0]->Type() == ty) { // Identity initializer. 
return args[0]->ConstantValue(); } // Multiple arguments. Must be a type initializer. utils::Vector els; els.Reserve(args.Length()); for (auto* arg : args) { els.Push(arg->ConstantValue()); } return CreateComposite(builder, ty, std::move(els)); } ConstEval::Result ConstEval::Conv(const sem::Type* ty, utils::VectorRef args, const Source& source) { uint32_t el_count = 0; auto* el_ty = sem::Type::ElementOf(ty, &el_count); if (!el_ty) { return nullptr; } if (!args[0]) { return nullptr; // Single argument is not constant. } return Convert(ty, args[0], source); } ConstEval::Result ConstEval::Zero(const sem::Type* ty, utils::VectorRef, const Source&) { return ZeroValue(builder, ty); } ConstEval::Result ConstEval::Identity(const sem::Type*, utils::VectorRef args, const Source&) { return args[0]; } ConstEval::Result ConstEval::VecSplat(const sem::Type* ty, utils::VectorRef args, const Source&) { if (auto* arg = args[0]) { return builder.create(ty, arg, static_cast(ty)->Width()); } return nullptr; } ConstEval::Result ConstEval::VecInitS(const sem::Type* ty, utils::VectorRef args, const Source&) { return CreateComposite(builder, ty, args); } ConstEval::Result ConstEval::VecInitM(const sem::Type* ty, utils::VectorRef args, const Source&) { utils::Vector els; for (auto* arg : args) { auto* val = arg; if (!val) { return nullptr; } auto* arg_ty = arg->Type(); if (auto* arg_vec = arg_ty->As()) { // Extract out vector elements. 
for (uint32_t j = 0; j < arg_vec->Width(); j++) { auto* el = val->Index(j); if (!el) { return nullptr; } els.Push(el); } } else { els.Push(val); } } return CreateComposite(builder, ty, std::move(els)); } ConstEval::Result ConstEval::MatInitS(const sem::Type* ty, utils::VectorRef args, const Source&) { auto* m = static_cast(ty); utils::Vector els; for (uint32_t c = 0; c < m->columns(); c++) { utils::Vector column; for (uint32_t r = 0; r < m->rows(); r++) { auto i = r + c * m->rows(); column.Push(args[i]); } els.Push(CreateComposite(builder, m->ColumnType(), std::move(column))); } return CreateComposite(builder, ty, std::move(els)); } ConstEval::Result ConstEval::MatInitV(const sem::Type* ty, utils::VectorRef args, const Source&) { return CreateComposite(builder, ty, args); } ConstEval::Result ConstEval::Index(const sem::Expression* obj_expr, const sem::Expression* idx_expr) { auto idx_val = idx_expr->ConstantValue(); if (!idx_val) { return nullptr; } uint32_t el_count = 0; sem::Type::ElementOf(obj_expr->Type()->UnwrapRef(), &el_count); AInt idx = idx_val->As(); if (idx < 0 || (el_count > 0 && idx >= el_count)) { std::string range; if (el_count > 0) { range = " [0.." 
+ std::to_string(el_count - 1) + "]"; } AddError("index " + std::to_string(idx) + " out of bounds" + range, idx_expr->Declaration()->source); return utils::Failure; } auto obj_val = obj_expr->ConstantValue(); if (!obj_val) { return nullptr; } return obj_val->Index(static_cast(idx)); } ConstEval::Result ConstEval::MemberAccess(const sem::Expression* obj_expr, const sem::StructMember* member) { auto obj_val = obj_expr->ConstantValue(); if (!obj_val) { return nullptr; } return obj_val->Index(static_cast(member->Index())); } ConstEval::Result ConstEval::Swizzle(const sem::Type* ty, const sem::Expression* vec_expr, utils::VectorRef indices) { auto* vec_val = vec_expr->ConstantValue(); if (!vec_val) { return nullptr; } if (indices.Length() == 1) { return vec_val->Index(static_cast(indices[0])); } auto values = utils::Transform<4>( indices, [&](uint32_t i) { return vec_val->Index(static_cast(i)); }); return CreateComposite(builder, ty, std::move(values)); } ConstEval::Result ConstEval::Bitcast(const sem::Type*, const sem::Expression*) { // TODO(crbug.com/tint/1581): Implement @const intrinsics return nullptr; } ConstEval::Result ConstEval::OpComplement(const sem::Type* ty, utils::VectorRef args, const Source&) { auto transform = [&](const sem::Constant* c) { auto create = [&](auto i) { return CreateElement(builder, c->Type(), decltype(i)(~i.value)); }; return Dispatch_ia_iu32(create, c); }; return TransformElements(builder, ty, transform, args[0]); } ConstEval::Result ConstEval::OpUnaryMinus(const sem::Type* ty, utils::VectorRef args, const Source&) { auto transform = [&](const sem::Constant* c) { auto create = [&](auto i) { // For signed integrals, avoid C++ UB by not negating the // smallest negative number. In WGSL, this operation is well // defined to return the same value, see: // https://gpuweb.github.io/gpuweb/wgsl/#arithmetic-expr. 
using T = UnwrapNumber; if constexpr (std::is_integral_v) { auto v = i.value; if (v != std::numeric_limits::min()) { v = -v; } return CreateElement(builder, c->Type(), decltype(i)(v)); } else { return CreateElement(builder, c->Type(), decltype(i)(-i.value)); } }; return Dispatch_fia_fi32_f16(create, c); }; return TransformElements(builder, ty, transform, args[0]); } ConstEval::Result ConstEval::OpNot(const sem::Type* ty, utils::VectorRef args, const Source&) { auto transform = [&](const sem::Constant* c) { auto create = [&](auto i) { return CreateElement(builder, c->Type(), decltype(i)(!i)); }; return Dispatch_bool(create, c); }; return TransformElements(builder, ty, transform, args[0]); } ConstEval::Result ConstEval::OpPlus(const sem::Type* ty, utils::VectorRef args, const Source& source) { TINT_SCOPED_ASSIGNMENT(current_source, &source); auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) { return Dispatch_fia_fiu32_f16(AddFunc(c0->Type()), c0, c1); }; return TransformBinaryElements(builder, ty, transform, args[0], args[1]); } ConstEval::Result ConstEval::OpMinus(const sem::Type* ty, utils::VectorRef args, const Source& source) { auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) { auto create = [&](auto i, auto j) -> ImplResult { using NumberT = decltype(i); NumberT result; if constexpr (IsAbstract) { // Check for over/underflow for abstract values if (auto r = CheckedSub(i, j)) { result = r->value; } else { AddError(OverflowErrorMessage(i, "-", j), source); return utils::Failure; } } else { using T = UnwrapNumber; auto subtract_values = [](T lhs, T rhs) { if constexpr (std::is_integral_v && std::is_signed_v) { // Ensure no UB for signed underflow using UT = std::make_unsigned_t; return static_cast(static_cast(lhs) - static_cast(rhs)); } else { return lhs - rhs; } }; result = subtract_values(i.value, j.value); } return CreateElement(builder, c0->Type(), result); }; return Dispatch_fia_fiu32_f16(create, c0, c1); }; return 
TransformBinaryElements(builder, ty, transform, args[0], args[1]); } ConstEval::Result ConstEval::OpMultiply(const sem::Type* ty, utils::VectorRef args, const Source& source) { TINT_SCOPED_ASSIGNMENT(current_source, &source); auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) { return Dispatch_fia_fiu32_f16(MulFunc(c0->Type()), c0, c1); }; return TransformBinaryElements(builder, ty, transform, args[0], args[1]); } ConstEval::Result ConstEval::OpMultiplyMatVec(const sem::Type* ty, utils::VectorRef args, const Source& source) { TINT_SCOPED_ASSIGNMENT(current_source, &source); auto* mat_ty = args[0]->Type()->As(); auto* vec_ty = args[1]->Type()->As(); auto* elem_ty = vec_ty->type(); auto dot = [&](const sem::Constant* m, size_t row, const sem::Constant* v) { ImplResult result; switch (mat_ty->columns()) { case 2: result = Dispatch_fa_f32_f16(Dot2Func(elem_ty), // m->Index(0)->Index(row), // m->Index(1)->Index(row), // v->Index(0), // v->Index(1)); break; case 3: result = Dispatch_fa_f32_f16(Dot3Func(elem_ty), // m->Index(0)->Index(row), // m->Index(1)->Index(row), // m->Index(2)->Index(row), // v->Index(0), // v->Index(1), v->Index(2)); break; case 4: result = Dispatch_fa_f32_f16(Dot4Func(elem_ty), // m->Index(0)->Index(row), // m->Index(1)->Index(row), // m->Index(2)->Index(row), // m->Index(3)->Index(row), // v->Index(0), // v->Index(1), // v->Index(2), // v->Index(3)); break; } return result; }; utils::Vector result; for (size_t i = 0; i < mat_ty->rows(); ++i) { auto r = dot(args[0], i, args[1]); // matrix row i * vector if (!r) { return utils::Failure; } result.Push(r.Get()); } return CreateComposite(builder, ty, result); } ConstEval::Result ConstEval::OpMultiplyVecMat(const sem::Type* ty, utils::VectorRef args, const Source& source) { TINT_SCOPED_ASSIGNMENT(current_source, &source); auto* vec_ty = args[0]->Type()->As(); auto* mat_ty = args[1]->Type()->As(); auto* elem_ty = vec_ty->type(); auto dot = [&](const sem::Constant* v, const 
sem::Constant* m, size_t col) { ImplResult result; switch (mat_ty->rows()) { case 2: result = Dispatch_fa_f32_f16(Dot2Func(elem_ty), // m->Index(col)->Index(0), // m->Index(col)->Index(1), // v->Index(0), // v->Index(1)); break; case 3: result = Dispatch_fa_f32_f16(Dot3Func(elem_ty), // m->Index(col)->Index(0), // m->Index(col)->Index(1), // m->Index(col)->Index(2), v->Index(0), // v->Index(1), // v->Index(2)); break; case 4: result = Dispatch_fa_f32_f16(Dot4Func(elem_ty), // m->Index(col)->Index(0), // m->Index(col)->Index(1), // m->Index(col)->Index(2), // m->Index(col)->Index(3), // v->Index(0), // v->Index(1), // v->Index(2), // v->Index(3)); } return result; }; utils::Vector result; for (size_t i = 0; i < mat_ty->columns(); ++i) { auto r = dot(args[0], args[1], i); // vector * matrix col i if (!r) { return utils::Failure; } result.Push(r.Get()); } return CreateComposite(builder, ty, result); } ConstEval::Result ConstEval::OpMultiplyMatMat(const sem::Type* ty, utils::VectorRef args, const Source& source) { TINT_SCOPED_ASSIGNMENT(current_source, &source); auto* mat1 = args[0]; auto* mat2 = args[1]; auto* mat1_ty = mat1->Type()->As(); auto* mat2_ty = mat2->Type()->As(); auto* elem_ty = mat1_ty->type(); auto dot = [&](const sem::Constant* m1, size_t row, const sem::Constant* m2, size_t col) { auto m1e = [&](size_t r, size_t c) { return m1->Index(c)->Index(r); }; auto m2e = [&](size_t r, size_t c) { return m2->Index(c)->Index(r); }; ImplResult result; switch (mat1_ty->columns()) { case 2: result = Dispatch_fa_f32_f16(Dot2Func(elem_ty), // m1e(row, 0), // m1e(row, 1), // m2e(0, col), // m2e(1, col)); break; case 3: result = Dispatch_fa_f32_f16(Dot3Func(elem_ty), // m1e(row, 0), // m1e(row, 1), // m1e(row, 2), // m2e(0, col), // m2e(1, col), // m2e(2, col)); break; case 4: result = Dispatch_fa_f32_f16(Dot4Func(elem_ty), // m1e(row, 0), // m1e(row, 1), // m1e(row, 2), // m1e(row, 3), // m2e(0, col), // m2e(1, col), // m2e(2, col), // m2e(3, col)); break; } return 
result; }; utils::Vector result_mat; for (size_t c = 0; c < mat2_ty->columns(); ++c) { utils::Vector col_vec; for (size_t r = 0; r < mat1_ty->rows(); ++r) { auto v = dot(mat1, r, mat2, c); // mat1 row r * mat2 col c if (!v) { return utils::Failure; } col_vec.Push(v.Get()); // mat1 row r * mat2 col c } // Add column vector to matrix auto* col_vec_ty = ty->As()->ColumnType(); result_mat.Push(CreateComposite(builder, col_vec_ty, col_vec)); } return CreateComposite(builder, ty, result_mat); } ConstEval::Result ConstEval::OpDivide(const sem::Type* ty, utils::VectorRef args, const Source& source) { auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) { auto create = [&](auto i, auto j) -> ImplResult { using NumberT = decltype(i); NumberT result; if constexpr (IsAbstract) { // Check for over/underflow for abstract values if (auto r = CheckedDiv(i, j)) { result = r->value; } else { AddError(OverflowErrorMessage(i, "/", j), source); return utils::Failure; } } else { using T = UnwrapNumber; auto divide_values = [](T lhs, T rhs) { if constexpr (std::is_integral_v) { // For integers, lhs / 0 returns lhs if (rhs == 0) { return lhs; } if constexpr (std::is_signed_v) { // For signed integers, for lhs / -1, return lhs if lhs is the // most negative value if (rhs == -1 && lhs == std::numeric_limits::min()) { return lhs; } } } return lhs / rhs; }; result = divide_values(i.value, j.value); } return CreateElement(builder, c0->Type(), result); }; return Dispatch_fia_fiu32_f16(create, c0, c1); }; return TransformBinaryElements(builder, ty, transform, args[0], args[1]); } ConstEval::Result ConstEval::OpEqual(const sem::Type* ty, utils::VectorRef args, const Source&) { auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) { auto create = [&](auto i, auto j) -> ImplResult { return CreateElement(builder, sem::Type::DeepestElementOf(ty), i == j); }; return Dispatch_fia_fiu32_f16_bool(create, c0, c1); }; return TransformElements(builder, ty, transform, 
args[0], args[1]); } ConstEval::Result ConstEval::OpNotEqual(const sem::Type* ty, utils::VectorRef args, const Source&) { auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) { auto create = [&](auto i, auto j) -> ImplResult { return CreateElement(builder, sem::Type::DeepestElementOf(ty), i != j); }; return Dispatch_fia_fiu32_f16_bool(create, c0, c1); }; return TransformElements(builder, ty, transform, args[0], args[1]); } ConstEval::Result ConstEval::OpLessThan(const sem::Type* ty, utils::VectorRef args, const Source&) { auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) { auto create = [&](auto i, auto j) -> ImplResult { return CreateElement(builder, sem::Type::DeepestElementOf(ty), i < j); }; return Dispatch_fia_fiu32_f16(create, c0, c1); }; return TransformElements(builder, ty, transform, args[0], args[1]); } ConstEval::Result ConstEval::OpGreaterThan(const sem::Type* ty, utils::VectorRef args, const Source&) { auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) { auto create = [&](auto i, auto j) -> ImplResult { return CreateElement(builder, sem::Type::DeepestElementOf(ty), i > j); }; return Dispatch_fia_fiu32_f16(create, c0, c1); }; return TransformElements(builder, ty, transform, args[0], args[1]); } ConstEval::Result ConstEval::OpLessThanEqual(const sem::Type* ty, utils::VectorRef args, const Source&) { auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) { auto create = [&](auto i, auto j) -> ImplResult { return CreateElement(builder, sem::Type::DeepestElementOf(ty), i <= j); }; return Dispatch_fia_fiu32_f16(create, c0, c1); }; return TransformElements(builder, ty, transform, args[0], args[1]); } ConstEval::Result ConstEval::OpGreaterThanEqual(const sem::Type* ty, utils::VectorRef args, const Source&) { auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) { auto create = [&](auto i, auto j) -> ImplResult { return CreateElement(builder, 
sem::Type::DeepestElementOf(ty), i >= j); }; return Dispatch_fia_fiu32_f16(create, c0, c1); }; return TransformElements(builder, ty, transform, args[0], args[1]); } ConstEval::Result ConstEval::OpAnd(const sem::Type* ty, utils::VectorRef args, const Source&) { auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) { auto create = [&](auto i, auto j) -> ImplResult { using T = decltype(i); T result; if constexpr (std::is_same_v) { result = i && j; } else { // integral result = i & j; } return CreateElement(builder, sem::Type::DeepestElementOf(ty), result); }; return Dispatch_ia_iu32_bool(create, c0, c1); }; return TransformElements(builder, ty, transform, args[0], args[1]); } ConstEval::Result ConstEval::OpOr(const sem::Type* ty, utils::VectorRef args, const Source&) { auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) { auto create = [&](auto i, auto j) -> ImplResult { using T = decltype(i); T result; if constexpr (std::is_same_v) { result = i || j; } else { // integral result = i | j; } return CreateElement(builder, sem::Type::DeepestElementOf(ty), result); }; return Dispatch_ia_iu32_bool(create, c0, c1); }; return TransformElements(builder, ty, transform, args[0], args[1]); } ConstEval::Result ConstEval::OpXor(const sem::Type* ty, utils::VectorRef args, const Source&) { auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) { auto create = [&](auto i, auto j) -> const ImplConstant* { return CreateElement(builder, sem::Type::DeepestElementOf(ty), decltype(i){i ^ j}); }; return Dispatch_ia_iu32(create, c0, c1); }; auto r = TransformElements(builder, ty, transform, args[0], args[1]); if (builder.Diagnostics().contains_errors()) { return utils::Failure; } return r; } ConstEval::Result ConstEval::OpShiftLeft(const sem::Type* ty, utils::VectorRef args, const Source& source) { auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) { auto create = [&](auto e1, auto e2) -> const ImplConstant* { using 
NumberT = decltype(e1); using T = UnwrapNumber; using UT = std::make_unsigned_t; constexpr size_t bit_width = BitWidth; UT e1u = static_cast(e1); UT e2u = static_cast(e2); if constexpr (IsAbstract) { // The e2 + 1 most significant bits of e1 must have the same bit value, otherwise // sign change (overflow) would occur. // Check sign change only if e2 is less than bit width of e1. If e1 is larger // than bit width, we check for non-representable value below. if (e2u < bit_width) { UT must_match_msb = e2u + 1; UT mask = ~UT{0} << (bit_width - must_match_msb); if ((e1u & mask) != 0 && (e1u & mask) != mask) { AddError("shift left operation results in sign change", source); return nullptr; } } else { // If shift value >= bit_width, then any non-zero value would overflow if (e1 != 0) { AddError(OverflowErrorMessage(e1, "<<", e2), source); return nullptr; } // It's UB in C++ to shift by greater or equal to the bit width (even if the lhs // is 0), so we make sure to avoid this by setting the shift value to 0. e2 = 0; } } else { if (static_cast(e2) >= bit_width) { // At shader/pipeline-creation time, it is an error to shift by the bit width of // the lhs or greater. // NOTE: At runtime, we shift by e2 % (bit width of e1). AddError( "shift left value must be less than the bit width of the lhs, which is " + std::to_string(bit_width), source); return nullptr; } if constexpr (std::is_signed_v) { // If T is a signed integer type, and the e2+1 most significant bits of e1 do // not have the same bit value, then error. size_t must_match_msb = e2u + 1; UT mask = ~UT{0} << (bit_width - must_match_msb); if ((e1u & mask) != 0 && (e1u & mask) != mask) { AddError("shift left operation results in sign change", source); return nullptr; } } else { // If T is an unsigned integer type, and any of the e2 most significant bits of // e1 are 1, then error. 
if (e2u > 0) { size_t must_be_zero_msb = e2u; UT mask = ~UT{0} << (bit_width - must_be_zero_msb); if ((e1u & mask) != 0) { AddError(OverflowErrorMessage(e1, "<<", e2), source); } } } } // Avoid UB by left shifting as unsigned value auto result = static_cast(static_cast(e1) << e2); return CreateElement(builder, sem::Type::DeepestElementOf(ty), NumberT{result}); }; return Dispatch_ia_iu32(create, c0, c1); }; if (!sem::Type::DeepestElementOf(args[1]->Type())->Is()) { TINT_ICE(Resolver, builder.Diagnostics()) << "Element type of rhs of ShiftLeft must be a u32"; return nullptr; } auto r = TransformElements(builder, ty, transform, args[0], args[1]); if (builder.Diagnostics().contains_errors()) { return utils::Failure; } return r; } ConstEval::Result ConstEval::all(const sem::Type* ty, utils::VectorRef args, const Source&) { return CreateElement(builder, ty, !args[0]->AnyZero()); } ConstEval::Result ConstEval::any(const sem::Type* ty, utils::VectorRef args, const Source&) { return CreateElement(builder, ty, !args[0]->AllZero()); } ConstEval::Result ConstEval::asin(const sem::Type* ty, utils::VectorRef args, const Source& source) { auto transform = [&](const sem::Constant* c0) { auto create = [&](auto i) -> ImplResult { using NumberT = decltype(i); if (i.value < NumberT(-1.0) || i.value > NumberT(1.0)) { AddError("asin must be called with a value in the range [-1, 1]", source); return utils::Failure; } return CreateElement(builder, c0->Type(), decltype(i)(std::asin(i.value))); }; return Dispatch_fa_f32_f16(create, c0); }; return TransformElements(builder, ty, transform, args[0]); } ConstEval::Result ConstEval::asinh(const sem::Type* ty, utils::VectorRef args, const Source&) { auto transform = [&](const sem::Constant* c0) { auto create = [&](auto i) { return CreateElement(builder, c0->Type(), decltype(i)(std::asinh(i.value))); }; return Dispatch_fa_f32_f16(create, c0); }; auto r = TransformElements(builder, ty, transform, args[0]); if 
(builder.Diagnostics().contains_errors()) { return utils::Failure; } return r; } ConstEval::Result ConstEval::atan(const sem::Type* ty, utils::VectorRef args, const Source&) { auto transform = [&](const sem::Constant* c0) { auto create = [&](auto i) { return CreateElement(builder, c0->Type(), decltype(i)(std::atan(i.value))); }; return Dispatch_fa_f32_f16(create, c0); }; return TransformElements(builder, ty, transform, args[0]); } ConstEval::Result ConstEval::atanh(const sem::Type* ty, utils::VectorRef args, const Source& source) { auto transform = [&](const sem::Constant* c0) { auto create = [&](auto i) -> ImplResult { using NumberT = decltype(i); if (i.value <= NumberT(-1.0) || i.value >= NumberT(1.0)) { AddError("atanh must be called with a value in the range (-1, 1)", source); return utils::Failure; } return CreateElement(builder, c0->Type(), decltype(i)(std::atanh(i.value))); }; return Dispatch_fa_f32_f16(create, c0); }; return TransformElements(builder, ty, transform, args[0]); } ConstEval::Result ConstEval::atan2(const sem::Type* ty, utils::VectorRef args, const Source&) { auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) { auto create = [&](auto i, auto j) { return CreateElement(builder, c0->Type(), decltype(i)(std::atan2(i.value, j.value))); }; return Dispatch_fa_f32_f16(create, c0, c1); }; return TransformElements(builder, ty, transform, args[0], args[1]); } ConstEval::Result ConstEval::clamp(const sem::Type* ty, utils::VectorRef args, const Source&) { auto transform = [&](const sem::Constant* c0, const sem::Constant* c1, const sem::Constant* c2) { auto create = [&](auto e, auto low, auto high) { return CreateElement(builder, c0->Type(), decltype(e)(std::min(std::max(e, low), high))); }; return Dispatch_fia_fiu32_f16(create, c0, c1, c2); }; return TransformElements(builder, ty, transform, args[0], args[1], args[2]); } ConstEval::Result ConstEval::countLeadingZeros(const sem::Type* ty, utils::VectorRef args, const Source&) { auto 
transform = [&](const sem::Constant* c0) { auto create = [&](auto e) { using NumberT = decltype(e); using T = UnwrapNumber; auto count = CountLeadingBits(T{e}, T{0}); return CreateElement(builder, c0->Type(), NumberT(count)); }; return Dispatch_iu32(create, c0); }; return TransformElements(builder, ty, transform, args[0]); } ConstEval::Result ConstEval::countOneBits(const sem::Type* ty, utils::VectorRef args, const Source&) { auto transform = [&](const sem::Constant* c0) { auto create = [&](auto e) { using NumberT = decltype(e); using T = UnwrapNumber; using UT = std::make_unsigned_t; constexpr UT kRightMost = UT{1}; auto count = UT{0}; for (auto v = static_cast(e); v != UT{0}; v >>= 1) { if ((v & kRightMost) == 1) { ++count; } } return CreateElement(builder, c0->Type(), NumberT(count)); }; return Dispatch_iu32(create, c0); }; return TransformElements(builder, ty, transform, args[0]); } ConstEval::Result ConstEval::countTrailingZeros(const sem::Type* ty, utils::VectorRef args, const Source&) { auto transform = [&](const sem::Constant* c0) { auto create = [&](auto e) { using NumberT = decltype(e); using T = UnwrapNumber; auto count = CountTrailingBits(T{e}, T{0}); return CreateElement(builder, c0->Type(), NumberT(count)); }; return Dispatch_iu32(create, c0); }; return TransformElements(builder, ty, transform, args[0]); } ConstEval::Result ConstEval::firstLeadingBit(const sem::Type* ty, utils::VectorRef args, const Source&) { auto transform = [&](const sem::Constant* c0) { auto create = [&](auto e) { using NumberT = decltype(e); using T = UnwrapNumber; using UT = std::make_unsigned_t; constexpr UT kNumBits = sizeof(UT) * 8; NumberT result; if constexpr (IsUnsignedIntegral) { if (e == T{0}) { // T(-1) if e is zero. result = NumberT(static_cast(-1)); } else { // Otherwise the position of the most significant 1 bit in e. 
static_assert(std::is_same_v); UT count = CountLeadingBits(UT{e}, UT{0}); UT pos = kNumBits - count - 1; result = NumberT(pos); } } else { if (e == T{0} || e == T{-1}) { // -1 if e is 0 or -1. result = NumberT(-1); } else { // Otherwise the position of the most significant bit in e that is different // from e's sign bit. UT eu = static_cast(e); UT sign_bit = eu >> (kNumBits - 1); UT count = CountLeadingBits(eu, sign_bit); UT pos = kNumBits - count - 1; result = NumberT(pos); } } return CreateElement(builder, c0->Type(), result); }; return Dispatch_iu32(create, c0); }; return TransformElements(builder, ty, transform, args[0]); } ConstEval::Result ConstEval::firstTrailingBit(const sem::Type* ty, utils::VectorRef args, const Source&) { auto transform = [&](const sem::Constant* c0) { auto create = [&](auto e) { using NumberT = decltype(e); using T = UnwrapNumber; using UT = std::make_unsigned_t; NumberT result; if (e == T{0}) { // T(-1) if e is zero. result = NumberT(static_cast(-1)); } else { // Otherwise the position of the least significant 1 bit in e. 
UT pos = CountTrailingBits(T{e}, T{0}); result = NumberT(pos); } return CreateElement(builder, c0->Type(), result); }; return Dispatch_iu32(create, c0); }; return TransformElements(builder, ty, transform, args[0]); } ConstEval::Result ConstEval::insertBits(const sem::Type* ty, utils::VectorRef args, const Source& source) { auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) { auto create = [&](auto in_e, auto in_newbits) -> ImplResult { using NumberT = decltype(in_e); using T = UnwrapNumber; using UT = std::make_unsigned_t; using NumberUT = Number; // Read args that are always scalar NumberUT in_offset = args[2]->As(); NumberUT in_count = args[3]->As(); constexpr UT w = sizeof(UT) * 8; if ((in_offset + in_count) > w) { AddError("'offset + 'count' must be less than or equal to the bit width of 'e'", source); return utils::Failure; } // Cast all to unsigned UT e = static_cast(in_e); UT newbits = static_cast(in_newbits); UT o = static_cast(in_offset); UT c = static_cast(in_count); NumberT result; if (c == UT{0}) { // The result is e if c is 0 result = NumberT{e}; } else if (c == w) { // The result is newbits if c is w result = NumberT{newbits}; } else { // Otherwise, bits o..o + c - 1 of the result are copied from bits 0..c - 1 of // newbits. Other bits of the result are copied from e. 
UT from = newbits << o; UT mask = ((UT{1} << c) - UT{1}) << UT{o}; auto r = e; // Start with 'e' as the result r = r & ~mask; // Zero the bits in 'e' we're overwriting r = r | (from & mask); // Overwrite from 'newbits' (shifted into position) result = NumberT{r}; } return CreateElement(builder, c0->Type(), result); }; return Dispatch_iu32(create, c0, c1); }; return TransformElements(builder, ty, transform, args[0], args[1]); } ConstEval::Result ConstEval::saturate(const sem::Type* ty, utils::VectorRef args, const Source&) { auto transform = [&](const sem::Constant* c0) { auto create = [&](auto e) { using NumberT = decltype(e); return CreateElement(builder, c0->Type(), NumberT(std::min(std::max(e, NumberT(0.0)), NumberT(1.0)))); }; return Dispatch_fa_f32_f16(create, c0); }; return TransformElements(builder, ty, transform, args[0]); } ConstEval::Result ConstEval::select_bool(const sem::Type* ty, utils::VectorRef args, const Source&) { auto cond = args[2]->As(); auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) { auto create = [&](auto f, auto t) -> ImplResult { return CreateElement(builder, sem::Type::DeepestElementOf(ty), cond ? t : f); }; return Dispatch_fia_fiu32_f16_bool(create, c0, c1); }; return TransformElements(builder, ty, transform, args[0], args[1]); } ConstEval::Result ConstEval::select_boolvec(const sem::Type* ty, utils::VectorRef args, const Source&) { auto transform = [&](const sem::Constant* c0, const sem::Constant* c1, size_t index) { auto create = [&](auto f, auto t) -> ImplResult { // Get corresponding bool value at the current vector value index auto cond = args[2]->Index(index)->As(); return CreateElement(builder, sem::Type::DeepestElementOf(ty), cond ? 
t : f); }; return Dispatch_fia_fiu32_f16_bool(create, c0, c1); }; return TransformElements(builder, ty, transform, args[0], args[1]); } ConstEval::Result ConstEval::sign(const sem::Type* ty, utils::VectorRef args, const Source&) { auto transform = [&](const sem::Constant* c0) { auto create = [&](auto e) -> ImplResult { using NumberT = decltype(e); NumberT result; NumberT zero{0.0}; if (e.value < zero) { result = NumberT{-1.0}; } else if (e.value > zero) { result = NumberT{1.0}; } else { result = zero; } return CreateElement(builder, c0->Type(), result); }; return Dispatch_fa_f32_f16(create, c0); }; return TransformElements(builder, ty, transform, args[0]); } ConstEval::Result ConstEval::step(const sem::Type* ty, utils::VectorRef args, const Source&) { auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) { auto create = [&](auto edge, auto x) -> ImplResult { using NumberT = decltype(edge); NumberT result = x.value < edge.value ? NumberT(0.0) : NumberT(1.0); return CreateElement(builder, c0->Type(), result); }; return Dispatch_fa_f32_f16(create, c0, c1); }; return TransformElements(builder, ty, transform, args[0], args[1]); } ConstEval::Result ConstEval::quantizeToF16(const sem::Type* ty, utils::VectorRef args, const Source&) { auto transform = [&](const sem::Constant* c) { auto conv = CheckedConvert(f16(c->As())); if (!conv) { // https://www.w3.org/TR/WGSL/#quantizeToF16-builtin // If e is outside the finite range of binary16, then the result is any value of type // f32 switch (conv.Failure()) { case ConversionFailure::kExceedsNegativeLimit: return CreateElement(builder, c->Type(), f16(f16::kLowestValue)); case ConversionFailure::kExceedsPositiveLimit: return CreateElement(builder, c->Type(), f16(f16::kHighestValue)); } } return CreateElement(builder, c->Type(), conv.Get()); }; return TransformElements(builder, ty, transform, args[0]); } ConstEval::Result ConstEval::Convert(const sem::Type* target_ty, const sem::Constant* value, const Source& 
source) { if (value->Type() == target_ty) { return value; } return static_cast(value)->Convert(builder, target_ty, source); } void ConstEval::AddError(const std::string& msg, const Source& source) const { builder.Diagnostics().add_error(diag::System::Resolver, msg, source); } void ConstEval::AddWarning(const std::string& msg, const Source& source) const { builder.Diagnostics().add_warning(diag::System::Resolver, msg, source); } } // namespace tint::resolver