From fc2083a6169d65428dbddd4ec7c73886d32635b4 Mon Sep 17 00:00:00 2001 From: James Price Date: Mon, 6 Feb 2023 22:01:48 +0000 Subject: [PATCH] tint: Make CreateScalar and ZeroValue members This avoids the need to pass the `use_runtime_semantics_` flag at each callsite. Change-Id: I2cce3f147226e1295b5dfa0239beeacd519d5bb4 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/118641 Commit-Queue: James Price Reviewed-by: Ben Clayton Kokoro: Kokoro --- src/tint/resolver/const_eval.cc | 456 ++++++++++++++------------------ src/tint/resolver/const_eval.h | 12 + 2 files changed, 205 insertions(+), 263 deletions(-) diff --git a/src/tint/resolver/const_eval.cc b/src/tint/resolver/const_eval.cc index 09878c4721..741c9d84a6 100644 --- a/src/tint/resolver/const_eval.cc +++ b/src/tint/resolver/const_eval.cc @@ -326,7 +326,6 @@ ConstEval::Result ConvertInternal(const constant::Value* c, const type::Type* target_ty, const Source& source, bool use_runtime_semantics); -const constant::Value* ZeroValue(ProgramBuilder& builder, const type::Type* type); ConstEval::Result SplatConvert(const constant::Splat* splat, ProgramBuilder& builder, @@ -417,78 +416,6 @@ ConstEval::Result ConvertInternal(const constant::Value* c, }); } -/// CreateScalar constructs and returns an constant::Scalar. -template -ConstEval::Result CreateScalar(ProgramBuilder& builder, - const Source& source, - const type::Type* t, - T v, - bool use_runtime_semantics) { - static_assert(IsNumber || std::is_same_v, "T must be a Number or bool"); - TINT_ASSERT(Resolver, t->is_scalar()); - - if constexpr (IsFloatingPoint) { - if (!std::isfinite(v.value)) { - auto msg = OverflowErrorMessage(v, builder.FriendlyName(t)); - if (use_runtime_semantics) { - builder.Diagnostics().add_warning(diag::System::Resolver, msg, source); - return ZeroValue(builder, t); - } else { - builder.Diagnostics().add_error(diag::System::Resolver, msg, source); - return utils::Failure; - } - } - } - return builder.create>(t, v); -} - -/// ZeroValue returns a Constant for the zero-value of the type `type`. -const constant::Value* ZeroValue(ProgramBuilder& builder, const type::Type* type) { - return Switch( - type, // - [&](const type::Vector* v) -> const constant::Value* { - auto* zero_el = ZeroValue(builder, v->type()); - return builder.create(type, zero_el, v->Width()); - }, - [&](const type::Matrix* m) -> const constant::Value* { - auto* zero_el = ZeroValue(builder, m->ColumnType()); - return builder.create(type, zero_el, m->columns()); - }, - [&](const type::Array* a) -> const constant::Value* { - if (auto n = a->ConstantCount()) { - if (auto* zero_el = ZeroValue(builder, a->ElemType())) { - return builder.create(type, zero_el, n.value()); - } - } - return nullptr; - }, - [&](const type::Struct* s) -> const constant::Value* { - utils::Hashmap zero_by_type; - utils::Vector zeros; - zeros.Reserve(s->Members().Length()); - for (auto* member : s->Members()) { - auto* zero = zero_by_type.GetOrCreate( - member->Type(), [&] { return ZeroValue(builder, member->Type()); }); - if (!zero) { - return nullptr; - } - zeros.Push(zero); - } - if (zero_by_type.Count() == 1) { - // All members were of the same type, so the zero value is the same for all members. - return builder.create(type, zeros[0], s->Members().Length()); - } - return builder.create(s, std::move(zeros)); - }, - [&](Default) -> const constant::Value* { - return ZeroTypeDispatch(type, [&](auto zero) -> const constant::Value* { - auto el = CreateScalar(builder, Source{}, type, zero, false); - TINT_ASSERT(Resolver, el); - return el.Get(); - }); - }); -} - namespace detail { /// Implementation of TransformElements template @@ -580,6 +507,70 @@ ConstEval::Result TransformBinaryElements(ProgramBuilder& builder, ConstEval::ConstEval(ProgramBuilder& b, bool use_runtime_semantics /* = false */) : builder(b), use_runtime_semantics_(use_runtime_semantics) {} +template +ConstEval::Result ConstEval::CreateScalar(const Source& source, const type::Type* t, T v) { + static_assert(IsNumber || std::is_same_v, "T must be a Number or bool"); + TINT_ASSERT(Resolver, t->is_scalar()); + + if constexpr (IsFloatingPoint) { + if (!std::isfinite(v.value)) { + AddError(OverflowErrorMessage(v, builder.FriendlyName(t)), source); + if (use_runtime_semantics_) { + return ZeroValue(t); + } else { + return utils::Failure; + } + } + } + return builder.create>(t, v); +} + +const constant::Value* ConstEval::ZeroValue(const type::Type* type) { + return Switch( + type, // + [&](const type::Vector* v) -> const constant::Value* { + auto* zero_el = ZeroValue(v->type()); + return builder.create(type, zero_el, v->Width()); + }, + [&](const type::Matrix* m) -> const constant::Value* { + auto* zero_el = ZeroValue(m->ColumnType()); + return builder.create(type, zero_el, m->columns()); + }, + [&](const type::Array* a) -> const constant::Value* { + if (auto n = a->ConstantCount()) { + if (auto* zero_el = ZeroValue(a->ElemType())) { + return builder.create(type, zero_el, n.value()); + } + } + return nullptr; + }, + [&](const type::Struct* s) -> const constant::Value* { + utils::Hashmap zero_by_type; + utils::Vector zeros; + zeros.Reserve(s->Members().Length()); + for (auto* member : s->Members()) { + auto* zero = zero_by_type.GetOrCreate(member->Type(), + [&] { return ZeroValue(member->Type()); }); + if (!zero) { + return nullptr; + } + zeros.Push(zero); + } + if (zero_by_type.Count() == 1) { + // All members were of the same type, so the zero value is the same for all members. + return builder.create(type, zeros[0], s->Members().Length()); + } + return builder.create(s, std::move(zeros)); + }, + [&](Default) -> const constant::Value* { + return ZeroTypeDispatch(type, [&](auto zero) -> const constant::Value* { + auto el = CreateScalar(Source{}, type, zero); + TINT_ASSERT(Resolver, el); + return el.Get(); + }); + }); +} + template utils::Result ConstEval::Add(const Source& source, NumberT a, NumberT b) { NumberT result; @@ -1018,7 +1009,7 @@ utils::Result ConstEval::Sqrt(const Source& source, NumberT v) { auto ConstEval::SqrtFunc(const Source& source, const type::Type* elem_ty) { return [=](auto v) -> ConstEval::Result { if (auto r = Sqrt(source, v)) { - return CreateScalar(builder, source, elem_ty, r.Get(), use_runtime_semantics_); + return CreateScalar(source, elem_ty, r.Get()); } return utils::Failure; }; @@ -1032,7 +1023,7 @@ utils::Result ConstEval::Clamp(const Source&, NumberT e, NumberT low, N auto ConstEval::ClampFunc(const Source& source, const type::Type* elem_ty) { return [=](auto e, auto low, auto high) -> ConstEval::Result { if (auto r = Clamp(source, e, low, high)) { - return CreateScalar(builder, source, elem_ty, r.Get(), use_runtime_semantics_); + return CreateScalar(source, elem_ty, r.Get()); } return utils::Failure; }; @@ -1041,7 +1032,7 @@ auto ConstEval::ClampFunc(const Source& source, const type::Type* elem_ty) { auto ConstEval::AddFunc(const Source& source, const type::Type* elem_ty) { return [=](auto a1, auto a2) -> ConstEval::Result { if (auto r = Add(source, a1, a2)) { - return CreateScalar(builder, source, elem_ty, r.Get(), use_runtime_semantics_); + return CreateScalar(source, elem_ty, r.Get()); } return utils::Failure; }; @@ -1050,7 +1041,7 @@ auto ConstEval::AddFunc(const Source& source, const type::Type* elem_ty) { auto ConstEval::SubFunc(const Source& source, const type::Type* elem_ty) { return [=](auto a1, auto a2) -> ConstEval::Result { if (auto r = Sub(source, a1, a2)) { - return CreateScalar(builder, source, elem_ty, r.Get(), use_runtime_semantics_); + return CreateScalar(source, elem_ty, r.Get()); } return utils::Failure; }; @@ -1059,7 +1050,7 @@ auto ConstEval::SubFunc(const Source& source, const type::Type* elem_ty) { auto ConstEval::MulFunc(const Source& source, const type::Type* elem_ty) { return [=](auto a1, auto a2) -> ConstEval::Result { if (auto r = Mul(source, a1, a2)) { - return CreateScalar(builder, source, elem_ty, r.Get(), use_runtime_semantics_); + return CreateScalar(source, elem_ty, r.Get()); } return utils::Failure; }; @@ -1068,7 +1059,7 @@ auto ConstEval::MulFunc(const Source& source, const type::Type* elem_ty) { auto ConstEval::DivFunc(const Source& source, const type::Type* elem_ty) { return [=](auto a1, auto a2) -> ConstEval::Result { if (auto r = Div(source, a1, a2)) { - return CreateScalar(builder, source, elem_ty, r.Get(), use_runtime_semantics_); + return CreateScalar(source, elem_ty, r.Get()); } return utils::Failure; }; @@ -1077,7 +1068,7 @@ auto ConstEval::DivFunc(const Source& source, const type::Type* elem_ty) { auto ConstEval::ModFunc(const Source& source, const type::Type* elem_ty) { return [=](auto a1, auto a2) -> ConstEval::Result { if (auto r = Mod(source, a1, a2)) { - return CreateScalar(builder, source, elem_ty, r.Get(), use_runtime_semantics_); + return CreateScalar(source, elem_ty, r.Get()); } return utils::Failure; }; @@ -1086,7 +1077,7 @@ auto ConstEval::ModFunc(const Source& source, const type::Type* elem_ty) { auto ConstEval::Dot2Func(const Source& source, const type::Type* elem_ty) { return [=](auto a1, auto a2, auto b1, auto b2) -> ConstEval::Result { if (auto r = Dot2(source, a1, a2, b1, b2)) { - return CreateScalar(builder, source, elem_ty, r.Get(), use_runtime_semantics_); + return CreateScalar(source, elem_ty, r.Get()); } return utils::Failure; }; @@ -1095,7 +1086,7 @@ auto ConstEval::Dot2Func(const Source& source, const type::Type* elem_ty) { auto ConstEval::Dot3Func(const Source& source, const type::Type* elem_ty) { return [=](auto a1, auto a2, auto a3, auto b1, auto b2, auto b3) -> ConstEval::Result { if (auto r = Dot3(source, a1, a2, a3, b1, b2, b3)) { - return CreateScalar(builder, source, elem_ty, r.Get(), use_runtime_semantics_); + return CreateScalar(source, elem_ty, r.Get()); } return utils::Failure; }; @@ -1105,7 +1096,7 @@ auto ConstEval::Dot4Func(const Source& source, const type::Type* elem_ty) { return [=](auto a1, auto a2, auto a3, auto a4, auto b1, auto b2, auto b3, auto b4) -> ConstEval::Result { if (auto r = Dot4(source, a1, a2, a3, a4, b1, b2, b3, b4)) { - return CreateScalar(builder, source, elem_ty, r.Get(), use_runtime_semantics_); + return CreateScalar(source, elem_ty, r.Get()); } return utils::Failure; }; @@ -1146,7 +1137,7 @@ ConstEval::Result ConstEval::Length(const Source& source, if (vec_ty == nullptr) { auto create = [&](auto e) { using NumberT = decltype(e); - return CreateScalar(builder, source, ty, NumberT{std::abs(e)}, use_runtime_semantics_); + return CreateScalar(source, ty, NumberT{std::abs(e)}); }; return Dispatch_fa_f32_f16(create, c0); } @@ -1182,7 +1173,7 @@ ConstEval::Result ConstEval::Sub(const Source& source, auto ConstEval::Det2Func(const Source& source, const type::Type* elem_ty) { return [=](auto a, auto b, auto c, auto d) -> ConstEval::Result { if (auto r = Det2(source, a, b, c, d)) { - return CreateScalar(builder, source, elem_ty, r.Get(), use_runtime_semantics_); + return CreateScalar(source, elem_ty, r.Get()); } return utils::Failure; }; @@ -1192,7 +1183,7 @@ auto ConstEval::Det3Func(const Source& source, const type::Type* elem_ty) { return [=](auto a, auto b, auto c, auto d, auto e, auto f, auto g, auto h, auto i) -> ConstEval::Result { if (auto r = Det3(source, a, b, c, d, e, f, g, h, i)) { - return CreateScalar(builder, source, elem_ty, r.Get(), use_runtime_semantics_); + return CreateScalar(source, elem_ty, r.Get()); } return utils::Failure; }; @@ -1202,7 +1193,7 @@ auto ConstEval::Det4Func(const Source& source, const type::Type* elem_ty) { return [=](auto a, auto b, auto c, auto d, auto e, auto f, auto g, auto h, auto i, auto j, auto k, auto l, auto m, auto n, auto o, auto p) -> ConstEval::Result { if (auto r = Det4(source, a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p)) { - return CreateScalar(builder, source, elem_ty, r.Get(), use_runtime_semantics_); + return CreateScalar(source, elem_ty, r.Get()); } return utils::Failure; }; @@ -1212,34 +1203,26 @@ ConstEval::Result ConstEval::Literal(const type::Type* ty, const ast::LiteralExp auto& source = literal->source; return Switch( literal, - [&](const ast::BoolLiteralExpression* lit) { - return CreateScalar(builder, source, ty, lit->value, use_runtime_semantics_); - }, + [&](const ast::BoolLiteralExpression* lit) { return CreateScalar(source, ty, lit->value); }, [&](const ast::IntLiteralExpression* lit) -> ConstEval::Result { switch (lit->suffix) { case ast::IntLiteralExpression::Suffix::kNone: - return CreateScalar(builder, source, ty, AInt(lit->value), - use_runtime_semantics_); + return CreateScalar(source, ty, AInt(lit->value)); case ast::IntLiteralExpression::Suffix::kI: - return CreateScalar(builder, source, ty, i32(lit->value), - use_runtime_semantics_); + return CreateScalar(source, ty, i32(lit->value)); case ast::IntLiteralExpression::Suffix::kU: - return CreateScalar(builder, source, ty, u32(lit->value), - use_runtime_semantics_); + return CreateScalar(source, ty, u32(lit->value)); } return nullptr; }, [&](const ast::FloatLiteralExpression* lit) -> ConstEval::Result { switch (lit->suffix) { case ast::FloatLiteralExpression::Suffix::kNone: - return CreateScalar(builder, source, ty, AFloat(lit->value), - use_runtime_semantics_); + return CreateScalar(source, ty, AFloat(lit->value)); case ast::FloatLiteralExpression::Suffix::kF: - return CreateScalar(builder, source, ty, f32(lit->value), - use_runtime_semantics_); + return CreateScalar(source, ty, f32(lit->value)); case ast::FloatLiteralExpression::Suffix::kH: - return CreateScalar(builder, source, ty, f16(lit->value), - use_runtime_semantics_); + return CreateScalar(source, ty, f16(lit->value)); } return nullptr; }); @@ -1248,7 +1231,7 @@ ConstEval::Result ConstEval::Literal(const type::Type* ty, const ast::LiteralExp ConstEval::Result ConstEval::ArrayOrStructInit(const type::Type* ty, utils::VectorRef args) { if (args.IsEmpty()) { - return ZeroValue(builder, ty); + return ZeroValue(ty); } if (args.Length() == 1 && args[0]->Type() == ty) { @@ -1284,7 +1267,7 @@ ConstEval::Result ConstEval::Conv(const type::Type* ty, ConstEval::Result ConstEval::Zero(const type::Type* ty, utils::VectorRef, const Source&) { - return ZeroValue(builder, ty); + return ZeroValue(ty); } ConstEval::Result ConstEval::Identity(const type::Type*, @@ -1378,7 +1361,7 @@ ConstEval::Result ConstEval::Index(const type::Type* ty, AddError("index " + std::to_string(idx) + " out of bounds" + range, idx_expr->Declaration()->source); if (use_runtime_semantics_) { - return ZeroValue(builder, ty); + return ZeroValue(ty); } else { return utils::Failure; } @@ -1426,15 +1409,15 @@ ConstEval::Result ConstEval::Bitcast(const type::Type* ty, el_ty, [&](const type::U32*) { // auto r = utils::Bitcast(e); - return CreateScalar(builder, source, el_ty, r, use_runtime_semantics_); + return CreateScalar(source, el_ty, r); }, [&](const type::I32*) { // auto r = utils::Bitcast(e); - return CreateScalar(builder, source, el_ty, r, use_runtime_semantics_); + return CreateScalar(source, el_ty, r); }, [&](const type::F32*) { // auto r = utils::Bitcast(e); - return CreateScalar(builder, source, el_ty, r, use_runtime_semantics_); + return CreateScalar(source, el_ty, r); }); }; return Dispatch_fiu32(create, c0); @@ -1447,8 +1430,7 @@ ConstEval::Result ConstEval::OpComplement(const type::Type* ty, const Source& source) { auto transform = [&](const constant::Value* c) { auto create = [&](auto i) { - return CreateScalar(builder, source, c->Type(), decltype(i)(~i.value), - use_runtime_semantics_); + return CreateScalar(source, c->Type(), decltype(i)(~i.value)); }; return Dispatch_ia_iu32(create, c); }; @@ -1470,11 +1452,9 @@ ConstEval::Result ConstEval::OpUnaryMinus(const type::Type* ty, if (v != std::numeric_limits::min()) { v = -v; } - return CreateScalar(builder, source, c->Type(), decltype(i)(v), - use_runtime_semantics_); + return CreateScalar(source, c->Type(), decltype(i)(v)); } else { - return CreateScalar(builder, source, c->Type(), decltype(i)(-i.value), - use_runtime_semantics_); + return CreateScalar(source, c->Type(), decltype(i)(-i.value)); } }; return Dispatch_fia_fi32_f16(create, c); @@ -1486,10 +1466,7 @@ ConstEval::Result ConstEval::OpNot(const type::Type* ty, utils::VectorRef args, const Source& source) { auto transform = [&](const constant::Value* c) { - auto create = [&](auto i) { - return CreateScalar(builder, source, c->Type(), decltype(i)(!i), - use_runtime_semantics_); - }; + auto create = [&](auto i) { return CreateScalar(source, c->Type(), decltype(i)(!i)); }; return Dispatch_bool(create, c); }; return TransformElements(builder, ty, transform, args[0]); @@ -1707,8 +1684,7 @@ ConstEval::Result ConstEval::OpEqual(const type::Type* ty, const Source& source) { auto transform = [&](const constant::Value* c0, const constant::Value* c1) { auto create = [&](auto i, auto j) -> ConstEval::Result { - return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), i == j, - use_runtime_semantics_); + return CreateScalar(source, type::Type::DeepestElementOf(ty), i == j); }; return Dispatch_fia_fiu32_f16_bool(create, c0, c1); }; @@ -1721,8 +1697,7 @@ ConstEval::Result ConstEval::OpNotEqual(const type::Type* ty, const Source& source) { auto transform = [&](const constant::Value* c0, const constant::Value* c1) { auto create = [&](auto i, auto j) -> ConstEval::Result { - return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), i != j, - use_runtime_semantics_); + return CreateScalar(source, type::Type::DeepestElementOf(ty), i != j); }; return Dispatch_fia_fiu32_f16_bool(create, c0, c1); }; @@ -1735,8 +1710,7 @@ ConstEval::Result ConstEval::OpLessThan(const type::Type* ty, const Source& source) { auto transform = [&](const constant::Value* c0, const constant::Value* c1) { auto create = [&](auto i, auto j) -> ConstEval::Result { - return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), i < j, - use_runtime_semantics_); + return CreateScalar(source, type::Type::DeepestElementOf(ty), i < j); }; return Dispatch_fia_fiu32_f16(create, c0, c1); }; @@ -1749,8 +1723,7 @@ ConstEval::Result ConstEval::OpGreaterThan(const type::Type* ty, const Source& source) { auto transform = [&](const constant::Value* c0, const constant::Value* c1) { auto create = [&](auto i, auto j) -> ConstEval::Result { - return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), i > j, - use_runtime_semantics_); + return CreateScalar(source, type::Type::DeepestElementOf(ty), i > j); }; return Dispatch_fia_fiu32_f16(create, c0, c1); }; @@ -1763,8 +1736,7 @@ ConstEval::Result ConstEval::OpLessThanEqual(const type::Type* ty, const Source& source) { auto transform = [&](const constant::Value* c0, const constant::Value* c1) { auto create = [&](auto i, auto j) -> ConstEval::Result { - return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), i <= j, - use_runtime_semantics_); + return CreateScalar(source, type::Type::DeepestElementOf(ty), i <= j); }; return Dispatch_fia_fiu32_f16(create, c0, c1); }; @@ -1777,8 +1749,7 @@ ConstEval::Result ConstEval::OpGreaterThanEqual(const type::Type* ty, const Source& source) { auto transform = [&](const constant::Value* c0, const constant::Value* c1) { auto create = [&](auto i, auto j) -> ConstEval::Result { - return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), i >= j, - use_runtime_semantics_); + return CreateScalar(source, type::Type::DeepestElementOf(ty), i >= j); }; return Dispatch_fia_fiu32_f16(create, c0, c1); }; @@ -1791,8 +1762,7 @@ ConstEval::Result ConstEval::OpLogicalAnd(const type::Type* ty, const Source& source) { // Note: Due to short-circuiting, this function is only called if lhs is true, so we could // technically only return the value of the rhs. - return CreateScalar(builder, source, ty, args[0]->ValueAs() && args[1]->ValueAs(), - use_runtime_semantics_); + return CreateScalar(source, ty, args[0]->ValueAs() && args[1]->ValueAs()); } ConstEval::Result ConstEval::OpLogicalOr(const type::Type* ty, @@ -1800,7 +1770,7 @@ ConstEval::Result ConstEval::OpLogicalOr(const type::Type* ty, const Source& source) { // Note: Due to short-circuiting, this function is only called if lhs is false, so we could // technically only return the value of the rhs. - return CreateScalar(builder, source, ty, args[1]->ValueAs(), use_runtime_semantics_); + return CreateScalar(source, ty, args[1]->ValueAs()); } ConstEval::Result ConstEval::OpAnd(const type::Type* ty, @@ -1815,8 +1785,7 @@ ConstEval::Result ConstEval::OpAnd(const type::Type* ty, } else { // integral result = i & j; } - return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), result, - use_runtime_semantics_); + return CreateScalar(source, type::Type::DeepestElementOf(ty), result); }; return Dispatch_ia_iu32_bool(create, c0, c1); }; @@ -1836,8 +1805,7 @@ ConstEval::Result ConstEval::OpOr(const type::Type* ty, } else { // integral result = i | j; } - return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), result, - use_runtime_semantics_); + return CreateScalar(source, type::Type::DeepestElementOf(ty), result); }; return Dispatch_ia_iu32_bool(create, c0, c1); }; @@ -1850,8 +1818,7 @@ ConstEval::Result ConstEval::OpXor(const type::Type* ty, const Source& source) { auto transform = [&](const constant::Value* c0, const constant::Value* c1) { auto create = [&](auto i, auto j) -> ConstEval::Result { - return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), - decltype(i){i ^ j}, use_runtime_semantics_); + return CreateScalar(source, type::Type::DeepestElementOf(ty), decltype(i){i ^ j}); }; return Dispatch_ia_iu32(create, c0, c1); }; @@ -1943,8 +1910,7 @@ ConstEval::Result ConstEval::OpShiftLeft(const type::Type* ty, // Avoid UB by left shifting as unsigned value auto result = static_cast(static_cast(e1) << e2u); - return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), NumberT{result}, - use_runtime_semantics_); + return CreateScalar(source, type::Type::DeepestElementOf(ty), NumberT{result}); }; return Dispatch_ia_iu32(create, c0, c1); }; @@ -2012,8 +1978,7 @@ ConstEval::Result ConstEval::OpShiftRight(const type::Type* ty, result = e1 >> e2u; } } - return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), NumberT{result}, - use_runtime_semantics_); + return CreateScalar(source, type::Type::DeepestElementOf(ty), NumberT{result}); }; return Dispatch_ia_iu32(create, c0, c1); }; @@ -2045,7 +2010,7 @@ ConstEval::Result ConstEval::abs(const type::Type* ty, } else { result = NumberT{std::abs(e)}; } - return CreateScalar(builder, source, c0->Type(), result, use_runtime_semantics_); + return CreateScalar(source, c0->Type(), result); }; return Dispatch_fia_fiu32_f16(create, c0); }; @@ -2062,13 +2027,12 @@ ConstEval::Result ConstEval::acos(const type::Type* ty, AddError("acos must be called with a value in the range [-1 .. 1] (inclusive)", source); if (use_runtime_semantics_) { - return ZeroValue(builder, c0->Type()); + return ZeroValue(c0->Type()); } else { return utils::Failure; } } - return CreateScalar(builder, source, c0->Type(), NumberT(std::acos(i.value)), - use_runtime_semantics_); + return CreateScalar(source, c0->Type(), NumberT(std::acos(i.value))); }; return Dispatch_fa_f32_f16(create, c0); }; @@ -2084,13 +2048,12 @@ ConstEval::Result ConstEval::acosh(const type::Type* ty, if (i < NumberT(1.0)) { AddError("acosh must be called with a value >= 1.0", source); if (use_runtime_semantics_) { - return ZeroValue(builder, c0->Type()); + return ZeroValue(c0->Type()); } else { return utils::Failure; } } - return CreateScalar(builder, source, c0->Type(), NumberT(std::acosh(i.value)), - use_runtime_semantics_); + return CreateScalar(source, c0->Type(), NumberT(std::acosh(i.value))); }; return Dispatch_fa_f32_f16(create, c0); }; @@ -2101,13 +2064,13 @@ ConstEval::Result ConstEval::acosh(const type::Type* ty, ConstEval::Result ConstEval::all(const type::Type* ty, utils::VectorRef args, const Source& source) { - return CreateScalar(builder, source, ty, !args[0]->AnyZero(), use_runtime_semantics_); + return CreateScalar(source, ty, !args[0]->AnyZero()); } ConstEval::Result ConstEval::any(const type::Type* ty, utils::VectorRef args, const Source& source) { - return CreateScalar(builder, source, ty, !args[0]->AllZero(), use_runtime_semantics_); + return CreateScalar(source, ty, !args[0]->AllZero()); } ConstEval::Result ConstEval::asin(const type::Type* ty, @@ -2120,13 +2083,12 @@ ConstEval::Result ConstEval::asin(const type::Type* ty, AddError("asin must be called with a value in the range [-1 .. 1] (inclusive)", source); if (use_runtime_semantics_) { - return ZeroValue(builder, c0->Type()); + return ZeroValue(c0->Type()); } else { return utils::Failure; } } - return CreateScalar(builder, source, c0->Type(), NumberT(std::asin(i.value)), - use_runtime_semantics_); + return CreateScalar(source, c0->Type(), NumberT(std::asin(i.value))); }; return Dispatch_fa_f32_f16(create, c0); }; @@ -2138,8 +2100,7 @@ ConstEval::Result ConstEval::asinh(const type::Type* ty, const Source& source) { auto transform = [&](const constant::Value* c0) { auto create = [&](auto i) { - return CreateScalar(builder, source, c0->Type(), decltype(i)(std::asinh(i.value)), - use_runtime_semantics_); + return CreateScalar(source, c0->Type(), decltype(i)(std::asinh(i.value))); }; return Dispatch_fa_f32_f16(create, c0); }; @@ -2152,8 +2113,7 @@ ConstEval::Result ConstEval::atan(const type::Type* ty, const Source& source) { auto transform = [&](const constant::Value* c0) { auto create = [&](auto i) { - return CreateScalar(builder, source, c0->Type(), decltype(i)(std::atan(i.value)), - use_runtime_semantics_); + return CreateScalar(source, c0->Type(), decltype(i)(std::atan(i.value))); }; return Dispatch_fa_f32_f16(create, c0); }; @@ -2170,13 +2130,12 @@ ConstEval::Result ConstEval::atanh(const type::Type* ty, AddError("atanh must be called with a value in the range (-1 .. 1) (exclusive)", source); if (use_runtime_semantics_) { - return ZeroValue(builder, c0->Type()); + return ZeroValue(c0->Type()); } else { return utils::Failure; } } - return CreateScalar(builder, source, c0->Type(), NumberT(std::atanh(i.value)), - use_runtime_semantics_); + return CreateScalar(source, c0->Type(), NumberT(std::atanh(i.value))); }; return Dispatch_fa_f32_f16(create, c0); }; @@ -2189,8 +2148,7 @@ ConstEval::Result ConstEval::atan2(const type::Type* ty, const Source& source) { auto transform = [&](const constant::Value* c0, const constant::Value* c1) { auto create = [&](auto i, auto j) { - return CreateScalar(builder, source, c0->Type(), - decltype(i)(std::atan2(i.value, j.value)), use_runtime_semantics_); + return CreateScalar(source, c0->Type(), decltype(i)(std::atan2(i.value, j.value))); }; return Dispatch_fa_f32_f16(create, c0, c1); }; @@ -2202,8 +2160,7 @@ ConstEval::Result ConstEval::ceil(const type::Type* ty, const Source& source) { auto transform = [&](const constant::Value* c0) { auto create = [&](auto e) { - return CreateScalar(builder, source, c0->Type(), decltype(e)(std::ceil(e)), - use_runtime_semantics_); + return CreateScalar(source, c0->Type(), decltype(e)(std::ceil(e))); }; return Dispatch_fa_f32_f16(create, c0); }; @@ -2226,8 +2183,7 @@ ConstEval::Result ConstEval::cos(const type::Type* ty, auto transform = [&](const constant::Value* c0) { auto create = [&](auto i) -> ConstEval::Result { using NumberT = decltype(i); - return CreateScalar(builder, source, c0->Type(), NumberT(std::cos(i.value)), - use_runtime_semantics_); + return CreateScalar(source, c0->Type(), NumberT(std::cos(i.value))); }; return Dispatch_fa_f32_f16(create, c0); }; @@ -2240,8 +2196,7 @@ ConstEval::Result ConstEval::cosh(const type::Type* ty, auto transform = [&](const constant::Value* c0) { auto create = [&](auto i) -> ConstEval::Result { using NumberT = decltype(i); - return CreateScalar(builder, source, c0->Type(), NumberT(std::cosh(i.value)), - use_runtime_semantics_); + return CreateScalar(source, c0->Type(), NumberT(std::cosh(i.value))); }; return Dispatch_fa_f32_f16(create, c0); }; @@ -2256,8 +2211,7 @@ ConstEval::Result ConstEval::countLeadingZeros(const type::Type* ty, using NumberT = decltype(e); using T = UnwrapNumber; auto count = CountLeadingBits(T{e}, T{0}); - return CreateScalar(builder, source, c0->Type(), NumberT(count), - use_runtime_semantics_); + return CreateScalar(source, c0->Type(), NumberT(count)); }; return Dispatch_iu32(create, c0); }; @@ -2281,8 +2235,7 @@ ConstEval::Result ConstEval::countOneBits(const type::Type* ty, } } - return CreateScalar(builder, source, c0->Type(), NumberT(count), - use_runtime_semantics_); + return CreateScalar(source, c0->Type(), NumberT(count)); }; return Dispatch_iu32(create, c0); }; @@ -2297,8 +2250,7 @@ ConstEval::Result ConstEval::countTrailingZeros(const type::Type* ty, using NumberT = decltype(e); using T = UnwrapNumber; auto count = CountTrailingBits(T{e}, T{0}); - return CreateScalar(builder, source, c0->Type(), NumberT(count), - use_runtime_semantics_); + return CreateScalar(source, c0->Type(), NumberT(count)); }; return Dispatch_iu32(create, c0); }; @@ -2367,7 +2319,7 @@ ConstEval::Result ConstEval::degrees(const type::Type* ty, AddNote("when calculating degrees", source); return utils::Failure; } - return CreateScalar(builder, source, c0->Type(), result.Get(), use_runtime_semantics_); + return CreateScalar(source, c0->Type(), result.Get()); }; return Dispatch_fa_f32_f16(create, c0); }; @@ -2450,12 +2402,12 @@ ConstEval::Result ConstEval::exp(const type::Type* ty, if (!std::isfinite(val.value)) { AddError(OverflowExpErrorMessage("e", e0), source); if (use_runtime_semantics_) { - return ZeroValue(builder, c0->Type()); + return ZeroValue(c0->Type()); } else { return utils::Failure; } } - return CreateScalar(builder, source, c0->Type(), val, use_runtime_semantics_); + return CreateScalar(source, c0->Type(), val); }; return Dispatch_fa_f32_f16(create, c0); }; @@ -2472,12 +2424,12 @@ ConstEval::Result ConstEval::exp2(const type::Type* ty, if (!std::isfinite(val.value)) { AddError(OverflowExpErrorMessage("2", e0), source); if (use_runtime_semantics_) { - return ZeroValue(builder, c0->Type()); + return ZeroValue(c0->Type()); } else { return utils::Failure; } } - return CreateScalar(builder, source, c0->Type(), val, use_runtime_semantics_); + return CreateScalar(source, c0->Type(), val); }; return Dispatch_fa_f32_f16(create, c0); }; @@ -2537,7 +2489,7 @@ ConstEval::Result ConstEval::extractBits(const type::Type* ty, result = NumberT{r}; } - return CreateScalar(builder, source, c0->Type(), result, use_runtime_semantics_); + return CreateScalar(source, c0->Type(), result); }; return Dispatch_iu32(create, c0); }; @@ -2600,7 +2552,7 @@ ConstEval::Result ConstEval::firstLeadingBit(const type::Type* ty, } } - return CreateScalar(builder, source, c0->Type(), result, use_runtime_semantics_); + return CreateScalar(source, c0->Type(), result); }; return Dispatch_iu32(create, c0); }; @@ -2626,7 +2578,7 @@ ConstEval::Result ConstEval::firstTrailingBit(const type::Type* ty, result = NumberT(pos); } - return CreateScalar(builder, source, c0->Type(), result, use_runtime_semantics_); + return CreateScalar(source, c0->Type(), result); }; return Dispatch_iu32(create, c0); }; @@ -2638,8 +2590,7 @@ ConstEval::Result ConstEval::floor(const type::Type* ty, const Source& source) { auto transform = [&](const constant::Value* c0) { auto create = [&](auto e) { - return CreateScalar(builder, source, c0->Type(), decltype(e)(std::floor(e)), - use_runtime_semantics_); + return CreateScalar(source, c0->Type(), decltype(e)(std::floor(e))); }; return Dispatch_fa_f32_f16(create, c0); }; @@ -2666,7 +2617,7 @@ ConstEval::Result ConstEval::fma(const type::Type* ty, if (!val) { return err_msg(); } - return CreateScalar(builder, source, c1->Type(), val.Get(), use_runtime_semantics_); + return CreateScalar(source, c1->Type(), val.Get()); }; return Dispatch_fa_f32_f16(create, c1, c2, c3); }; @@ -2680,7 +2631,7 @@ ConstEval::Result ConstEval::fract(const type::Type* ty, auto create = [&](auto e) -> ConstEval::Result { using NumberT = decltype(e); auto r = e - std::floor(e); - return CreateScalar(builder, source, c1->Type(), NumberT{r}, use_runtime_semantics_); + return CreateScalar(source, c1->Type(), NumberT{r}); }; return Dispatch_fa_f32_f16(create, c1); }; @@ -2704,26 +2655,20 @@ ConstEval::Result ConstEval::frexp(const type::Type* ty, s->Type(), [&](const type::F32*) { return FractExp{ - CreateScalar(builder, source, builder.create(), f32(fract), - use_runtime_semantics_), - CreateScalar(builder, source, builder.create(), i32(exp), - use_runtime_semantics_), + CreateScalar(source, builder.create(), f32(fract)), + CreateScalar(source, builder.create(), i32(exp)), }; }, [&](const type::F16*) { return FractExp{ - CreateScalar(builder, source, builder.create(), f16(fract), - use_runtime_semantics_), - CreateScalar(builder, source, builder.create(), i32(exp), - use_runtime_semantics_), + CreateScalar(source, builder.create(), f16(fract)), + CreateScalar(source, builder.create(), i32(exp)), }; }, [&](const type::AbstractFloat*) { return FractExp{ - CreateScalar(builder, source, builder.create(), - AFloat(fract), use_runtime_semantics_), - CreateScalar(builder, source, builder.create(), AInt(exp), - use_runtime_semantics_), + CreateScalar(source, builder.create(), AFloat(fract)), + CreateScalar(source, builder.create(), AInt(exp)), }; }, [&](Default) { @@ -2814,7 +2759,7 @@ ConstEval::Result ConstEval::insertBits(const type::Type* ty, result = NumberT{r}; } - return CreateScalar(builder, source, c0->Type(), result, use_runtime_semantics_); + return CreateScalar(source, c0->Type(), result); }; return Dispatch_iu32(create, c0, c1); }; @@ -2831,7 +2776,7 @@ ConstEval::Result ConstEval::inverseSqrt(const type::Type* ty, if (e <= NumberT(0)) { AddError("inverseSqrt must be called with a value > 0", source); if (use_runtime_semantics_) { - return ZeroValue(builder, c0->Type()); + return ZeroValue(c0->Type()); } else { return utils::Failure; } @@ -2851,7 +2796,7 @@ ConstEval::Result ConstEval::inverseSqrt(const type::Type* ty, return err(); } - return CreateScalar(builder, source, c0->Type(), div.Get(), use_runtime_semantics_); + return CreateScalar(source, c0->Type(), div.Get()); }; return Dispatch_fa_f32_f16(create, c0); }; @@ -2888,7 +2833,7 @@ ConstEval::Result ConstEval::ldexp(const type::Type* ty, if (e2 > bias + 1) { AddError("e2 must be less than or equal to " + std::to_string(bias + 1), source); if (use_runtime_semantics_) { - return ZeroValue(builder, c1->Type()); + return ZeroValue(c1->Type()); } else { return utils::Failure; } @@ -2897,7 +2842,7 @@ ConstEval::Result ConstEval::ldexp(const type::Type* ty, auto target_ty = type::Type::DeepestElementOf(ty); auto r = std::ldexp(e1, static_cast(e2)); - return CreateScalar(builder, source, target_ty, E1Type{r}, use_runtime_semantics_); + return CreateScalar(source, target_ty, E1Type{r}); }; return Dispatch_fa_f32_f16(create, c1); }; @@ -2924,13 +2869,12 @@ ConstEval::Result ConstEval::log(const type::Type* ty, if (v <= NumberT(0)) { AddError("log must be called with a value > 0", source); if (use_runtime_semantics_) { - return ZeroValue(builder, c0->Type()); + return ZeroValue(c0->Type()); } else { return utils::Failure; } } - return CreateScalar(builder, source, c0->Type(), NumberT(std::log(v)), - use_runtime_semantics_); + return CreateScalar(source, c0->Type(), NumberT(std::log(v))); }; return Dispatch_fa_f32_f16(create, c0); }; @@ -2946,13 +2890,12 @@ ConstEval::Result ConstEval::log2(const type::Type* ty, if (v <= NumberT(0)) { AddError("log2 must be called with a value > 0", source); if (use_runtime_semantics_) { - return ZeroValue(builder, c0->Type()); + return ZeroValue(c0->Type()); } else { return utils::Failure; } } - return CreateScalar(builder, source, c0->Type(), NumberT(std::log2(v)), - use_runtime_semantics_); + return CreateScalar(source, c0->Type(), NumberT(std::log2(v))); }; return Dispatch_fa_f32_f16(create, c0); }; @@ -2964,8 +2907,7 @@ ConstEval::Result ConstEval::max(const type::Type* ty, const Source& source) { auto transform = [&](const constant::Value* c0, const constant::Value* c1) { auto create = [&](auto e0, auto e1) { - return CreateScalar(builder, source, c0->Type(), decltype(e0)(std::max(e0, e1)), - use_runtime_semantics_); + return CreateScalar(source, c0->Type(), decltype(e0)(std::max(e0, e1))); }; return Dispatch_fia_fiu32_f16(create, c0, c1); }; @@ -2977,8 +2919,7 @@ ConstEval::Result ConstEval::min(const type::Type* ty, const Source& source) { auto transform = [&](const constant::Value* c0, const constant::Value* c1) { auto create = [&](auto e0, auto e1) { - return CreateScalar(builder, source, c0->Type(), decltype(e0)(std::min(e0, e1)), - use_runtime_semantics_); + return CreateScalar(source, c0->Type(), decltype(e0)(std::min(e0, e1))); }; return Dispatch_fia_fiu32_f16(create, c0, c1); }; @@ -3017,7 +2958,7 @@ ConstEval::Result ConstEval::mix(const type::Type* ty, if (!r) { return utils::Failure; } - return CreateScalar(builder, source, c0->Type(), r.Get(), use_runtime_semantics_); + return CreateScalar(source, c0->Type(), r.Get()); }; return Dispatch_fa_f32_f16(create, c0, c1); }; @@ -3033,15 +2974,13 @@ ConstEval::Result ConstEval::modf(const type::Type* ty, const Source& source) { auto transform_fract = [&](const constant::Value* c) { auto create = [&](auto e) { - return CreateScalar(builder, source, c->Type(), - decltype(e)(e.value - std::trunc(e.value)), use_runtime_semantics_); + return CreateScalar(source, c->Type(), decltype(e)(e.value - std::trunc(e.value))); }; return Dispatch_fa_f32_f16(create, c); }; auto transform_whole = [&](const constant::Value* c) { auto create = [&](auto e) { - return CreateScalar(builder, source, c->Type(), decltype(e)(std::trunc(e.value)), - use_runtime_semantics_); + return CreateScalar(source, c->Type(), decltype(e)(std::trunc(e.value))); }; return Dispatch_fa_f32_f16(create, c); }; @@ -3076,7 +3015,7 @@ ConstEval::Result ConstEval::normalize(const type::Type* ty, if (v->AllZero()) { AddError("zero length vector can not be normalized", source); if (use_runtime_semantics_) { - return ZeroValue(builder, ty); + return ZeroValue(ty); } else { return utils::Failure; } @@ -3113,7 +3052,7 @@ ConstEval::Result ConstEval::pack2x16float(const type::Type* ty, } u32 ret = u32((e0.Get() & 0x0000'ffff) | (e1.Get() << 16)); - return CreateScalar(builder, source, ty, ret, use_runtime_semantics_); + return CreateScalar(source, ty, ret); } ConstEval::Result ConstEval::pack2x16snorm(const type::Type* ty, @@ -3130,7 +3069,7 @@ ConstEval::Result ConstEval::pack2x16snorm(const type::Type* ty, auto e1 = calc(e->Index(1)->ValueAs()); u32 ret = u32((e0 & 0x0000'ffff) | (e1 << 16)); - return CreateScalar(builder, source, ty, ret, use_runtime_semantics_); + return CreateScalar(source, ty, ret); } ConstEval::Result ConstEval::pack2x16unorm(const type::Type* ty, @@ -3146,7 +3085,7 @@ ConstEval::Result ConstEval::pack2x16unorm(const type::Type* ty, auto e1 = calc(e->Index(1)->ValueAs()); u32 ret = u32((e0 & 0x0000'ffff) | (e1 << 16)); - return CreateScalar(builder, source, ty, ret, use_runtime_semantics_); + return CreateScalar(source, ty, ret); } ConstEval::Result ConstEval::pack4x8snorm(const type::Type* ty, @@ -3166,7 +3105,7 @@ ConstEval::Result ConstEval::pack4x8snorm(const type::Type* ty, uint32_t mask = 0x0000'00ff; u32 ret = u32((e0 & mask) | ((e1 & mask) << 8) | ((e2 & mask) << 16) | ((e3 & mask) << 24)); - return CreateScalar(builder, source, ty, ret, use_runtime_semantics_); + return CreateScalar(source, ty, ret); } ConstEval::Result ConstEval::pack4x8unorm(const type::Type* ty, @@ -3185,7 +3124,7 @@ ConstEval::Result ConstEval::pack4x8unorm(const type::Type* ty, uint32_t mask = 0x0000'00ff; u32 ret = u32((e0 & mask) | ((e1 & mask) << 8) | ((e2 & mask) << 16) | ((e3 & mask) << 24)); - return CreateScalar(builder, source, ty, ret, use_runtime_semantics_); + return CreateScalar(source, ty, ret); } ConstEval::Result ConstEval::pow(const type::Type* ty, @@ -3197,12 +3136,12 @@ ConstEval::Result ConstEval::pow(const type::Type* ty, if (!r) { AddError(OverflowErrorMessage(e1, "^", e2), source); if (use_runtime_semantics_) { - return ZeroValue(builder, c0->Type()); + return ZeroValue(c0->Type()); } else { return utils::Failure; } } - return CreateScalar(builder, source, c0->Type(), *r, use_runtime_semantics_); + return CreateScalar(source, c0->Type(), *r); }; return Dispatch_fa_f32_f16(create, c0, c1); }; @@ -3228,7 +3167,7 @@ ConstEval::Result ConstEval::radians(const type::Type* ty, AddNote("when calculating radians", source); return utils::Failure; } - return CreateScalar(builder, source, c0->Type(), result.Get(), use_runtime_semantics_); + return CreateScalar(source, c0->Type(), result.Get()); }; return Dispatch_fa_f32_f16(create, c0); }; @@ -3255,8 +3194,7 @@ ConstEval::Result ConstEval::reflect(const type::Type* ty, // 2 * dot(e2, e1) auto mul2 = [&](auto v) -> ConstEval::Result { using NumberT = decltype(v); - return CreateScalar(builder, source, el_ty, NumberT{NumberT{2} * v}, - use_runtime_semantics_); + return CreateScalar(source, el_ty, NumberT{NumberT{2} * v}); }; auto dot_e2_e1_2 = Dispatch_fa_f32_f16(mul2, dot_e2_e1.Get()); if (!dot_e2_e1_2) { @@ -3308,7 +3246,7 @@ ConstEval::Result ConstEval::refract(const type::Type* ty, if (!r) { return utils::Failure; } - return CreateScalar(builder, source, el_ty, r.Get(), use_runtime_semantics_); + return CreateScalar(source, el_ty, r.Get()); }; auto compute_e2_scale = [&](auto e3, auto dot_e2_e1, auto k) -> ConstEval::Result { @@ -3325,7 +3263,7 @@ ConstEval::Result ConstEval::refract(const type::Type* ty, if (!r) { return utils::Failure; } - return CreateScalar(builder, source, el_ty, r.Get(), use_runtime_semantics_); + return CreateScalar(source, el_ty, r.Get()); }; auto calculate = [&]() -> ConstEval::Result { @@ -3352,7 +3290,7 @@ ConstEval::Result ConstEval::refract(const type::Type* ty, // If k < 0.0, returns the refraction vector 0.0 if (k.Get()->ValueAs() < 0) { - return ZeroValue(builder, ty); + return ZeroValue(ty); } // Otherwise return the refraction vector e3 * e1 - (e3 * dot(e2, e1) + sqrt(k)) * e2 @@ -3397,7 +3335,7 @@ ConstEval::Result ConstEval::reverseBits(const type::Type* ty, } } - return CreateScalar(builder, source, c0->Type(), NumberT{r}, use_runtime_semantics_); + return CreateScalar(source, c0->Type(), NumberT{r}); }; return Dispatch_iu32(create, c0); }; @@ -3433,7 +3371,7 @@ ConstEval::Result ConstEval::round(const type::Type* ty, } else { result = NumberT(std::round(e.value)); } - return CreateScalar(builder, source, c0->Type(), result, use_runtime_semantics_); + return CreateScalar(source, c0->Type(), result); }; return Dispatch_fa_f32_f16(create, c0); }; @@ -3446,9 +3384,8 @@ ConstEval::Result ConstEval::saturate(const type::Type* ty, auto transform = [&](const constant::Value* c0) { auto create = [&](auto e) { using NumberT = decltype(e); - return CreateScalar(builder, source, c0->Type(), - NumberT(std::min(std::max(e, NumberT(0.0)), NumberT(1.0))), - use_runtime_semantics_); + return CreateScalar(source, c0->Type(), + NumberT(std::min(std::max(e, NumberT(0.0)), NumberT(1.0)))); }; return Dispatch_fa_f32_f16(create, c0); }; @@ -3461,8 +3398,7 @@ ConstEval::Result ConstEval::select_bool(const type::Type* ty, auto cond = args[2]->ValueAs(); auto transform = [&](const constant::Value* c0, const constant::Value* c1) { auto create = [&](auto f, auto t) -> ConstEval::Result { - return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), cond ? t : f, - use_runtime_semantics_); + return CreateScalar(source, type::Type::DeepestElementOf(ty), cond ? t : f); }; return Dispatch_fia_fiu32_f16_bool(create, c0, c1); }; @@ -3477,8 +3413,7 @@ ConstEval::Result ConstEval::select_boolvec(const type::Type* ty, auto create = [&](auto f, auto t) -> ConstEval::Result { // Get corresponding bool value at the current vector value index auto cond = args[2]->Index(index)->ValueAs(); - return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), cond ? t : f, - use_runtime_semantics_); + return CreateScalar(source, type::Type::DeepestElementOf(ty), cond ? t : f); }; return Dispatch_fia_fiu32_f16_bool(create, c0, c1); }; @@ -3501,7 +3436,7 @@ ConstEval::Result ConstEval::sign(const type::Type* ty, } else { result = zero; } - return CreateScalar(builder, source, c0->Type(), result, use_runtime_semantics_); + return CreateScalar(source, c0->Type(), result); }; return Dispatch_fia_fi32_f16(create, c0); }; @@ -3514,8 +3449,7 @@ ConstEval::Result ConstEval::sin(const type::Type* ty, auto transform = [&](const constant::Value* c0) { auto create = [&](auto i) -> ConstEval::Result { using NumberT = decltype(i); - return CreateScalar(builder, source, c0->Type(), NumberT(std::sin(i.value)), - use_runtime_semantics_); + return CreateScalar(source, c0->Type(), NumberT(std::sin(i.value))); }; return Dispatch_fa_f32_f16(create, c0); }; @@ -3528,8 +3462,7 @@ ConstEval::Result ConstEval::sinh(const type::Type* ty, auto transform = [&](const constant::Value* c0) { auto create = [&](auto i) -> ConstEval::Result { using NumberT = decltype(i); - return CreateScalar(builder, source, c0->Type(), NumberT(std::sinh(i.value)), - use_runtime_semantics_); + return CreateScalar(source, c0->Type(), NumberT(std::sinh(i.value))); }; return Dispatch_fa_f32_f16(create, c0); }; @@ -3580,7 +3513,7 @@ ConstEval::Result ConstEval::smoothstep(const type::Type* ty, if (!result) { return err(); } - return CreateScalar(builder, source, c0->Type(), result.Get(), use_runtime_semantics_); + return CreateScalar(source, c0->Type(), result.Get()); }; return Dispatch_fa_f32_f16(create, c0, c1, c2); }; @@ -3594,7 +3527,7 @@ ConstEval::Result ConstEval::step(const type::Type* ty, auto create = [&](auto edge, auto x) -> ConstEval::Result { using NumberT = decltype(edge); NumberT result = x.value < edge.value ? NumberT(0.0) : NumberT(1.0); - return CreateScalar(builder, source, c0->Type(), result, use_runtime_semantics_); + return CreateScalar(source, c0->Type(), result); }; return Dispatch_fa_f32_f16(create, c0, c1); }; @@ -3617,8 +3550,7 @@ ConstEval::Result ConstEval::tan(const type::Type* ty, auto transform = [&](const constant::Value* c0) { auto create = [&](auto i) -> ConstEval::Result { using NumberT = decltype(i); - return CreateScalar(builder, source, c0->Type(), NumberT(std::tan(i.value)), - use_runtime_semantics_); + return CreateScalar(source, c0->Type(), NumberT(std::tan(i.value))); }; return Dispatch_fa_f32_f16(create, c0); }; @@ -3631,8 +3563,7 @@ ConstEval::Result ConstEval::tanh(const type::Type* ty, auto transform = [&](const constant::Value* c0) { auto create = [&](auto i) -> ConstEval::Result { using NumberT = decltype(i); - return CreateScalar(builder, source, c0->Type(), NumberT(std::tanh(i.value)), - use_runtime_semantics_); + return CreateScalar(source, c0->Type(), NumberT(std::tanh(i.value))); }; return Dispatch_fa_f32_f16(create, c0); }; @@ -3665,8 +3596,7 @@ ConstEval::Result ConstEval::trunc(const type::Type* ty, const Source& source) { auto transform = [&](const constant::Value* c0) { auto create = [&](auto i) { - return CreateScalar(builder, source, c0->Type(), decltype(i)(std::trunc(i.value)), - use_runtime_semantics_); + return CreateScalar(source, c0->Type(), decltype(i)(std::trunc(i.value))); }; return Dispatch_fa_f32_f16(create, c0); }; @@ -3692,7 +3622,7 @@ ConstEval::Result ConstEval::unpack2x16float(const type::Type* ty, return utils::Failure; } } - auto el = CreateScalar(builder, source, inner_ty, val.Get(), use_runtime_semantics_); + auto el = CreateScalar(source, inner_ty, val.Get()); if (!el) { return el; } @@ -3712,7 +3642,7 @@ ConstEval::Result ConstEval::unpack2x16snorm(const type::Type* ty, for (size_t i = 0; i < 2; ++i) { auto val = f32( std::max(static_cast(int16_t((e >> (16 * i)) & 0x0000'ffff)) / 32767.f, -1.f)); - auto el = CreateScalar(builder, source, inner_ty, val, use_runtime_semantics_); + auto el = CreateScalar(source, inner_ty, val); if (!el) { return el; } @@ -3731,7 +3661,7 @@ ConstEval::Result ConstEval::unpack2x16unorm(const type::Type* ty, els.Reserve(2); for (size_t i = 0; i < 2; ++i) { auto val = f32(static_cast(uint16_t((e >> (16 * i)) & 0x0000'ffff)) / 65535.f); - auto el = CreateScalar(builder, source, inner_ty, val, use_runtime_semantics_); + auto el = CreateScalar(source, inner_ty, val); if (!el) { return el; } @@ -3751,7 +3681,7 @@ ConstEval::Result ConstEval::unpack4x8snorm(const type::Type* ty, for (size_t i = 0; i < 4; ++i) { auto val = f32(std::max(static_cast(int8_t((e >> (8 * i)) & 0x0000'00ff)) / 127.f, -1.f)); - auto el = CreateScalar(builder, source, inner_ty, val, use_runtime_semantics_); + auto el = CreateScalar(source, inner_ty, val); if (!el) { return el; } @@ -3770,7 +3700,7 @@ ConstEval::Result ConstEval::unpack4x8unorm(const type::Type* ty, els.Reserve(4); for (size_t i = 0; i < 4; ++i) { auto val = f32(static_cast(uint8_t((e >> (8 * i)) & 0x0000'00ff)) / 255.f); - auto el = CreateScalar(builder, source, inner_ty, val, use_runtime_semantics_); + auto el = CreateScalar(source, inner_ty, val); if (!el) { return el; } @@ -3788,12 +3718,12 @@ ConstEval::Result ConstEval::quantizeToF16(const type::Type* ty, if (!conv) { AddError(OverflowErrorMessage(value, "f16"), source); if (use_runtime_semantics_) { - return ZeroValue(builder, c->Type()); + return ZeroValue(c->Type()); } else { return utils::Failure; } } - return CreateScalar(builder, source, c->Type(), conv.Get(), use_runtime_semantics_); + return CreateScalar(source, c->Type(), conv.Get()); }; return TransformElements(builder, ty, transform, args[0]); } diff --git a/src/tint/resolver/const_eval.h b/src/tint/resolver/const_eval.h index df91cc0615..934924e57c 100644 --- a/src/tint/resolver/const_eval.h +++ b/src/tint/resolver/const_eval.h @@ -18,6 +18,7 @@ #include #include +#include "src/tint/number.h" #include "src/tint/type/type.h" #include "src/tint/utils/result.h" #include "src/tint/utils/vector.h" @@ -1093,6 +1094,17 @@ class ConstEval { /// Adds the given note message to the diagnostics void AddNote(const std::string& msg, const Source& source) const; + /// CreateScalar constructs and returns a constant::Scalar. + /// @param source the source location + /// @param t the result type + /// @param v the scalar value + /// @return the constant value with the same type and value + template + ConstEval::Result CreateScalar(const Source& source, const type::Type* t, T v); + + /// ZeroValue returns a Constant for the zero-value of the type `type`. + const constant::Value* ZeroValue(const type::Type* type); + /// Adds two Numbers /// @param source the source location /// @param a the lhs number