diff --git a/src/tint/resolver/const_eval.cc b/src/tint/resolver/const_eval.cc index 7e33244cdb..820823677c 100644 --- a/src/tint/resolver/const_eval.cc +++ b/src/tint/resolver/const_eval.cc @@ -182,24 +182,27 @@ std::string OverflowErrorMessage(NumberT lhs, const char* op, NumberT rhs) { return ss.str(); } -/// Constant inherits from sem::Constant to add an private implementation method for conversion. -struct Constant : public sem::Constant { +/// ImplConstant inherits from sem::Constant to add an private implementation method for conversion. +struct ImplConstant : public sem::Constant { /// Convert attempts to convert the constant value to the given type. On error, Convert() /// creates a new diagnostic message and returns a Failure. - virtual utils::Result Convert(ProgramBuilder& builder, - const sem::Type* target_ty, - const Source& source) const = 0; + virtual utils::Result Convert(ProgramBuilder& builder, + const sem::Type* target_ty, + const Source& source) const = 0; }; +/// A result templated with a ImplConstant. +using ImplResult = utils::Result; + // Forward declaration -const Constant* CreateComposite(ProgramBuilder& builder, - const sem::Type* type, - utils::VectorRef elements); +const ImplConstant* CreateComposite(ProgramBuilder& builder, + const sem::Type* type, + utils::VectorRef elements); /// Element holds a single scalar or abstract-numeric value. /// Element implements the Constant interface. template -struct Element : Constant { +struct Element : ImplConstant { static_assert(!std::is_same_v, T> || std::is_same_v, "T must be a Number or bool"); @@ -219,16 +222,15 @@ struct Element : Constant { bool AllEqual() const override { return true; } size_t Hash() const override { return utils::Hash(type, ValueOf(value)); } - utils::Result Convert(ProgramBuilder& builder, - const sem::Type* target_ty, - const Source& source) const override { + ImplResult Convert(ProgramBuilder& builder, + const sem::Type* target_ty, + const Source& source) const override { TINT_BEGIN_DISABLE_WARNING(UNREACHABLE_CODE); if (target_ty == type) { // If the types are identical, then no conversion is needed. return this; } - bool failed = false; - auto* res = ZeroTypeDispatch(target_ty, [&](auto zero_to) -> const Constant* { + return ZeroTypeDispatch(target_ty, [&](auto zero_to) -> ImplResult { // `T` is the source type, `value` is the source value. // `TO` is the target type. using TO = std::decay_t; @@ -248,7 +250,7 @@ struct Element : Constant { ss << "value " << value << " cannot be represented as "; ss << "'" << builder.FriendlyName(target_ty) << "'"; builder.Diagnostics().add_error(tint::diag::System::Resolver, ss.str(), source); - failed = true; + return utils::Failure; } else if constexpr (IsFloatingPoint>) { // [x -> floating-point] - number not exactly representable // https://www.w3.org/TR/WGSL/#floating-point-conversion @@ -270,11 +272,6 @@ struct Element : Constant { } return nullptr; // Expression is not constant. }); - if (failed) { - // A diagnostic error has been raised, and resolving should abort. - return utils::Failure; - } - return res; TINT_END_DISABLE_WARNING(UNREACHABLE_CODE); } @@ -286,7 +283,7 @@ struct Element : Constant { /// Splat is used for zero-initializers, 'splat' constructors, or constructors where each element is /// identical. Splat may be of a vector, matrix or array type. /// Splat implements the Constant interface. -struct Splat : Constant { +struct Splat : ImplConstant { Splat(const sem::Type* t, const sem::Constant* e, size_t n) : type(t), el(e), count(n) {} ~Splat() override = default; const sem::Type* Type() const override { return type; } @@ -297,13 +294,13 @@ struct Splat : Constant { bool AllEqual() const override { return true; } size_t Hash() const override { return utils::Hash(type, el->Hash(), count); } - utils::Result Convert(ProgramBuilder& builder, - const sem::Type* target_ty, - const Source& source) const override { + ImplResult Convert(ProgramBuilder& builder, + const sem::Type* target_ty, + const Source& source) const override { // Convert the single splatted element type. // Note: This file is the only place where `sem::Constant`s are created, so this static_cast // is safe. - auto conv_el = static_cast(el)->Convert( + auto conv_el = static_cast(el)->Convert( builder, sem::Type::ElementOf(target_ty), source); if (!conv_el) { return utils::Failure; @@ -324,7 +321,7 @@ struct Splat : Constant { /// If each element is the same type and value, then a Splat would be a more efficient constant /// implementation. Use CreateComposite() to create the appropriate Constant type. /// Composite implements the Constant interface. -struct Composite : Constant { +struct Composite : ImplConstant { Composite(const sem::Type* t, utils::VectorRef els, bool all_0, @@ -341,9 +338,9 @@ struct Composite : Constant { bool AllEqual() const override { return false; /* otherwise this should be a Splat */ } size_t Hash() const override { return hash; } - utils::Result Convert(ProgramBuilder& builder, - const sem::Type* target_ty, - const Source& source) const override { + ImplResult Convert(ProgramBuilder& builder, + const sem::Type* target_ty, + const Source& source) const override { // Convert each of the composite element types. auto* el_ty = sem::Type::ElementOf(target_ty); utils::Vector conv_els; @@ -351,7 +348,7 @@ struct Composite : Constant { for (auto* el : elements) { // Note: This file is the only place where `sem::Constant`s are created, so this // static_cast is safe. - auto conv_el = static_cast(el)->Convert(builder, el_ty, source); + auto conv_el = static_cast(el)->Convert(builder, el_ty, source); if (!conv_el) { return utils::Failure; } @@ -380,30 +377,30 @@ struct Composite : Constant { /// CreateElement constructs and returns an Element. template -const Constant* CreateElement(ProgramBuilder& builder, const sem::Type* t, T v) { +const ImplConstant* CreateElement(ProgramBuilder& builder, const sem::Type* t, T v) { return builder.create>(t, v); } /// ZeroValue returns a Constant for the zero-value of the type `type`. -const Constant* ZeroValue(ProgramBuilder& builder, const sem::Type* type) { +const ImplConstant* ZeroValue(ProgramBuilder& builder, const sem::Type* type) { return Switch( type, // - [&](const sem::Vector* v) -> const Constant* { + [&](const sem::Vector* v) -> const ImplConstant* { auto* zero_el = ZeroValue(builder, v->type()); return builder.create(type, zero_el, v->Width()); }, - [&](const sem::Matrix* m) -> const Constant* { + [&](const sem::Matrix* m) -> const ImplConstant* { auto* zero_el = ZeroValue(builder, m->ColumnType()); return builder.create(type, zero_el, m->columns()); }, - [&](const sem::Array* a) -> const Constant* { + [&](const sem::Array* a) -> const ImplConstant* { if (auto* zero_el = ZeroValue(builder, a->ElemType())) { return builder.create(type, zero_el, a->Count()); } return nullptr; }, - [&](const sem::Struct* s) -> const Constant* { - std::unordered_map zero_by_type; + [&](const sem::Struct* s) -> const ImplConstant* { + std::unordered_map zero_by_type; utils::Vector zeros; zeros.Reserve(s->Members().size()); for (auto* member : s->Members()) { @@ -420,8 +417,8 @@ const Constant* ZeroValue(ProgramBuilder& builder, const sem::Type* type) { } return CreateComposite(builder, s, std::move(zeros)); }, - [&](Default) -> const Constant* { - return ZeroTypeDispatch(type, [&](auto zero) -> const Constant* { + [&](Default) -> const ImplConstant* { + return ZeroTypeDispatch(type, [&](auto zero) -> const ImplConstant* { return CreateElement(builder, type, zero); }); }); @@ -467,9 +464,9 @@ bool Equal(const sem::Constant* a, const sem::Constant* b) { /// CreateComposite is used to construct a constant of a vector, matrix or array type. /// CreateComposite examines the element values and will return either a Composite or a Splat, /// depending on the element types and values. -const Constant* CreateComposite(ProgramBuilder& builder, - const sem::Type* type, - utils::VectorRef elements) { +const ImplConstant* CreateComposite(ProgramBuilder& builder, + const sem::Type* type, + utils::VectorRef elements) { if (elements.IsEmpty()) { return nullptr; } @@ -504,10 +501,10 @@ const Constant* CreateComposite(ProgramBuilder& builder, /// transformation function 'f' on each of the most deeply nested elements of 'cs'. Assumes that all /// input constants `cs` are of the same type. template -const Constant* TransformElements(ProgramBuilder& builder, - const sem::Type* composite_ty, - F&& f, - CONSTANTS&&... cs) { +ImplResult TransformElements(ProgramBuilder& builder, + const sem::Type* composite_ty, + F&& f, + CONSTANTS&&... cs) { uint32_t n = 0; auto* ty = First(cs...)->Type(); auto* el_ty = sem::Type::ElementOf(ty, &n); @@ -517,8 +514,13 @@ const Constant* TransformElements(ProgramBuilder& builder, utils::Vector els; els.Reserve(n); for (uint32_t i = 0; i < n; i++) { - els.Push(TransformElements(builder, sem::Type::ElementOf(composite_ty), std::forward(f), - cs->Index(i)...)); + if (auto el = TransformElements(builder, sem::Type::ElementOf(composite_ty), + std::forward(f), cs->Index(i)...)) { + els.Push(el.Get()); + + } else { + return el.Failure(); + } } return CreateComposite(builder, composite_ty, std::move(els)); } @@ -528,11 +530,11 @@ const Constant* TransformElements(ProgramBuilder& builder, /// Unlike TransformElements, this function handles the constants being of different types, e.g. /// vector-scalar, scalar-vector. template -const Constant* TransformBinaryElements(ProgramBuilder& builder, - const sem::Type* composite_ty, - F&& f, - const sem::Constant* c0, - const sem::Constant* c1) { +ImplResult TransformBinaryElements(ProgramBuilder& builder, + const sem::Type* composite_ty, + F&& f, + const sem::Constant* c0, + const sem::Constant* c1) { uint32_t n0 = 0, n1 = 0; sem::Type::ElementOf(c0->Type(), &n0); sem::Type::ElementOf(c1->Type(), &n1); @@ -551,9 +553,13 @@ const Constant* TransformBinaryElements(ProgramBuilder& builder, } return c->Index(i); }; - els.Push(TransformBinaryElements(builder, sem::Type::ElementOf(composite_ty), - std::forward(f), nested_or_self(c0, n0), - nested_or_self(c1, n1))); + if (auto el = TransformBinaryElements(builder, sem::Type::ElementOf(composite_ty), + std::forward(f), nested_or_self(c0, n0), + nested_or_self(c1, n1))) { + els.Push(el.Get()); + } else { + return el.Failure(); + } } return CreateComposite(builder, composite_ty, std::move(els)); } @@ -703,7 +709,7 @@ utils::Result ConstEval::Dot4(NumberT a1, } auto ConstEval::AddFunc(const sem::Type* elem_ty) { - return [=](auto a1, auto a2) -> utils::Result { + return [=](auto a1, auto a2) -> ImplResult { if (auto r = Add(a1, a2)) { return CreateElement(builder, elem_ty, r.Get()); } @@ -712,7 +718,7 @@ auto ConstEval::AddFunc(const sem::Type* elem_ty) { } auto ConstEval::MulFunc(const sem::Type* elem_ty) { - return [=](auto a1, auto a2) -> utils::Result { + return [=](auto a1, auto a2) -> ImplResult { if (auto r = Mul(a1, a2)) { return CreateElement(builder, elem_ty, r.Get()); } @@ -721,7 +727,7 @@ auto ConstEval::MulFunc(const sem::Type* elem_ty) { } auto ConstEval::Dot2Func(const sem::Type* elem_ty) { - return [=](auto a1, auto a2, auto b1, auto b2) -> utils::Result { + return [=](auto a1, auto a2, auto b1, auto b2) -> ImplResult { if (auto r = Dot2(a1, a2, b1, b2)) { return CreateElement(builder, elem_ty, r.Get()); } @@ -730,8 +736,7 @@ auto ConstEval::Dot2Func(const sem::Type* elem_ty) { } auto ConstEval::Dot3Func(const sem::Type* elem_ty) { - return [=](auto a1, auto a2, auto a3, auto b1, auto b2, - auto b3) -> utils::Result { + return [=](auto a1, auto a2, auto a3, auto b1, auto b2, auto b3) -> ImplResult { if (auto r = Dot3(a1, a2, a3, b1, b2, b3)) { return CreateElement(builder, elem_ty, r.Get()); } @@ -740,23 +745,22 @@ auto ConstEval::Dot3Func(const sem::Type* elem_ty) { } auto ConstEval::Dot4Func(const sem::Type* elem_ty) { - return [=](auto a1, auto a2, auto a3, auto a4, auto b1, auto b2, auto b3, - auto b4) -> utils::Result { - if (auto r = Dot4(a1, a2, a3, a4, b1, b2, b3, b4)) { - return CreateElement(builder, elem_ty, r.Get()); - } - return utils::Failure; - }; + return + [=](auto a1, auto a2, auto a3, auto a4, auto b1, auto b2, auto b3, auto b4) -> ImplResult { + if (auto r = Dot4(a1, a2, a3, a4, b1, b2, b3, b4)) { + return CreateElement(builder, elem_ty, r.Get()); + } + return utils::Failure; + }; } -ConstEval::ConstantResult ConstEval::Literal(const sem::Type* ty, - const ast::LiteralExpression* literal) { +ConstEval::Result ConstEval::Literal(const sem::Type* ty, const ast::LiteralExpression* literal) { return Switch( literal, [&](const ast::BoolLiteralExpression* lit) { return CreateElement(builder, ty, lit->value); }, - [&](const ast::IntLiteralExpression* lit) -> const Constant* { + [&](const ast::IntLiteralExpression* lit) -> ImplResult { switch (lit->suffix) { case ast::IntLiteralExpression::Suffix::kNone: return CreateElement(builder, ty, AInt(lit->value)); @@ -767,7 +771,7 @@ ConstEval::ConstantResult ConstEval::Literal(const sem::Type* ty, } return nullptr; }, - [&](const ast::FloatLiteralExpression* lit) -> const Constant* { + [&](const ast::FloatLiteralExpression* lit) -> ImplResult { switch (lit->suffix) { case ast::FloatLiteralExpression::Suffix::kNone: return CreateElement(builder, ty, AFloat(lit->value)); @@ -780,9 +784,8 @@ ConstEval::ConstantResult ConstEval::Literal(const sem::Type* ty, }); } -ConstEval::ConstantResult ConstEval::ArrayOrStructCtor( - const sem::Type* ty, - utils::VectorRef args) { +ConstEval::Result ConstEval::ArrayOrStructCtor(const sem::Type* ty, + utils::VectorRef args) { if (args.IsEmpty()) { return ZeroValue(builder, ty); } @@ -801,9 +804,9 @@ ConstEval::ConstantResult ConstEval::ArrayOrStructCtor( return CreateComposite(builder, ty, std::move(els)); } -ConstEval::ConstantResult ConstEval::Conv(const sem::Type* ty, - utils::VectorRef args, - const Source& source) { +ConstEval::Result ConstEval::Conv(const sem::Type* ty, + utils::VectorRef args, + const Source& source) { uint32_t el_count = 0; auto* el_ty = sem::Type::ElementOf(ty, &el_count); if (!el_ty) { @@ -821,36 +824,36 @@ ConstEval::ConstantResult ConstEval::Conv(const sem::Type* ty, return nullptr; } -ConstEval::ConstantResult ConstEval::Zero(const sem::Type* ty, - utils::VectorRef, - const Source&) { +ConstEval::Result ConstEval::Zero(const sem::Type* ty, + utils::VectorRef, + const Source&) { return ZeroValue(builder, ty); } -ConstEval::ConstantResult ConstEval::Identity(const sem::Type*, - utils::VectorRef args, - const Source&) { +ConstEval::Result ConstEval::Identity(const sem::Type*, + utils::VectorRef args, + const Source&) { return args[0]; } -ConstEval::ConstantResult ConstEval::VecSplat(const sem::Type* ty, - utils::VectorRef args, - const Source&) { +ConstEval::Result ConstEval::VecSplat(const sem::Type* ty, + utils::VectorRef args, + const Source&) { if (auto* arg = args[0]) { return builder.create(ty, arg, static_cast(ty)->Width()); } return nullptr; } -ConstEval::ConstantResult ConstEval::VecCtorS(const sem::Type* ty, - utils::VectorRef args, - const Source&) { +ConstEval::Result ConstEval::VecCtorS(const sem::Type* ty, + utils::VectorRef args, + const Source&) { return CreateComposite(builder, ty, args); } -ConstEval::ConstantResult ConstEval::VecCtorM(const sem::Type* ty, - utils::VectorRef args, - const Source&) { +ConstEval::Result ConstEval::VecCtorM(const sem::Type* ty, + utils::VectorRef args, + const Source&) { utils::Vector els; for (auto* arg : args) { auto* val = arg; @@ -874,9 +877,9 @@ ConstEval::ConstantResult ConstEval::VecCtorM(const sem::Type* ty, return CreateComposite(builder, ty, std::move(els)); } -ConstEval::ConstantResult ConstEval::MatCtorS(const sem::Type* ty, - utils::VectorRef args, - const Source&) { +ConstEval::Result ConstEval::MatCtorS(const sem::Type* ty, + utils::VectorRef args, + const Source&) { auto* m = static_cast(ty); utils::Vector els; @@ -891,14 +894,14 @@ ConstEval::ConstantResult ConstEval::MatCtorS(const sem::Type* ty, return CreateComposite(builder, ty, std::move(els)); } -ConstEval::ConstantResult ConstEval::MatCtorV(const sem::Type* ty, - utils::VectorRef args, - const Source&) { +ConstEval::Result ConstEval::MatCtorV(const sem::Type* ty, + utils::VectorRef args, + const Source&) { return CreateComposite(builder, ty, args); } -ConstEval::ConstantResult ConstEval::Index(const sem::Expression* obj_expr, - const sem::Expression* idx_expr) { +ConstEval::Result ConstEval::Index(const sem::Expression* obj_expr, + const sem::Expression* idx_expr) { auto idx_val = idx_expr->ConstantValue(); if (!idx_val) { return nullptr; @@ -926,8 +929,8 @@ ConstEval::ConstantResult ConstEval::Index(const sem::Expression* obj_expr, return obj_val->Index(static_cast(idx)); } -ConstEval::ConstantResult ConstEval::MemberAccess(const sem::Expression* obj_expr, - const sem::StructMember* member) { +ConstEval::Result ConstEval::MemberAccess(const sem::Expression* obj_expr, + const sem::StructMember* member) { auto obj_val = obj_expr->ConstantValue(); if (!obj_val) { return nullptr; @@ -935,30 +938,29 @@ ConstEval::ConstantResult ConstEval::MemberAccess(const sem::Expression* obj_exp return obj_val->Index(static_cast(member->Index())); } -ConstEval::ConstantResult ConstEval::Swizzle(const sem::Type* ty, - const sem::Expression* vec_expr, - utils::VectorRef indices) { +ConstEval::Result ConstEval::Swizzle(const sem::Type* ty, + const sem::Expression* vec_expr, + utils::VectorRef indices) { auto* vec_val = vec_expr->ConstantValue(); if (!vec_val) { return nullptr; } if (indices.Length() == 1) { return vec_val->Index(static_cast(indices[0])); - } else { - auto values = utils::Transform<4>( - indices, [&](uint32_t i) { return vec_val->Index(static_cast(i)); }); - return CreateComposite(builder, ty, std::move(values)); } + auto values = utils::Transform<4>( + indices, [&](uint32_t i) { return vec_val->Index(static_cast(i)); }); + return CreateComposite(builder, ty, std::move(values)); } -ConstEval::ConstantResult ConstEval::Bitcast(const sem::Type*, const sem::Expression*) { +ConstEval::Result ConstEval::Bitcast(const sem::Type*, const sem::Expression*) { // TODO(crbug.com/tint/1581): Implement @const intrinsics return nullptr; } -ConstEval::ConstantResult ConstEval::OpComplement(const sem::Type* ty, - utils::VectorRef args, - const Source&) { +ConstEval::Result ConstEval::OpComplement(const sem::Type* ty, + utils::VectorRef args, + const Source&) { auto transform = [&](const sem::Constant* c) { auto create = [&](auto i) { return CreateElement(builder, c->Type(), decltype(i)(~i.value)); @@ -968,9 +970,9 @@ ConstEval::ConstantResult ConstEval::OpComplement(const sem::Type* ty, return TransformElements(builder, ty, transform, args[0]); } -ConstEval::ConstantResult ConstEval::OpUnaryMinus(const sem::Type* ty, - utils::VectorRef args, - const Source&) { +ConstEval::Result ConstEval::OpUnaryMinus(const sem::Type* ty, + utils::VectorRef args, + const Source&) { auto transform = [&](const sem::Constant* c) { auto create = [&](auto i) { // For signed integrals, avoid C++ UB by not negating the @@ -993,9 +995,9 @@ ConstEval::ConstantResult ConstEval::OpUnaryMinus(const sem::Type* ty, return TransformElements(builder, ty, transform, args[0]); } -ConstEval::ConstantResult ConstEval::OpNot(const sem::Type* ty, - utils::VectorRef args, - const Source&) { +ConstEval::Result ConstEval::OpNot(const sem::Type* ty, + utils::VectorRef args, + const Source&) { auto transform = [&](const sem::Constant* c) { auto create = [&](auto i) { return CreateElement(builder, c->Type(), decltype(i)(!i)); }; return Dispatch_bool(create, c); @@ -1003,29 +1005,22 @@ ConstEval::ConstantResult ConstEval::OpNot(const sem::Type* ty, return TransformElements(builder, ty, transform, args[0]); } -ConstEval::ConstantResult ConstEval::OpPlus(const sem::Type* ty, - utils::VectorRef args, - const Source& source) { +ConstEval::Result ConstEval::OpPlus(const sem::Type* ty, + utils::VectorRef args, + const Source& source) { TINT_SCOPED_ASSIGNMENT(current_source, &source); - auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) -> const Constant* { - if (auto r = Dispatch_fia_fiu32_f16(AddFunc(c0->Type()), c0, c1)) { - return r.Get(); - } - return nullptr; + auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) { + return Dispatch_fia_fiu32_f16(AddFunc(c0->Type()), c0, c1); }; - auto r = TransformBinaryElements(builder, ty, transform, args[0], args[1]); - if (builder.Diagnostics().contains_errors()) { - return utils::Failure; - } - return r; + return TransformBinaryElements(builder, ty, transform, args[0], args[1]); } -ConstEval::ConstantResult ConstEval::OpMinus(const sem::Type* ty, - utils::VectorRef args, - const Source& source) { +ConstEval::Result ConstEval::OpMinus(const sem::Type* ty, + utils::VectorRef args, + const Source& source) { auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) { - auto create = [&](auto i, auto j) -> const Constant* { + auto create = [&](auto i, auto j) -> ImplResult { using NumberT = decltype(i); NumberT result; if constexpr (std::is_same_v || std::is_same_v) { @@ -1034,7 +1029,7 @@ ConstEval::ConstantResult ConstEval::OpMinus(const sem::Type* ty, result = r->value; } else { AddError(OverflowErrorMessage(i, "-", j), source); - return nullptr; + return utils::Failure; } } else { using T = UnwrapNumber; @@ -1054,41 +1049,30 @@ ConstEval::ConstantResult ConstEval::OpMinus(const sem::Type* ty, return Dispatch_fia_fiu32_f16(create, c0, c1); }; - auto r = TransformBinaryElements(builder, ty, transform, args[0], args[1]); - if (builder.Diagnostics().contains_errors()) { - return utils::Failure; - } - return r; + return TransformBinaryElements(builder, ty, transform, args[0], args[1]); } -ConstEval::ConstantResult ConstEval::OpMultiply(const sem::Type* ty, - utils::VectorRef args, - const Source& source) { +ConstEval::Result ConstEval::OpMultiply(const sem::Type* ty, + utils::VectorRef args, + const Source& source) { TINT_SCOPED_ASSIGNMENT(current_source, &source); - auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) -> const Constant* { - if (auto r = Dispatch_fia_fiu32_f16(MulFunc(c0->Type()), c0, c1)) { - return r.Get(); - } - return nullptr; + auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) { + return Dispatch_fia_fiu32_f16(MulFunc(c0->Type()), c0, c1); }; - auto r = TransformBinaryElements(builder, ty, transform, args[0], args[1]); - if (builder.Diagnostics().contains_errors()) { - return utils::Failure; - } - return r; + return TransformBinaryElements(builder, ty, transform, args[0], args[1]); } -ConstEval::ConstantResult ConstEval::OpMultiplyMatVec(const sem::Type* ty, - utils::VectorRef args, - const Source& source) { +ConstEval::Result ConstEval::OpMultiplyMatVec(const sem::Type* ty, + utils::VectorRef args, + const Source& source) { TINT_SCOPED_ASSIGNMENT(current_source, &source); auto* mat_ty = args[0]->Type()->As(); auto* vec_ty = args[1]->Type()->As(); auto* elem_ty = vec_ty->type(); auto dot = [&](const sem::Constant* m, size_t row, const sem::Constant* v) { - utils::Result result; + ImplResult result; switch (mat_ty->columns()) { case 2: result = Dispatch_fa_f32_f16(Dot2Func(elem_ty), // @@ -1130,16 +1114,16 @@ ConstEval::ConstantResult ConstEval::OpMultiplyMatVec(const sem::Type* ty, } return CreateComposite(builder, ty, result); } -ConstEval::ConstantResult ConstEval::OpMultiplyVecMat(const sem::Type* ty, - utils::VectorRef args, - const Source& source) { +ConstEval::Result ConstEval::OpMultiplyVecMat(const sem::Type* ty, + utils::VectorRef args, + const Source& source) { TINT_SCOPED_ASSIGNMENT(current_source, &source); auto* vec_ty = args[0]->Type()->As(); auto* mat_ty = args[1]->Type()->As(); auto* elem_ty = vec_ty->type(); auto dot = [&](const sem::Constant* v, const sem::Constant* m, size_t col) { - utils::Result result; + ImplResult result; switch (mat_ty->rows()) { case 2: result = Dispatch_fa_f32_f16(Dot2Func(elem_ty), // @@ -1182,9 +1166,9 @@ ConstEval::ConstantResult ConstEval::OpMultiplyVecMat(const sem::Type* ty, return CreateComposite(builder, ty, result); } -ConstEval::ConstantResult ConstEval::OpMultiplyMatMat(const sem::Type* ty, - utils::VectorRef args, - const Source& source) { +ConstEval::Result ConstEval::OpMultiplyMatMat(const sem::Type* ty, + utils::VectorRef args, + const Source& source) { TINT_SCOPED_ASSIGNMENT(current_source, &source); auto* mat1 = args[0]; auto* mat2 = args[1]; @@ -1196,7 +1180,7 @@ ConstEval::ConstantResult ConstEval::OpMultiplyMatMat(const sem::Type* ty, auto m1e = [&](size_t r, size_t c) { return m1->Index(c)->Index(r); }; auto m2e = [&](size_t r, size_t c) { return m2->Index(c)->Index(r); }; - utils::Result result; + ImplResult result; switch (mat1_ty->columns()) { case 2: result = Dispatch_fa_f32_f16(Dot2Func(elem_ty), // @@ -1247,11 +1231,11 @@ ConstEval::ConstantResult ConstEval::OpMultiplyMatMat(const sem::Type* ty, return CreateComposite(builder, ty, result_mat); } -ConstEval::ConstantResult ConstEval::OpDivide(const sem::Type* ty, - utils::VectorRef args, - const Source& source) { +ConstEval::Result ConstEval::OpDivide(const sem::Type* ty, + utils::VectorRef args, + const Source& source) { auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) { - auto create = [&](auto i, auto j) -> const Constant* { + auto create = [&](auto i, auto j) -> ImplResult { using NumberT = decltype(i); NumberT result; if constexpr (std::is_same_v || std::is_same_v) { @@ -1260,7 +1244,7 @@ ConstEval::ConstantResult ConstEval::OpDivide(const sem::Type* ty, result = r->value; } else { AddError(OverflowErrorMessage(i, "/", j), source); - return nullptr; + return utils::Failure; } } else { using T = UnwrapNumber; @@ -1288,120 +1272,92 @@ ConstEval::ConstantResult ConstEval::OpDivide(const sem::Type* ty, return Dispatch_fia_fiu32_f16(create, c0, c1); }; - auto r = TransformBinaryElements(builder, ty, transform, args[0], args[1]); - if (builder.Diagnostics().contains_errors()) { - return utils::Failure; - } - return r; + return TransformBinaryElements(builder, ty, transform, args[0], args[1]); } -ConstEval::ConstantResult ConstEval::OpEqual(const sem::Type* ty, - utils::VectorRef args, - const Source&) { +ConstEval::Result ConstEval::OpEqual(const sem::Type* ty, + utils::VectorRef args, + const Source&) { auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) { - auto create = [&](auto i, auto j) -> const Constant* { + auto create = [&](auto i, auto j) -> ImplResult { return CreateElement(builder, sem::Type::DeepestElementOf(ty), i == j); }; return Dispatch_fia_fiu32_f16_bool(create, c0, c1); }; - auto r = TransformElements(builder, ty, transform, args[0], args[1]); - if (builder.Diagnostics().contains_errors()) { - return utils::Failure; - } - return r; + return TransformElements(builder, ty, transform, args[0], args[1]); } -ConstEval::ConstantResult ConstEval::OpNotEqual(const sem::Type* ty, - utils::VectorRef args, - const Source&) { +ConstEval::Result ConstEval::OpNotEqual(const sem::Type* ty, + utils::VectorRef args, + const Source&) { auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) { - auto create = [&](auto i, auto j) -> const Constant* { + auto create = [&](auto i, auto j) -> ImplResult { return CreateElement(builder, sem::Type::DeepestElementOf(ty), i != j); }; return Dispatch_fia_fiu32_f16_bool(create, c0, c1); }; - auto r = TransformElements(builder, ty, transform, args[0], args[1]); - if (builder.Diagnostics().contains_errors()) { - return utils::Failure; - } - return r; + return TransformElements(builder, ty, transform, args[0], args[1]); } -ConstEval::ConstantResult ConstEval::OpLessThan(const sem::Type* ty, - utils::VectorRef args, - const Source&) { +ConstEval::Result ConstEval::OpLessThan(const sem::Type* ty, + utils::VectorRef args, + const Source&) { auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) { - auto create = [&](auto i, auto j) -> const Constant* { + auto create = [&](auto i, auto j) -> ImplResult { return CreateElement(builder, sem::Type::DeepestElementOf(ty), i < j); }; return Dispatch_fia_fiu32_f16(create, c0, c1); }; - auto r = TransformElements(builder, ty, transform, args[0], args[1]); - if (builder.Diagnostics().contains_errors()) { - return utils::Failure; - } - return r; + return TransformElements(builder, ty, transform, args[0], args[1]); } -ConstEval::ConstantResult ConstEval::OpGreaterThan(const sem::Type* ty, - utils::VectorRef args, - const Source&) { +ConstEval::Result ConstEval::OpGreaterThan(const sem::Type* ty, + utils::VectorRef args, + const Source&) { auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) { - auto create = [&](auto i, auto j) -> const Constant* { + auto create = [&](auto i, auto j) -> ImplResult { return CreateElement(builder, sem::Type::DeepestElementOf(ty), i > j); }; return Dispatch_fia_fiu32_f16(create, c0, c1); }; - auto r = TransformElements(builder, ty, transform, args[0], args[1]); - if (builder.Diagnostics().contains_errors()) { - return utils::Failure; - } - return r; + return TransformElements(builder, ty, transform, args[0], args[1]); } -ConstEval::ConstantResult ConstEval::OpLessThanEqual(const sem::Type* ty, - utils::VectorRef args, - const Source&) { +ConstEval::Result ConstEval::OpLessThanEqual(const sem::Type* ty, + utils::VectorRef args, + const Source&) { auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) { - auto create = [&](auto i, auto j) -> const Constant* { + auto create = [&](auto i, auto j) -> ImplResult { return CreateElement(builder, sem::Type::DeepestElementOf(ty), i <= j); }; return Dispatch_fia_fiu32_f16(create, c0, c1); }; - auto r = TransformElements(builder, ty, transform, args[0], args[1]); - if (builder.Diagnostics().contains_errors()) { - return utils::Failure; - } - return r; + return TransformElements(builder, ty, transform, args[0], args[1]); } -ConstEval::ConstantResult ConstEval::OpGreaterThanEqual(const sem::Type* ty, - utils::VectorRef args, - const Source&) { +ConstEval::Result ConstEval::OpGreaterThanEqual(const sem::Type* ty, + utils::VectorRef args, + const Source&) { auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) { - auto create = [&](auto i, auto j) -> const Constant* { + auto create = [&](auto i, auto j) -> ImplResult { return CreateElement(builder, sem::Type::DeepestElementOf(ty), i >= j); }; return Dispatch_fia_fiu32_f16(create, c0, c1); }; - auto r = TransformElements(builder, ty, transform, args[0], args[1]); - if (builder.Diagnostics().contains_errors()) { - return utils::Failure; - } - return r; + return TransformElements(builder, ty, transform, args[0], args[1]); } -ConstEval::ConstantResult ConstEval::OpAnd(const sem::Type* ty, - utils::VectorRef args, - const Source&) { +ConstEval::Result ConstEval::OpAnd(const sem::Type* ty, + utils::VectorRef args, + const Source&) { auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) { - auto create = [&](auto i, auto j) -> const Constant* { + auto create = [&](auto i, auto j) -> ImplResult { using T = decltype(i); T result; if constexpr (std::is_same_v) { @@ -1414,18 +1370,14 @@ ConstEval::ConstantResult ConstEval::OpAnd(const sem::Type* ty, return Dispatch_ia_iu32_bool(create, c0, c1); }; - auto r = TransformElements(builder, ty, transform, args[0], args[1]); - if (builder.Diagnostics().contains_errors()) { - return utils::Failure; - } - return r; + return TransformElements(builder, ty, transform, args[0], args[1]); } -ConstEval::ConstantResult ConstEval::OpOr(const sem::Type* ty, - utils::VectorRef args, - const Source&) { +ConstEval::Result ConstEval::OpOr(const sem::Type* ty, + utils::VectorRef args, + const Source&) { auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) { - auto create = [&](auto i, auto j) -> const Constant* { + auto create = [&](auto i, auto j) -> ImplResult { using T = decltype(i); T result; if constexpr (std::is_same_v) { @@ -1438,18 +1390,14 @@ ConstEval::ConstantResult ConstEval::OpOr(const sem::Type* ty, return Dispatch_ia_iu32_bool(create, c0, c1); }; - auto r = TransformElements(builder, ty, transform, args[0], args[1]); - if (builder.Diagnostics().contains_errors()) { - return utils::Failure; - } - return r; + return TransformElements(builder, ty, transform, args[0], args[1]); } -ConstEval::ConstantResult ConstEval::OpXor(const sem::Type* ty, - utils::VectorRef args, - const Source&) { +ConstEval::Result ConstEval::OpXor(const sem::Type* ty, + utils::VectorRef args, + const Source&) { auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) { - auto create = [&](auto i, auto j) -> const Constant* { + auto create = [&](auto i, auto j) -> const ImplConstant* { return CreateElement(builder, sem::Type::DeepestElementOf(ty), decltype(i){i ^ j}); }; return Dispatch_ia_iu32(create, c0, c1); @@ -1462,9 +1410,9 @@ ConstEval::ConstantResult ConstEval::OpXor(const sem::Type* ty, return r; } -ConstEval::ConstantResult ConstEval::atan2(const sem::Type* ty, - utils::VectorRef args, - const Source&) { +ConstEval::Result ConstEval::atan2(const sem::Type* ty, + utils::VectorRef args, + const Source&) { auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) { auto create = [&](auto i, auto j) { return CreateElement(builder, c0->Type(), decltype(i)(std::atan2(i.value, j.value))); @@ -1474,9 +1422,9 @@ ConstEval::ConstantResult ConstEval::atan2(const sem::Type* ty, return TransformElements(builder, ty, transform, args[0], args[1]); } -ConstEval::ConstantResult ConstEval::clamp(const sem::Type* ty, - utils::VectorRef args, - const Source&) { +ConstEval::Result ConstEval::clamp(const sem::Type* ty, + utils::VectorRef args, + const Source&) { auto transform = [&](const sem::Constant* c0, const sem::Constant* c1, const sem::Constant* c2) { auto create = [&](auto e, auto low, auto high) { @@ -1488,17 +1436,13 @@ ConstEval::ConstantResult ConstEval::clamp(const sem::Type* ty, return TransformElements(builder, ty, transform, args[0], args[1], args[2]); } -utils::Result ConstEval::Convert(const sem::Type* target_ty, - const sem::Constant* value, - const Source& source) { +ConstEval::Result ConstEval::Convert(const sem::Type* target_ty, + const sem::Constant* value, + const Source& source) { if (value->Type() == target_ty) { return value; } - auto conv = static_cast(value)->Convert(builder, target_ty, source); - if (!conv) { - return utils::Failure; - } - return conv.Get(); + return static_cast(value)->Convert(builder, target_ty, source); } void ConstEval::AddError(const std::string& msg, const Source& source) const { diff --git a/src/tint/resolver/const_eval.h b/src/tint/resolver/const_eval.h index 04e22826ee..10a4d60a9c 100644 --- a/src/tint/resolver/const_eval.h +++ b/src/tint/resolver/const_eval.h @@ -53,12 +53,12 @@ class ConstEval { /// * `utils::Failure`. Returned when there was a resolver error. In this situation the method /// will have already reported a diagnostic error message, and the caller should abort /// resolving. - using ConstantResult = utils::Result; + using Result = utils::Result; /// Typedef for a constant evaluation function - using Function = ConstantResult (ConstEval::*)(const sem::Type* result_ty, - utils::VectorRef, - const Source&); + using Function = Result (ConstEval::*)(const sem::Type* result_ty, + utils::VectorRef, + const Source&); /// Constructor /// @param b the program builder @@ -71,44 +71,43 @@ class ConstEval { /// @param ty the target type - must be an array or constructor /// @param args the input arguments /// @return the constructed value, or null if the value cannot be calculated - ConstantResult ArrayOrStructCtor(const sem::Type* ty, - utils::VectorRef args); + Result ArrayOrStructCtor(const sem::Type* ty, utils::VectorRef args); /// @param ty the target type /// @param expr the input expression /// @return the bit-cast of the given expression to the given type, or null if the value cannot /// be calculated - ConstantResult Bitcast(const sem::Type* ty, const sem::Expression* expr); + Result Bitcast(const sem::Type* ty, const sem::Expression* expr); /// @param obj the object being indexed /// @param idx the index expression /// @return the result of the index, or null if the value cannot be calculated - ConstantResult Index(const sem::Expression* obj, const sem::Expression* idx); + Result Index(const sem::Expression* obj, const sem::Expression* idx); /// @param ty the result type /// @param lit the literal AST node /// @return the constant value of the literal - ConstantResult Literal(const sem::Type* ty, const ast::LiteralExpression* lit); + Result Literal(const sem::Type* ty, const ast::LiteralExpression* lit); /// @param obj the object being accessed /// @param member the member /// @return the result of the member access, or null if the value cannot be calculated - ConstantResult MemberAccess(const sem::Expression* obj, const sem::StructMember* member); + Result MemberAccess(const sem::Expression* obj, const sem::StructMember* member); /// @param ty the result type /// @param vector the vector being swizzled /// @param indices the swizzle indices /// @return the result of the swizzle, or null if the value cannot be calculated - ConstantResult Swizzle(const sem::Type* ty, - const sem::Expression* vector, - utils::VectorRef indices); + Result Swizzle(const sem::Type* ty, + const sem::Expression* vector, + utils::VectorRef indices); /// Convert the `value` to `target_type` /// @param ty the result type /// @param value the value being converted /// @param source the source location of the conversion /// @return the converted value, or null if the value cannot be calculated - ConstantResult Convert(const sem::Type* ty, const sem::Constant* value, const Source& source); + Result Convert(const sem::Type* ty, const sem::Constant* value, const Source& source); //////////////////////////////////////////////////////////////////////////////////////////////// // Constant value evaluation methods, to be indirectly called via the intrinsic table @@ -119,72 +118,72 @@ class ConstEval { /// @param args the input arguments /// @param source the source location of the conversion /// @return the converted value, or null if the value cannot be calculated - ConstantResult Conv(const sem::Type* ty, - utils::VectorRef args, - const Source& source); + Result Conv(const sem::Type* ty, + utils::VectorRef args, + const Source& source); /// Zero value type constructor /// @param ty the result type /// @param args the input arguments (no arguments provided) /// @param source the source location of the conversion /// @return the constructed value, or null if the value cannot be calculated - ConstantResult Zero(const sem::Type* ty, - utils::VectorRef args, - const Source& source); + Result Zero(const sem::Type* ty, + utils::VectorRef args, + const Source& source); /// Identity value type constructor /// @param ty the result type /// @param args the input arguments /// @param source the source location of the conversion /// @return the constructed value, or null if the value cannot be calculated - ConstantResult Identity(const sem::Type* ty, - utils::VectorRef args, - const Source& source); + Result Identity(const sem::Type* ty, + utils::VectorRef args, + const Source& source); /// Vector splat constructor /// @param ty the vector type /// @param args the input arguments /// @param source the source location of the conversion /// @return the constructed value, or null if the value cannot be calculated - ConstantResult VecSplat(const sem::Type* ty, - utils::VectorRef args, - const Source& source); + Result VecSplat(const sem::Type* ty, + utils::VectorRef args, + const Source& source); /// Vector constructor using scalars /// @param ty the vector type /// @param args the input arguments /// @param source the source location of the conversion /// @return the constructed value, or null if the value cannot be calculated - ConstantResult VecCtorS(const sem::Type* ty, - utils::VectorRef args, - const Source& source); + Result VecCtorS(const sem::Type* ty, + utils::VectorRef args, + const Source& source); /// Vector constructor using a mix of scalars and smaller vectors /// @param ty the vector type /// @param args the input arguments /// @param source the source location of the conversion /// @return the constructed value, or null if the value cannot be calculated - ConstantResult VecCtorM(const sem::Type* ty, - utils::VectorRef args, - const Source& source); + Result VecCtorM(const sem::Type* ty, + utils::VectorRef args, + const Source& source); /// Matrix constructor using scalar values /// @param ty the matrix type /// @param args the input arguments /// @param source the source location of the conversion /// @return the constructed value, or null if the value cannot be calculated - ConstantResult MatCtorS(const sem::Type* ty, - utils::VectorRef args, - const Source& source); + Result MatCtorS(const sem::Type* ty, + utils::VectorRef args, + const Source& source); /// Matrix constructor using column vectors /// @param ty the matrix type /// @param args the input arguments /// @param source the source location of the conversion /// @return the constructed value, or null if the value cannot be calculated - ConstantResult MatCtorV(const sem::Type* ty, - utils::VectorRef args, - const Source& source); + Result MatCtorV(const sem::Type* ty, + utils::VectorRef args, + const Source& source); //////////////////////////////////////////////////////////////////////////// // Unary Operators @@ -195,27 +194,27 @@ class ConstEval { /// @param args the input arguments /// @param source the source location of the conversion /// @return the result value, or null if the value cannot be calculated - ConstantResult OpComplement(const sem::Type* ty, - utils::VectorRef args, - const Source& source); + Result OpComplement(const sem::Type* ty, + utils::VectorRef args, + const Source& source); /// Unary minus operator '-' /// @param ty the expression type /// @param args the input arguments /// @param source the source location of the conversion /// @return the result value, or null if the value cannot be calculated - ConstantResult OpUnaryMinus(const sem::Type* ty, - utils::VectorRef args, - const Source& source); + Result OpUnaryMinus(const sem::Type* ty, + utils::VectorRef args, + const Source& source); /// Unary not operator '!' /// @param ty the expression type /// @param args the input arguments /// @param source the source location of the conversion /// @return the result value, or null if the value cannot be calculated - ConstantResult OpNot(const sem::Type* ty, - utils::VectorRef args, - const Source& source); + Result OpNot(const sem::Type* ty, + utils::VectorRef args, + const Source& source); //////////////////////////////////////////////////////////////////////////// // Binary Operators @@ -226,142 +225,142 @@ class ConstEval { /// @param args the input arguments /// @param source the source location of the conversion /// @return the result value, or null if the value cannot be calculated - ConstantResult OpPlus(const sem::Type* ty, - utils::VectorRef args, - const Source& source); + Result OpPlus(const sem::Type* ty, + utils::VectorRef args, + const Source& source); /// Minus operator '-' /// @param ty the expression type /// @param args the input arguments /// @param source the source location of the conversion /// @return the result value, or null if the value cannot be calculated - ConstantResult OpMinus(const sem::Type* ty, - utils::VectorRef args, - const Source& source); + Result OpMinus(const sem::Type* ty, + utils::VectorRef args, + const Source& source); /// Multiply operator '*' for the same type on the LHS and RHS /// @param ty the expression type /// @param args the input arguments /// @param source the source location of the conversion /// @return the result value, or null if the value cannot be calculated - ConstantResult OpMultiply(const sem::Type* ty, - utils::VectorRef args, - const Source& source); + Result OpMultiply(const sem::Type* ty, + utils::VectorRef args, + const Source& source); /// Multiply operator '*' for matCxR * vecC /// @param ty the expression type /// @param args the input arguments /// @param source the source location of the conversion /// @return the result value, or null if the value cannot be calculated - ConstantResult OpMultiplyMatVec(const sem::Type* ty, - utils::VectorRef args, - const Source& source); + Result OpMultiplyMatVec(const sem::Type* ty, + utils::VectorRef args, + const Source& source); /// Multiply operator '*' for vecR * matCxR /// @param ty the expression type /// @param args the input arguments /// @param source the source location of the conversion /// @return the result value, or null if the value cannot be calculated - ConstantResult OpMultiplyVecMat(const sem::Type* ty, - utils::VectorRef args, - const Source& source); + Result OpMultiplyVecMat(const sem::Type* ty, + utils::VectorRef args, + const Source& source); /// Multiply operator '*' for matKxR * matCxK /// @param ty the expression type /// @param args the input arguments /// @param source the source location of the conversion /// @return the result value, or null if the value cannot be calculated - ConstantResult OpMultiplyMatMat(const sem::Type* ty, - utils::VectorRef args, - const Source& source); + Result OpMultiplyMatMat(const sem::Type* ty, + utils::VectorRef args, + const Source& source); /// Divide operator '/' /// @param ty the expression type /// @param args the input arguments /// @param source the source location of the conversion /// @return the result value, or null if the value cannot be calculated - ConstantResult OpDivide(const sem::Type* ty, - utils::VectorRef args, - const Source& source); + Result OpDivide(const sem::Type* ty, + utils::VectorRef args, + const Source& source); /// Equality operator '==' /// @param ty the expression type /// @param args the input arguments /// @param source the source location of the conversion /// @return the result value, or null if the value cannot be calculated - ConstantResult OpEqual(const sem::Type* ty, - utils::VectorRef args, - const Source& source); + Result OpEqual(const sem::Type* ty, + utils::VectorRef args, + const Source& source); /// Inequality operator '!=' /// @param ty the expression type /// @param args the input arguments /// @param source the source location of the conversion /// @return the result value, or null if the value cannot be calculated - ConstantResult OpNotEqual(const sem::Type* ty, - utils::VectorRef args, - const Source& source); + Result OpNotEqual(const sem::Type* ty, + utils::VectorRef args, + const Source& source); /// Less than operator '<' /// @param ty the expression type /// @param args the input arguments /// @param source the source location of the conversion /// @return the result value, or null if the value cannot be calculated - ConstantResult OpLessThan(const sem::Type* ty, - utils::VectorRef args, - const Source& source); + Result OpLessThan(const sem::Type* ty, + utils::VectorRef args, + const Source& source); /// Greater than operator '>' /// @param ty the expression type /// @param args the input arguments /// @param source the source location of the conversion /// @return the result value, or null if the value cannot be calculated - ConstantResult OpGreaterThan(const sem::Type* ty, - utils::VectorRef args, - const Source& source); + Result OpGreaterThan(const sem::Type* ty, + utils::VectorRef args, + const Source& source); /// Less than or equal operator '<=' /// @param ty the expression type /// @param args the input arguments /// @param source the source location of the conversion /// @return the result value, or null if the value cannot be calculated - ConstantResult OpLessThanEqual(const sem::Type* ty, - utils::VectorRef args, - const Source& source); + Result OpLessThanEqual(const sem::Type* ty, + utils::VectorRef args, + const Source& source); /// Greater than or equal operator '>=' /// @param ty the expression type /// @param args the input arguments /// @param source the source location of the conversion /// @return the result value, or null if the value cannot be calculated - ConstantResult OpGreaterThanEqual(const sem::Type* ty, - utils::VectorRef args, - const Source& source); + Result OpGreaterThanEqual(const sem::Type* ty, + utils::VectorRef args, + const Source& source); /// Bitwise and operator '&' /// @param ty the expression type /// @param args the input arguments /// @param source the source location of the conversion /// @return the result value, or null if the value cannot be calculated - ConstantResult OpAnd(const sem::Type* ty, - utils::VectorRef args, - const Source& source); + Result OpAnd(const sem::Type* ty, + utils::VectorRef args, + const Source& source); /// Bitwise or operator '|' /// @param ty the expression type /// @param args the input arguments /// @param source the source location of the conversion /// @return the result value, or null if the value cannot be calculated - ConstantResult OpOr(const sem::Type* ty, - utils::VectorRef args, - const Source& source); + Result OpOr(const sem::Type* ty, + utils::VectorRef args, + const Source& source); /// Bitwise xor operator '^' /// @param ty the expression type /// @param args the input arguments /// @param source the source location of the conversion /// @return the result value, or null if the value cannot be calculated - ConstantResult OpXor(const sem::Type* ty, + Result OpXor(const sem::Type* ty, utils::VectorRef args, const Source& source); @@ -374,18 +373,18 @@ class ConstEval { /// @param args the input arguments /// @param source the source location of the conversion /// @return the result value, or null if the value cannot be calculated - ConstantResult atan2(const sem::Type* ty, - utils::VectorRef args, - const Source& source); + Result atan2(const sem::Type* ty, + utils::VectorRef args, + const Source& source); /// clamp builtin /// @param ty the expression type /// @param args the input arguments /// @param source the source location of the conversion /// @return the result value, or null if the value cannot be calculated - ConstantResult clamp(const sem::Type* ty, - utils::VectorRef args, - const Source& source); + Result clamp(const sem::Type* ty, + utils::VectorRef args, + const Source& source); private: /// Adds the given error message to the diagnostics diff --git a/src/tint/utils/result.h b/src/tint/utils/result.h index 6a14352670..baca65f4a3 100644 --- a/src/tint/utils/result.h +++ b/src/tint/utils/result.h @@ -17,6 +17,7 @@ #include #include + #include "src/tint/debug.h" namespace tint::utils { @@ -50,6 +51,20 @@ struct [[nodiscard]] Result { Result(const FAILURE_TYPE& failure) // NOLINT(runtime/explicit): : value{failure} {} + /// Copy constructor with success / failure casting + /// @param other the Result to copy + template ()}), + decltype(FAILURE_TYPE{std::declval()})>> + Result(const Result& other) { // NOLINT(runtime/explicit): + if (other) { + value = SUCCESS_TYPE{other.Get()}; + } else { + value = FAILURE_TYPE{other.Failure()}; + } + } + /// @returns true if the result was a success operator bool() const { Validate(); diff --git a/src/tint/utils/result_test.cc b/src/tint/utils/result_test.cc index ce125f452b..6614028563 100644 --- a/src/tint/utils/result_test.cc +++ b/src/tint/utils/result_test.cc @@ -51,5 +51,17 @@ TEST(ResultTest, CustomFailure) { EXPECT_EQ(r.Failure(), "oh noes!"); } +TEST(ResultTest, ValueCast) { + struct X {}; + struct Y : X {}; + + Y* y = nullptr; + auto r_y = Result{y}; + auto r_x = Result{r_y}; + + (void)r_x; + (void)r_y; +} + } // namespace } // namespace tint::utils