tint: Allow ConstEval functions to fail

They now return a utils::Result so they can add an error to diagnostics
and return Failure. Returning nullptr still means cannot evaluate at
compile time, but not a failure.

Bug: tint:1581
Change-Id: Ic30d782fb9fa725ec2faf89a87f74de6282d0304
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/98107
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
This commit is contained in:
Antonio Maiorano 2022-08-04 13:59:36 +00:00 committed by Dawn LUCI CQ
parent 6091d838a7
commit c2a052eaa4
3 changed files with 133 additions and 93 deletions

View File

@ -479,7 +479,7 @@ const Constant* TransformElements(ProgramBuilder& builder, F&& f, CONSTANTS&&...
ConstEval::ConstEval(ProgramBuilder& b) : builder(b) {} ConstEval::ConstEval(ProgramBuilder& b) : builder(b) {}
const sem::Constant* ConstEval::Literal(const sem::Type* ty, ConstEval::ConstantResult ConstEval::Literal(const sem::Type* ty,
const ast::LiteralExpression* literal) { const ast::LiteralExpression* literal) {
return Switch( return Switch(
literal, literal,
@ -510,7 +510,8 @@ const sem::Constant* ConstEval::Literal(const sem::Type* ty,
}); });
} }
const sem::Constant* ConstEval::ArrayOrStructCtor(const sem::Type* ty, ConstEval::ConstantResult ConstEval::ArrayOrStructCtor(
const sem::Type* ty,
utils::VectorRef<const sem::Expression*> args) { utils::VectorRef<const sem::Expression*> args) {
if (args.IsEmpty()) { if (args.IsEmpty()) {
return ZeroValue(builder, ty); return ZeroValue(builder, ty);
@ -530,7 +531,7 @@ const sem::Constant* ConstEval::ArrayOrStructCtor(const sem::Type* ty,
return CreateComposite(builder, ty, std::move(els)); return CreateComposite(builder, ty, std::move(els));
} }
const sem::Constant* ConstEval::Conv(const sem::Type* ty, ConstEval::ConstantResult ConstEval::Conv(const sem::Type* ty,
utils::VectorRef<const sem::Expression*> args) { utils::VectorRef<const sem::Expression*> args) {
uint32_t el_count = 0; uint32_t el_count = 0;
auto* el_ty = sem::Type::ElementOf(ty, &el_count); auto* el_ty = sem::Type::ElementOf(ty, &el_count);
@ -551,17 +552,17 @@ const sem::Constant* ConstEval::Conv(const sem::Type* ty,
return nullptr; return nullptr;
} }
const sem::Constant* ConstEval::Zero(const sem::Type* ty, ConstEval::ConstantResult ConstEval::Zero(const sem::Type* ty,
utils::VectorRef<const sem::Expression*>) { utils::VectorRef<const sem::Expression*>) {
return ZeroValue(builder, ty); return ZeroValue(builder, ty);
} }
const sem::Constant* ConstEval::Identity(const sem::Type*, ConstEval::ConstantResult ConstEval::Identity(const sem::Type*,
utils::VectorRef<const sem::Expression*> args) { utils::VectorRef<const sem::Expression*> args) {
return args[0]->ConstantValue(); return args[0]->ConstantValue();
} }
const sem::Constant* ConstEval::VecSplat(const sem::Type* ty, ConstEval::ConstantResult ConstEval::VecSplat(const sem::Type* ty,
utils::VectorRef<const sem::Expression*> args) { utils::VectorRef<const sem::Expression*> args) {
if (auto* arg = args[0]->ConstantValue()) { if (auto* arg = args[0]->ConstantValue()) {
return builder.create<Splat>(ty, arg, static_cast<const sem::Vector*>(ty)->Width()); return builder.create<Splat>(ty, arg, static_cast<const sem::Vector*>(ty)->Width());
@ -569,7 +570,7 @@ const sem::Constant* ConstEval::VecSplat(const sem::Type* ty,
return nullptr; return nullptr;
} }
const sem::Constant* ConstEval::VecCtorS(const sem::Type* ty, ConstEval::ConstantResult ConstEval::VecCtorS(const sem::Type* ty,
utils::VectorRef<const sem::Expression*> args) { utils::VectorRef<const sem::Expression*> args) {
utils::Vector<const sem::Constant*, 4> els; utils::Vector<const sem::Constant*, 4> els;
for (auto* arg : args) { for (auto* arg : args) {
@ -578,7 +579,7 @@ const sem::Constant* ConstEval::VecCtorS(const sem::Type* ty,
return CreateComposite(builder, ty, std::move(els)); return CreateComposite(builder, ty, std::move(els));
} }
const sem::Constant* ConstEval::VecCtorM(const sem::Type* ty, ConstEval::ConstantResult ConstEval::VecCtorM(const sem::Type* ty,
utils::VectorRef<const sem::Expression*> args) { utils::VectorRef<const sem::Expression*> args) {
utils::Vector<const sem::Constant*, 4> els; utils::Vector<const sem::Constant*, 4> els;
for (auto* arg : args) { for (auto* arg : args) {
@ -603,7 +604,7 @@ const sem::Constant* ConstEval::VecCtorM(const sem::Type* ty,
return CreateComposite(builder, ty, std::move(els)); return CreateComposite(builder, ty, std::move(els));
} }
const sem::Constant* ConstEval::MatCtorS(const sem::Type* ty, ConstEval::ConstantResult ConstEval::MatCtorS(const sem::Type* ty,
utils::VectorRef<const sem::Expression*> args) { utils::VectorRef<const sem::Expression*> args) {
auto* m = static_cast<const sem::Matrix*>(ty); auto* m = static_cast<const sem::Matrix*>(ty);
@ -619,7 +620,7 @@ const sem::Constant* ConstEval::MatCtorS(const sem::Type* ty,
return CreateComposite(builder, ty, std::move(els)); return CreateComposite(builder, ty, std::move(els));
} }
const sem::Constant* ConstEval::MatCtorV(const sem::Type* ty, ConstEval::ConstantResult ConstEval::MatCtorV(const sem::Type* ty,
utils::VectorRef<const sem::Expression*> args) { utils::VectorRef<const sem::Expression*> args) {
utils::Vector<const sem::Constant*, 4> els; utils::Vector<const sem::Constant*, 4> els;
for (auto* arg : args) { for (auto* arg : args) {
@ -628,16 +629,16 @@ const sem::Constant* ConstEval::MatCtorV(const sem::Type* ty,
return CreateComposite(builder, ty, std::move(els)); return CreateComposite(builder, ty, std::move(els));
} }
const sem::Constant* ConstEval::Index(const sem::Expression* obj_expr, ConstEval::ConstantResult ConstEval::Index(const sem::Expression* obj_expr,
const sem::Expression* idx_expr) { const sem::Expression* idx_expr) {
auto obj_val = obj_expr->ConstantValue(); auto obj_val = obj_expr->ConstantValue();
if (!obj_val) { if (!obj_val) {
return {}; return nullptr;
} }
auto idx_val = idx_expr->ConstantValue(); auto idx_val = idx_expr->ConstantValue();
if (!idx_val) { if (!idx_val) {
return {}; return nullptr;
} }
uint32_t el_count = 0; uint32_t el_count = 0;
@ -656,16 +657,16 @@ const sem::Constant* ConstEval::Index(const sem::Expression* obj_expr,
return obj_val->Index(static_cast<size_t>(idx)); return obj_val->Index(static_cast<size_t>(idx));
} }
const sem::Constant* ConstEval::MemberAccess(const sem::Expression* obj_expr, ConstEval::ConstantResult ConstEval::MemberAccess(const sem::Expression* obj_expr,
const sem::StructMember* member) { const sem::StructMember* member) {
auto obj_val = obj_expr->ConstantValue(); auto obj_val = obj_expr->ConstantValue();
if (!obj_val) { if (!obj_val) {
return {}; return nullptr;
} }
return obj_val->Index(static_cast<size_t>(member->Index())); return obj_val->Index(static_cast<size_t>(member->Index()));
} }
const sem::Constant* ConstEval::Swizzle(const sem::Type* ty, ConstEval::ConstantResult ConstEval::Swizzle(const sem::Type* ty,
const sem::Expression* vec_expr, const sem::Expression* vec_expr,
utils::VectorRef<uint32_t> indices) { utils::VectorRef<uint32_t> indices) {
auto* vec_val = vec_expr->ConstantValue(); auto* vec_val = vec_expr->ConstantValue();
@ -681,12 +682,12 @@ const sem::Constant* ConstEval::Swizzle(const sem::Type* ty,
} }
} }
const sem::Constant* ConstEval::Bitcast(const sem::Type*, const sem::Expression*) { ConstEval::ConstantResult ConstEval::Bitcast(const sem::Type*, const sem::Expression*) {
// TODO(crbug.com/tint/1581): Implement @const intrinsics // TODO(crbug.com/tint/1581): Implement @const intrinsics
return nullptr; return nullptr;
} }
const sem::Constant* ConstEval::OpComplement(const sem::Type*, ConstEval::ConstantResult ConstEval::OpComplement(const sem::Type*,
utils::VectorRef<const sem::Expression*> args) { utils::VectorRef<const sem::Expression*> args) {
auto transform = [&](const sem::Constant* c) { auto transform = [&](const sem::Constant* c) {
auto create = [&](auto i) { auto create = [&](auto i) {
@ -697,10 +698,10 @@ const sem::Constant* ConstEval::OpComplement(const sem::Type*,
return TransformElements(builder, transform, args[0]->ConstantValue()); return TransformElements(builder, transform, args[0]->ConstantValue());
} }
const sem::Constant* ConstEval::OpMinus(const sem::Type*, ConstEval::ConstantResult ConstEval::OpMinus(const sem::Type*,
utils::VectorRef<const sem::Expression*> args) { utils::VectorRef<const sem::Expression*> args) {
auto transform = [&](const sem::Constant* c) { auto transform = [&](const sem::Constant* c) {
auto create = [&](auto i) { // auto create = [&](auto i) {
// For signed integrals, avoid C++ UB by not negating the // For signed integrals, avoid C++ UB by not negating the
// smallest negative number. In WGSL, this operation is well // smallest negative number. In WGSL, this operation is well
// defined to return the same value, see: // defined to return the same value, see:
@ -721,7 +722,7 @@ const sem::Constant* ConstEval::OpMinus(const sem::Type*,
return TransformElements(builder, transform, args[0]->ConstantValue()); return TransformElements(builder, transform, args[0]->ConstantValue());
} }
const sem::Constant* ConstEval::atan2(const sem::Type*, ConstEval::ConstantResult ConstEval::atan2(const sem::Type*,
utils::VectorRef<const sem::Expression*> args) { utils::VectorRef<const sem::Expression*> args) {
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) { auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
auto create = [&](auto i, auto j) { auto create = [&](auto i, auto j) {
@ -733,7 +734,7 @@ const sem::Constant* ConstEval::atan2(const sem::Type*,
args[1]->ConstantValue()); args[1]->ConstantValue());
} }
const sem::Constant* ConstEval::clamp(const sem::Type*, ConstEval::ConstantResult ConstEval::clamp(const sem::Type*,
utils::VectorRef<const sem::Expression*> args) { utils::VectorRef<const sem::Expression*> args) {
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1, auto transform = [&](const sem::Constant* c0, const sem::Constant* c1,
const sem::Constant* c2) { const sem::Constant* c2) {

View File

@ -44,10 +44,6 @@ namespace tint::resolver {
/// before calling a method to evaluate an expression's value. /// before calling a method to evaluate an expression's value.
class ConstEval { class ConstEval {
public: public:
/// Typedef for a constant evaluation function
using Function = const sem::Constant* (ConstEval::*)(const sem::Type* result_ty,
utils::VectorRef<const sem::Expression*>);
/// The result type of a method that may raise a diagnostic error and the caller should abort /// The result type of a method that may raise a diagnostic error and the caller should abort
/// resolving. Can be one of three distinct values: /// resolving. Can be one of three distinct values:
/// * A non-null sem::Constant pointer. Returned when a expression resolves to a creation time /// * A non-null sem::Constant pointer. Returned when a expression resolves to a creation time
@ -59,6 +55,10 @@ class ConstEval {
/// resolving. /// resolving.
using ConstantResult = utils::Result<const sem::Constant*>; using ConstantResult = utils::Result<const sem::Constant*>;
/// Typedef for a constant evaluation function
using Function = ConstantResult (ConstEval::*)(const sem::Type* result_ty,
utils::VectorRef<const sem::Expression*>);
/// Constructor /// Constructor
/// @param b the program builder /// @param b the program builder
explicit ConstEval(ProgramBuilder& b); explicit ConstEval(ProgramBuilder& b);
@ -70,35 +70,35 @@ class ConstEval {
/// @param ty the target type - must be an array or constructor /// @param ty the target type - must be an array or constructor
/// @param args the input arguments /// @param args the input arguments
/// @return the constructed value, or null if the value cannot be calculated /// @return the constructed value, or null if the value cannot be calculated
const sem::Constant* ArrayOrStructCtor(const sem::Type* ty, ConstantResult ArrayOrStructCtor(const sem::Type* ty,
utils::VectorRef<const sem::Expression*> args); utils::VectorRef<const sem::Expression*> args);
/// @param ty the target type /// @param ty the target type
/// @param expr the input expression /// @param expr the input expression
/// @return the bit-cast of the given expression to the given type, or null if the value cannot /// @return the bit-cast of the given expression to the given type, or null if the value cannot
/// be calculated /// be calculated
const sem::Constant* Bitcast(const sem::Type* ty, const sem::Expression* expr); ConstantResult Bitcast(const sem::Type* ty, const sem::Expression* expr);
/// @param obj the object being indexed /// @param obj the object being indexed
/// @param idx the index expression /// @param idx the index expression
/// @return the result of the index, or null if the value cannot be calculated /// @return the result of the index, or null if the value cannot be calculated
const sem::Constant* Index(const sem::Expression* obj, const sem::Expression* idx); ConstantResult Index(const sem::Expression* obj, const sem::Expression* idx);
/// @param ty the result type /// @param ty the result type
/// @param lit the literal AST node /// @param lit the literal AST node
/// @return the constant value of the literal /// @return the constant value of the literal
const sem::Constant* Literal(const sem::Type* ty, const ast::LiteralExpression* lit); ConstantResult Literal(const sem::Type* ty, const ast::LiteralExpression* lit);
/// @param obj the object being accessed /// @param obj the object being accessed
/// @param member the member /// @param member the member
/// @return the result of the member access, or null if the value cannot be calculated /// @return the result of the member access, or null if the value cannot be calculated
const sem::Constant* MemberAccess(const sem::Expression* obj, const sem::StructMember* member); ConstantResult MemberAccess(const sem::Expression* obj, const sem::StructMember* member);
/// @param ty the result type /// @param ty the result type
/// @param vector the vector being swizzled /// @param vector the vector being swizzled
/// @param indices the swizzle indices /// @param indices the swizzle indices
/// @return the result of the swizzle, or null if the value cannot be calculated /// @return the result of the swizzle, or null if the value cannot be calculated
const sem::Constant* Swizzle(const sem::Type* ty, ConstantResult Swizzle(const sem::Type* ty,
const sem::Expression* vector, const sem::Expression* vector,
utils::VectorRef<uint32_t> indices); utils::VectorRef<uint32_t> indices);
@ -117,73 +117,65 @@ class ConstEval {
/// @param ty the result type /// @param ty the result type
/// @param args the input arguments /// @param args the input arguments
/// @return the converted value, or null if the value cannot be calculated /// @return the converted value, or null if the value cannot be calculated
const sem::Constant* Conv(const sem::Type* ty, utils::VectorRef<const sem::Expression*> args); ConstantResult Conv(const sem::Type* ty, utils::VectorRef<const sem::Expression*> args);
/// Zero value type constructor /// Zero value type constructor
/// @param ty the result type /// @param ty the result type
/// @param args the input arguments (no arguments provided) /// @param args the input arguments (no arguments provided)
/// @return the constructed value, or null if the value cannot be calculated /// @return the constructed value, or null if the value cannot be calculated
const sem::Constant* Zero(const sem::Type* ty, utils::VectorRef<const sem::Expression*> args); ConstantResult Zero(const sem::Type* ty, utils::VectorRef<const sem::Expression*> args);
/// Identity value type constructor /// Identity value type constructor
/// @param ty the result type /// @param ty the result type
/// @param args the input arguments /// @param args the input arguments
/// @return the constructed value, or null if the value cannot be calculated /// @return the constructed value, or null if the value cannot be calculated
const sem::Constant* Identity(const sem::Type* ty, ConstantResult Identity(const sem::Type* ty, utils::VectorRef<const sem::Expression*> args);
utils::VectorRef<const sem::Expression*> args);
/// Vector splat constructor /// Vector splat constructor
/// @param ty the vector type /// @param ty the vector type
/// @param args the input arguments /// @param args the input arguments
/// @return the constructed value, or null if the value cannot be calculated /// @return the constructed value, or null if the value cannot be calculated
const sem::Constant* VecSplat(const sem::Type* ty, ConstantResult VecSplat(const sem::Type* ty, utils::VectorRef<const sem::Expression*> args);
utils::VectorRef<const sem::Expression*> args);
/// Vector constructor using scalars /// Vector constructor using scalars
/// @param ty the vector type /// @param ty the vector type
/// @param args the input arguments /// @param args the input arguments
/// @return the constructed value, or null if the value cannot be calculated /// @return the constructed value, or null if the value cannot be calculated
const sem::Constant* VecCtorS(const sem::Type* ty, ConstantResult VecCtorS(const sem::Type* ty, utils::VectorRef<const sem::Expression*> args);
utils::VectorRef<const sem::Expression*> args);
/// Vector constructor using a mix of scalars and smaller vectors /// Vector constructor using a mix of scalars and smaller vectors
/// @param ty the vector type /// @param ty the vector type
/// @param args the input arguments /// @param args the input arguments
/// @return the constructed value, or null if the value cannot be calculated /// @return the constructed value, or null if the value cannot be calculated
const sem::Constant* VecCtorM(const sem::Type* ty, ConstantResult VecCtorM(const sem::Type* ty, utils::VectorRef<const sem::Expression*> args);
utils::VectorRef<const sem::Expression*> args);
/// Matrix constructor using scalar values /// Matrix constructor using scalar values
/// @param ty the matrix type /// @param ty the matrix type
/// @param args the input arguments /// @param args the input arguments
/// @return the constructed value, or null if the value cannot be calculated /// @return the constructed value, or null if the value cannot be calculated
const sem::Constant* MatCtorS(const sem::Type* ty, ConstantResult MatCtorS(const sem::Type* ty, utils::VectorRef<const sem::Expression*> args);
utils::VectorRef<const sem::Expression*> args);
/// Matrix constructor using column vectors /// Matrix constructor using column vectors
/// @param ty the matrix type /// @param ty the matrix type
/// @param args the input arguments /// @param args the input arguments
/// @return the constructed value, or null if the value cannot be calculated /// @return the constructed value, or null if the value cannot be calculated
const sem::Constant* MatCtorV(const sem::Type* ty, ConstantResult MatCtorV(const sem::Type* ty, utils::VectorRef<const sem::Expression*> args);
utils::VectorRef<const sem::Expression*> args);
//////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////
// Operators // Unary Operators
//////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////
/// Complement operator '~' /// Complement operator '~'
/// @param ty the integer type /// @param ty the integer type
/// @param args the input arguments /// @param args the input arguments
/// @return the result value, or null if the value cannot be calculated /// @return the result value, or null if the value cannot be calculated
const sem::Constant* OpComplement(const sem::Type* ty, ConstantResult OpComplement(const sem::Type* ty, utils::VectorRef<const sem::Expression*> args);
utils::VectorRef<const sem::Expression*> args);
/// Minus operator '-' /// Minus operator '-'
/// @param ty the expression type /// @param ty the expression type
/// @param args the input arguments /// @param args the input arguments
/// @return the result value, or null if the value cannot be calculated /// @return the result value, or null if the value cannot be calculated
const sem::Constant* OpMinus(const sem::Type* ty, ConstantResult OpMinus(const sem::Type* ty, utils::VectorRef<const sem::Expression*> args);
utils::VectorRef<const sem::Expression*> args);
//////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////
// Builtins // Builtins
@ -193,13 +185,13 @@ class ConstEval {
/// @param ty the expression type /// @param ty the expression type
/// @param args the input arguments /// @param args the input arguments
/// @return the result value, or null if the value cannot be calculated /// @return the result value, or null if the value cannot be calculated
const sem::Constant* atan2(const sem::Type* ty, utils::VectorRef<const sem::Expression*> args); ConstantResult atan2(const sem::Type* ty, utils::VectorRef<const sem::Expression*> args);
/// clamp builtin /// clamp builtin
/// @param ty the expression type /// @param ty the expression type
/// @param args the input arguments /// @param args the input arguments
/// @return the result value, or null if the value cannot be calculated /// @return the result value, or null if the value cannot be calculated
const sem::Constant* clamp(const sem::Type* ty, utils::VectorRef<const sem::Expression*> args); ConstantResult clamp(const sem::Type* ty, utils::VectorRef<const sem::Expression*> args);
private: private:
/// Adds the given error message to the diagnostics /// Adds the given error message to the diagnostics

View File

@ -1501,7 +1501,12 @@ sem::Expression* Resolver::IndexAccessor(const ast::IndexAccessorExpression* exp
} }
auto stage = sem::EarliestStage(obj->Stage(), idx->Stage()); auto stage = sem::EarliestStage(obj->Stage(), idx->Stage());
auto val = const_eval_.Index(obj, idx); const sem::Constant* val = nullptr;
if (auto r = const_eval_.Index(obj, idx)) {
val = r.Get();
} else {
return nullptr;
}
bool has_side_effects = idx->HasSideEffects() || obj->HasSideEffects(); bool has_side_effects = idx->HasSideEffects() || obj->HasSideEffects();
auto* sem = builder_->create<sem::IndexAccessorExpression>( auto* sem = builder_->create<sem::IndexAccessorExpression>(
expr, ty, stage, obj, idx, current_statement_, std::move(val), has_side_effects, expr, ty, stage, obj, idx, current_statement_, std::move(val), has_side_effects,
@ -1520,7 +1525,12 @@ sem::Expression* Resolver::Bitcast(const ast::BitcastExpression* expr) {
return nullptr; return nullptr;
} }
auto val = const_eval_.Bitcast(ty, inner); const sem::Constant* val = nullptr;
if (auto r = const_eval_.Bitcast(ty, inner)) {
val = r.Get();
} else {
return nullptr;
}
auto stage = sem::EvaluationStage::kRuntime; // TODO(crbug.com/tint/1581) auto stage = sem::EvaluationStage::kRuntime; // TODO(crbug.com/tint/1581)
auto* sem = builder_->create<sem::Expression>(expr, ty, stage, current_statement_, auto* sem = builder_->create<sem::Expression>(expr, ty, stage, current_statement_,
std::move(val), inner->HasSideEffects()); std::move(val), inner->HasSideEffects());
@ -1575,8 +1585,12 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) {
const sem::Constant* value = nullptr; const sem::Constant* value = nullptr;
auto stage = sem::EarliestStage(ctor_or_conv.target->Stage(), args_stage); auto stage = sem::EarliestStage(ctor_or_conv.target->Stage(), args_stage);
if (stage == sem::EvaluationStage::kConstant) { if (stage == sem::EvaluationStage::kConstant) {
value = if (auto r = (const_eval_.*ctor_or_conv.const_eval_fn)(
(const_eval_.*ctor_or_conv.const_eval_fn)(ctor_or_conv.target->ReturnType(), args); ctor_or_conv.target->ReturnType(), args)) {
value = r.Get();
} else {
return nullptr;
}
} }
return builder_->create<sem::Call>(expr, ctor_or_conv.target, stage, std::move(args), return builder_->create<sem::Call>(expr, ctor_or_conv.target, stage, std::move(args),
current_statement_, value, has_side_effects); current_statement_, value, has_side_effects);
@ -1593,7 +1607,11 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) {
auto stage = args_stage; // The evaluation stage of the call auto stage = args_stage; // The evaluation stage of the call
const sem::Constant* value = nullptr; // The constant value for the call const sem::Constant* value = nullptr; // The constant value for the call
if (stage == sem::EvaluationStage::kConstant) { if (stage == sem::EvaluationStage::kConstant) {
value = const_eval_.ArrayOrStructCtor(ty, args); if (auto r = const_eval_.ArrayOrStructCtor(ty, args)) {
value = r.Get();
} else {
return nullptr;
}
if (!value) { if (!value) {
// Constant evaluation failed. // Constant evaluation failed.
// Can happen for expressions that will fail validation (later). // Can happen for expressions that will fail validation (later).
@ -1873,7 +1891,11 @@ sem::Call* Resolver::BuiltinCall(const ast::CallExpression* expr,
// If the builtin is @const, and all arguments have constant values, evaluate the builtin now. // If the builtin is @const, and all arguments have constant values, evaluate the builtin now.
const sem::Constant* value = nullptr; const sem::Constant* value = nullptr;
if (stage == sem::EvaluationStage::kConstant) { if (stage == sem::EvaluationStage::kConstant) {
value = (const_eval_.*builtin.const_eval_fn)(builtin.sem->ReturnType(), args); if (auto r = (const_eval_.*builtin.const_eval_fn)(builtin.sem->ReturnType(), args)) {
value = r.Get();
} else {
return nullptr;
}
} }
bool has_side_effects = bool has_side_effects =
@ -2035,7 +2057,12 @@ sem::Expression* Resolver::Literal(const ast::LiteralExpression* literal) {
return nullptr; return nullptr;
} }
auto val = const_eval_.Literal(ty, literal); const sem::Constant* val = nullptr;
if (auto r = const_eval_.Literal(ty, literal)) {
val = r.Get();
} else {
return nullptr;
}
return builder_->create<sem::Expression>(literal, ty, sem::EvaluationStage::kConstant, return builder_->create<sem::Expression>(literal, ty, sem::EvaluationStage::kConstant,
current_statement_, std::move(val), current_statement_, std::move(val),
/* has_side_effects */ false); /* has_side_effects */ false);
@ -2156,7 +2183,12 @@ sem::Expression* Resolver::MemberAccessor(const ast::MemberAccessorExpression* e
ret = builder_->create<sem::Reference>(ret, ref->StorageClass(), ref->Access()); ret = builder_->create<sem::Reference>(ret, ref->StorageClass(), ref->Access());
} }
auto* val = const_eval_.MemberAccess(object, member); const sem::Constant* val = nullptr;
if (auto r = const_eval_.MemberAccess(object, member)) {
val = r.Get();
} else {
return nullptr;
}
return builder_->create<sem::StructMemberAccess>(expr, ret, current_statement_, val, object, return builder_->create<sem::StructMemberAccess>(expr, ret, current_statement_, val, object,
member, has_side_effects, source_var); member, has_side_effects, source_var);
} }
@ -2224,10 +2256,13 @@ sem::Expression* Resolver::MemberAccessor(const ast::MemberAccessorExpression* e
// the swizzle. // the swizzle.
ret = builder_->create<sem::Vector>(vec->type(), static_cast<uint32_t>(size)); ret = builder_->create<sem::Vector>(vec->type(), static_cast<uint32_t>(size));
} }
auto* val = const_eval_.Swizzle(ret, object, swizzle); if (auto r = const_eval_.Swizzle(ret, object, swizzle)) {
auto* val = r.Get();
return builder_->create<sem::Swizzle>(expr, ret, current_statement_, val, object, return builder_->create<sem::Swizzle>(expr, ret, current_statement_, val, object,
std::move(swizzle), has_side_effects, source_var); std::move(swizzle), has_side_effects, source_var);
} }
return nullptr;
}
AddError("invalid member accessor expression. Expected vector or struct, got '" + AddError("invalid member accessor expression. Expected vector or struct, got '" +
sem_.TypeNameOf(storage_ty) + "'", sem_.TypeNameOf(storage_ty) + "'",
@ -2240,7 +2275,6 @@ sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) {
const auto* rhs = sem_.Get(expr->rhs); const auto* rhs = sem_.Get(expr->rhs);
auto* lhs_ty = lhs->Type()->UnwrapRef(); auto* lhs_ty = lhs->Type()->UnwrapRef();
auto* rhs_ty = rhs->Type()->UnwrapRef(); auto* rhs_ty = rhs->Type()->UnwrapRef();
auto stage = sem::EvaluationStage::kRuntime; // TODO(crbug.com/tint/1581)
auto op = intrinsic_table_->Lookup(expr->op, lhs_ty, rhs_ty, expr->source, false); auto op = intrinsic_table_->Lookup(expr->op, lhs_ty, rhs_ty, expr->source, false);
if (!op.result) { if (!op.result) {
@ -2260,8 +2294,17 @@ sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) {
} }
const sem::Constant* value = nullptr; const sem::Constant* value = nullptr;
auto stage = sem::EarliestStage(lhs->Stage(), rhs->Stage());
if (stage == sem::EvaluationStage::kConstant) {
if (op.const_eval_fn) { if (op.const_eval_fn) {
value = (const_eval_.*op.const_eval_fn)(op.result, utils::Vector{lhs, rhs}); if (auto r = (const_eval_.*op.const_eval_fn)(op.result, utils::Vector{lhs, rhs})) {
value = r.Get();
} else {
return nullptr;
}
} else {
stage = sem::EvaluationStage::kRuntime;
}
} }
bool has_side_effects = lhs->HasSideEffects() || rhs->HasSideEffects(); bool has_side_effects = lhs->HasSideEffects() || rhs->HasSideEffects();
@ -2337,7 +2380,11 @@ sem::Expression* Resolver::UnaryOp(const ast::UnaryOpExpression* unary) {
stage = expr->Stage(); stage = expr->Stage();
if (stage == sem::EvaluationStage::kConstant) { if (stage == sem::EvaluationStage::kConstant) {
if (op.const_eval_fn) { if (op.const_eval_fn) {
value = (const_eval_.*op.const_eval_fn)(ty, utils::Vector{expr}); if (auto r = (const_eval_.*op.const_eval_fn)(ty, utils::Vector{expr})) {
value = r.Get();
} else {
return nullptr;
}
} else { } else {
stage = sem::EvaluationStage::kRuntime; stage = sem::EvaluationStage::kRuntime;
} }