From c2a052eaa4453eeb9c797b6f576143bdcf3b26a8 Mon Sep 17 00:00:00 2001 From: Antonio Maiorano Date: Thu, 4 Aug 2022 13:59:36 +0000 Subject: [PATCH] tint: Allow ConstEval functions to fail They now return a utils::Result so they can add an error to diagnostics and return Failure. Returning nullptr still means cannot evaluate at compile time, but not a failure. Bug: tint:1581 Change-Id: Ic30d782fb9fa725ec2faf89a87f74de6282d0304 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/98107 Kokoro: Kokoro Commit-Queue: Antonio Maiorano Reviewed-by: Dan Sinclair --- src/tint/resolver/const_eval.cc | 89 +++++++++++++++++---------------- src/tint/resolver/const_eval.h | 60 ++++++++++------------ src/tint/resolver/resolver.cc | 77 ++++++++++++++++++++++------ 3 files changed, 133 insertions(+), 93 deletions(-) diff --git a/src/tint/resolver/const_eval.cc b/src/tint/resolver/const_eval.cc index 71840a880a..4c9808b907 100644 --- a/src/tint/resolver/const_eval.cc +++ b/src/tint/resolver/const_eval.cc @@ -479,8 +479,8 @@ const Constant* TransformElements(ProgramBuilder& builder, F&& f, CONSTANTS&&... ConstEval::ConstEval(ProgramBuilder& b) : builder(b) {} -const sem::Constant* ConstEval::Literal(const sem::Type* ty, - const ast::LiteralExpression* literal) { +ConstEval::ConstantResult ConstEval::Literal(const sem::Type* ty, + const ast::LiteralExpression* literal) { return Switch( literal, [&](const ast::BoolLiteralExpression* lit) { @@ -510,8 +510,9 @@ const sem::Constant* ConstEval::Literal(const sem::Type* ty, }); } -const sem::Constant* ConstEval::ArrayOrStructCtor(const sem::Type* ty, - utils::VectorRef args) { +ConstEval::ConstantResult ConstEval::ArrayOrStructCtor( + const sem::Type* ty, + utils::VectorRef args) { if (args.IsEmpty()) { return ZeroValue(builder, ty); } @@ -530,8 +531,8 @@ const sem::Constant* ConstEval::ArrayOrStructCtor(const sem::Type* ty, return CreateComposite(builder, ty, std::move(els)); } -const sem::Constant* ConstEval::Conv(const sem::Type* ty, - utils::VectorRef args) { +ConstEval::ConstantResult ConstEval::Conv(const sem::Type* ty, + utils::VectorRef args) { uint32_t el_count = 0; auto* el_ty = sem::Type::ElementOf(ty, &el_count); if (!el_ty) { @@ -551,26 +552,26 @@ const sem::Constant* ConstEval::Conv(const sem::Type* ty, return nullptr; } -const sem::Constant* ConstEval::Zero(const sem::Type* ty, - utils::VectorRef) { +ConstEval::ConstantResult ConstEval::Zero(const sem::Type* ty, + utils::VectorRef) { return ZeroValue(builder, ty); } -const sem::Constant* ConstEval::Identity(const sem::Type*, - utils::VectorRef args) { +ConstEval::ConstantResult ConstEval::Identity(const sem::Type*, + utils::VectorRef args) { return args[0]->ConstantValue(); } -const sem::Constant* ConstEval::VecSplat(const sem::Type* ty, - utils::VectorRef args) { +ConstEval::ConstantResult ConstEval::VecSplat(const sem::Type* ty, + utils::VectorRef args) { if (auto* arg = args[0]->ConstantValue()) { return builder.create(ty, arg, static_cast(ty)->Width()); } return nullptr; } -const sem::Constant* ConstEval::VecCtorS(const sem::Type* ty, - utils::VectorRef args) { +ConstEval::ConstantResult ConstEval::VecCtorS(const sem::Type* ty, + utils::VectorRef args) { utils::Vector els; for (auto* arg : args) { els.Push(arg->ConstantValue()); @@ -578,8 +579,8 @@ const sem::Constant* ConstEval::VecCtorS(const sem::Type* ty, return CreateComposite(builder, ty, std::move(els)); } -const sem::Constant* ConstEval::VecCtorM(const sem::Type* ty, - utils::VectorRef args) { +ConstEval::ConstantResult ConstEval::VecCtorM(const sem::Type* ty, + utils::VectorRef args) { utils::Vector els; for (auto* arg : args) { auto* val = arg->ConstantValue(); @@ -603,8 +604,8 @@ const sem::Constant* ConstEval::VecCtorM(const sem::Type* ty, return CreateComposite(builder, ty, std::move(els)); } -const sem::Constant* ConstEval::MatCtorS(const sem::Type* ty, - utils::VectorRef args) { +ConstEval::ConstantResult ConstEval::MatCtorS(const sem::Type* ty, + utils::VectorRef args) { auto* m = static_cast(ty); utils::Vector els; @@ -619,8 +620,8 @@ const sem::Constant* ConstEval::MatCtorS(const sem::Type* ty, return CreateComposite(builder, ty, std::move(els)); } -const sem::Constant* ConstEval::MatCtorV(const sem::Type* ty, - utils::VectorRef args) { +ConstEval::ConstantResult ConstEval::MatCtorV(const sem::Type* ty, + utils::VectorRef args) { utils::Vector els; for (auto* arg : args) { els.Push(arg->ConstantValue()); @@ -628,16 +629,16 @@ const sem::Constant* ConstEval::MatCtorV(const sem::Type* ty, return CreateComposite(builder, ty, std::move(els)); } -const sem::Constant* ConstEval::Index(const sem::Expression* obj_expr, - const sem::Expression* idx_expr) { +ConstEval::ConstantResult ConstEval::Index(const sem::Expression* obj_expr, + const sem::Expression* idx_expr) { auto obj_val = obj_expr->ConstantValue(); if (!obj_val) { - return {}; + return nullptr; } auto idx_val = idx_expr->ConstantValue(); if (!idx_val) { - return {}; + return nullptr; } uint32_t el_count = 0; @@ -656,18 +657,18 @@ const sem::Constant* ConstEval::Index(const sem::Expression* obj_expr, return obj_val->Index(static_cast(idx)); } -const sem::Constant* ConstEval::MemberAccess(const sem::Expression* obj_expr, - const sem::StructMember* member) { +ConstEval::ConstantResult ConstEval::MemberAccess(const sem::Expression* obj_expr, + const sem::StructMember* member) { auto obj_val = obj_expr->ConstantValue(); if (!obj_val) { - return {}; + return nullptr; } return obj_val->Index(static_cast(member->Index())); } -const sem::Constant* ConstEval::Swizzle(const sem::Type* ty, - const sem::Expression* vec_expr, - utils::VectorRef indices) { +ConstEval::ConstantResult ConstEval::Swizzle(const sem::Type* ty, + const sem::Expression* vec_expr, + utils::VectorRef indices) { auto* vec_val = vec_expr->ConstantValue(); if (!vec_val) { return nullptr; @@ -681,13 +682,13 @@ const sem::Constant* ConstEval::Swizzle(const sem::Type* ty, } } -const sem::Constant* ConstEval::Bitcast(const sem::Type*, const sem::Expression*) { +ConstEval::ConstantResult ConstEval::Bitcast(const sem::Type*, const sem::Expression*) { // TODO(crbug.com/tint/1581): Implement @const intrinsics return nullptr; } -const sem::Constant* ConstEval::OpComplement(const sem::Type*, - utils::VectorRef args) { +ConstEval::ConstantResult ConstEval::OpComplement(const sem::Type*, + utils::VectorRef args) { auto transform = [&](const sem::Constant* c) { auto create = [&](auto i) { return CreateElement(builder, c->Type(), decltype(i)(~i.value)); @@ -697,14 +698,14 @@ const sem::Constant* ConstEval::OpComplement(const sem::Type*, return TransformElements(builder, transform, args[0]->ConstantValue()); } -const sem::Constant* ConstEval::OpMinus(const sem::Type*, - utils::VectorRef args) { +ConstEval::ConstantResult ConstEval::OpMinus(const sem::Type*, + utils::VectorRef args) { auto transform = [&](const sem::Constant* c) { - auto create = [&](auto i) { // - // For signed integrals, avoid C++ UB by not negating the - // smallest negative number. In WGSL, this operation is well - // defined to return the same value, see: - // https://gpuweb.github.io/gpuweb/wgsl/#arithmetic-expr. + auto create = [&](auto i) { + // For signed integrals, avoid C++ UB by not negating the + // smallest negative number. In WGSL, this operation is well + // defined to return the same value, see: + // https://gpuweb.github.io/gpuweb/wgsl/#arithmetic-expr. using T = UnwrapNumber; if constexpr (std::is_integral_v) { auto v = i.value; @@ -721,8 +722,8 @@ const sem::Constant* ConstEval::OpMinus(const sem::Type*, return TransformElements(builder, transform, args[0]->ConstantValue()); } -const sem::Constant* ConstEval::atan2(const sem::Type*, - utils::VectorRef args) { +ConstEval::ConstantResult ConstEval::atan2(const sem::Type*, + utils::VectorRef args) { auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) { auto create = [&](auto i, auto j) { return CreateElement(builder, c0->Type(), decltype(i)(std::atan2(i.value, j.value))); @@ -733,8 +734,8 @@ const sem::Constant* ConstEval::atan2(const sem::Type*, args[1]->ConstantValue()); } -const sem::Constant* ConstEval::clamp(const sem::Type*, - utils::VectorRef args) { +ConstEval::ConstantResult ConstEval::clamp(const sem::Type*, + utils::VectorRef args) { auto transform = [&](const sem::Constant* c0, const sem::Constant* c1, const sem::Constant* c2) { auto create = [&](auto e, auto low, auto high) { diff --git a/src/tint/resolver/const_eval.h b/src/tint/resolver/const_eval.h index 5349637865..2059731005 100644 --- a/src/tint/resolver/const_eval.h +++ b/src/tint/resolver/const_eval.h @@ -44,10 +44,6 @@ namespace tint::resolver { /// before calling a method to evaluate an expression's value. class ConstEval { public: - /// Typedef for a constant evaluation function - using Function = const sem::Constant* (ConstEval::*)(const sem::Type* result_ty, - utils::VectorRef); - /// The result type of a method that may raise a diagnostic error and the caller should abort /// resolving. Can be one of three distinct values: /// * A non-null sem::Constant pointer. Returned when a expression resolves to a creation time @@ -59,6 +55,10 @@ class ConstEval { /// resolving. using ConstantResult = utils::Result; + /// Typedef for a constant evaluation function + using Function = ConstantResult (ConstEval::*)(const sem::Type* result_ty, + utils::VectorRef); + /// Constructor /// @param b the program builder explicit ConstEval(ProgramBuilder& b); @@ -70,37 +70,37 @@ class ConstEval { /// @param ty the target type - must be an array or constructor /// @param args the input arguments /// @return the constructed value, or null if the value cannot be calculated - const sem::Constant* ArrayOrStructCtor(const sem::Type* ty, - utils::VectorRef args); + ConstantResult ArrayOrStructCtor(const sem::Type* ty, + utils::VectorRef args); /// @param ty the target type /// @param expr the input expression /// @return the bit-cast of the given expression to the given type, or null if the value cannot /// be calculated - const sem::Constant* Bitcast(const sem::Type* ty, const sem::Expression* expr); + ConstantResult Bitcast(const sem::Type* ty, const sem::Expression* expr); /// @param obj the object being indexed /// @param idx the index expression /// @return the result of the index, or null if the value cannot be calculated - const sem::Constant* Index(const sem::Expression* obj, const sem::Expression* idx); + ConstantResult Index(const sem::Expression* obj, const sem::Expression* idx); /// @param ty the result type /// @param lit the literal AST node /// @return the constant value of the literal - const sem::Constant* Literal(const sem::Type* ty, const ast::LiteralExpression* lit); + ConstantResult Literal(const sem::Type* ty, const ast::LiteralExpression* lit); /// @param obj the object being accessed /// @param member the member /// @return the result of the member access, or null if the value cannot be calculated - const sem::Constant* MemberAccess(const sem::Expression* obj, const sem::StructMember* member); + ConstantResult MemberAccess(const sem::Expression* obj, const sem::StructMember* member); /// @param ty the result type /// @param vector the vector being swizzled /// @param indices the swizzle indices /// @return the result of the swizzle, or null if the value cannot be calculated - const sem::Constant* Swizzle(const sem::Type* ty, - const sem::Expression* vector, - utils::VectorRef indices); + ConstantResult Swizzle(const sem::Type* ty, + const sem::Expression* vector, + utils::VectorRef indices); /// Convert the `value` to `target_type` /// @param ty the result type @@ -117,73 +117,65 @@ class ConstEval { /// @param ty the result type /// @param args the input arguments /// @return the converted value, or null if the value cannot be calculated - const sem::Constant* Conv(const sem::Type* ty, utils::VectorRef args); + ConstantResult Conv(const sem::Type* ty, utils::VectorRef args); /// Zero value type constructor /// @param ty the result type /// @param args the input arguments (no arguments provided) /// @return the constructed value, or null if the value cannot be calculated - const sem::Constant* Zero(const sem::Type* ty, utils::VectorRef args); + ConstantResult Zero(const sem::Type* ty, utils::VectorRef args); /// Identity value type constructor /// @param ty the result type /// @param args the input arguments /// @return the constructed value, or null if the value cannot be calculated - const sem::Constant* Identity(const sem::Type* ty, - utils::VectorRef args); + ConstantResult Identity(const sem::Type* ty, utils::VectorRef args); /// Vector splat constructor /// @param ty the vector type /// @param args the input arguments /// @return the constructed value, or null if the value cannot be calculated - const sem::Constant* VecSplat(const sem::Type* ty, - utils::VectorRef args); + ConstantResult VecSplat(const sem::Type* ty, utils::VectorRef args); /// Vector constructor using scalars /// @param ty the vector type /// @param args the input arguments /// @return the constructed value, or null if the value cannot be calculated - const sem::Constant* VecCtorS(const sem::Type* ty, - utils::VectorRef args); + ConstantResult VecCtorS(const sem::Type* ty, utils::VectorRef args); /// Vector constructor using a mix of scalars and smaller vectors /// @param ty the vector type /// @param args the input arguments /// @return the constructed value, or null if the value cannot be calculated - const sem::Constant* VecCtorM(const sem::Type* ty, - utils::VectorRef args); + ConstantResult VecCtorM(const sem::Type* ty, utils::VectorRef args); /// Matrix constructor using scalar values /// @param ty the matrix type /// @param args the input arguments /// @return the constructed value, or null if the value cannot be calculated - const sem::Constant* MatCtorS(const sem::Type* ty, - utils::VectorRef args); + ConstantResult MatCtorS(const sem::Type* ty, utils::VectorRef args); /// Matrix constructor using column vectors /// @param ty the matrix type /// @param args the input arguments /// @return the constructed value, or null if the value cannot be calculated - const sem::Constant* MatCtorV(const sem::Type* ty, - utils::VectorRef args); + ConstantResult MatCtorV(const sem::Type* ty, utils::VectorRef args); //////////////////////////////////////////////////////////////////////////// - // Operators + // Unary Operators //////////////////////////////////////////////////////////////////////////// /// Complement operator '~' /// @param ty the integer type /// @param args the input arguments /// @return the result value, or null if the value cannot be calculated - const sem::Constant* OpComplement(const sem::Type* ty, - utils::VectorRef args); + ConstantResult OpComplement(const sem::Type* ty, utils::VectorRef args); /// Minus operator '-' /// @param ty the expression type /// @param args the input arguments /// @return the result value, or null if the value cannot be calculated - const sem::Constant* OpMinus(const sem::Type* ty, - utils::VectorRef args); + ConstantResult OpMinus(const sem::Type* ty, utils::VectorRef args); //////////////////////////////////////////////////////////////////////////// // Builtins @@ -193,13 +185,13 @@ class ConstEval { /// @param ty the expression type /// @param args the input arguments /// @return the result value, or null if the value cannot be calculated - const sem::Constant* atan2(const sem::Type* ty, utils::VectorRef args); + ConstantResult atan2(const sem::Type* ty, utils::VectorRef args); /// clamp builtin /// @param ty the expression type /// @param args the input arguments /// @return the result value, or null if the value cannot be calculated - const sem::Constant* clamp(const sem::Type* ty, utils::VectorRef args); + ConstantResult clamp(const sem::Type* ty, utils::VectorRef args); private: /// Adds the given error message to the diagnostics diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc index ab25bfbe1f..317a91a96b 100644 --- a/src/tint/resolver/resolver.cc +++ b/src/tint/resolver/resolver.cc @@ -1501,7 +1501,12 @@ sem::Expression* Resolver::IndexAccessor(const ast::IndexAccessorExpression* exp } auto stage = sem::EarliestStage(obj->Stage(), idx->Stage()); - auto val = const_eval_.Index(obj, idx); + const sem::Constant* val = nullptr; + if (auto r = const_eval_.Index(obj, idx)) { + val = r.Get(); + } else { + return nullptr; + } bool has_side_effects = idx->HasSideEffects() || obj->HasSideEffects(); auto* sem = builder_->create( expr, ty, stage, obj, idx, current_statement_, std::move(val), has_side_effects, @@ -1520,7 +1525,12 @@ sem::Expression* Resolver::Bitcast(const ast::BitcastExpression* expr) { return nullptr; } - auto val = const_eval_.Bitcast(ty, inner); + const sem::Constant* val = nullptr; + if (auto r = const_eval_.Bitcast(ty, inner)) { + val = r.Get(); + } else { + return nullptr; + } auto stage = sem::EvaluationStage::kRuntime; // TODO(crbug.com/tint/1581) auto* sem = builder_->create(expr, ty, stage, current_statement_, std::move(val), inner->HasSideEffects()); @@ -1575,8 +1585,12 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) { const sem::Constant* value = nullptr; auto stage = sem::EarliestStage(ctor_or_conv.target->Stage(), args_stage); if (stage == sem::EvaluationStage::kConstant) { - value = - (const_eval_.*ctor_or_conv.const_eval_fn)(ctor_or_conv.target->ReturnType(), args); + if (auto r = (const_eval_.*ctor_or_conv.const_eval_fn)( + ctor_or_conv.target->ReturnType(), args)) { + value = r.Get(); + } else { + return nullptr; + } } return builder_->create(expr, ctor_or_conv.target, stage, std::move(args), current_statement_, value, has_side_effects); @@ -1593,7 +1607,11 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) { auto stage = args_stage; // The evaluation stage of the call const sem::Constant* value = nullptr; // The constant value for the call if (stage == sem::EvaluationStage::kConstant) { - value = const_eval_.ArrayOrStructCtor(ty, args); + if (auto r = const_eval_.ArrayOrStructCtor(ty, args)) { + value = r.Get(); + } else { + return nullptr; + } if (!value) { // Constant evaluation failed. // Can happen for expressions that will fail validation (later). @@ -1873,7 +1891,11 @@ sem::Call* Resolver::BuiltinCall(const ast::CallExpression* expr, // If the builtin is @const, and all arguments have constant values, evaluate the builtin now. const sem::Constant* value = nullptr; if (stage == sem::EvaluationStage::kConstant) { - value = (const_eval_.*builtin.const_eval_fn)(builtin.sem->ReturnType(), args); + if (auto r = (const_eval_.*builtin.const_eval_fn)(builtin.sem->ReturnType(), args)) { + value = r.Get(); + } else { + return nullptr; + } } bool has_side_effects = @@ -2035,7 +2057,12 @@ sem::Expression* Resolver::Literal(const ast::LiteralExpression* literal) { return nullptr; } - auto val = const_eval_.Literal(ty, literal); + const sem::Constant* val = nullptr; + if (auto r = const_eval_.Literal(ty, literal)) { + val = r.Get(); + } else { + return nullptr; + } return builder_->create(literal, ty, sem::EvaluationStage::kConstant, current_statement_, std::move(val), /* has_side_effects */ false); @@ -2156,7 +2183,12 @@ sem::Expression* Resolver::MemberAccessor(const ast::MemberAccessorExpression* e ret = builder_->create(ret, ref->StorageClass(), ref->Access()); } - auto* val = const_eval_.MemberAccess(object, member); + const sem::Constant* val = nullptr; + if (auto r = const_eval_.MemberAccess(object, member)) { + val = r.Get(); + } else { + return nullptr; + } return builder_->create(expr, ret, current_statement_, val, object, member, has_side_effects, source_var); } @@ -2224,9 +2256,12 @@ sem::Expression* Resolver::MemberAccessor(const ast::MemberAccessorExpression* e // the swizzle. ret = builder_->create(vec->type(), static_cast(size)); } - auto* val = const_eval_.Swizzle(ret, object, swizzle); - return builder_->create(expr, ret, current_statement_, val, object, - std::move(swizzle), has_side_effects, source_var); + if (auto r = const_eval_.Swizzle(ret, object, swizzle)) { + auto* val = r.Get(); + return builder_->create(expr, ret, current_statement_, val, object, + std::move(swizzle), has_side_effects, source_var); + } + return nullptr; } AddError("invalid member accessor expression. Expected vector or struct, got '" + @@ -2240,7 +2275,6 @@ sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) { const auto* rhs = sem_.Get(expr->rhs); auto* lhs_ty = lhs->Type()->UnwrapRef(); auto* rhs_ty = rhs->Type()->UnwrapRef(); - auto stage = sem::EvaluationStage::kRuntime; // TODO(crbug.com/tint/1581) auto op = intrinsic_table_->Lookup(expr->op, lhs_ty, rhs_ty, expr->source, false); if (!op.result) { @@ -2260,8 +2294,17 @@ sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) { } const sem::Constant* value = nullptr; - if (op.const_eval_fn) { - value = (const_eval_.*op.const_eval_fn)(op.result, utils::Vector{lhs, rhs}); + auto stage = sem::EarliestStage(lhs->Stage(), rhs->Stage()); + if (stage == sem::EvaluationStage::kConstant) { + if (op.const_eval_fn) { + if (auto r = (const_eval_.*op.const_eval_fn)(op.result, utils::Vector{lhs, rhs})) { + value = r.Get(); + } else { + return nullptr; + } + } else { + stage = sem::EvaluationStage::kRuntime; + } } bool has_side_effects = lhs->HasSideEffects() || rhs->HasSideEffects(); @@ -2337,7 +2380,11 @@ sem::Expression* Resolver::UnaryOp(const ast::UnaryOpExpression* unary) { stage = expr->Stage(); if (stage == sem::EvaluationStage::kConstant) { if (op.const_eval_fn) { - value = (const_eval_.*op.const_eval_fn)(ty, utils::Vector{expr}); + if (auto r = (const_eval_.*op.const_eval_fn)(ty, utils::Vector{expr})) { + value = r.Get(); + } else { + return nullptr; + } } else { stage = sem::EvaluationStage::kRuntime; }