Remove ImplResult.

The ImplResult type is the same as ConstEval::Result after recent
changes. This CL replaces all usages and removes ImplResult.

Bug: tint:1718
Change-Id: If424f3d00f953d97a339de8ae18c94083f3346bf
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/114162
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
dan sinclair 2022-12-14 19:23:48 +00:00 committed by Dawn LUCI CQ
parent 8626c9ee87
commit 9268561678
1 changed files with 114 additions and 114 deletions

View File

@ -232,25 +232,22 @@ std::make_unsigned_t<T> CountTrailingBits(T e, T bit_value_to_count) {
return count; return count;
} }
/// A result templated with a constant::Constant.
using ImplResult = utils::Result<const constant::Constant*>;
// Forward declaration // Forward declaration
const constant::Constant* CreateComposite(ProgramBuilder& builder, const constant::Constant* CreateComposite(ProgramBuilder& builder,
const type::Type* type, const type::Type* type,
utils::VectorRef<const constant::Constant*> elements); utils::VectorRef<const constant::Constant*> elements);
template <typename T> template <typename T>
ImplResult ScalarConvert(const constant::Scalar<T>* scalar, ConstEval::Result ScalarConvert(const constant::Scalar<T>* scalar,
ProgramBuilder& builder, ProgramBuilder& builder,
const type::Type* target_ty, const type::Type* target_ty,
const Source& source) { const Source& source) {
TINT_BEGIN_DISABLE_WARNING(UNREACHABLE_CODE); TINT_BEGIN_DISABLE_WARNING(UNREACHABLE_CODE);
if (target_ty == scalar->type) { if (target_ty == scalar->type) {
// If the types are identical, then no conversion is needed. // If the types are identical, then no conversion is needed.
return scalar; return scalar;
} }
return ZeroTypeDispatch(target_ty, [&](auto zero_to) -> ImplResult { return ZeroTypeDispatch(target_ty, [&](auto zero_to) -> ConstEval::Result {
// `value` is the source value. // `value` is the source value.
// `FROM` is the source type. // `FROM` is the source type.
// `TO` is the target type. // `TO` is the target type.
@ -299,15 +296,15 @@ ImplResult ScalarConvert(const constant::Scalar<T>* scalar,
} }
// Forward declare // Forward declare
ImplResult ConvertInternal(const constant::Constant* c, ConstEval::Result ConvertInternal(const constant::Constant* c,
ProgramBuilder& builder, ProgramBuilder& builder,
const type::Type* target_ty, const type::Type* target_ty,
const Source& source); const Source& source);
ImplResult SplatConvert(const constant::Splat* splat, ConstEval::Result SplatConvert(const constant::Splat* splat,
ProgramBuilder& builder, ProgramBuilder& builder,
const type::Type* target_ty, const type::Type* target_ty,
const Source& source) { const Source& source) {
// Convert the single splatted element type. // Convert the single splatted element type.
auto conv_el = ConvertInternal(splat->el, builder, type::Type::ElementOf(target_ty), source); auto conv_el = ConvertInternal(splat->el, builder, type::Type::ElementOf(target_ty), source);
if (!conv_el) { if (!conv_el) {
@ -319,10 +316,10 @@ ImplResult SplatConvert(const constant::Splat* splat,
return builder.create<constant::Splat>(target_ty, conv_el.Get(), splat->count); return builder.create<constant::Splat>(target_ty, conv_el.Get(), splat->count);
} }
ImplResult CompositeConvert(const constant::Composite* composite, ConstEval::Result CompositeConvert(const constant::Composite* composite,
ProgramBuilder& builder, ProgramBuilder& builder,
const type::Type* target_ty, const type::Type* target_ty,
const Source& source) { const Source& source) {
// Convert each of the composite element types. // Convert each of the composite element types.
utils::Vector<const constant::Constant*, 4> conv_els; utils::Vector<const constant::Constant*, 4> conv_els;
conv_els.Reserve(composite->elements.Length()); conv_els.Reserve(composite->elements.Length());
@ -353,10 +350,10 @@ ImplResult CompositeConvert(const constant::Composite* composite,
return CreateComposite(builder, target_ty, std::move(conv_els)); return CreateComposite(builder, target_ty, std::move(conv_els));
} }
ImplResult ConvertInternal(const constant::Constant* c, ConstEval::Result ConvertInternal(const constant::Constant* c,
ProgramBuilder& builder, ProgramBuilder& builder,
const type::Type* target_ty, const type::Type* target_ty,
const Source& source) { const Source& source) {
return Switch( return Switch(
c, c,
[&](const constant::Scalar<tint::AFloat>* val) { [&](const constant::Scalar<tint::AFloat>* val) {
@ -388,7 +385,10 @@ ImplResult ConvertInternal(const constant::Constant* c,
/// CreateScalar constructs and returns an constant::Scalar<T>. /// CreateScalar constructs and returns an constant::Scalar<T>.
template <typename T> template <typename T>
ImplResult CreateScalar(ProgramBuilder& builder, const Source& source, const type::Type* t, T v) { ConstEval::Result CreateScalar(ProgramBuilder& builder,
const Source& source,
const type::Type* t,
T v) {
static_assert(IsNumber<T> || std::is_same_v<T, bool>, "T must be a Number or bool"); static_assert(IsNumber<T> || std::is_same_v<T, bool>, "T must be a Number or bool");
TINT_ASSERT(Resolver, t->is_scalar()); TINT_ASSERT(Resolver, t->is_scalar());
@ -544,11 +544,11 @@ const constant::Constant* CreateComposite(ProgramBuilder& builder,
namespace detail { namespace detail {
/// Implementation of TransformElements /// Implementation of TransformElements
template <typename F, typename... CONSTANTS> template <typename F, typename... CONSTANTS>
ImplResult TransformElements(ProgramBuilder& builder, ConstEval::Result TransformElements(ProgramBuilder& builder,
const type::Type* composite_ty, const type::Type* composite_ty,
F&& f, F&& f,
size_t index, size_t index,
CONSTANTS&&... cs) { CONSTANTS&&... cs) {
uint32_t n = 0; uint32_t n = 0;
auto* ty = First(cs...)->Type(); auto* ty = First(cs...)->Type();
auto* el_ty = type::Type::ElementOf(ty, &n); auto* el_ty = type::Type::ElementOf(ty, &n);
@ -581,10 +581,10 @@ ImplResult TransformElements(ProgramBuilder& builder,
/// If `f`'s last argument is a `size_t`, then the index of the most deeply nested element inside /// If `f`'s last argument is a `size_t`, then the index of the most deeply nested element inside
/// the most deeply nested aggregate type will be passed in. /// the most deeply nested aggregate type will be passed in.
template <typename F, typename... CONSTANTS> template <typename F, typename... CONSTANTS>
ImplResult TransformElements(ProgramBuilder& builder, ConstEval::Result TransformElements(ProgramBuilder& builder,
const type::Type* composite_ty, const type::Type* composite_ty,
F&& f, F&& f,
CONSTANTS&&... cs) { CONSTANTS&&... cs) {
return detail::TransformElements(builder, composite_ty, f, 0, cs...); return detail::TransformElements(builder, composite_ty, f, 0, cs...);
} }
@ -593,11 +593,11 @@ ImplResult TransformElements(ProgramBuilder& builder,
/// Unlike TransformElements, this function handles the constants being of different arity, e.g. /// Unlike TransformElements, this function handles the constants being of different arity, e.g.
/// vector-scalar, scalar-vector. /// vector-scalar, scalar-vector.
template <typename F> template <typename F>
ImplResult TransformBinaryElements(ProgramBuilder& builder, ConstEval::Result TransformBinaryElements(ProgramBuilder& builder,
const type::Type* composite_ty, const type::Type* composite_ty,
F&& f, F&& f,
const constant::Constant* c0, const constant::Constant* c0,
const constant::Constant* c1) { const constant::Constant* c1) {
uint32_t n0 = 0; uint32_t n0 = 0;
type::Type::ElementOf(c0->Type(), &n0); type::Type::ElementOf(c0->Type(), &n0);
uint32_t n1 = 0; uint32_t n1 = 0;
@ -1027,7 +1027,7 @@ utils::Result<NumberT> ConstEval::Sqrt(const Source& source, NumberT v) {
} }
auto ConstEval::SqrtFunc(const Source& source, const type::Type* elem_ty) { auto ConstEval::SqrtFunc(const Source& source, const type::Type* elem_ty) {
return [=](auto v) -> ImplResult { return [=](auto v) -> ConstEval::Result {
if (auto r = Sqrt(source, v)) { if (auto r = Sqrt(source, v)) {
return CreateScalar(builder, source, elem_ty, r.Get()); return CreateScalar(builder, source, elem_ty, r.Get());
} }
@ -1041,7 +1041,7 @@ utils::Result<NumberT> ConstEval::Clamp(const Source&, NumberT e, NumberT low, N
} }
auto ConstEval::ClampFunc(const Source& source, const type::Type* elem_ty) { auto ConstEval::ClampFunc(const Source& source, const type::Type* elem_ty) {
return [=](auto e, auto low, auto high) -> ImplResult { return [=](auto e, auto low, auto high) -> ConstEval::Result {
if (auto r = Clamp(source, e, low, high)) { if (auto r = Clamp(source, e, low, high)) {
return CreateScalar(builder, source, elem_ty, r.Get()); return CreateScalar(builder, source, elem_ty, r.Get());
} }
@ -1050,7 +1050,7 @@ auto ConstEval::ClampFunc(const Source& source, const type::Type* elem_ty) {
} }
auto ConstEval::AddFunc(const Source& source, const type::Type* elem_ty) { auto ConstEval::AddFunc(const Source& source, const type::Type* elem_ty) {
return [=](auto a1, auto a2) -> ImplResult { return [=](auto a1, auto a2) -> ConstEval::Result {
if (auto r = Add(source, a1, a2)) { if (auto r = Add(source, a1, a2)) {
return CreateScalar(builder, source, elem_ty, r.Get()); return CreateScalar(builder, source, elem_ty, r.Get());
} }
@ -1059,7 +1059,7 @@ auto ConstEval::AddFunc(const Source& source, const type::Type* elem_ty) {
} }
auto ConstEval::SubFunc(const Source& source, const type::Type* elem_ty) { auto ConstEval::SubFunc(const Source& source, const type::Type* elem_ty) {
return [=](auto a1, auto a2) -> ImplResult { return [=](auto a1, auto a2) -> ConstEval::Result {
if (auto r = Sub(source, a1, a2)) { if (auto r = Sub(source, a1, a2)) {
return CreateScalar(builder, source, elem_ty, r.Get()); return CreateScalar(builder, source, elem_ty, r.Get());
} }
@ -1068,7 +1068,7 @@ auto ConstEval::SubFunc(const Source& source, const type::Type* elem_ty) {
} }
auto ConstEval::MulFunc(const Source& source, const type::Type* elem_ty) { auto ConstEval::MulFunc(const Source& source, const type::Type* elem_ty) {
return [=](auto a1, auto a2) -> ImplResult { return [=](auto a1, auto a2) -> ConstEval::Result {
if (auto r = Mul(source, a1, a2)) { if (auto r = Mul(source, a1, a2)) {
return CreateScalar(builder, source, elem_ty, r.Get()); return CreateScalar(builder, source, elem_ty, r.Get());
} }
@ -1077,7 +1077,7 @@ auto ConstEval::MulFunc(const Source& source, const type::Type* elem_ty) {
} }
auto ConstEval::DivFunc(const Source& source, const type::Type* elem_ty) { auto ConstEval::DivFunc(const Source& source, const type::Type* elem_ty) {
return [=](auto a1, auto a2) -> ImplResult { return [=](auto a1, auto a2) -> ConstEval::Result {
if (auto r = Div(source, a1, a2)) { if (auto r = Div(source, a1, a2)) {
return CreateScalar(builder, source, elem_ty, r.Get()); return CreateScalar(builder, source, elem_ty, r.Get());
} }
@ -1086,7 +1086,7 @@ auto ConstEval::DivFunc(const Source& source, const type::Type* elem_ty) {
} }
auto ConstEval::ModFunc(const Source& source, const type::Type* elem_ty) { auto ConstEval::ModFunc(const Source& source, const type::Type* elem_ty) {
return [=](auto a1, auto a2) -> ImplResult { return [=](auto a1, auto a2) -> ConstEval::Result {
if (auto r = Mod(source, a1, a2)) { if (auto r = Mod(source, a1, a2)) {
return CreateScalar(builder, source, elem_ty, r.Get()); return CreateScalar(builder, source, elem_ty, r.Get());
} }
@ -1095,7 +1095,7 @@ auto ConstEval::ModFunc(const Source& source, const type::Type* elem_ty) {
} }
auto ConstEval::Dot2Func(const Source& source, const type::Type* elem_ty) { auto ConstEval::Dot2Func(const Source& source, const type::Type* elem_ty) {
return [=](auto a1, auto a2, auto b1, auto b2) -> ImplResult { return [=](auto a1, auto a2, auto b1, auto b2) -> ConstEval::Result {
if (auto r = Dot2(source, a1, a2, b1, b2)) { if (auto r = Dot2(source, a1, a2, b1, b2)) {
return CreateScalar(builder, source, elem_ty, r.Get()); return CreateScalar(builder, source, elem_ty, r.Get());
} }
@ -1104,7 +1104,7 @@ auto ConstEval::Dot2Func(const Source& source, const type::Type* elem_ty) {
} }
auto ConstEval::Dot3Func(const Source& source, const type::Type* elem_ty) { auto ConstEval::Dot3Func(const Source& source, const type::Type* elem_ty) {
return [=](auto a1, auto a2, auto a3, auto b1, auto b2, auto b3) -> ImplResult { return [=](auto a1, auto a2, auto a3, auto b1, auto b2, auto b3) -> ConstEval::Result {
if (auto r = Dot3(source, a1, a2, a3, b1, b2, b3)) { if (auto r = Dot3(source, a1, a2, a3, b1, b2, b3)) {
return CreateScalar(builder, source, elem_ty, r.Get()); return CreateScalar(builder, source, elem_ty, r.Get());
} }
@ -1113,13 +1113,13 @@ auto ConstEval::Dot3Func(const Source& source, const type::Type* elem_ty) {
} }
auto ConstEval::Dot4Func(const Source& source, const type::Type* elem_ty) { auto ConstEval::Dot4Func(const Source& source, const type::Type* elem_ty) {
return return [=](auto a1, auto a2, auto a3, auto a4, auto b1, auto b2, auto b3,
[=](auto a1, auto a2, auto a3, auto a4, auto b1, auto b2, auto b3, auto b4) -> ImplResult { auto b4) -> ConstEval::Result {
if (auto r = Dot4(source, a1, a2, a3, a4, b1, b2, b3, b4)) { if (auto r = Dot4(source, a1, a2, a3, a4, b1, b2, b3, b4)) {
return CreateScalar(builder, source, elem_ty, r.Get()); return CreateScalar(builder, source, elem_ty, r.Get());
} }
return utils::Failure; return utils::Failure;
}; };
} }
ConstEval::Result ConstEval::Dot(const Source& source, ConstEval::Result ConstEval::Dot(const Source& source,
@ -1191,7 +1191,7 @@ ConstEval::Result ConstEval::Sub(const Source& source,
} }
auto ConstEval::Det2Func(const Source& source, const type::Type* elem_ty) { auto ConstEval::Det2Func(const Source& source, const type::Type* elem_ty) {
return [=](auto a, auto b, auto c, auto d) -> ImplResult { return [=](auto a, auto b, auto c, auto d) -> ConstEval::Result {
if (auto r = Det2(source, a, b, c, d)) { if (auto r = Det2(source, a, b, c, d)) {
return CreateScalar(builder, source, elem_ty, r.Get()); return CreateScalar(builder, source, elem_ty, r.Get());
} }
@ -1200,18 +1200,18 @@ auto ConstEval::Det2Func(const Source& source, const type::Type* elem_ty) {
} }
auto ConstEval::Det3Func(const Source& source, const type::Type* elem_ty) { auto ConstEval::Det3Func(const Source& source, const type::Type* elem_ty) {
return return [=](auto a, auto b, auto c, auto d, auto e, auto f, auto g, auto h,
[=](auto a, auto b, auto c, auto d, auto e, auto f, auto g, auto h, auto i) -> ImplResult { auto i) -> ConstEval::Result {
if (auto r = Det3(source, a, b, c, d, e, f, g, h, i)) { if (auto r = Det3(source, a, b, c, d, e, f, g, h, i)) {
return CreateScalar(builder, source, elem_ty, r.Get()); return CreateScalar(builder, source, elem_ty, r.Get());
} }
return utils::Failure; return utils::Failure;
}; };
} }
auto ConstEval::Det4Func(const Source& source, const type::Type* elem_ty) { auto ConstEval::Det4Func(const Source& source, const type::Type* elem_ty) {
return [=](auto a, auto b, auto c, auto d, auto e, auto f, auto g, auto h, auto i, auto j, return [=](auto a, auto b, auto c, auto d, auto e, auto f, auto g, auto h, auto i, auto j,
auto k, auto l, auto m, auto n, auto o, auto p) -> ImplResult { auto k, auto l, auto m, auto n, auto o, auto p) -> ConstEval::Result {
if (auto r = Det4(source, a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p)) { if (auto r = Det4(source, a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p)) {
return CreateScalar(builder, source, elem_ty, r.Get()); return CreateScalar(builder, source, elem_ty, r.Get());
} }
@ -1226,7 +1226,7 @@ ConstEval::Result ConstEval::Literal(const type::Type* ty, const ast::LiteralExp
[&](const ast::BoolLiteralExpression* lit) { [&](const ast::BoolLiteralExpression* lit) {
return CreateScalar(builder, source, ty, lit->value); return CreateScalar(builder, source, ty, lit->value);
}, },
[&](const ast::IntLiteralExpression* lit) -> ImplResult { [&](const ast::IntLiteralExpression* lit) -> ConstEval::Result {
switch (lit->suffix) { switch (lit->suffix) {
case ast::IntLiteralExpression::Suffix::kNone: case ast::IntLiteralExpression::Suffix::kNone:
return CreateScalar(builder, source, ty, AInt(lit->value)); return CreateScalar(builder, source, ty, AInt(lit->value));
@ -1237,7 +1237,7 @@ ConstEval::Result ConstEval::Literal(const type::Type* ty, const ast::LiteralExp
} }
return nullptr; return nullptr;
}, },
[&](const ast::FloatLiteralExpression* lit) -> ImplResult { [&](const ast::FloatLiteralExpression* lit) -> ConstEval::Result {
switch (lit->suffix) { switch (lit->suffix) {
case ast::FloatLiteralExpression::Suffix::kNone: case ast::FloatLiteralExpression::Suffix::kNone:
return CreateScalar(builder, source, ty, AFloat(lit->value)); return CreateScalar(builder, source, ty, AFloat(lit->value));
@ -1500,7 +1500,7 @@ ConstEval::Result ConstEval::OpMultiplyMatVec(const type::Type* ty,
auto* elem_ty = vec_ty->type(); auto* elem_ty = vec_ty->type();
auto dot = [&](const constant::Constant* m, size_t row, const constant::Constant* v) { auto dot = [&](const constant::Constant* m, size_t row, const constant::Constant* v) {
ImplResult result; ConstEval::Result result;
switch (mat_ty->columns()) { switch (mat_ty->columns()) {
case 2: case 2:
result = Dispatch_fa_f32_f16(Dot2Func(source, elem_ty), // result = Dispatch_fa_f32_f16(Dot2Func(source, elem_ty), //
@ -1550,7 +1550,7 @@ ConstEval::Result ConstEval::OpMultiplyVecMat(const type::Type* ty,
auto* elem_ty = vec_ty->type(); auto* elem_ty = vec_ty->type();
auto dot = [&](const constant::Constant* v, const constant::Constant* m, size_t col) { auto dot = [&](const constant::Constant* v, const constant::Constant* m, size_t col) {
ImplResult result; ConstEval::Result result;
switch (mat_ty->rows()) { switch (mat_ty->rows()) {
case 2: case 2:
result = Dispatch_fa_f32_f16(Dot2Func(source, elem_ty), // result = Dispatch_fa_f32_f16(Dot2Func(source, elem_ty), //
@ -1607,7 +1607,7 @@ ConstEval::Result ConstEval::OpMultiplyMatMat(const type::Type* ty,
auto m1e = [&](size_t r, size_t c) { return m1->Index(c)->Index(r); }; auto m1e = [&](size_t r, size_t c) { return m1->Index(c)->Index(r); };
auto m2e = [&](size_t r, size_t c) { return m2->Index(c)->Index(r); }; auto m2e = [&](size_t r, size_t c) { return m2->Index(c)->Index(r); };
ImplResult result; ConstEval::Result result;
switch (mat1_ty->columns()) { switch (mat1_ty->columns()) {
case 2: case 2:
result = Dispatch_fa_f32_f16(Dot2Func(source, elem_ty), // result = Dispatch_fa_f32_f16(Dot2Func(source, elem_ty), //
@ -1682,7 +1682,7 @@ ConstEval::Result ConstEval::OpEqual(const type::Type* ty,
utils::VectorRef<const constant::Constant*> args, utils::VectorRef<const constant::Constant*> args,
const Source& source) { const Source& source) {
auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) { auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) {
auto create = [&](auto i, auto j) -> ImplResult { auto create = [&](auto i, auto j) -> ConstEval::Result {
return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), i == j); return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), i == j);
}; };
return Dispatch_fia_fiu32_f16_bool(create, c0, c1); return Dispatch_fia_fiu32_f16_bool(create, c0, c1);
@ -1695,7 +1695,7 @@ ConstEval::Result ConstEval::OpNotEqual(const type::Type* ty,
utils::VectorRef<const constant::Constant*> args, utils::VectorRef<const constant::Constant*> args,
const Source& source) { const Source& source) {
auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) { auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) {
auto create = [&](auto i, auto j) -> ImplResult { auto create = [&](auto i, auto j) -> ConstEval::Result {
return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), i != j); return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), i != j);
}; };
return Dispatch_fia_fiu32_f16_bool(create, c0, c1); return Dispatch_fia_fiu32_f16_bool(create, c0, c1);
@ -1708,7 +1708,7 @@ ConstEval::Result ConstEval::OpLessThan(const type::Type* ty,
utils::VectorRef<const constant::Constant*> args, utils::VectorRef<const constant::Constant*> args,
const Source& source) { const Source& source) {
auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) { auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) {
auto create = [&](auto i, auto j) -> ImplResult { auto create = [&](auto i, auto j) -> ConstEval::Result {
return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), i < j); return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), i < j);
}; };
return Dispatch_fia_fiu32_f16(create, c0, c1); return Dispatch_fia_fiu32_f16(create, c0, c1);
@ -1721,7 +1721,7 @@ ConstEval::Result ConstEval::OpGreaterThan(const type::Type* ty,
utils::VectorRef<const constant::Constant*> args, utils::VectorRef<const constant::Constant*> args,
const Source& source) { const Source& source) {
auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) { auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) {
auto create = [&](auto i, auto j) -> ImplResult { auto create = [&](auto i, auto j) -> ConstEval::Result {
return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), i > j); return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), i > j);
}; };
return Dispatch_fia_fiu32_f16(create, c0, c1); return Dispatch_fia_fiu32_f16(create, c0, c1);
@ -1734,7 +1734,7 @@ ConstEval::Result ConstEval::OpLessThanEqual(const type::Type* ty,
utils::VectorRef<const constant::Constant*> args, utils::VectorRef<const constant::Constant*> args,
const Source& source) { const Source& source) {
auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) { auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) {
auto create = [&](auto i, auto j) -> ImplResult { auto create = [&](auto i, auto j) -> ConstEval::Result {
return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), i <= j); return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), i <= j);
}; };
return Dispatch_fia_fiu32_f16(create, c0, c1); return Dispatch_fia_fiu32_f16(create, c0, c1);
@ -1747,7 +1747,7 @@ ConstEval::Result ConstEval::OpGreaterThanEqual(const type::Type* ty,
utils::VectorRef<const constant::Constant*> args, utils::VectorRef<const constant::Constant*> args,
const Source& source) { const Source& source) {
auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) { auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) {
auto create = [&](auto i, auto j) -> ImplResult { auto create = [&](auto i, auto j) -> ConstEval::Result {
return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), i >= j); return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), i >= j);
}; };
return Dispatch_fia_fiu32_f16(create, c0, c1); return Dispatch_fia_fiu32_f16(create, c0, c1);
@ -1776,7 +1776,7 @@ ConstEval::Result ConstEval::OpAnd(const type::Type* ty,
utils::VectorRef<const constant::Constant*> args, utils::VectorRef<const constant::Constant*> args,
const Source& source) { const Source& source) {
auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) { auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) {
auto create = [&](auto i, auto j) -> ImplResult { auto create = [&](auto i, auto j) -> ConstEval::Result {
using T = decltype(i); using T = decltype(i);
T result; T result;
if constexpr (std::is_same_v<T, bool>) { if constexpr (std::is_same_v<T, bool>) {
@ -1796,7 +1796,7 @@ ConstEval::Result ConstEval::OpOr(const type::Type* ty,
utils::VectorRef<const constant::Constant*> args, utils::VectorRef<const constant::Constant*> args,
const Source& source) { const Source& source) {
auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) { auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) {
auto create = [&](auto i, auto j) -> ImplResult { auto create = [&](auto i, auto j) -> ConstEval::Result {
using T = decltype(i); using T = decltype(i);
T result; T result;
if constexpr (std::is_same_v<T, bool>) { if constexpr (std::is_same_v<T, bool>) {
@ -1816,7 +1816,7 @@ ConstEval::Result ConstEval::OpXor(const type::Type* ty,
utils::VectorRef<const constant::Constant*> args, utils::VectorRef<const constant::Constant*> args,
const Source& source) { const Source& source) {
auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) { auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) {
auto create = [&](auto i, auto j) -> ImplResult { auto create = [&](auto i, auto j) -> ConstEval::Result {
return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), return CreateScalar(builder, source, type::Type::DeepestElementOf(ty),
decltype(i){i ^ j}); decltype(i){i ^ j});
}; };
@ -1830,7 +1830,7 @@ ConstEval::Result ConstEval::OpShiftLeft(const type::Type* ty,
utils::VectorRef<const constant::Constant*> args, utils::VectorRef<const constant::Constant*> args,
const Source& source) { const Source& source) {
auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) { auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) {
auto create = [&](auto e1, auto e2) -> ImplResult { auto create = [&](auto e1, auto e2) -> ConstEval::Result {
using NumberT = decltype(e1); using NumberT = decltype(e1);
using T = UnwrapNumber<NumberT>; using T = UnwrapNumber<NumberT>;
using UT = std::make_unsigned_t<T>; using UT = std::make_unsigned_t<T>;
@ -1916,7 +1916,7 @@ ConstEval::Result ConstEval::OpShiftRight(const type::Type* ty,
utils::VectorRef<const constant::Constant*> args, utils::VectorRef<const constant::Constant*> args,
const Source& source) { const Source& source) {
auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) { auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) {
auto create = [&](auto e1, auto e2) -> ImplResult { auto create = [&](auto e1, auto e2) -> ConstEval::Result {
using NumberT = decltype(e1); using NumberT = decltype(e1);
using T = UnwrapNumber<NumberT>; using T = UnwrapNumber<NumberT>;
using UT = std::make_unsigned_t<T>; using UT = std::make_unsigned_t<T>;
@ -2005,7 +2005,7 @@ ConstEval::Result ConstEval::acos(const type::Type* ty,
utils::VectorRef<const constant::Constant*> args, utils::VectorRef<const constant::Constant*> args,
const Source& source) { const Source& source) {
auto transform = [&](const constant::Constant* c0) { auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto i) -> ImplResult { auto create = [&](auto i) -> ConstEval::Result {
using NumberT = decltype(i); using NumberT = decltype(i);
if (i < NumberT(-1.0) || i > NumberT(1.0)) { if (i < NumberT(-1.0) || i > NumberT(1.0)) {
AddError("acos must be called with a value in the range [-1 .. 1] (inclusive)", AddError("acos must be called with a value in the range [-1 .. 1] (inclusive)",
@ -2023,7 +2023,7 @@ ConstEval::Result ConstEval::acosh(const type::Type* ty,
utils::VectorRef<const constant::Constant*> args, utils::VectorRef<const constant::Constant*> args,
const Source& source) { const Source& source) {
auto transform = [&](const constant::Constant* c0) { auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto i) -> ImplResult { auto create = [&](auto i) -> ConstEval::Result {
using NumberT = decltype(i); using NumberT = decltype(i);
if (i < NumberT(1.0)) { if (i < NumberT(1.0)) {
AddError("acosh must be called with a value >= 1.0", source); AddError("acosh must be called with a value >= 1.0", source);
@ -2053,7 +2053,7 @@ ConstEval::Result ConstEval::asin(const type::Type* ty,
utils::VectorRef<const constant::Constant*> args, utils::VectorRef<const constant::Constant*> args,
const Source& source) { const Source& source) {
auto transform = [&](const constant::Constant* c0) { auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto i) -> ImplResult { auto create = [&](auto i) -> ConstEval::Result {
using NumberT = decltype(i); using NumberT = decltype(i);
if (i < NumberT(-1.0) || i > NumberT(1.0)) { if (i < NumberT(-1.0) || i > NumberT(1.0)) {
AddError("asin must be called with a value in the range [-1 .. 1] (inclusive)", AddError("asin must be called with a value in the range [-1 .. 1] (inclusive)",
@ -2096,7 +2096,7 @@ ConstEval::Result ConstEval::atanh(const type::Type* ty,
utils::VectorRef<const constant::Constant*> args, utils::VectorRef<const constant::Constant*> args,
const Source& source) { const Source& source) {
auto transform = [&](const constant::Constant* c0) { auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto i) -> ImplResult { auto create = [&](auto i) -> ConstEval::Result {
using NumberT = decltype(i); using NumberT = decltype(i);
if (i <= NumberT(-1.0) || i >= NumberT(1.0)) { if (i <= NumberT(-1.0) || i >= NumberT(1.0)) {
AddError("atanh must be called with a value in the range (-1 .. 1) (exclusive)", AddError("atanh must be called with a value in the range (-1 .. 1) (exclusive)",
@ -2150,7 +2150,7 @@ ConstEval::Result ConstEval::cos(const type::Type* ty,
utils::VectorRef<const constant::Constant*> args, utils::VectorRef<const constant::Constant*> args,
const Source& source) { const Source& source) {
auto transform = [&](const constant::Constant* c0) { auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto i) -> ImplResult { auto create = [&](auto i) -> ConstEval::Result {
using NumberT = decltype(i); using NumberT = decltype(i);
return CreateScalar(builder, source, c0->Type(), NumberT(std::cos(i.value))); return CreateScalar(builder, source, c0->Type(), NumberT(std::cos(i.value)));
}; };
@ -2163,7 +2163,7 @@ ConstEval::Result ConstEval::cosh(const type::Type* ty,
utils::VectorRef<const constant::Constant*> args, utils::VectorRef<const constant::Constant*> args,
const Source& source) { const Source& source) {
auto transform = [&](const constant::Constant* c0) { auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto i) -> ImplResult { auto create = [&](auto i) -> ConstEval::Result {
using NumberT = decltype(i); using NumberT = decltype(i);
return CreateScalar(builder, source, c0->Type(), NumberT(std::cosh(i.value))); return CreateScalar(builder, source, c0->Type(), NumberT(std::cosh(i.value)));
}; };
@ -2273,7 +2273,7 @@ ConstEval::Result ConstEval::degrees(const type::Type* ty,
utils::VectorRef<const constant::Constant*> args, utils::VectorRef<const constant::Constant*> args,
const Source& source) { const Source& source) {
auto transform = [&](const constant::Constant* c0) { auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto e) -> ImplResult { auto create = [&](auto e) -> ConstEval::Result {
using NumberT = decltype(e); using NumberT = decltype(e);
using T = UnwrapNumber<NumberT>; using T = UnwrapNumber<NumberT>;
@ -2334,7 +2334,7 @@ ConstEval::Result ConstEval::determinant(const type::Type* ty,
ConstEval::Result ConstEval::distance(const type::Type* ty, ConstEval::Result ConstEval::distance(const type::Type* ty,
utils::VectorRef<const constant::Constant*> args, utils::VectorRef<const constant::Constant*> args,
const Source& source) { const Source& source) {
auto err = [&]() -> ImplResult { auto err = [&]() -> ConstEval::Result {
AddNote("when calculating distance", source); AddNote("when calculating distance", source);
return utils::Failure; return utils::Failure;
}; };
@ -2365,7 +2365,7 @@ ConstEval::Result ConstEval::exp(const type::Type* ty,
utils::VectorRef<const constant::Constant*> args, utils::VectorRef<const constant::Constant*> args,
const Source& source) { const Source& source) {
auto transform = [&](const constant::Constant* c0) { auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto e0) -> ImplResult { auto create = [&](auto e0) -> ConstEval::Result {
using NumberT = decltype(e0); using NumberT = decltype(e0);
auto val = NumberT(std::exp(e0)); auto val = NumberT(std::exp(e0));
if (!std::isfinite(val.value)) { if (!std::isfinite(val.value)) {
@ -2383,7 +2383,7 @@ ConstEval::Result ConstEval::exp2(const type::Type* ty,
utils::VectorRef<const constant::Constant*> args, utils::VectorRef<const constant::Constant*> args,
const Source& source) { const Source& source) {
auto transform = [&](const constant::Constant* c0) { auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto e0) -> ImplResult { auto create = [&](auto e0) -> ConstEval::Result {
using NumberT = decltype(e0); using NumberT = decltype(e0);
auto val = NumberT(std::exp2(e0)); auto val = NumberT(std::exp2(e0));
if (!std::isfinite(val.value)) { if (!std::isfinite(val.value)) {
@ -2401,7 +2401,7 @@ ConstEval::Result ConstEval::extractBits(const type::Type* ty,
utils::VectorRef<const constant::Constant*> args, utils::VectorRef<const constant::Constant*> args,
const Source& source) { const Source& source) {
auto transform = [&](const constant::Constant* c0) { auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto in_e) -> ImplResult { auto create = [&](auto in_e) -> ConstEval::Result {
using NumberT = decltype(in_e); using NumberT = decltype(in_e);
using T = UnwrapNumber<NumberT>; using T = UnwrapNumber<NumberT>;
using UT = std::make_unsigned_t<T>; using UT = std::make_unsigned_t<T>;
@ -2558,7 +2558,7 @@ ConstEval::Result ConstEval::fma(const type::Type* ty,
const Source& source) { const Source& source) {
auto transform = [&](const constant::Constant* c1, const constant::Constant* c2, auto transform = [&](const constant::Constant* c1, const constant::Constant* c2,
const constant::Constant* c3) { const constant::Constant* c3) {
auto create = [&](auto e1, auto e2, auto e3) -> ImplResult { auto create = [&](auto e1, auto e2, auto e3) -> ConstEval::Result {
auto err_msg = [&] { auto err_msg = [&] {
AddNote("when calculating fma", source); AddNote("when calculating fma", source);
return utils::Failure; return utils::Failure;
@ -2584,7 +2584,7 @@ ConstEval::Result ConstEval::fract(const type::Type* ty,
utils::VectorRef<const constant::Constant*> args, utils::VectorRef<const constant::Constant*> args,
const Source& source) { const Source& source) {
auto transform = [&](const constant::Constant* c1) { auto transform = [&](const constant::Constant* c1) {
auto create = [&](auto e) -> ImplResult { auto create = [&](auto e) -> ConstEval::Result {
using NumberT = decltype(e); using NumberT = decltype(e);
auto r = e - std::floor(e); auto r = e - std::floor(e);
return CreateScalar(builder, source, c1->Type(), NumberT{r}); return CreateScalar(builder, source, c1->Type(), NumberT{r});
@ -2600,8 +2600,8 @@ ConstEval::Result ConstEval::frexp(const type::Type* ty,
auto* arg = args[0]; auto* arg = args[0];
struct FractExp { struct FractExp {
ImplResult fract; ConstEval::Result fract;
ImplResult exp; ConstEval::Result exp;
}; };
auto scalar = [&](const constant::Constant* s) { auto scalar = [&](const constant::Constant* s) {
@ -2671,7 +2671,7 @@ ConstEval::Result ConstEval::insertBits(const type::Type* ty,
utils::VectorRef<const constant::Constant*> args, utils::VectorRef<const constant::Constant*> args,
const Source& source) { const Source& source) {
auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) { auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) {
auto create = [&](auto in_e, auto in_newbits) -> ImplResult { auto create = [&](auto in_e, auto in_newbits) -> ConstEval::Result {
using NumberT = decltype(in_e); using NumberT = decltype(in_e);
using T = UnwrapNumber<NumberT>; using T = UnwrapNumber<NumberT>;
using UT = std::make_unsigned_t<T>; using UT = std::make_unsigned_t<T>;
@ -2723,7 +2723,7 @@ ConstEval::Result ConstEval::inverseSqrt(const type::Type* ty,
utils::VectorRef<const constant::Constant*> args, utils::VectorRef<const constant::Constant*> args,
const Source& source) { const Source& source) {
auto transform = [&](const constant::Constant* c0) { auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto e) -> ImplResult { auto create = [&](auto e) -> ConstEval::Result {
using NumberT = decltype(e); using NumberT = decltype(e);
if (e <= NumberT(0)) { if (e <= NumberT(0)) {
@ -2767,7 +2767,7 @@ ConstEval::Result ConstEval::log(const type::Type* ty,
utils::VectorRef<const constant::Constant*> args, utils::VectorRef<const constant::Constant*> args,
const Source& source) { const Source& source) {
auto transform = [&](const constant::Constant* c0) { auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto v) -> ImplResult { auto create = [&](auto v) -> ConstEval::Result {
using NumberT = decltype(v); using NumberT = decltype(v);
if (v <= NumberT(0)) { if (v <= NumberT(0)) {
AddError("log must be called with a value > 0", source); AddError("log must be called with a value > 0", source);
@ -2784,7 +2784,7 @@ ConstEval::Result ConstEval::log2(const type::Type* ty,
utils::VectorRef<const constant::Constant*> args, utils::VectorRef<const constant::Constant*> args,
const Source& source) { const Source& source) {
auto transform = [&](const constant::Constant* c0) { auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto v) -> ImplResult { auto create = [&](auto v) -> ConstEval::Result {
using NumberT = decltype(v); using NumberT = decltype(v);
if (v <= NumberT(0)) { if (v <= NumberT(0)) {
AddError("log2 must be called with a value > 0", source); AddError("log2 must be called with a value > 0", source);
@ -2825,7 +2825,7 @@ ConstEval::Result ConstEval::mix(const type::Type* ty,
utils::VectorRef<const constant::Constant*> args, utils::VectorRef<const constant::Constant*> args,
const Source& source) { const Source& source) {
auto transform = [&](const constant::Constant* c0, const constant::Constant* c1, size_t index) { auto transform = [&](const constant::Constant* c0, const constant::Constant* c1, size_t index) {
auto create = [&](auto e1, auto e2) -> ImplResult { auto create = [&](auto e1, auto e2) -> ConstEval::Result {
using NumberT = decltype(e1); using NumberT = decltype(e1);
// e3 is either a vector or a scalar // e3 is either a vector or a scalar
NumberT e3; NumberT e3;
@ -3019,7 +3019,7 @@ ConstEval::Result ConstEval::pow(const type::Type* ty,
utils::VectorRef<const constant::Constant*> args, utils::VectorRef<const constant::Constant*> args,
const Source& source) { const Source& source) {
auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) { auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) {
auto create = [&](auto e1, auto e2) -> ImplResult { auto create = [&](auto e1, auto e2) -> ConstEval::Result {
auto r = CheckedPow(e1, e2); auto r = CheckedPow(e1, e2);
if (!r) { if (!r) {
AddError(OverflowErrorMessage(e1, "^", e2), source); AddError(OverflowErrorMessage(e1, "^", e2), source);
@ -3036,7 +3036,7 @@ ConstEval::Result ConstEval::radians(const type::Type* ty,
utils::VectorRef<const constant::Constant*> args, utils::VectorRef<const constant::Constant*> args,
const Source& source) { const Source& source) {
auto transform = [&](const constant::Constant* c0) { auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto e) -> ImplResult { auto create = [&](auto e) -> ConstEval::Result {
using NumberT = decltype(e); using NumberT = decltype(e);
using T = UnwrapNumber<NumberT>; using T = UnwrapNumber<NumberT>;
@ -3076,7 +3076,7 @@ ConstEval::Result ConstEval::reflect(const type::Type* ty,
} }
// 2 * dot(e2, e1) // 2 * dot(e2, e1)
auto mul2 = [&](auto v) -> ImplResult { auto mul2 = [&](auto v) -> ConstEval::Result {
using NumberT = decltype(v); using NumberT = decltype(v);
return CreateScalar(builder, source, el_ty, NumberT{NumberT{2} * v}); return CreateScalar(builder, source, el_ty, NumberT{NumberT{2} * v});
}; };
@ -3203,7 +3203,7 @@ ConstEval::Result ConstEval::reverseBits(const type::Type* ty,
utils::VectorRef<const constant::Constant*> args, utils::VectorRef<const constant::Constant*> args,
const Source& source) { const Source& source) {
auto transform = [&](const constant::Constant* c0) { auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto in_e) -> ImplResult { auto create = [&](auto in_e) -> ConstEval::Result {
using NumberT = decltype(in_e); using NumberT = decltype(in_e);
using T = UnwrapNumber<NumberT>; using T = UnwrapNumber<NumberT>;
using UT = std::make_unsigned_t<T>; using UT = std::make_unsigned_t<T>;
@ -3281,7 +3281,7 @@ ConstEval::Result ConstEval::select_bool(const type::Type* ty,
const Source& source) { const Source& source) {
auto cond = args[2]->As<bool>(); auto cond = args[2]->As<bool>();
auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) { auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) {
auto create = [&](auto f, auto t) -> ImplResult { auto create = [&](auto f, auto t) -> ConstEval::Result {
return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), cond ? t : f); return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), cond ? t : f);
}; };
return Dispatch_fia_fiu32_f16_bool(create, c0, c1); return Dispatch_fia_fiu32_f16_bool(create, c0, c1);
@ -3294,7 +3294,7 @@ ConstEval::Result ConstEval::select_boolvec(const type::Type* ty,
utils::VectorRef<const constant::Constant*> args, utils::VectorRef<const constant::Constant*> args,
const Source& source) { const Source& source) {
auto transform = [&](const constant::Constant* c0, const constant::Constant* c1, size_t index) { auto transform = [&](const constant::Constant* c0, const constant::Constant* c1, size_t index) {
auto create = [&](auto f, auto t) -> ImplResult { auto create = [&](auto f, auto t) -> ConstEval::Result {
// Get corresponding bool value at the current vector value index // Get corresponding bool value at the current vector value index
auto cond = args[2]->Index(index)->As<bool>(); auto cond = args[2]->Index(index)->As<bool>();
return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), cond ? t : f); return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), cond ? t : f);
@ -3309,7 +3309,7 @@ ConstEval::Result ConstEval::sign(const type::Type* ty,
utils::VectorRef<const constant::Constant*> args, utils::VectorRef<const constant::Constant*> args,
const Source& source) { const Source& source) {
auto transform = [&](const constant::Constant* c0) { auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto e) -> ImplResult { auto create = [&](auto e) -> ConstEval::Result {
using NumberT = decltype(e); using NumberT = decltype(e);
NumberT result; NumberT result;
NumberT zero{0.0}; NumberT zero{0.0};
@ -3331,7 +3331,7 @@ ConstEval::Result ConstEval::sin(const type::Type* ty,
utils::VectorRef<const constant::Constant*> args, utils::VectorRef<const constant::Constant*> args,
const Source& source) { const Source& source) {
auto transform = [&](const constant::Constant* c0) { auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto i) -> ImplResult { auto create = [&](auto i) -> ConstEval::Result {
using NumberT = decltype(i); using NumberT = decltype(i);
return CreateScalar(builder, source, c0->Type(), NumberT(std::sin(i.value))); return CreateScalar(builder, source, c0->Type(), NumberT(std::sin(i.value)));
}; };
@ -3344,7 +3344,7 @@ ConstEval::Result ConstEval::sinh(const type::Type* ty,
utils::VectorRef<const constant::Constant*> args, utils::VectorRef<const constant::Constant*> args,
const Source& source) { const Source& source) {
auto transform = [&](const constant::Constant* c0) { auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto i) -> ImplResult { auto create = [&](auto i) -> ConstEval::Result {
using NumberT = decltype(i); using NumberT = decltype(i);
return CreateScalar(builder, source, c0->Type(), NumberT(std::sinh(i.value))); return CreateScalar(builder, source, c0->Type(), NumberT(std::sinh(i.value)));
}; };
@ -3358,7 +3358,7 @@ ConstEval::Result ConstEval::smoothstep(const type::Type* ty,
const Source& source) { const Source& source) {
auto transform = [&](const constant::Constant* c0, const constant::Constant* c1, auto transform = [&](const constant::Constant* c0, const constant::Constant* c1,
const constant::Constant* c2) { const constant::Constant* c2) {
auto create = [&](auto low, auto high, auto x) -> ImplResult { auto create = [&](auto low, auto high, auto x) -> ConstEval::Result {
using NumberT = decltype(low); using NumberT = decltype(low);
auto err = [&] { auto err = [&] {
@ -3408,7 +3408,7 @@ ConstEval::Result ConstEval::step(const type::Type* ty,
utils::VectorRef<const constant::Constant*> args, utils::VectorRef<const constant::Constant*> args,
const Source& source) { const Source& source) {
auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) { auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) {
auto create = [&](auto edge, auto x) -> ImplResult { auto create = [&](auto edge, auto x) -> ConstEval::Result {
using NumberT = decltype(edge); using NumberT = decltype(edge);
NumberT result = x.value < edge.value ? NumberT(0.0) : NumberT(1.0); NumberT result = x.value < edge.value ? NumberT(0.0) : NumberT(1.0);
return CreateScalar(builder, source, c0->Type(), result); return CreateScalar(builder, source, c0->Type(), result);
@ -3432,7 +3432,7 @@ ConstEval::Result ConstEval::tan(const type::Type* ty,
utils::VectorRef<const constant::Constant*> args, utils::VectorRef<const constant::Constant*> args,
const Source& source) { const Source& source) {
auto transform = [&](const constant::Constant* c0) { auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto i) -> ImplResult { auto create = [&](auto i) -> ConstEval::Result {
using NumberT = decltype(i); using NumberT = decltype(i);
return CreateScalar(builder, source, c0->Type(), NumberT(std::tan(i.value))); return CreateScalar(builder, source, c0->Type(), NumberT(std::tan(i.value)));
}; };
@ -3445,7 +3445,7 @@ ConstEval::Result ConstEval::tanh(const type::Type* ty,
utils::VectorRef<const constant::Constant*> args, utils::VectorRef<const constant::Constant*> args,
const Source& source) { const Source& source) {
auto transform = [&](const constant::Constant* c0) { auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto i) -> ImplResult { auto create = [&](auto i) -> ConstEval::Result {
using NumberT = decltype(i); using NumberT = decltype(i);
return CreateScalar(builder, source, c0->Type(), NumberT(std::tanh(i.value))); return CreateScalar(builder, source, c0->Type(), NumberT(std::tanh(i.value)));
}; };
@ -3591,7 +3591,7 @@ ConstEval::Result ConstEval::unpack4x8unorm(const type::Type* ty,
ConstEval::Result ConstEval::quantizeToF16(const type::Type* ty, ConstEval::Result ConstEval::quantizeToF16(const type::Type* ty,
utils::VectorRef<const constant::Constant*> args, utils::VectorRef<const constant::Constant*> args,
const Source& source) { const Source& source) {
auto transform = [&](const constant::Constant* c) -> ImplResult { auto transform = [&](const constant::Constant* c) -> ConstEval::Result {
auto value = c->As<f32>(); auto value = c->As<f32>();
auto conv = CheckedConvert<f32>(f16(value)); auto conv = CheckedConvert<f32>(f16(value));
if (!conv) { if (!conv) {