Rename Element to Shader in ConstEval.

The const-eval Element is renamed to Scalar to better represent what is
stored.

Bug: tint:1718
Change-Id: I882a8d955f805bc04cea6794fdeaeba0ff2f2ae8
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/114101
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
dan sinclair 2022-12-14 13:43:51 +00:00 committed by Dawn LUCI CQ
parent 0335c7d65d
commit 2e737daaae
1 changed files with 134 additions and 136 deletions

View File

@ -269,20 +269,20 @@ const ImplConstant* CreateComposite(ProgramBuilder& builder,
const type::Type* type, const type::Type* type,
utils::VectorRef<const constant::Constant*> elements); utils::VectorRef<const constant::Constant*> elements);
/// Element holds a single scalar or abstract-numeric value. /// Scalar holds a single scalar or abstract-numeric value.
/// Element implements the Constant interface. /// Scalar implements the Constant interface.
template <typename T> template <typename T>
class Element : public Castable<Element<T>, ImplConstant> { class Scalar : public Castable<Scalar<T>, ImplConstant> {
public: public:
static_assert(!std::is_same_v<UnwrapNumber<T>, T> || std::is_same_v<T, bool>, static_assert(!std::is_same_v<UnwrapNumber<T>, T> || std::is_same_v<T, bool>,
"T must be a Number or bool"); "T must be a Number or bool");
Element(const type::Type* t, T v) : type(t), value(v) { Scalar(const type::Type* t, T v) : type(t), value(v) {
if constexpr (IsFloatingPoint<T>) { if constexpr (IsFloatingPoint<T>) {
TINT_ASSERT(Resolver, std::isfinite(v.value)); TINT_ASSERT(Resolver, std::isfinite(v.value));
} }
} }
~Element() override = default; ~Scalar() override = default;
const type::Type* Type() const override { return type; } const type::Type* Type() const override { return type; }
std::variant<std::monostate, AInt, AFloat> Value() const override { std::variant<std::monostate, AInt, AFloat> Value() const override {
if constexpr (IsFloatingPoint<UnwrapNumber<T>>) { if constexpr (IsFloatingPoint<UnwrapNumber<T>>) {
@ -313,13 +313,13 @@ class Element : public Castable<Element<T>, ImplConstant> {
using FROM = T; using FROM = T;
if constexpr (std::is_same_v<TO, bool>) { if constexpr (std::is_same_v<TO, bool>) {
// [x -> bool] // [x -> bool]
return builder.create<Element<TO>>(target_ty, !IsPositiveZero(value)); return builder.create<Scalar<TO>>(target_ty, !IsPositiveZero(value));
} else if constexpr (std::is_same_v<FROM, bool>) { } else if constexpr (std::is_same_v<FROM, bool>) {
// [bool -> x] // [bool -> x]
return builder.create<Element<TO>>(target_ty, TO(value ? 1 : 0)); return builder.create<Scalar<TO>>(target_ty, TO(value ? 1 : 0));
} else if (auto conv = CheckedConvert<TO>(value)) { } else if (auto conv = CheckedConvert<TO>(value)) {
// Conversion success // Conversion success
return builder.create<Element<TO>>(target_ty, conv.Get()); return builder.create<Scalar<TO>>(target_ty, conv.Get());
// --- Below this point are the failure cases --- // --- Below this point are the failure cases ---
} else if constexpr (IsAbstract<FROM>) { } else if constexpr (IsAbstract<FROM>) {
// [abstract-numeric -> x] - materialization failure // [abstract-numeric -> x] - materialization failure
@ -339,14 +339,14 @@ class Element : public Castable<Element<T>, ImplConstant> {
// https://www.w3.org/TR/WGSL/#floating-point-conversion // https://www.w3.org/TR/WGSL/#floating-point-conversion
switch (conv.Failure()) { switch (conv.Failure()) {
case ConversionFailure::kExceedsNegativeLimit: case ConversionFailure::kExceedsNegativeLimit:
return builder.create<Element<TO>>(target_ty, TO::Lowest()); return builder.create<Scalar<TO>>(target_ty, TO::Lowest());
case ConversionFailure::kExceedsPositiveLimit: case ConversionFailure::kExceedsPositiveLimit:
return builder.create<Element<TO>>(target_ty, TO::Highest()); return builder.create<Scalar<TO>>(target_ty, TO::Highest());
} }
} else if constexpr (IsIntegral<FROM>) { } else if constexpr (IsIntegral<FROM>) {
// [integer -> integer] - number not exactly representable // [integer -> integer] - number not exactly representable
// Static cast // Static cast
return builder.create<Element<TO>>(target_ty, static_cast<TO>(value)); return builder.create<Scalar<TO>>(target_ty, static_cast<TO>(value));
} }
return nullptr; // Expression is not constant. return nullptr; // Expression is not constant.
}); });
@ -472,22 +472,22 @@ class Composite : public Castable<Composite, ImplConstant> {
} // namespace tint::resolver } // namespace tint::resolver
TINT_INSTANTIATE_TYPEINFO(tint::resolver::ImplConstant); TINT_INSTANTIATE_TYPEINFO(tint::resolver::ImplConstant);
TINT_INSTANTIATE_TYPEINFO(tint::resolver::Element<tint::AInt>); TINT_INSTANTIATE_TYPEINFO(tint::resolver::Scalar<tint::AInt>);
TINT_INSTANTIATE_TYPEINFO(tint::resolver::Element<tint::AFloat>); TINT_INSTANTIATE_TYPEINFO(tint::resolver::Scalar<tint::AFloat>);
TINT_INSTANTIATE_TYPEINFO(tint::resolver::Element<tint::i32>); TINT_INSTANTIATE_TYPEINFO(tint::resolver::Scalar<tint::i32>);
TINT_INSTANTIATE_TYPEINFO(tint::resolver::Element<tint::u32>); TINT_INSTANTIATE_TYPEINFO(tint::resolver::Scalar<tint::u32>);
TINT_INSTANTIATE_TYPEINFO(tint::resolver::Element<tint::f16>); TINT_INSTANTIATE_TYPEINFO(tint::resolver::Scalar<tint::f16>);
TINT_INSTANTIATE_TYPEINFO(tint::resolver::Element<tint::f32>); TINT_INSTANTIATE_TYPEINFO(tint::resolver::Scalar<tint::f32>);
TINT_INSTANTIATE_TYPEINFO(tint::resolver::Element<bool>); TINT_INSTANTIATE_TYPEINFO(tint::resolver::Scalar<bool>);
TINT_INSTANTIATE_TYPEINFO(tint::resolver::Splat); TINT_INSTANTIATE_TYPEINFO(tint::resolver::Splat);
TINT_INSTANTIATE_TYPEINFO(tint::resolver::Composite); TINT_INSTANTIATE_TYPEINFO(tint::resolver::Composite);
namespace tint::resolver { namespace tint::resolver {
namespace { namespace {
/// CreateElement constructs and returns an Element<T>. /// CreateScalar constructs and returns an Scalar<T>.
template <typename T> template <typename T>
ImplResult CreateElement(ProgramBuilder& builder, const Source& source, const type::Type* t, T v) { ImplResult CreateScalar(ProgramBuilder& builder, const Source& source, const type::Type* t, T v) {
static_assert(IsNumber<T> || std::is_same_v<T, bool>, "T must be a Number or bool"); static_assert(IsNumber<T> || std::is_same_v<T, bool>, "T must be a Number or bool");
TINT_ASSERT(Resolver, t->is_scalar()); TINT_ASSERT(Resolver, t->is_scalar());
@ -498,7 +498,7 @@ ImplResult CreateElement(ProgramBuilder& builder, const Source& source, const ty
return utils::Failure; return utils::Failure;
} }
} }
return builder.create<Element<T>>(t, v); return builder.create<Scalar<T>>(t, v);
} }
/// ZeroValue returns a Constant for the zero-value of the type `type`. /// ZeroValue returns a Constant for the zero-value of the type `type`.
@ -541,7 +541,7 @@ const ImplConstant* ZeroValue(ProgramBuilder& builder, const type::Type* type) {
}, },
[&](Default) -> const ImplConstant* { [&](Default) -> const ImplConstant* {
return ZeroTypeDispatch(type, [&](auto zero) -> const ImplConstant* { return ZeroTypeDispatch(type, [&](auto zero) -> const ImplConstant* {
auto el = CreateElement(builder, Source{}, type, zero); auto el = CreateScalar(builder, Source{}, type, zero);
TINT_ASSERT(Resolver, el); TINT_ASSERT(Resolver, el);
return el.Get(); return el.Get();
}); });
@ -1128,7 +1128,7 @@ utils::Result<NumberT> ConstEval::Sqrt(const Source& source, NumberT v) {
auto ConstEval::SqrtFunc(const Source& source, const type::Type* elem_ty) { auto ConstEval::SqrtFunc(const Source& source, const type::Type* elem_ty) {
return [=](auto v) -> ImplResult { return [=](auto v) -> ImplResult {
if (auto r = Sqrt(source, v)) { if (auto r = Sqrt(source, v)) {
return CreateElement(builder, source, elem_ty, r.Get()); return CreateScalar(builder, source, elem_ty, r.Get());
} }
return utils::Failure; return utils::Failure;
}; };
@ -1142,7 +1142,7 @@ utils::Result<NumberT> ConstEval::Clamp(const Source&, NumberT e, NumberT low, N
auto ConstEval::ClampFunc(const Source& source, const type::Type* elem_ty) { auto ConstEval::ClampFunc(const Source& source, const type::Type* elem_ty) {
return [=](auto e, auto low, auto high) -> ImplResult { return [=](auto e, auto low, auto high) -> ImplResult {
if (auto r = Clamp(source, e, low, high)) { if (auto r = Clamp(source, e, low, high)) {
return CreateElement(builder, source, elem_ty, r.Get()); return CreateScalar(builder, source, elem_ty, r.Get());
} }
return utils::Failure; return utils::Failure;
}; };
@ -1151,7 +1151,7 @@ auto ConstEval::ClampFunc(const Source& source, const type::Type* elem_ty) {
auto ConstEval::AddFunc(const Source& source, const type::Type* elem_ty) { auto ConstEval::AddFunc(const Source& source, const type::Type* elem_ty) {
return [=](auto a1, auto a2) -> ImplResult { return [=](auto a1, auto a2) -> ImplResult {
if (auto r = Add(source, a1, a2)) { if (auto r = Add(source, a1, a2)) {
return CreateElement(builder, source, elem_ty, r.Get()); return CreateScalar(builder, source, elem_ty, r.Get());
} }
return utils::Failure; return utils::Failure;
}; };
@ -1160,7 +1160,7 @@ auto ConstEval::AddFunc(const Source& source, const type::Type* elem_ty) {
auto ConstEval::SubFunc(const Source& source, const type::Type* elem_ty) { auto ConstEval::SubFunc(const Source& source, const type::Type* elem_ty) {
return [=](auto a1, auto a2) -> ImplResult { return [=](auto a1, auto a2) -> ImplResult {
if (auto r = Sub(source, a1, a2)) { if (auto r = Sub(source, a1, a2)) {
return CreateElement(builder, source, elem_ty, r.Get()); return CreateScalar(builder, source, elem_ty, r.Get());
} }
return utils::Failure; return utils::Failure;
}; };
@ -1169,7 +1169,7 @@ auto ConstEval::SubFunc(const Source& source, const type::Type* elem_ty) {
auto ConstEval::MulFunc(const Source& source, const type::Type* elem_ty) { auto ConstEval::MulFunc(const Source& source, const type::Type* elem_ty) {
return [=](auto a1, auto a2) -> ImplResult { return [=](auto a1, auto a2) -> ImplResult {
if (auto r = Mul(source, a1, a2)) { if (auto r = Mul(source, a1, a2)) {
return CreateElement(builder, source, elem_ty, r.Get()); return CreateScalar(builder, source, elem_ty, r.Get());
} }
return utils::Failure; return utils::Failure;
}; };
@ -1178,7 +1178,7 @@ auto ConstEval::MulFunc(const Source& source, const type::Type* elem_ty) {
auto ConstEval::DivFunc(const Source& source, const type::Type* elem_ty) { auto ConstEval::DivFunc(const Source& source, const type::Type* elem_ty) {
return [=](auto a1, auto a2) -> ImplResult { return [=](auto a1, auto a2) -> ImplResult {
if (auto r = Div(source, a1, a2)) { if (auto r = Div(source, a1, a2)) {
return CreateElement(builder, source, elem_ty, r.Get()); return CreateScalar(builder, source, elem_ty, r.Get());
} }
return utils::Failure; return utils::Failure;
}; };
@ -1187,7 +1187,7 @@ auto ConstEval::DivFunc(const Source& source, const type::Type* elem_ty) {
auto ConstEval::ModFunc(const Source& source, const type::Type* elem_ty) { auto ConstEval::ModFunc(const Source& source, const type::Type* elem_ty) {
return [=](auto a1, auto a2) -> ImplResult { return [=](auto a1, auto a2) -> ImplResult {
if (auto r = Mod(source, a1, a2)) { if (auto r = Mod(source, a1, a2)) {
return CreateElement(builder, source, elem_ty, r.Get()); return CreateScalar(builder, source, elem_ty, r.Get());
} }
return utils::Failure; return utils::Failure;
}; };
@ -1196,7 +1196,7 @@ auto ConstEval::ModFunc(const Source& source, const type::Type* elem_ty) {
auto ConstEval::Dot2Func(const Source& source, const type::Type* elem_ty) { auto ConstEval::Dot2Func(const Source& source, const type::Type* elem_ty) {
return [=](auto a1, auto a2, auto b1, auto b2) -> ImplResult { return [=](auto a1, auto a2, auto b1, auto b2) -> ImplResult {
if (auto r = Dot2(source, a1, a2, b1, b2)) { if (auto r = Dot2(source, a1, a2, b1, b2)) {
return CreateElement(builder, source, elem_ty, r.Get()); return CreateScalar(builder, source, elem_ty, r.Get());
} }
return utils::Failure; return utils::Failure;
}; };
@ -1205,7 +1205,7 @@ auto ConstEval::Dot2Func(const Source& source, const type::Type* elem_ty) {
auto ConstEval::Dot3Func(const Source& source, const type::Type* elem_ty) { auto ConstEval::Dot3Func(const Source& source, const type::Type* elem_ty) {
return [=](auto a1, auto a2, auto a3, auto b1, auto b2, auto b3) -> ImplResult { return [=](auto a1, auto a2, auto a3, auto b1, auto b2, auto b3) -> ImplResult {
if (auto r = Dot3(source, a1, a2, a3, b1, b2, b3)) { if (auto r = Dot3(source, a1, a2, a3, b1, b2, b3)) {
return CreateElement(builder, source, elem_ty, r.Get()); return CreateScalar(builder, source, elem_ty, r.Get());
} }
return utils::Failure; return utils::Failure;
}; };
@ -1215,7 +1215,7 @@ auto ConstEval::Dot4Func(const Source& source, const type::Type* elem_ty) {
return return
[=](auto a1, auto a2, auto a3, auto a4, auto b1, auto b2, auto b3, auto b4) -> ImplResult { [=](auto a1, auto a2, auto a3, auto a4, auto b1, auto b2, auto b3, auto b4) -> ImplResult {
if (auto r = Dot4(source, a1, a2, a3, a4, b1, b2, b3, b4)) { if (auto r = Dot4(source, a1, a2, a3, a4, b1, b2, b3, b4)) {
return CreateElement(builder, source, elem_ty, r.Get()); return CreateScalar(builder, source, elem_ty, r.Get());
} }
return utils::Failure; return utils::Failure;
}; };
@ -1256,7 +1256,7 @@ ConstEval::Result ConstEval::Length(const Source& source,
if (vec_ty == nullptr) { if (vec_ty == nullptr) {
auto create = [&](auto e) { auto create = [&](auto e) {
using NumberT = decltype(e); using NumberT = decltype(e);
return CreateElement(builder, source, ty, NumberT{std::abs(e)}); return CreateScalar(builder, source, ty, NumberT{std::abs(e)});
}; };
return Dispatch_fa_f32_f16(create, c0); return Dispatch_fa_f32_f16(create, c0);
} }
@ -1292,7 +1292,7 @@ ConstEval::Result ConstEval::Sub(const Source& source,
auto ConstEval::Det2Func(const Source& source, const type::Type* elem_ty) { auto ConstEval::Det2Func(const Source& source, const type::Type* elem_ty) {
return [=](auto a, auto b, auto c, auto d) -> ImplResult { return [=](auto a, auto b, auto c, auto d) -> ImplResult {
if (auto r = Det2(source, a, b, c, d)) { if (auto r = Det2(source, a, b, c, d)) {
return CreateElement(builder, source, elem_ty, r.Get()); return CreateScalar(builder, source, elem_ty, r.Get());
} }
return utils::Failure; return utils::Failure;
}; };
@ -1302,7 +1302,7 @@ auto ConstEval::Det3Func(const Source& source, const type::Type* elem_ty) {
return return
[=](auto a, auto b, auto c, auto d, auto e, auto f, auto g, auto h, auto i) -> ImplResult { [=](auto a, auto b, auto c, auto d, auto e, auto f, auto g, auto h, auto i) -> ImplResult {
if (auto r = Det3(source, a, b, c, d, e, f, g, h, i)) { if (auto r = Det3(source, a, b, c, d, e, f, g, h, i)) {
return CreateElement(builder, source, elem_ty, r.Get()); return CreateScalar(builder, source, elem_ty, r.Get());
} }
return utils::Failure; return utils::Failure;
}; };
@ -1312,7 +1312,7 @@ auto ConstEval::Det4Func(const Source& source, const type::Type* elem_ty) {
return [=](auto a, auto b, auto c, auto d, auto e, auto f, auto g, auto h, auto i, auto j, return [=](auto a, auto b, auto c, auto d, auto e, auto f, auto g, auto h, auto i, auto j,
auto k, auto l, auto m, auto n, auto o, auto p) -> ImplResult { auto k, auto l, auto m, auto n, auto o, auto p) -> ImplResult {
if (auto r = Det4(source, a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p)) { if (auto r = Det4(source, a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p)) {
return CreateElement(builder, source, elem_ty, r.Get()); return CreateScalar(builder, source, elem_ty, r.Get());
} }
return utils::Failure; return utils::Failure;
}; };
@ -1323,27 +1323,27 @@ ConstEval::Result ConstEval::Literal(const type::Type* ty, const ast::LiteralExp
return Switch( return Switch(
literal, literal,
[&](const ast::BoolLiteralExpression* lit) { [&](const ast::BoolLiteralExpression* lit) {
return CreateElement(builder, source, ty, lit->value); return CreateScalar(builder, source, ty, lit->value);
}, },
[&](const ast::IntLiteralExpression* lit) -> ImplResult { [&](const ast::IntLiteralExpression* lit) -> ImplResult {
switch (lit->suffix) { switch (lit->suffix) {
case ast::IntLiteralExpression::Suffix::kNone: case ast::IntLiteralExpression::Suffix::kNone:
return CreateElement(builder, source, ty, AInt(lit->value)); return CreateScalar(builder, source, ty, AInt(lit->value));
case ast::IntLiteralExpression::Suffix::kI: case ast::IntLiteralExpression::Suffix::kI:
return CreateElement(builder, source, ty, i32(lit->value)); return CreateScalar(builder, source, ty, i32(lit->value));
case ast::IntLiteralExpression::Suffix::kU: case ast::IntLiteralExpression::Suffix::kU:
return CreateElement(builder, source, ty, u32(lit->value)); return CreateScalar(builder, source, ty, u32(lit->value));
} }
return nullptr; return nullptr;
}, },
[&](const ast::FloatLiteralExpression* lit) -> ImplResult { [&](const ast::FloatLiteralExpression* lit) -> ImplResult {
switch (lit->suffix) { switch (lit->suffix) {
case ast::FloatLiteralExpression::Suffix::kNone: case ast::FloatLiteralExpression::Suffix::kNone:
return CreateElement(builder, source, ty, AFloat(lit->value)); return CreateScalar(builder, source, ty, AFloat(lit->value));
case ast::FloatLiteralExpression::Suffix::kF: case ast::FloatLiteralExpression::Suffix::kF:
return CreateElement(builder, source, ty, f32(lit->value)); return CreateScalar(builder, source, ty, f32(lit->value));
case ast::FloatLiteralExpression::Suffix::kH: case ast::FloatLiteralExpression::Suffix::kH:
return CreateElement(builder, source, ty, f16(lit->value)); return CreateScalar(builder, source, ty, f16(lit->value));
} }
return nullptr; return nullptr;
}); });
@ -1524,7 +1524,7 @@ ConstEval::Result ConstEval::OpComplement(const type::Type* ty,
const Source& source) { const Source& source) {
auto transform = [&](const constant::Constant* c) { auto transform = [&](const constant::Constant* c) {
auto create = [&](auto i) { auto create = [&](auto i) {
return CreateElement(builder, source, c->Type(), decltype(i)(~i.value)); return CreateScalar(builder, source, c->Type(), decltype(i)(~i.value));
}; };
return Dispatch_ia_iu32(create, c); return Dispatch_ia_iu32(create, c);
}; };
@ -1546,9 +1546,9 @@ ConstEval::Result ConstEval::OpUnaryMinus(const type::Type* ty,
if (v != std::numeric_limits<T>::min()) { if (v != std::numeric_limits<T>::min()) {
v = -v; v = -v;
} }
return CreateElement(builder, source, c->Type(), decltype(i)(v)); return CreateScalar(builder, source, c->Type(), decltype(i)(v));
} else { } else {
return CreateElement(builder, source, c->Type(), decltype(i)(-i.value)); return CreateScalar(builder, source, c->Type(), decltype(i)(-i.value));
} }
}; };
return Dispatch_fia_fi32_f16(create, c); return Dispatch_fia_fi32_f16(create, c);
@ -1561,7 +1561,7 @@ ConstEval::Result ConstEval::OpNot(const type::Type* ty,
const Source& source) { const Source& source) {
auto transform = [&](const constant::Constant* c) { auto transform = [&](const constant::Constant* c) {
auto create = [&](auto i) { auto create = [&](auto i) {
return CreateElement(builder, source, c->Type(), decltype(i)(!i)); return CreateScalar(builder, source, c->Type(), decltype(i)(!i));
}; };
return Dispatch_bool(create, c); return Dispatch_bool(create, c);
}; };
@ -1781,7 +1781,7 @@ ConstEval::Result ConstEval::OpEqual(const type::Type* ty,
const Source& source) { const Source& source) {
auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) { auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) {
auto create = [&](auto i, auto j) -> ImplResult { auto create = [&](auto i, auto j) -> ImplResult {
return CreateElement(builder, source, type::Type::DeepestElementOf(ty), i == j); return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), i == j);
}; };
return Dispatch_fia_fiu32_f16_bool(create, c0, c1); return Dispatch_fia_fiu32_f16_bool(create, c0, c1);
}; };
@ -1794,7 +1794,7 @@ ConstEval::Result ConstEval::OpNotEqual(const type::Type* ty,
const Source& source) { const Source& source) {
auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) { auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) {
auto create = [&](auto i, auto j) -> ImplResult { auto create = [&](auto i, auto j) -> ImplResult {
return CreateElement(builder, source, type::Type::DeepestElementOf(ty), i != j); return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), i != j);
}; };
return Dispatch_fia_fiu32_f16_bool(create, c0, c1); return Dispatch_fia_fiu32_f16_bool(create, c0, c1);
}; };
@ -1807,7 +1807,7 @@ ConstEval::Result ConstEval::OpLessThan(const type::Type* ty,
const Source& source) { const Source& source) {
auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) { auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) {
auto create = [&](auto i, auto j) -> ImplResult { auto create = [&](auto i, auto j) -> ImplResult {
return CreateElement(builder, source, type::Type::DeepestElementOf(ty), i < j); return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), i < j);
}; };
return Dispatch_fia_fiu32_f16(create, c0, c1); return Dispatch_fia_fiu32_f16(create, c0, c1);
}; };
@ -1820,7 +1820,7 @@ ConstEval::Result ConstEval::OpGreaterThan(const type::Type* ty,
const Source& source) { const Source& source) {
auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) { auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) {
auto create = [&](auto i, auto j) -> ImplResult { auto create = [&](auto i, auto j) -> ImplResult {
return CreateElement(builder, source, type::Type::DeepestElementOf(ty), i > j); return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), i > j);
}; };
return Dispatch_fia_fiu32_f16(create, c0, c1); return Dispatch_fia_fiu32_f16(create, c0, c1);
}; };
@ -1833,7 +1833,7 @@ ConstEval::Result ConstEval::OpLessThanEqual(const type::Type* ty,
const Source& source) { const Source& source) {
auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) { auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) {
auto create = [&](auto i, auto j) -> ImplResult { auto create = [&](auto i, auto j) -> ImplResult {
return CreateElement(builder, source, type::Type::DeepestElementOf(ty), i <= j); return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), i <= j);
}; };
return Dispatch_fia_fiu32_f16(create, c0, c1); return Dispatch_fia_fiu32_f16(create, c0, c1);
}; };
@ -1846,7 +1846,7 @@ ConstEval::Result ConstEval::OpGreaterThanEqual(const type::Type* ty,
const Source& source) { const Source& source) {
auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) { auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) {
auto create = [&](auto i, auto j) -> ImplResult { auto create = [&](auto i, auto j) -> ImplResult {
return CreateElement(builder, source, type::Type::DeepestElementOf(ty), i >= j); return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), i >= j);
}; };
return Dispatch_fia_fiu32_f16(create, c0, c1); return Dispatch_fia_fiu32_f16(create, c0, c1);
}; };
@ -1859,7 +1859,7 @@ ConstEval::Result ConstEval::OpLogicalAnd(const type::Type* ty,
const Source& source) { const Source& source) {
// Note: Due to short-circuiting, this function is only called if lhs is true, so we could // Note: Due to short-circuiting, this function is only called if lhs is true, so we could
// technically only return the value of the rhs. // technically only return the value of the rhs.
return CreateElement(builder, source, ty, args[0]->As<bool>() && args[1]->As<bool>()); return CreateScalar(builder, source, ty, args[0]->As<bool>() && args[1]->As<bool>());
} }
ConstEval::Result ConstEval::OpLogicalOr(const type::Type* ty, ConstEval::Result ConstEval::OpLogicalOr(const type::Type* ty,
@ -1867,7 +1867,7 @@ ConstEval::Result ConstEval::OpLogicalOr(const type::Type* ty,
const Source& source) { const Source& source) {
// Note: Due to short-circuiting, this function is only called if lhs is false, so we could // Note: Due to short-circuiting, this function is only called if lhs is false, so we could
// technically only return the value of the rhs. // technically only return the value of the rhs.
return CreateElement(builder, source, ty, args[1]->As<bool>()); return CreateScalar(builder, source, ty, args[1]->As<bool>());
} }
ConstEval::Result ConstEval::OpAnd(const type::Type* ty, ConstEval::Result ConstEval::OpAnd(const type::Type* ty,
@ -1882,7 +1882,7 @@ ConstEval::Result ConstEval::OpAnd(const type::Type* ty,
} else { // integral } else { // integral
result = i & j; result = i & j;
} }
return CreateElement(builder, source, type::Type::DeepestElementOf(ty), result); return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), result);
}; };
return Dispatch_ia_iu32_bool(create, c0, c1); return Dispatch_ia_iu32_bool(create, c0, c1);
}; };
@ -1902,7 +1902,7 @@ ConstEval::Result ConstEval::OpOr(const type::Type* ty,
} else { // integral } else { // integral
result = i | j; result = i | j;
} }
return CreateElement(builder, source, type::Type::DeepestElementOf(ty), result); return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), result);
}; };
return Dispatch_ia_iu32_bool(create, c0, c1); return Dispatch_ia_iu32_bool(create, c0, c1);
}; };
@ -1915,8 +1915,8 @@ ConstEval::Result ConstEval::OpXor(const type::Type* ty,
const Source& source) { const Source& source) {
auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) { auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) {
auto create = [&](auto i, auto j) -> ImplResult { auto create = [&](auto i, auto j) -> ImplResult {
return CreateElement(builder, source, type::Type::DeepestElementOf(ty), return CreateScalar(builder, source, type::Type::DeepestElementOf(ty),
decltype(i){i ^ j}); decltype(i){i ^ j});
}; };
return Dispatch_ia_iu32(create, c0, c1); return Dispatch_ia_iu32(create, c0, c1);
}; };
@ -1996,8 +1996,7 @@ ConstEval::Result ConstEval::OpShiftLeft(const type::Type* ty,
// Avoid UB by left shifting as unsigned value // Avoid UB by left shifting as unsigned value
auto result = static_cast<T>(static_cast<UT>(e1) << e2); auto result = static_cast<T>(static_cast<UT>(e1) << e2);
return CreateElement(builder, source, type::Type::DeepestElementOf(ty), return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), NumberT{result});
NumberT{result});
}; };
return Dispatch_ia_iu32(create, c0, c1); return Dispatch_ia_iu32(create, c0, c1);
}; };
@ -2061,8 +2060,7 @@ ConstEval::Result ConstEval::OpShiftRight(const type::Type* ty,
result = e1 >> e2; result = e1 >> e2;
} }
} }
return CreateElement(builder, source, type::Type::DeepestElementOf(ty), return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), NumberT{result});
NumberT{result});
}; };
return Dispatch_ia_iu32(create, c0, c1); return Dispatch_ia_iu32(create, c0, c1);
}; };
@ -2094,7 +2092,7 @@ ConstEval::Result ConstEval::abs(const type::Type* ty,
} else { } else {
result = NumberT{std::abs(e)}; result = NumberT{std::abs(e)};
} }
return CreateElement(builder, source, c0->Type(), result); return CreateScalar(builder, source, c0->Type(), result);
}; };
return Dispatch_fia_fiu32_f16(create, c0); return Dispatch_fia_fiu32_f16(create, c0);
}; };
@ -2112,7 +2110,7 @@ ConstEval::Result ConstEval::acos(const type::Type* ty,
source); source);
return utils::Failure; return utils::Failure;
} }
return CreateElement(builder, source, c0->Type(), NumberT(std::acos(i.value))); return CreateScalar(builder, source, c0->Type(), NumberT(std::acos(i.value)));
}; };
return Dispatch_fa_f32_f16(create, c0); return Dispatch_fa_f32_f16(create, c0);
}; };
@ -2129,7 +2127,7 @@ ConstEval::Result ConstEval::acosh(const type::Type* ty,
AddError("acosh must be called with a value >= 1.0", source); AddError("acosh must be called with a value >= 1.0", source);
return utils::Failure; return utils::Failure;
} }
return CreateElement(builder, source, c0->Type(), NumberT(std::acosh(i.value))); return CreateScalar(builder, source, c0->Type(), NumberT(std::acosh(i.value)));
}; };
return Dispatch_fa_f32_f16(create, c0); return Dispatch_fa_f32_f16(create, c0);
}; };
@ -2140,13 +2138,13 @@ ConstEval::Result ConstEval::acosh(const type::Type* ty,
ConstEval::Result ConstEval::all(const type::Type* ty, ConstEval::Result ConstEval::all(const type::Type* ty,
utils::VectorRef<const constant::Constant*> args, utils::VectorRef<const constant::Constant*> args,
const Source& source) { const Source& source) {
return CreateElement(builder, source, ty, !args[0]->AnyZero()); return CreateScalar(builder, source, ty, !args[0]->AnyZero());
} }
ConstEval::Result ConstEval::any(const type::Type* ty, ConstEval::Result ConstEval::any(const type::Type* ty,
utils::VectorRef<const constant::Constant*> args, utils::VectorRef<const constant::Constant*> args,
const Source& source) { const Source& source) {
return CreateElement(builder, source, ty, !args[0]->AllZero()); return CreateScalar(builder, source, ty, !args[0]->AllZero());
} }
ConstEval::Result ConstEval::asin(const type::Type* ty, ConstEval::Result ConstEval::asin(const type::Type* ty,
@ -2160,7 +2158,7 @@ ConstEval::Result ConstEval::asin(const type::Type* ty,
source); source);
return utils::Failure; return utils::Failure;
} }
return CreateElement(builder, source, c0->Type(), NumberT(std::asin(i.value))); return CreateScalar(builder, source, c0->Type(), NumberT(std::asin(i.value)));
}; };
return Dispatch_fa_f32_f16(create, c0); return Dispatch_fa_f32_f16(create, c0);
}; };
@ -2172,7 +2170,7 @@ ConstEval::Result ConstEval::asinh(const type::Type* ty,
const Source& source) { const Source& source) {
auto transform = [&](const constant::Constant* c0) { auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto i) { auto create = [&](auto i) {
return CreateElement(builder, source, c0->Type(), decltype(i)(std::asinh(i.value))); return CreateScalar(builder, source, c0->Type(), decltype(i)(std::asinh(i.value)));
}; };
return Dispatch_fa_f32_f16(create, c0); return Dispatch_fa_f32_f16(create, c0);
}; };
@ -2185,7 +2183,7 @@ ConstEval::Result ConstEval::atan(const type::Type* ty,
const Source& source) { const Source& source) {
auto transform = [&](const constant::Constant* c0) { auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto i) { auto create = [&](auto i) {
return CreateElement(builder, source, c0->Type(), decltype(i)(std::atan(i.value))); return CreateScalar(builder, source, c0->Type(), decltype(i)(std::atan(i.value)));
}; };
return Dispatch_fa_f32_f16(create, c0); return Dispatch_fa_f32_f16(create, c0);
}; };
@ -2203,7 +2201,7 @@ ConstEval::Result ConstEval::atanh(const type::Type* ty,
source); source);
return utils::Failure; return utils::Failure;
} }
return CreateElement(builder, source, c0->Type(), NumberT(std::atanh(i.value))); return CreateScalar(builder, source, c0->Type(), NumberT(std::atanh(i.value)));
}; };
return Dispatch_fa_f32_f16(create, c0); return Dispatch_fa_f32_f16(create, c0);
}; };
@ -2216,8 +2214,8 @@ ConstEval::Result ConstEval::atan2(const type::Type* ty,
const Source& source) { const Source& source) {
auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) { auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) {
auto create = [&](auto i, auto j) { auto create = [&](auto i, auto j) {
return CreateElement(builder, source, c0->Type(), return CreateScalar(builder, source, c0->Type(),
decltype(i)(std::atan2(i.value, j.value))); decltype(i)(std::atan2(i.value, j.value)));
}; };
return Dispatch_fa_f32_f16(create, c0, c1); return Dispatch_fa_f32_f16(create, c0, c1);
}; };
@ -2229,7 +2227,7 @@ ConstEval::Result ConstEval::ceil(const type::Type* ty,
const Source& source) { const Source& source) {
auto transform = [&](const constant::Constant* c0) { auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto e) { auto create = [&](auto e) {
return CreateElement(builder, source, c0->Type(), decltype(e)(std::ceil(e))); return CreateScalar(builder, source, c0->Type(), decltype(e)(std::ceil(e)));
}; };
return Dispatch_fa_f32_f16(create, c0); return Dispatch_fa_f32_f16(create, c0);
}; };
@ -2252,7 +2250,7 @@ ConstEval::Result ConstEval::cos(const type::Type* ty,
auto transform = [&](const constant::Constant* c0) { auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto i) -> ImplResult { auto create = [&](auto i) -> ImplResult {
using NumberT = decltype(i); using NumberT = decltype(i);
return CreateElement(builder, source, c0->Type(), NumberT(std::cos(i.value))); return CreateScalar(builder, source, c0->Type(), NumberT(std::cos(i.value)));
}; };
return Dispatch_fa_f32_f16(create, c0); return Dispatch_fa_f32_f16(create, c0);
}; };
@ -2265,7 +2263,7 @@ ConstEval::Result ConstEval::cosh(const type::Type* ty,
auto transform = [&](const constant::Constant* c0) { auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto i) -> ImplResult { auto create = [&](auto i) -> ImplResult {
using NumberT = decltype(i); using NumberT = decltype(i);
return CreateElement(builder, source, c0->Type(), NumberT(std::cosh(i.value))); return CreateScalar(builder, source, c0->Type(), NumberT(std::cosh(i.value)));
}; };
return Dispatch_fa_f32_f16(create, c0); return Dispatch_fa_f32_f16(create, c0);
}; };
@ -2280,7 +2278,7 @@ ConstEval::Result ConstEval::countLeadingZeros(const type::Type* ty,
using NumberT = decltype(e); using NumberT = decltype(e);
using T = UnwrapNumber<NumberT>; using T = UnwrapNumber<NumberT>;
auto count = CountLeadingBits(T{e}, T{0}); auto count = CountLeadingBits(T{e}, T{0});
return CreateElement(builder, source, c0->Type(), NumberT(count)); return CreateScalar(builder, source, c0->Type(), NumberT(count));
}; };
return Dispatch_iu32(create, c0); return Dispatch_iu32(create, c0);
}; };
@ -2304,7 +2302,7 @@ ConstEval::Result ConstEval::countOneBits(const type::Type* ty,
} }
} }
return CreateElement(builder, source, c0->Type(), NumberT(count)); return CreateScalar(builder, source, c0->Type(), NumberT(count));
}; };
return Dispatch_iu32(create, c0); return Dispatch_iu32(create, c0);
}; };
@ -2319,7 +2317,7 @@ ConstEval::Result ConstEval::countTrailingZeros(const type::Type* ty,
using NumberT = decltype(e); using NumberT = decltype(e);
using T = UnwrapNumber<NumberT>; using T = UnwrapNumber<NumberT>;
auto count = CountTrailingBits(T{e}, T{0}); auto count = CountTrailingBits(T{e}, T{0});
return CreateElement(builder, source, c0->Type(), NumberT(count)); return CreateScalar(builder, source, c0->Type(), NumberT(count));
}; };
return Dispatch_iu32(create, c0); return Dispatch_iu32(create, c0);
}; };
@ -2388,7 +2386,7 @@ ConstEval::Result ConstEval::degrees(const type::Type* ty,
AddNote("when calculating degrees", source); AddNote("when calculating degrees", source);
return utils::Failure; return utils::Failure;
} }
return CreateElement(builder, source, c0->Type(), result.Get()); return CreateScalar(builder, source, c0->Type(), result.Get());
}; };
return Dispatch_fa_f32_f16(create, c0); return Dispatch_fa_f32_f16(create, c0);
}; };
@ -2472,7 +2470,7 @@ ConstEval::Result ConstEval::exp(const type::Type* ty,
AddError(OverflowExpErrorMessage("e", e0), source); AddError(OverflowExpErrorMessage("e", e0), source);
return utils::Failure; return utils::Failure;
} }
return CreateElement(builder, source, c0->Type(), val); return CreateScalar(builder, source, c0->Type(), val);
}; };
return Dispatch_fa_f32_f16(create, c0); return Dispatch_fa_f32_f16(create, c0);
}; };
@ -2490,7 +2488,7 @@ ConstEval::Result ConstEval::exp2(const type::Type* ty,
AddError(OverflowExpErrorMessage("2", e0), source); AddError(OverflowExpErrorMessage("2", e0), source);
return utils::Failure; return utils::Failure;
} }
return CreateElement(builder, source, c0->Type(), val); return CreateScalar(builder, source, c0->Type(), val);
}; };
return Dispatch_fa_f32_f16(create, c0); return Dispatch_fa_f32_f16(create, c0);
}; };
@ -2545,7 +2543,7 @@ ConstEval::Result ConstEval::extractBits(const type::Type* ty,
result = NumberT{r}; result = NumberT{r};
} }
return CreateElement(builder, source, c0->Type(), result); return CreateScalar(builder, source, c0->Type(), result);
}; };
return Dispatch_iu32(create, c0); return Dispatch_iu32(create, c0);
}; };
@ -2608,7 +2606,7 @@ ConstEval::Result ConstEval::firstLeadingBit(const type::Type* ty,
} }
} }
return CreateElement(builder, source, c0->Type(), result); return CreateScalar(builder, source, c0->Type(), result);
}; };
return Dispatch_iu32(create, c0); return Dispatch_iu32(create, c0);
}; };
@ -2634,7 +2632,7 @@ ConstEval::Result ConstEval::firstTrailingBit(const type::Type* ty,
result = NumberT(pos); result = NumberT(pos);
} }
return CreateElement(builder, source, c0->Type(), result); return CreateScalar(builder, source, c0->Type(), result);
}; };
return Dispatch_iu32(create, c0); return Dispatch_iu32(create, c0);
}; };
@ -2646,7 +2644,7 @@ ConstEval::Result ConstEval::floor(const type::Type* ty,
const Source& source) { const Source& source) {
auto transform = [&](const constant::Constant* c0) { auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto e) { auto create = [&](auto e) {
return CreateElement(builder, source, c0->Type(), decltype(e)(std::floor(e))); return CreateScalar(builder, source, c0->Type(), decltype(e)(std::floor(e)));
}; };
return Dispatch_fa_f32_f16(create, c0); return Dispatch_fa_f32_f16(create, c0);
}; };
@ -2673,7 +2671,7 @@ ConstEval::Result ConstEval::fma(const type::Type* ty,
if (!val) { if (!val) {
return err_msg(); return err_msg();
} }
return CreateElement(builder, source, c1->Type(), val.Get()); return CreateScalar(builder, source, c1->Type(), val.Get());
}; };
return Dispatch_fa_f32_f16(create, c1, c2, c3); return Dispatch_fa_f32_f16(create, c1, c2, c3);
}; };
@ -2687,7 +2685,7 @@ ConstEval::Result ConstEval::fract(const type::Type* ty,
auto create = [&](auto e) -> ImplResult { auto create = [&](auto e) -> ImplResult {
using NumberT = decltype(e); using NumberT = decltype(e);
auto r = e - std::floor(e); auto r = e - std::floor(e);
return CreateElement(builder, source, c1->Type(), NumberT{r}); return CreateScalar(builder, source, c1->Type(), NumberT{r});
}; };
return Dispatch_fa_f32_f16(create, c1); return Dispatch_fa_f32_f16(create, c1);
}; };
@ -2711,21 +2709,21 @@ ConstEval::Result ConstEval::frexp(const type::Type* ty,
s->Type(), s->Type(),
[&](const type::F32*) { [&](const type::F32*) {
return FractExp{ return FractExp{
CreateElement(builder, source, builder.create<type::F32>(), f32(fract)), CreateScalar(builder, source, builder.create<type::F32>(), f32(fract)),
CreateElement(builder, source, builder.create<type::I32>(), i32(exp)), CreateScalar(builder, source, builder.create<type::I32>(), i32(exp)),
}; };
}, },
[&](const type::F16*) { [&](const type::F16*) {
return FractExp{ return FractExp{
CreateElement(builder, source, builder.create<type::F16>(), f16(fract)), CreateScalar(builder, source, builder.create<type::F16>(), f16(fract)),
CreateElement(builder, source, builder.create<type::I32>(), i32(exp)), CreateScalar(builder, source, builder.create<type::I32>(), i32(exp)),
}; };
}, },
[&](const type::AbstractFloat*) { [&](const type::AbstractFloat*) {
return FractExp{ return FractExp{
CreateElement(builder, source, builder.create<type::AbstractFloat>(), CreateScalar(builder, source, builder.create<type::AbstractFloat>(),
AFloat(fract)), AFloat(fract)),
CreateElement(builder, source, builder.create<type::AbstractInt>(), AInt(exp)), CreateScalar(builder, source, builder.create<type::AbstractInt>(), AInt(exp)),
}; };
}, },
[&](Default) { [&](Default) {
@ -2812,7 +2810,7 @@ ConstEval::Result ConstEval::insertBits(const type::Type* ty,
result = NumberT{r}; result = NumberT{r};
} }
return CreateElement(builder, source, c0->Type(), result); return CreateScalar(builder, source, c0->Type(), result);
}; };
return Dispatch_iu32(create, c0, c1); return Dispatch_iu32(create, c0, c1);
}; };
@ -2845,7 +2843,7 @@ ConstEval::Result ConstEval::inverseSqrt(const type::Type* ty,
return err(); return err();
} }
return CreateElement(builder, source, c0->Type(), div.Get()); return CreateScalar(builder, source, c0->Type(), div.Get());
}; };
return Dispatch_fa_f32_f16(create, c0); return Dispatch_fa_f32_f16(create, c0);
}; };
@ -2873,7 +2871,7 @@ ConstEval::Result ConstEval::log(const type::Type* ty,
AddError("log must be called with a value > 0", source); AddError("log must be called with a value > 0", source);
return utils::Failure; return utils::Failure;
} }
return CreateElement(builder, source, c0->Type(), NumberT(std::log(v))); return CreateScalar(builder, source, c0->Type(), NumberT(std::log(v)));
}; };
return Dispatch_fa_f32_f16(create, c0); return Dispatch_fa_f32_f16(create, c0);
}; };
@ -2890,7 +2888,7 @@ ConstEval::Result ConstEval::log2(const type::Type* ty,
AddError("log2 must be called with a value > 0", source); AddError("log2 must be called with a value > 0", source);
return utils::Failure; return utils::Failure;
} }
return CreateElement(builder, source, c0->Type(), NumberT(std::log2(v))); return CreateScalar(builder, source, c0->Type(), NumberT(std::log2(v)));
}; };
return Dispatch_fa_f32_f16(create, c0); return Dispatch_fa_f32_f16(create, c0);
}; };
@ -2902,7 +2900,7 @@ ConstEval::Result ConstEval::max(const type::Type* ty,
const Source& source) { const Source& source) {
auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) { auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) {
auto create = [&](auto e0, auto e1) { auto create = [&](auto e0, auto e1) {
return CreateElement(builder, source, c0->Type(), decltype(e0)(std::max(e0, e1))); return CreateScalar(builder, source, c0->Type(), decltype(e0)(std::max(e0, e1)));
}; };
return Dispatch_fia_fiu32_f16(create, c0, c1); return Dispatch_fia_fiu32_f16(create, c0, c1);
}; };
@ -2914,7 +2912,7 @@ ConstEval::Result ConstEval::min(const type::Type* ty,
const Source& source) { const Source& source) {
auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) { auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) {
auto create = [&](auto e0, auto e1) { auto create = [&](auto e0, auto e1) {
return CreateElement(builder, source, c0->Type(), decltype(e0)(std::min(e0, e1))); return CreateScalar(builder, source, c0->Type(), decltype(e0)(std::min(e0, e1)));
}; };
return Dispatch_fia_fiu32_f16(create, c0, c1); return Dispatch_fia_fiu32_f16(create, c0, c1);
}; };
@ -2953,7 +2951,7 @@ ConstEval::Result ConstEval::mix(const type::Type* ty,
if (!r) { if (!r) {
return utils::Failure; return utils::Failure;
} }
return CreateElement(builder, source, c0->Type(), r.Get()); return CreateScalar(builder, source, c0->Type(), r.Get());
}; };
return Dispatch_fa_f32_f16(create, c0, c1); return Dispatch_fa_f32_f16(create, c0, c1);
}; };
@ -2969,14 +2967,14 @@ ConstEval::Result ConstEval::modf(const type::Type* ty,
const Source& source) { const Source& source) {
auto transform_fract = [&](const constant::Constant* c) { auto transform_fract = [&](const constant::Constant* c) {
auto create = [&](auto e) { auto create = [&](auto e) {
return CreateElement(builder, source, c->Type(), return CreateScalar(builder, source, c->Type(),
decltype(e)(e.value - std::trunc(e.value))); decltype(e)(e.value - std::trunc(e.value)));
}; };
return Dispatch_fa_f32_f16(create, c); return Dispatch_fa_f32_f16(create, c);
}; };
auto transform_whole = [&](const constant::Constant* c) { auto transform_whole = [&](const constant::Constant* c) {
auto create = [&](auto e) { auto create = [&](auto e) {
return CreateElement(builder, source, c->Type(), decltype(e)(std::trunc(e.value))); return CreateScalar(builder, source, c->Type(), decltype(e)(std::trunc(e.value)));
}; };
return Dispatch_fa_f32_f16(create, c); return Dispatch_fa_f32_f16(create, c);
}; };
@ -3040,7 +3038,7 @@ ConstEval::Result ConstEval::pack2x16float(const type::Type* ty,
} }
u32 ret = u32((e0.Get() & 0x0000'ffff) | (e1.Get() << 16)); u32 ret = u32((e0.Get() & 0x0000'ffff) | (e1.Get() << 16));
return CreateElement(builder, source, ty, ret); return CreateScalar(builder, source, ty, ret);
} }
ConstEval::Result ConstEval::pack2x16snorm(const type::Type* ty, ConstEval::Result ConstEval::pack2x16snorm(const type::Type* ty,
@ -3057,7 +3055,7 @@ ConstEval::Result ConstEval::pack2x16snorm(const type::Type* ty,
auto e1 = calc(e->Index(1)->As<f32>()); auto e1 = calc(e->Index(1)->As<f32>());
u32 ret = u32((e0 & 0x0000'ffff) | (e1 << 16)); u32 ret = u32((e0 & 0x0000'ffff) | (e1 << 16));
return CreateElement(builder, source, ty, ret); return CreateScalar(builder, source, ty, ret);
} }
ConstEval::Result ConstEval::pack2x16unorm(const type::Type* ty, ConstEval::Result ConstEval::pack2x16unorm(const type::Type* ty,
@ -3073,7 +3071,7 @@ ConstEval::Result ConstEval::pack2x16unorm(const type::Type* ty,
auto e1 = calc(e->Index(1)->As<f32>()); auto e1 = calc(e->Index(1)->As<f32>());
u32 ret = u32((e0 & 0x0000'ffff) | (e1 << 16)); u32 ret = u32((e0 & 0x0000'ffff) | (e1 << 16));
return CreateElement(builder, source, ty, ret); return CreateScalar(builder, source, ty, ret);
} }
ConstEval::Result ConstEval::pack4x8snorm(const type::Type* ty, ConstEval::Result ConstEval::pack4x8snorm(const type::Type* ty,
@ -3093,7 +3091,7 @@ ConstEval::Result ConstEval::pack4x8snorm(const type::Type* ty,
uint32_t mask = 0x0000'00ff; uint32_t mask = 0x0000'00ff;
u32 ret = u32((e0 & mask) | ((e1 & mask) << 8) | ((e2 & mask) << 16) | ((e3 & mask) << 24)); u32 ret = u32((e0 & mask) | ((e1 & mask) << 8) | ((e2 & mask) << 16) | ((e3 & mask) << 24));
return CreateElement(builder, source, ty, ret); return CreateScalar(builder, source, ty, ret);
} }
ConstEval::Result ConstEval::pack4x8unorm(const type::Type* ty, ConstEval::Result ConstEval::pack4x8unorm(const type::Type* ty,
@ -3112,7 +3110,7 @@ ConstEval::Result ConstEval::pack4x8unorm(const type::Type* ty,
uint32_t mask = 0x0000'00ff; uint32_t mask = 0x0000'00ff;
u32 ret = u32((e0 & mask) | ((e1 & mask) << 8) | ((e2 & mask) << 16) | ((e3 & mask) << 24)); u32 ret = u32((e0 & mask) | ((e1 & mask) << 8) | ((e2 & mask) << 16) | ((e3 & mask) << 24));
return CreateElement(builder, source, ty, ret); return CreateScalar(builder, source, ty, ret);
} }
ConstEval::Result ConstEval::pow(const type::Type* ty, ConstEval::Result ConstEval::pow(const type::Type* ty,
@ -3125,7 +3123,7 @@ ConstEval::Result ConstEval::pow(const type::Type* ty,
AddError(OverflowErrorMessage(e1, "^", e2), source); AddError(OverflowErrorMessage(e1, "^", e2), source);
return utils::Failure; return utils::Failure;
} }
return CreateElement(builder, source, c0->Type(), *r); return CreateScalar(builder, source, c0->Type(), *r);
}; };
return Dispatch_fa_f32_f16(create, c0, c1); return Dispatch_fa_f32_f16(create, c0, c1);
}; };
@ -3151,7 +3149,7 @@ ConstEval::Result ConstEval::radians(const type::Type* ty,
AddNote("when calculating radians", source); AddNote("when calculating radians", source);
return utils::Failure; return utils::Failure;
} }
return CreateElement(builder, source, c0->Type(), result.Get()); return CreateScalar(builder, source, c0->Type(), result.Get());
}; };
return Dispatch_fa_f32_f16(create, c0); return Dispatch_fa_f32_f16(create, c0);
}; };
@ -3178,7 +3176,7 @@ ConstEval::Result ConstEval::reflect(const type::Type* ty,
// 2 * dot(e2, e1) // 2 * dot(e2, e1)
auto mul2 = [&](auto v) -> ImplResult { auto mul2 = [&](auto v) -> ImplResult {
using NumberT = decltype(v); using NumberT = decltype(v);
return CreateElement(builder, source, el_ty, NumberT{NumberT{2} * v}); return CreateScalar(builder, source, el_ty, NumberT{NumberT{2} * v});
}; };
auto dot_e2_e1_2 = Dispatch_fa_f32_f16(mul2, dot_e2_e1.Get()); auto dot_e2_e1_2 = Dispatch_fa_f32_f16(mul2, dot_e2_e1.Get());
if (!dot_e2_e1_2) { if (!dot_e2_e1_2) {
@ -3230,7 +3228,7 @@ ConstEval::Result ConstEval::refract(const type::Type* ty,
if (!r) { if (!r) {
return utils::Failure; return utils::Failure;
} }
return CreateElement(builder, source, el_ty, r.Get()); return CreateScalar(builder, source, el_ty, r.Get());
}; };
auto compute_e2_scale = [&](auto e3, auto dot_e2_e1, auto k) -> ConstEval::Result { auto compute_e2_scale = [&](auto e3, auto dot_e2_e1, auto k) -> ConstEval::Result {
@ -3247,7 +3245,7 @@ ConstEval::Result ConstEval::refract(const type::Type* ty,
if (!r) { if (!r) {
return utils::Failure; return utils::Failure;
} }
return CreateElement(builder, source, el_ty, r.Get()); return CreateScalar(builder, source, el_ty, r.Get());
}; };
auto calculate = [&]() -> ConstEval::Result { auto calculate = [&]() -> ConstEval::Result {
@ -3319,7 +3317,7 @@ ConstEval::Result ConstEval::reverseBits(const type::Type* ty,
} }
} }
return CreateElement(builder, source, c0->Type(), NumberT{r}); return CreateScalar(builder, source, c0->Type(), NumberT{r});
}; };
return Dispatch_iu32(create, c0); return Dispatch_iu32(create, c0);
}; };
@ -3355,7 +3353,7 @@ ConstEval::Result ConstEval::round(const type::Type* ty,
} else { } else {
result = NumberT(std::round(e.value)); result = NumberT(std::round(e.value));
} }
return CreateElement(builder, source, c0->Type(), result); return CreateScalar(builder, source, c0->Type(), result);
}; };
return Dispatch_fa_f32_f16(create, c0); return Dispatch_fa_f32_f16(create, c0);
}; };
@ -3368,8 +3366,8 @@ ConstEval::Result ConstEval::saturate(const type::Type* ty,
auto transform = [&](const constant::Constant* c0) { auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto e) { auto create = [&](auto e) {
using NumberT = decltype(e); using NumberT = decltype(e);
return CreateElement(builder, source, c0->Type(), return CreateScalar(builder, source, c0->Type(),
NumberT(std::min(std::max(e, NumberT(0.0)), NumberT(1.0)))); NumberT(std::min(std::max(e, NumberT(0.0)), NumberT(1.0))));
}; };
return Dispatch_fa_f32_f16(create, c0); return Dispatch_fa_f32_f16(create, c0);
}; };
@ -3382,7 +3380,7 @@ ConstEval::Result ConstEval::select_bool(const type::Type* ty,
auto cond = args[2]->As<bool>(); auto cond = args[2]->As<bool>();
auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) { auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) {
auto create = [&](auto f, auto t) -> ImplResult { auto create = [&](auto f, auto t) -> ImplResult {
return CreateElement(builder, source, type::Type::DeepestElementOf(ty), cond ? t : f); return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), cond ? t : f);
}; };
return Dispatch_fia_fiu32_f16_bool(create, c0, c1); return Dispatch_fia_fiu32_f16_bool(create, c0, c1);
}; };
@ -3397,7 +3395,7 @@ ConstEval::Result ConstEval::select_boolvec(const type::Type* ty,
auto create = [&](auto f, auto t) -> ImplResult { auto create = [&](auto f, auto t) -> ImplResult {
// Get corresponding bool value at the current vector value index // Get corresponding bool value at the current vector value index
auto cond = args[2]->Index(index)->As<bool>(); auto cond = args[2]->Index(index)->As<bool>();
return CreateElement(builder, source, type::Type::DeepestElementOf(ty), cond ? t : f); return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), cond ? t : f);
}; };
return Dispatch_fia_fiu32_f16_bool(create, c0, c1); return Dispatch_fia_fiu32_f16_bool(create, c0, c1);
}; };
@ -3420,7 +3418,7 @@ ConstEval::Result ConstEval::sign(const type::Type* ty,
} else { } else {
result = zero; result = zero;
} }
return CreateElement(builder, source, c0->Type(), result); return CreateScalar(builder, source, c0->Type(), result);
}; };
return Dispatch_fia_fi32_f16(create, c0); return Dispatch_fia_fi32_f16(create, c0);
}; };
@ -3433,7 +3431,7 @@ ConstEval::Result ConstEval::sin(const type::Type* ty,
auto transform = [&](const constant::Constant* c0) { auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto i) -> ImplResult { auto create = [&](auto i) -> ImplResult {
using NumberT = decltype(i); using NumberT = decltype(i);
return CreateElement(builder, source, c0->Type(), NumberT(std::sin(i.value))); return CreateScalar(builder, source, c0->Type(), NumberT(std::sin(i.value)));
}; };
return Dispatch_fa_f32_f16(create, c0); return Dispatch_fa_f32_f16(create, c0);
}; };
@ -3446,7 +3444,7 @@ ConstEval::Result ConstEval::sinh(const type::Type* ty,
auto transform = [&](const constant::Constant* c0) { auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto i) -> ImplResult { auto create = [&](auto i) -> ImplResult {
using NumberT = decltype(i); using NumberT = decltype(i);
return CreateElement(builder, source, c0->Type(), NumberT(std::sinh(i.value))); return CreateScalar(builder, source, c0->Type(), NumberT(std::sinh(i.value)));
}; };
return Dispatch_fa_f32_f16(create, c0); return Dispatch_fa_f32_f16(create, c0);
}; };
@ -3497,7 +3495,7 @@ ConstEval::Result ConstEval::smoothstep(const type::Type* ty,
if (!result) { if (!result) {
return err(); return err();
} }
return CreateElement(builder, source, c0->Type(), result.Get()); return CreateScalar(builder, source, c0->Type(), result.Get());
}; };
return Dispatch_fa_f32_f16(create, c0, c1, c2); return Dispatch_fa_f32_f16(create, c0, c1, c2);
}; };
@ -3511,7 +3509,7 @@ ConstEval::Result ConstEval::step(const type::Type* ty,
auto create = [&](auto edge, auto x) -> ImplResult { auto create = [&](auto edge, auto x) -> ImplResult {
using NumberT = decltype(edge); using NumberT = decltype(edge);
NumberT result = x.value < edge.value ? NumberT(0.0) : NumberT(1.0); NumberT result = x.value < edge.value ? NumberT(0.0) : NumberT(1.0);
return CreateElement(builder, source, c0->Type(), result); return CreateScalar(builder, source, c0->Type(), result);
}; };
return Dispatch_fa_f32_f16(create, c0, c1); return Dispatch_fa_f32_f16(create, c0, c1);
}; };
@ -3534,7 +3532,7 @@ ConstEval::Result ConstEval::tan(const type::Type* ty,
auto transform = [&](const constant::Constant* c0) { auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto i) -> ImplResult { auto create = [&](auto i) -> ImplResult {
using NumberT = decltype(i); using NumberT = decltype(i);
return CreateElement(builder, source, c0->Type(), NumberT(std::tan(i.value))); return CreateScalar(builder, source, c0->Type(), NumberT(std::tan(i.value)));
}; };
return Dispatch_fa_f32_f16(create, c0); return Dispatch_fa_f32_f16(create, c0);
}; };
@ -3547,7 +3545,7 @@ ConstEval::Result ConstEval::tanh(const type::Type* ty,
auto transform = [&](const constant::Constant* c0) { auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto i) -> ImplResult { auto create = [&](auto i) -> ImplResult {
using NumberT = decltype(i); using NumberT = decltype(i);
return CreateElement(builder, source, c0->Type(), NumberT(std::tanh(i.value))); return CreateScalar(builder, source, c0->Type(), NumberT(std::tanh(i.value)));
}; };
return Dispatch_fa_f32_f16(create, c0); return Dispatch_fa_f32_f16(create, c0);
}; };
@ -3579,7 +3577,7 @@ ConstEval::Result ConstEval::trunc(const type::Type* ty,
const Source& source) { const Source& source) {
auto transform = [&](const constant::Constant* c0) { auto transform = [&](const constant::Constant* c0) {
auto create = [&](auto i) { auto create = [&](auto i) {
return CreateElement(builder, source, c0->Type(), decltype(i)(std::trunc(i.value))); return CreateScalar(builder, source, c0->Type(), decltype(i)(std::trunc(i.value)));
}; };
return Dispatch_fa_f32_f16(create, c0); return Dispatch_fa_f32_f16(create, c0);
}; };
@ -3601,7 +3599,7 @@ ConstEval::Result ConstEval::unpack2x16float(const type::Type* ty,
AddError(OverflowErrorMessage(in, "f32"), source); AddError(OverflowErrorMessage(in, "f32"), source);
return utils::Failure; return utils::Failure;
} }
auto el = CreateElement(builder, source, inner_ty, val.Get()); auto el = CreateScalar(builder, source, inner_ty, val.Get());
if (!el) { if (!el) {
return el; return el;
} }
@ -3621,7 +3619,7 @@ ConstEval::Result ConstEval::unpack2x16snorm(const type::Type* ty,
for (size_t i = 0; i < 2; ++i) { for (size_t i = 0; i < 2; ++i) {
auto val = f32( auto val = f32(
std::max(static_cast<float>(int16_t((e >> (16 * i)) & 0x0000'ffff)) / 32767.f, -1.f)); std::max(static_cast<float>(int16_t((e >> (16 * i)) & 0x0000'ffff)) / 32767.f, -1.f));
auto el = CreateElement(builder, source, inner_ty, val); auto el = CreateScalar(builder, source, inner_ty, val);
if (!el) { if (!el) {
return el; return el;
} }
@ -3640,7 +3638,7 @@ ConstEval::Result ConstEval::unpack2x16unorm(const type::Type* ty,
els.Reserve(2); els.Reserve(2);
for (size_t i = 0; i < 2; ++i) { for (size_t i = 0; i < 2; ++i) {
auto val = f32(static_cast<float>(uint16_t((e >> (16 * i)) & 0x0000'ffff)) / 65535.f); auto val = f32(static_cast<float>(uint16_t((e >> (16 * i)) & 0x0000'ffff)) / 65535.f);
auto el = CreateElement(builder, source, inner_ty, val); auto el = CreateScalar(builder, source, inner_ty, val);
if (!el) { if (!el) {
return el; return el;
} }
@ -3660,7 +3658,7 @@ ConstEval::Result ConstEval::unpack4x8snorm(const type::Type* ty,
for (size_t i = 0; i < 4; ++i) { for (size_t i = 0; i < 4; ++i) {
auto val = auto val =
f32(std::max(static_cast<float>(int8_t((e >> (8 * i)) & 0x0000'00ff)) / 127.f, -1.f)); f32(std::max(static_cast<float>(int8_t((e >> (8 * i)) & 0x0000'00ff)) / 127.f, -1.f));
auto el = CreateElement(builder, source, inner_ty, val); auto el = CreateScalar(builder, source, inner_ty, val);
if (!el) { if (!el) {
return el; return el;
} }
@ -3679,7 +3677,7 @@ ConstEval::Result ConstEval::unpack4x8unorm(const type::Type* ty,
els.Reserve(4); els.Reserve(4);
for (size_t i = 0; i < 4; ++i) { for (size_t i = 0; i < 4; ++i) {
auto val = f32(static_cast<float>(uint8_t((e >> (8 * i)) & 0x0000'00ff)) / 255.f); auto val = f32(static_cast<float>(uint8_t((e >> (8 * i)) & 0x0000'00ff)) / 255.f);
auto el = CreateElement(builder, source, inner_ty, val); auto el = CreateScalar(builder, source, inner_ty, val);
if (!el) { if (!el) {
return el; return el;
} }
@ -3698,7 +3696,7 @@ ConstEval::Result ConstEval::quantizeToF16(const type::Type* ty,
AddError(OverflowErrorMessage(value, "f16"), source); AddError(OverflowErrorMessage(value, "f16"), source);
return utils::Failure; return utils::Failure;
} }
return CreateElement(builder, source, c->Type(), conv.Get()); return CreateScalar(builder, source, c->Type(), conv.Get());
}; };
return TransformElements(builder, ty, transform, args[0]); return TransformElements(builder, ty, transform, args[0]);
} }