tint: Allow ConstEval functions to fail

They now return a utils::Result so they can add an error to diagnostics
and return Failure. Returning nullptr still means cannot evaluate at
compile time, but not a failure.

Bug: tint:1581
Change-Id: Ic30d782fb9fa725ec2faf89a87f74de6282d0304
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/98107
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
This commit is contained in:
Antonio Maiorano 2022-08-04 13:59:36 +00:00 committed by Dawn LUCI CQ
parent 6091d838a7
commit c2a052eaa4
3 changed files with 133 additions and 93 deletions

View File

@ -479,7 +479,7 @@ const Constant* TransformElements(ProgramBuilder& builder, F&& f, CONSTANTS&&...
ConstEval::ConstEval(ProgramBuilder& b) : builder(b) {}
const sem::Constant* ConstEval::Literal(const sem::Type* ty,
ConstEval::ConstantResult ConstEval::Literal(const sem::Type* ty,
const ast::LiteralExpression* literal) {
return Switch(
literal,
@ -510,7 +510,8 @@ const sem::Constant* ConstEval::Literal(const sem::Type* ty,
});
}
const sem::Constant* ConstEval::ArrayOrStructCtor(const sem::Type* ty,
ConstEval::ConstantResult ConstEval::ArrayOrStructCtor(
const sem::Type* ty,
utils::VectorRef<const sem::Expression*> args) {
if (args.IsEmpty()) {
return ZeroValue(builder, ty);
@ -530,7 +531,7 @@ const sem::Constant* ConstEval::ArrayOrStructCtor(const sem::Type* ty,
return CreateComposite(builder, ty, std::move(els));
}
const sem::Constant* ConstEval::Conv(const sem::Type* ty,
ConstEval::ConstantResult ConstEval::Conv(const sem::Type* ty,
utils::VectorRef<const sem::Expression*> args) {
uint32_t el_count = 0;
auto* el_ty = sem::Type::ElementOf(ty, &el_count);
@ -551,17 +552,17 @@ const sem::Constant* ConstEval::Conv(const sem::Type* ty,
return nullptr;
}
const sem::Constant* ConstEval::Zero(const sem::Type* ty,
ConstEval::ConstantResult ConstEval::Zero(const sem::Type* ty,
utils::VectorRef<const sem::Expression*>) {
return ZeroValue(builder, ty);
}
const sem::Constant* ConstEval::Identity(const sem::Type*,
ConstEval::ConstantResult ConstEval::Identity(const sem::Type*,
utils::VectorRef<const sem::Expression*> args) {
return args[0]->ConstantValue();
}
const sem::Constant* ConstEval::VecSplat(const sem::Type* ty,
ConstEval::ConstantResult ConstEval::VecSplat(const sem::Type* ty,
utils::VectorRef<const sem::Expression*> args) {
if (auto* arg = args[0]->ConstantValue()) {
return builder.create<Splat>(ty, arg, static_cast<const sem::Vector*>(ty)->Width());
@ -569,7 +570,7 @@ const sem::Constant* ConstEval::VecSplat(const sem::Type* ty,
return nullptr;
}
const sem::Constant* ConstEval::VecCtorS(const sem::Type* ty,
ConstEval::ConstantResult ConstEval::VecCtorS(const sem::Type* ty,
utils::VectorRef<const sem::Expression*> args) {
utils::Vector<const sem::Constant*, 4> els;
for (auto* arg : args) {
@ -578,7 +579,7 @@ const sem::Constant* ConstEval::VecCtorS(const sem::Type* ty,
return CreateComposite(builder, ty, std::move(els));
}
const sem::Constant* ConstEval::VecCtorM(const sem::Type* ty,
ConstEval::ConstantResult ConstEval::VecCtorM(const sem::Type* ty,
utils::VectorRef<const sem::Expression*> args) {
utils::Vector<const sem::Constant*, 4> els;
for (auto* arg : args) {
@ -603,7 +604,7 @@ const sem::Constant* ConstEval::VecCtorM(const sem::Type* ty,
return CreateComposite(builder, ty, std::move(els));
}
const sem::Constant* ConstEval::MatCtorS(const sem::Type* ty,
ConstEval::ConstantResult ConstEval::MatCtorS(const sem::Type* ty,
utils::VectorRef<const sem::Expression*> args) {
auto* m = static_cast<const sem::Matrix*>(ty);
@ -619,7 +620,7 @@ const sem::Constant* ConstEval::MatCtorS(const sem::Type* ty,
return CreateComposite(builder, ty, std::move(els));
}
const sem::Constant* ConstEval::MatCtorV(const sem::Type* ty,
ConstEval::ConstantResult ConstEval::MatCtorV(const sem::Type* ty,
utils::VectorRef<const sem::Expression*> args) {
utils::Vector<const sem::Constant*, 4> els;
for (auto* arg : args) {
@ -628,16 +629,16 @@ const sem::Constant* ConstEval::MatCtorV(const sem::Type* ty,
return CreateComposite(builder, ty, std::move(els));
}
const sem::Constant* ConstEval::Index(const sem::Expression* obj_expr,
ConstEval::ConstantResult ConstEval::Index(const sem::Expression* obj_expr,
const sem::Expression* idx_expr) {
auto obj_val = obj_expr->ConstantValue();
if (!obj_val) {
return {};
return nullptr;
}
auto idx_val = idx_expr->ConstantValue();
if (!idx_val) {
return {};
return nullptr;
}
uint32_t el_count = 0;
@ -656,16 +657,16 @@ const sem::Constant* ConstEval::Index(const sem::Expression* obj_expr,
return obj_val->Index(static_cast<size_t>(idx));
}
const sem::Constant* ConstEval::MemberAccess(const sem::Expression* obj_expr,
ConstEval::ConstantResult ConstEval::MemberAccess(const sem::Expression* obj_expr,
const sem::StructMember* member) {
auto obj_val = obj_expr->ConstantValue();
if (!obj_val) {
return {};
return nullptr;
}
return obj_val->Index(static_cast<size_t>(member->Index()));
}
const sem::Constant* ConstEval::Swizzle(const sem::Type* ty,
ConstEval::ConstantResult ConstEval::Swizzle(const sem::Type* ty,
const sem::Expression* vec_expr,
utils::VectorRef<uint32_t> indices) {
auto* vec_val = vec_expr->ConstantValue();
@ -681,12 +682,12 @@ const sem::Constant* ConstEval::Swizzle(const sem::Type* ty,
}
}
const sem::Constant* ConstEval::Bitcast(const sem::Type*, const sem::Expression*) {
ConstEval::ConstantResult ConstEval::Bitcast(const sem::Type*, const sem::Expression*) {
// TODO(crbug.com/tint/1581): Implement @const intrinsics
return nullptr;
}
const sem::Constant* ConstEval::OpComplement(const sem::Type*,
ConstEval::ConstantResult ConstEval::OpComplement(const sem::Type*,
utils::VectorRef<const sem::Expression*> args) {
auto transform = [&](const sem::Constant* c) {
auto create = [&](auto i) {
@ -697,10 +698,10 @@ const sem::Constant* ConstEval::OpComplement(const sem::Type*,
return TransformElements(builder, transform, args[0]->ConstantValue());
}
const sem::Constant* ConstEval::OpMinus(const sem::Type*,
ConstEval::ConstantResult ConstEval::OpMinus(const sem::Type*,
utils::VectorRef<const sem::Expression*> args) {
auto transform = [&](const sem::Constant* c) {
auto create = [&](auto i) { //
auto create = [&](auto i) {
// For signed integrals, avoid C++ UB by not negating the
// smallest negative number. In WGSL, this operation is well
// defined to return the same value, see:
@ -721,7 +722,7 @@ const sem::Constant* ConstEval::OpMinus(const sem::Type*,
return TransformElements(builder, transform, args[0]->ConstantValue());
}
const sem::Constant* ConstEval::atan2(const sem::Type*,
ConstEval::ConstantResult ConstEval::atan2(const sem::Type*,
utils::VectorRef<const sem::Expression*> args) {
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
auto create = [&](auto i, auto j) {
@ -733,7 +734,7 @@ const sem::Constant* ConstEval::atan2(const sem::Type*,
args[1]->ConstantValue());
}
const sem::Constant* ConstEval::clamp(const sem::Type*,
ConstEval::ConstantResult ConstEval::clamp(const sem::Type*,
utils::VectorRef<const sem::Expression*> args) {
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1,
const sem::Constant* c2) {

View File

@ -44,10 +44,6 @@ namespace tint::resolver {
/// before calling a method to evaluate an expression's value.
class ConstEval {
public:
/// Typedef for a constant evaluation function
using Function = const sem::Constant* (ConstEval::*)(const sem::Type* result_ty,
utils::VectorRef<const sem::Expression*>);
/// The result type of a method that may raise a diagnostic error and the caller should abort
/// resolving. Can be one of three distinct values:
/// * A non-null sem::Constant pointer. Returned when a expression resolves to a creation time
@ -59,6 +55,10 @@ class ConstEval {
/// resolving.
using ConstantResult = utils::Result<const sem::Constant*>;
/// Typedef for a constant evaluation function
using Function = ConstantResult (ConstEval::*)(const sem::Type* result_ty,
utils::VectorRef<const sem::Expression*>);
/// Constructor
/// @param b the program builder
explicit ConstEval(ProgramBuilder& b);
@ -70,35 +70,35 @@ class ConstEval {
/// @param ty the target type - must be an array or constructor
/// @param args the input arguments
/// @return the constructed value, or null if the value cannot be calculated
const sem::Constant* ArrayOrStructCtor(const sem::Type* ty,
ConstantResult ArrayOrStructCtor(const sem::Type* ty,
utils::VectorRef<const sem::Expression*> args);
/// @param ty the target type
/// @param expr the input expression
/// @return the bit-cast of the given expression to the given type, or null if the value cannot
/// be calculated
const sem::Constant* Bitcast(const sem::Type* ty, const sem::Expression* expr);
ConstantResult Bitcast(const sem::Type* ty, const sem::Expression* expr);
/// @param obj the object being indexed
/// @param idx the index expression
/// @return the result of the index, or null if the value cannot be calculated
const sem::Constant* Index(const sem::Expression* obj, const sem::Expression* idx);
ConstantResult Index(const sem::Expression* obj, const sem::Expression* idx);
/// @param ty the result type
/// @param lit the literal AST node
/// @return the constant value of the literal
const sem::Constant* Literal(const sem::Type* ty, const ast::LiteralExpression* lit);
ConstantResult Literal(const sem::Type* ty, const ast::LiteralExpression* lit);
/// @param obj the object being accessed
/// @param member the member
/// @return the result of the member access, or null if the value cannot be calculated
const sem::Constant* MemberAccess(const sem::Expression* obj, const sem::StructMember* member);
ConstantResult MemberAccess(const sem::Expression* obj, const sem::StructMember* member);
/// @param ty the result type
/// @param vector the vector being swizzled
/// @param indices the swizzle indices
/// @return the result of the swizzle, or null if the value cannot be calculated
const sem::Constant* Swizzle(const sem::Type* ty,
ConstantResult Swizzle(const sem::Type* ty,
const sem::Expression* vector,
utils::VectorRef<uint32_t> indices);
@ -117,73 +117,65 @@ class ConstEval {
/// @param ty the result type
/// @param args the input arguments
/// @return the converted value, or null if the value cannot be calculated
const sem::Constant* Conv(const sem::Type* ty, utils::VectorRef<const sem::Expression*> args);
ConstantResult Conv(const sem::Type* ty, utils::VectorRef<const sem::Expression*> args);
/// Zero value type constructor
/// @param ty the result type
/// @param args the input arguments (no arguments provided)
/// @return the constructed value, or null if the value cannot be calculated
const sem::Constant* Zero(const sem::Type* ty, utils::VectorRef<const sem::Expression*> args);
ConstantResult Zero(const sem::Type* ty, utils::VectorRef<const sem::Expression*> args);
/// Identity value type constructor
/// @param ty the result type
/// @param args the input arguments
/// @return the constructed value, or null if the value cannot be calculated
const sem::Constant* Identity(const sem::Type* ty,
utils::VectorRef<const sem::Expression*> args);
ConstantResult Identity(const sem::Type* ty, utils::VectorRef<const sem::Expression*> args);
/// Vector splat constructor
/// @param ty the vector type
/// @param args the input arguments
/// @return the constructed value, or null if the value cannot be calculated
const sem::Constant* VecSplat(const sem::Type* ty,
utils::VectorRef<const sem::Expression*> args);
ConstantResult VecSplat(const sem::Type* ty, utils::VectorRef<const sem::Expression*> args);
/// Vector constructor using scalars
/// @param ty the vector type
/// @param args the input arguments
/// @return the constructed value, or null if the value cannot be calculated
const sem::Constant* VecCtorS(const sem::Type* ty,
utils::VectorRef<const sem::Expression*> args);
ConstantResult VecCtorS(const sem::Type* ty, utils::VectorRef<const sem::Expression*> args);
/// Vector constructor using a mix of scalars and smaller vectors
/// @param ty the vector type
/// @param args the input arguments
/// @return the constructed value, or null if the value cannot be calculated
const sem::Constant* VecCtorM(const sem::Type* ty,
utils::VectorRef<const sem::Expression*> args);
ConstantResult VecCtorM(const sem::Type* ty, utils::VectorRef<const sem::Expression*> args);
/// Matrix constructor using scalar values
/// @param ty the matrix type
/// @param args the input arguments
/// @return the constructed value, or null if the value cannot be calculated
const sem::Constant* MatCtorS(const sem::Type* ty,
utils::VectorRef<const sem::Expression*> args);
ConstantResult MatCtorS(const sem::Type* ty, utils::VectorRef<const sem::Expression*> args);
/// Matrix constructor using column vectors
/// @param ty the matrix type
/// @param args the input arguments
/// @return the constructed value, or null if the value cannot be calculated
const sem::Constant* MatCtorV(const sem::Type* ty,
utils::VectorRef<const sem::Expression*> args);
ConstantResult MatCtorV(const sem::Type* ty, utils::VectorRef<const sem::Expression*> args);
////////////////////////////////////////////////////////////////////////////
// Operators
// Unary Operators
////////////////////////////////////////////////////////////////////////////
/// Complement operator '~'
/// @param ty the integer type
/// @param args the input arguments
/// @return the result value, or null if the value cannot be calculated
const sem::Constant* OpComplement(const sem::Type* ty,
utils::VectorRef<const sem::Expression*> args);
ConstantResult OpComplement(const sem::Type* ty, utils::VectorRef<const sem::Expression*> args);
/// Minus operator '-'
/// @param ty the expression type
/// @param args the input arguments
/// @return the result value, or null if the value cannot be calculated
const sem::Constant* OpMinus(const sem::Type* ty,
utils::VectorRef<const sem::Expression*> args);
ConstantResult OpMinus(const sem::Type* ty, utils::VectorRef<const sem::Expression*> args);
////////////////////////////////////////////////////////////////////////////
// Builtins
@ -193,13 +185,13 @@ class ConstEval {
/// @param ty the expression type
/// @param args the input arguments
/// @return the result value, or null if the value cannot be calculated
const sem::Constant* atan2(const sem::Type* ty, utils::VectorRef<const sem::Expression*> args);
ConstantResult atan2(const sem::Type* ty, utils::VectorRef<const sem::Expression*> args);
/// clamp builtin
/// @param ty the expression type
/// @param args the input arguments
/// @return the result value, or null if the value cannot be calculated
const sem::Constant* clamp(const sem::Type* ty, utils::VectorRef<const sem::Expression*> args);
ConstantResult clamp(const sem::Type* ty, utils::VectorRef<const sem::Expression*> args);
private:
/// Adds the given error message to the diagnostics

View File

@ -1501,7 +1501,12 @@ sem::Expression* Resolver::IndexAccessor(const ast::IndexAccessorExpression* exp
}
auto stage = sem::EarliestStage(obj->Stage(), idx->Stage());
auto val = const_eval_.Index(obj, idx);
const sem::Constant* val = nullptr;
if (auto r = const_eval_.Index(obj, idx)) {
val = r.Get();
} else {
return nullptr;
}
bool has_side_effects = idx->HasSideEffects() || obj->HasSideEffects();
auto* sem = builder_->create<sem::IndexAccessorExpression>(
expr, ty, stage, obj, idx, current_statement_, std::move(val), has_side_effects,
@ -1520,7 +1525,12 @@ sem::Expression* Resolver::Bitcast(const ast::BitcastExpression* expr) {
return nullptr;
}
auto val = const_eval_.Bitcast(ty, inner);
const sem::Constant* val = nullptr;
if (auto r = const_eval_.Bitcast(ty, inner)) {
val = r.Get();
} else {
return nullptr;
}
auto stage = sem::EvaluationStage::kRuntime; // TODO(crbug.com/tint/1581)
auto* sem = builder_->create<sem::Expression>(expr, ty, stage, current_statement_,
std::move(val), inner->HasSideEffects());
@ -1575,8 +1585,12 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) {
const sem::Constant* value = nullptr;
auto stage = sem::EarliestStage(ctor_or_conv.target->Stage(), args_stage);
if (stage == sem::EvaluationStage::kConstant) {
value =
(const_eval_.*ctor_or_conv.const_eval_fn)(ctor_or_conv.target->ReturnType(), args);
if (auto r = (const_eval_.*ctor_or_conv.const_eval_fn)(
ctor_or_conv.target->ReturnType(), args)) {
value = r.Get();
} else {
return nullptr;
}
}
return builder_->create<sem::Call>(expr, ctor_or_conv.target, stage, std::move(args),
current_statement_, value, has_side_effects);
@ -1593,7 +1607,11 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) {
auto stage = args_stage; // The evaluation stage of the call
const sem::Constant* value = nullptr; // The constant value for the call
if (stage == sem::EvaluationStage::kConstant) {
value = const_eval_.ArrayOrStructCtor(ty, args);
if (auto r = const_eval_.ArrayOrStructCtor(ty, args)) {
value = r.Get();
} else {
return nullptr;
}
if (!value) {
// Constant evaluation failed.
// Can happen for expressions that will fail validation (later).
@ -1873,7 +1891,11 @@ sem::Call* Resolver::BuiltinCall(const ast::CallExpression* expr,
// If the builtin is @const, and all arguments have constant values, evaluate the builtin now.
const sem::Constant* value = nullptr;
if (stage == sem::EvaluationStage::kConstant) {
value = (const_eval_.*builtin.const_eval_fn)(builtin.sem->ReturnType(), args);
if (auto r = (const_eval_.*builtin.const_eval_fn)(builtin.sem->ReturnType(), args)) {
value = r.Get();
} else {
return nullptr;
}
}
bool has_side_effects =
@ -2035,7 +2057,12 @@ sem::Expression* Resolver::Literal(const ast::LiteralExpression* literal) {
return nullptr;
}
auto val = const_eval_.Literal(ty, literal);
const sem::Constant* val = nullptr;
if (auto r = const_eval_.Literal(ty, literal)) {
val = r.Get();
} else {
return nullptr;
}
return builder_->create<sem::Expression>(literal, ty, sem::EvaluationStage::kConstant,
current_statement_, std::move(val),
/* has_side_effects */ false);
@ -2156,7 +2183,12 @@ sem::Expression* Resolver::MemberAccessor(const ast::MemberAccessorExpression* e
ret = builder_->create<sem::Reference>(ret, ref->StorageClass(), ref->Access());
}
auto* val = const_eval_.MemberAccess(object, member);
const sem::Constant* val = nullptr;
if (auto r = const_eval_.MemberAccess(object, member)) {
val = r.Get();
} else {
return nullptr;
}
return builder_->create<sem::StructMemberAccess>(expr, ret, current_statement_, val, object,
member, has_side_effects, source_var);
}
@ -2224,10 +2256,13 @@ sem::Expression* Resolver::MemberAccessor(const ast::MemberAccessorExpression* e
// the swizzle.
ret = builder_->create<sem::Vector>(vec->type(), static_cast<uint32_t>(size));
}
auto* val = const_eval_.Swizzle(ret, object, swizzle);
if (auto r = const_eval_.Swizzle(ret, object, swizzle)) {
auto* val = r.Get();
return builder_->create<sem::Swizzle>(expr, ret, current_statement_, val, object,
std::move(swizzle), has_side_effects, source_var);
}
return nullptr;
}
AddError("invalid member accessor expression. Expected vector or struct, got '" +
sem_.TypeNameOf(storage_ty) + "'",
@ -2240,7 +2275,6 @@ sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) {
const auto* rhs = sem_.Get(expr->rhs);
auto* lhs_ty = lhs->Type()->UnwrapRef();
auto* rhs_ty = rhs->Type()->UnwrapRef();
auto stage = sem::EvaluationStage::kRuntime; // TODO(crbug.com/tint/1581)
auto op = intrinsic_table_->Lookup(expr->op, lhs_ty, rhs_ty, expr->source, false);
if (!op.result) {
@ -2260,8 +2294,17 @@ sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) {
}
const sem::Constant* value = nullptr;
auto stage = sem::EarliestStage(lhs->Stage(), rhs->Stage());
if (stage == sem::EvaluationStage::kConstant) {
if (op.const_eval_fn) {
value = (const_eval_.*op.const_eval_fn)(op.result, utils::Vector{lhs, rhs});
if (auto r = (const_eval_.*op.const_eval_fn)(op.result, utils::Vector{lhs, rhs})) {
value = r.Get();
} else {
return nullptr;
}
} else {
stage = sem::EvaluationStage::kRuntime;
}
}
bool has_side_effects = lhs->HasSideEffects() || rhs->HasSideEffects();
@ -2337,7 +2380,11 @@ sem::Expression* Resolver::UnaryOp(const ast::UnaryOpExpression* unary) {
stage = expr->Stage();
if (stage == sem::EvaluationStage::kConstant) {
if (op.const_eval_fn) {
value = (const_eval_.*op.const_eval_fn)(ty, utils::Vector{expr});
if (auto r = (const_eval_.*op.const_eval_fn)(ty, utils::Vector{expr})) {
value = r.Get();
} else {
return nullptr;
}
} else {
stage = sem::EvaluationStage::kRuntime;
}