diff --git a/src/tint/program_builder.h b/src/tint/program_builder.h index 7774a334ae..25766e671f 100644 --- a/src/tint/program_builder.h +++ b/src/tint/program_builder.h @@ -1978,6 +1978,17 @@ class ProgramBuilder { Expr(std::forward(rhs))); } + /// @param source the source information + /// @param lhs the left hand argument to the addition operation + /// @param rhs the right hand argument to the addition operation + /// @returns a `ast::BinaryExpression` summing the arguments `lhs` and `rhs` + template + const ast::BinaryExpression* Add(const Source& source, LHS&& lhs, RHS&& rhs) { + return create(source, ast::BinaryOp::kAdd, + Expr(std::forward(lhs)), + Expr(std::forward(rhs))); + } + /// @param lhs the left hand argument to the and operation /// @param rhs the right hand argument to the and operation /// @returns a `ast::BinaryExpression` bitwise anding `lhs` and `rhs` diff --git a/src/tint/resolver/const_eval.cc b/src/tint/resolver/const_eval.cc index 26a040396d..5144e6308d 100644 --- a/src/tint/resolver/const_eval.cc +++ b/src/tint/resolver/const_eval.cc @@ -509,15 +509,6 @@ const Constant* TransformBinaryElements(ProgramBuilder& builder, return CreateComposite(builder, ty, std::move(els)); } -/// CombineSource returns the combined `Source`s of each expression in `exprs`. -Source CombineSource(utils::VectorRef exprs) { - Source result = exprs[0]->Declaration()->source; - for (size_t i = 1; i < exprs.Length(); ++i) { - result = result.Combine(result, exprs[i]->Declaration()->source); - } - return result; -} - } // namespace ConstEval::ConstEval(ProgramBuilder& b) : builder(b) {} @@ -575,20 +566,19 @@ ConstEval::ConstantResult ConstEval::ArrayOrStructCtor( } ConstEval::ConstantResult ConstEval::Conv(const sem::Type* ty, - utils::VectorRef args) { + utils::VectorRef args, + const Source& source) { uint32_t el_count = 0; auto* el_ty = sem::Type::ElementOf(ty, &el_count); if (!el_ty) { return nullptr; } - auto& src = args[0]->Declaration()->source; - auto* arg = args[0]->ConstantValue(); - if (!arg) { + if (!args[0]) { return nullptr; // Single argument is not constant. } - if (auto conv = Convert(ty, arg, src)) { + if (auto conv = Convert(ty, args[0], source)) { return conv.Get(); } @@ -596,37 +586,38 @@ ConstEval::ConstantResult ConstEval::Conv(const sem::Type* ty, } ConstEval::ConstantResult ConstEval::Zero(const sem::Type* ty, - utils::VectorRef) { + utils::VectorRef, + const Source&) { return ZeroValue(builder, ty); } ConstEval::ConstantResult ConstEval::Identity(const sem::Type*, - utils::VectorRef args) { - return args[0]->ConstantValue(); + utils::VectorRef args, + const Source&) { + return args[0]; } ConstEval::ConstantResult ConstEval::VecSplat(const sem::Type* ty, - utils::VectorRef args) { - if (auto* arg = args[0]->ConstantValue()) { + utils::VectorRef args, + const Source&) { + if (auto* arg = args[0]) { return builder.create(ty, arg, static_cast(ty)->Width()); } return nullptr; } ConstEval::ConstantResult ConstEval::VecCtorS(const sem::Type* ty, - utils::VectorRef args) { - utils::Vector els; - for (auto* arg : args) { - els.Push(arg->ConstantValue()); - } - return CreateComposite(builder, ty, std::move(els)); + utils::VectorRef args, + const Source&) { + return CreateComposite(builder, ty, args); } ConstEval::ConstantResult ConstEval::VecCtorM(const sem::Type* ty, - utils::VectorRef args) { + utils::VectorRef args, + const Source&) { utils::Vector els; for (auto* arg : args) { - auto* val = arg->ConstantValue(); + auto* val = arg; if (!val) { return nullptr; } @@ -648,7 +639,8 @@ ConstEval::ConstantResult ConstEval::VecCtorM(const sem::Type* ty, } ConstEval::ConstantResult ConstEval::MatCtorS(const sem::Type* ty, - utils::VectorRef args) { + utils::VectorRef args, + const Source&) { auto* m = static_cast(ty); utils::Vector els; @@ -656,7 +648,7 @@ ConstEval::ConstantResult ConstEval::MatCtorS(const sem::Type* ty, utils::Vector column; for (uint32_t r = 0; r < m->rows(); r++) { auto i = r + c * m->rows(); - column.Push(args[i]->ConstantValue()); + column.Push(args[i]); } els.Push(CreateComposite(builder, m->ColumnType(), std::move(column))); } @@ -664,12 +656,9 @@ ConstEval::ConstantResult ConstEval::MatCtorS(const sem::Type* ty, } ConstEval::ConstantResult ConstEval::MatCtorV(const sem::Type* ty, - utils::VectorRef args) { - utils::Vector els; - for (auto* arg : args) { - els.Push(arg->ConstantValue()); - } - return CreateComposite(builder, ty, std::move(els)); + utils::VectorRef args, + const Source&) { + return CreateComposite(builder, ty, args); } ConstEval::ConstantResult ConstEval::Index(const sem::Expression* obj_expr, @@ -731,18 +720,20 @@ ConstEval::ConstantResult ConstEval::Bitcast(const sem::Type*, const sem::Expres } ConstEval::ConstantResult ConstEval::OpComplement(const sem::Type*, - utils::VectorRef args) { + utils::VectorRef args, + const Source&) { auto transform = [&](const sem::Constant* c) { auto create = [&](auto i) { return CreateElement(builder, c->Type(), decltype(i)(~i.value)); }; return Dispatch_ia_iu32(create, c); }; - return TransformElements(builder, transform, args[0]->ConstantValue()); + return TransformElements(builder, transform, args[0]); } ConstEval::ConstantResult ConstEval::OpMinus(const sem::Type*, - utils::VectorRef args) { + utils::VectorRef args, + const Source&) { auto transform = [&](const sem::Constant* c) { auto create = [&](auto i) { // For signed integrals, avoid C++ UB by not negating the @@ -762,11 +753,12 @@ ConstEval::ConstantResult ConstEval::OpMinus(const sem::Type*, }; return Dispatch_fia_fi32_f16(create, c); }; - return TransformElements(builder, transform, args[0]->ConstantValue()); + return TransformElements(builder, transform, args[0]); } ConstEval::ConstantResult ConstEval::OpPlus(const sem::Type* ty, - utils::VectorRef args) { + utils::VectorRef args, + const Source& source) { auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) { auto create = [&](auto i, auto j) -> const Constant* { using NumberT = decltype(i); @@ -791,7 +783,7 @@ ConstEval::ConstantResult ConstEval::OpPlus(const sem::Type* ty, AddError("'" + std::to_string(add_values(i.value, j.value)) + "' cannot be represented as '" + ty->FriendlyName(builder.Symbols()) + "'", - CombineSource(args)); + source); return nullptr; } } else { @@ -802,8 +794,7 @@ ConstEval::ConstantResult ConstEval::OpPlus(const sem::Type* ty, return Dispatch_fia_fiu32_f16(create, c0, c1); }; - auto r = TransformBinaryElements(builder, transform, args[0]->ConstantValue(), - args[1]->ConstantValue()); + auto r = TransformBinaryElements(builder, transform, args[0], args[1]); if (builder.Diagnostics().contains_errors()) { return utils::Failure; } @@ -811,19 +802,20 @@ ConstEval::ConstantResult ConstEval::OpPlus(const sem::Type* ty, } ConstEval::ConstantResult ConstEval::atan2(const sem::Type*, - utils::VectorRef args) { + utils::VectorRef args, + const Source&) { auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) { auto create = [&](auto i, auto j) { return CreateElement(builder, c0->Type(), decltype(i)(std::atan2(i.value, j.value))); }; return Dispatch_fa_f32_f16(create, c0, c1); }; - return TransformElements(builder, transform, args[0]->ConstantValue(), - args[1]->ConstantValue()); + return TransformElements(builder, transform, args[0], args[1]); } ConstEval::ConstantResult ConstEval::clamp(const sem::Type*, - utils::VectorRef args) { + utils::VectorRef args, + const Source&) { auto transform = [&](const sem::Constant* c0, const sem::Constant* c1, const sem::Constant* c2) { auto create = [&](auto e, auto low, auto high) { @@ -832,8 +824,7 @@ ConstEval::ConstantResult ConstEval::clamp(const sem::Type*, }; return Dispatch_fia_fiu32_f16(create, c0, c1, c2); }; - return TransformElements(builder, transform, args[0]->ConstantValue(), args[1]->ConstantValue(), - args[2]->ConstantValue()); + return TransformElements(builder, transform, args[0], args[1], args[2]); } utils::Result ConstEval::Convert(const sem::Type* target_ty, diff --git a/src/tint/resolver/const_eval.h b/src/tint/resolver/const_eval.h index 38bde53c8e..dbc3dbd74a 100644 --- a/src/tint/resolver/const_eval.h +++ b/src/tint/resolver/const_eval.h @@ -57,7 +57,8 @@ class ConstEval { /// Typedef for a constant evaluation function using Function = ConstantResult (ConstEval::*)(const sem::Type* result_ty, - utils::VectorRef); + utils::VectorRef, + const Source&); /// Constructor /// @param b the program builder @@ -116,50 +117,74 @@ class ConstEval { /// Type conversion /// @param ty the result type /// @param args the input arguments + /// @param source the source location of the conversion /// @return the converted value, or null if the value cannot be calculated - ConstantResult Conv(const sem::Type* ty, utils::VectorRef args); + ConstantResult Conv(const sem::Type* ty, + utils::VectorRef args, + const Source& source); /// Zero value type constructor /// @param ty the result type /// @param args the input arguments (no arguments provided) + /// @param source the source location of the conversion /// @return the constructed value, or null if the value cannot be calculated - ConstantResult Zero(const sem::Type* ty, utils::VectorRef args); + ConstantResult Zero(const sem::Type* ty, + utils::VectorRef args, + const Source& source); /// Identity value type constructor /// @param ty the result type /// @param args the input arguments + /// @param source the source location of the conversion /// @return the constructed value, or null if the value cannot be calculated - ConstantResult Identity(const sem::Type* ty, utils::VectorRef args); + ConstantResult Identity(const sem::Type* ty, + utils::VectorRef args, + const Source& source); /// Vector splat constructor /// @param ty the vector type /// @param args the input arguments + /// @param source the source location of the conversion /// @return the constructed value, or null if the value cannot be calculated - ConstantResult VecSplat(const sem::Type* ty, utils::VectorRef args); + ConstantResult VecSplat(const sem::Type* ty, + utils::VectorRef args, + const Source& source); /// Vector constructor using scalars /// @param ty the vector type /// @param args the input arguments + /// @param source the source location of the conversion /// @return the constructed value, or null if the value cannot be calculated - ConstantResult VecCtorS(const sem::Type* ty, utils::VectorRef args); + ConstantResult VecCtorS(const sem::Type* ty, + utils::VectorRef args, + const Source& source); /// Vector constructor using a mix of scalars and smaller vectors /// @param ty the vector type /// @param args the input arguments + /// @param source the source location of the conversion /// @return the constructed value, or null if the value cannot be calculated - ConstantResult VecCtorM(const sem::Type* ty, utils::VectorRef args); + ConstantResult VecCtorM(const sem::Type* ty, + utils::VectorRef args, + const Source& source); /// Matrix constructor using scalar values /// @param ty the matrix type /// @param args the input arguments + /// @param source the source location of the conversion /// @return the constructed value, or null if the value cannot be calculated - ConstantResult MatCtorS(const sem::Type* ty, utils::VectorRef args); + ConstantResult MatCtorS(const sem::Type* ty, + utils::VectorRef args, + const Source& source); /// Matrix constructor using column vectors /// @param ty the matrix type /// @param args the input arguments + /// @param source the source location of the conversion /// @return the constructed value, or null if the value cannot be calculated - ConstantResult MatCtorV(const sem::Type* ty, utils::VectorRef args); + ConstantResult MatCtorV(const sem::Type* ty, + utils::VectorRef args, + const Source& source); //////////////////////////////////////////////////////////////////////////// // Unary Operators @@ -168,14 +193,20 @@ class ConstEval { /// Complement operator '~' /// @param ty the integer type /// @param args the input arguments + /// @param source the source location of the conversion /// @return the result value, or null if the value cannot be calculated - ConstantResult OpComplement(const sem::Type* ty, utils::VectorRef args); + ConstantResult OpComplement(const sem::Type* ty, + utils::VectorRef args, + const Source& source); /// Minus operator '-' /// @param ty the expression type /// @param args the input arguments + /// @param source the source location of the conversion /// @return the result value, or null if the value cannot be calculated - ConstantResult OpMinus(const sem::Type* ty, utils::VectorRef args); + ConstantResult OpMinus(const sem::Type* ty, + utils::VectorRef args, + const Source& source); //////////////////////////////////////////////////////////////////////////// // Binary Operators @@ -184,8 +215,11 @@ class ConstEval { /// Plus operator '+' /// @param ty the expression type /// @param args the input arguments + /// @param source the source location of the conversion /// @return the result value, or null if the value cannot be calculated - ConstantResult OpPlus(const sem::Type* ty, utils::VectorRef args); + ConstantResult OpPlus(const sem::Type* ty, + utils::VectorRef args, + const Source& source); //////////////////////////////////////////////////////////////////////////// // Builtins @@ -194,14 +228,20 @@ class ConstEval { /// atan2 builtin /// @param ty the expression type /// @param args the input arguments + /// @param source the source location of the conversion /// @return the result value, or null if the value cannot be calculated - ConstantResult atan2(const sem::Type* ty, utils::VectorRef args); + ConstantResult atan2(const sem::Type* ty, + utils::VectorRef args, + const Source& source); /// clamp builtin /// @param ty the expression type /// @param args the input arguments + /// @param source the source location of the conversion /// @return the result value, or null if the value cannot be calculated - ConstantResult clamp(const sem::Type* ty, utils::VectorRef args); + ConstantResult clamp(const sem::Type* ty, + utils::VectorRef args, + const Source& source); private: /// Adds the given error message to the diagnostics diff --git a/src/tint/resolver/const_eval_test.cc b/src/tint/resolver/const_eval_test.cc index e7c342d21a..9377fb86b1 100644 --- a/src/tint/resolver/const_eval_test.cc +++ b/src/tint/resolver/const_eval_test.cc @@ -3229,27 +3229,27 @@ INSTANTIATE_TEST_SUITE_P(Add, )))); TEST_F(ResolverConstEvalTest, BinaryAbstractAddOverflow_AInt) { - GlobalConst("c", nullptr, Add(Expr(Source{{1, 1}}, AInt::Highest()), 1_a)); + GlobalConst("c", nullptr, Add(Source{{1, 1}}, Expr(AInt::Highest()), 1_a)); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), "1:1 error: '-9223372036854775808' cannot be represented as 'abstract-int'"); } TEST_F(ResolverConstEvalTest, BinaryAbstractAddUnderflow_AInt) { - GlobalConst("c", nullptr, Add(Expr(Source{{1, 1}}, AInt::Lowest()), -1_a)); + GlobalConst("c", nullptr, Add(Source{{1, 1}}, Expr(AInt::Lowest()), -1_a)); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), "1:1 error: '9223372036854775807' cannot be represented as 'abstract-int'"); } TEST_F(ResolverConstEvalTest, BinaryAbstractAddOverflow_AFloat) { - GlobalConst("c", nullptr, Add(Expr(Source{{1, 1}}, AFloat::Highest()), AFloat::Highest())); + GlobalConst("c", nullptr, Add(Source{{1, 1}}, Expr(AFloat::Highest()), AFloat::Highest())); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), "1:1 error: 'inf' cannot be represented as 'abstract-float'"); } TEST_F(ResolverConstEvalTest, BinaryAbstractAddUnderflow_AFloat) { - GlobalConst("c", nullptr, Add(Expr(Source{{1, 1}}, AFloat::Lowest()), AFloat::Lowest())); + GlobalConst("c", nullptr, Add(Source{{1, 1}}, Expr(AFloat::Lowest()), AFloat::Lowest())); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), "1:1 error: '-inf' cannot be represented as 'abstract-float'"); } diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc index 046cf6f002..5d52c140dc 100644 --- a/src/tint/resolver/resolver.cc +++ b/src/tint/resolver/resolver.cc @@ -1585,8 +1585,10 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) { const sem::Constant* value = nullptr; auto stage = sem::EarliestStage(ctor_or_conv.target->Stage(), args_stage); if (stage == sem::EvaluationStage::kConstant) { + auto const_args = + utils::Transform(args, [](auto* arg) { return arg->ConstantValue(); }); if (auto r = (const_eval_.*ctor_or_conv.const_eval_fn)( - ctor_or_conv.target->ReturnType(), args)) { + ctor_or_conv.target->ReturnType(), const_args, expr->source)) { value = r.Get(); } else { return nullptr; @@ -1891,7 +1893,9 @@ sem::Call* Resolver::BuiltinCall(const ast::CallExpression* expr, // If the builtin is @const, and all arguments have constant values, evaluate the builtin now. const sem::Constant* value = nullptr; if (stage == sem::EvaluationStage::kConstant) { - if (auto r = (const_eval_.*builtin.const_eval_fn)(builtin.sem->ReturnType(), args)) { + auto const_args = utils::Transform(args, [](auto* arg) { return arg->ConstantValue(); }); + if (auto r = (const_eval_.*builtin.const_eval_fn)(builtin.sem->ReturnType(), const_args, + expr->source)) { value = r.Get(); } else { return nullptr; @@ -2297,7 +2301,8 @@ sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) { auto stage = sem::EarliestStage(lhs->Stage(), rhs->Stage()); if (stage == sem::EvaluationStage::kConstant) { if (op.const_eval_fn) { - if (auto r = (const_eval_.*op.const_eval_fn)(op.result, utils::Vector{lhs, rhs})) { + auto const_args = utils::Vector{lhs->ConstantValue(), rhs->ConstantValue()}; + if (auto r = (const_eval_.*op.const_eval_fn)(op.result, const_args, expr->source)) { value = r.Get(); } else { return nullptr; @@ -2380,7 +2385,9 @@ sem::Expression* Resolver::UnaryOp(const ast::UnaryOpExpression* unary) { stage = expr->Stage(); if (stage == sem::EvaluationStage::kConstant) { if (op.const_eval_fn) { - if (auto r = (const_eval_.*op.const_eval_fn)(ty, utils::Vector{expr})) { + if (auto r = (const_eval_.*op.const_eval_fn)( + ty, utils::Vector{expr->ConstantValue()}, + expr->Declaration()->source)) { value = r.Get(); } else { return nullptr;