tint/resolver: Optimize constant evaluation methods

Materialize() was re-evaluating the constant values for the incoming
semantic expression, despite this already being evaluated. Just use the
sem::Expression::ConstantValue().

resolver.cc already has all the semantic pointers, so pass them in
instead of pointlessly hitting the ast -> sem map.

Change-Id: If2bc7cd10f79079fb811e9d83c5150dd3c0c244c
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/95764
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
This commit is contained in:
Ben Clayton 2022-07-07 17:30:11 +00:00 committed by Dawn LUCI CQ
parent 51265542e9
commit d5f53ab580
3 changed files with 71 additions and 84 deletions

View File

@ -1306,11 +1306,11 @@ const sem::Expression* Resolver::Materialize(const sem::Expression* expr,
if (!validator_.Materialize(target_ty, src_ty, decl->source)) {
return nullptr;
}
auto expr_val = EvaluateConstantValue(decl, expr->Type());
auto expr_val = expr->ConstantValue();
if (!expr_val) {
TINT_ICE(Resolver, builder_->Diagnostics())
<< decl->source << "EvaluateConstantValue(" << decl->TypeInfo().name
<< ") returned invalid value";
<< decl->source << "Materialize(" << decl->TypeInfo().name
<< ") called on expression with no constant value";
return nullptr;
}
auto materialized_val = ConvertValue(expr_val, target_ty, decl->source);
@ -1422,7 +1422,7 @@ sem::Expression* Resolver::IndexAccessor(const ast::IndexAccessorExpression* exp
ty = builder_->create<sem::Reference>(ty, ref->StorageClass(), ref->Access());
}
auto val = EvaluateConstantValue(expr, ty);
auto val = EvaluateIndexValue(obj, idx);
bool has_side_effects = idx->HasSideEffects() || obj->HasSideEffects();
auto* sem = builder_->create<sem::IndexAccessorExpression>(
expr, ty, obj, idx, current_statement_, std::move(val), has_side_effects,
@ -1441,7 +1441,7 @@ sem::Expression* Resolver::Bitcast(const ast::BitcastExpression* expr) {
return nullptr;
}
auto val = EvaluateConstantValue(expr, ty);
auto val = EvaluateBitcastValue(inner, ty);
auto* sem = builder_->create<sem::Expression>(expr, ty, current_statement_, std::move(val),
inner->HasSideEffects());
@ -1489,9 +1489,9 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) {
if (!MaterializeArguments(args, call_target)) {
return nullptr;
}
auto val = EvaluateConstantValue(expr, call_target->ReturnType());
auto val = EvaluateCtorOrConvValue(args, call_target->ReturnType());
return builder_->create<sem::Call>(expr, call_target, std::move(args), current_statement_,
std::move(val), has_side_effects);
val, has_side_effects);
};
// ct_ctor_or_conv is a helper for building either a sem::TypeConstructor or sem::TypeConversion
@ -1528,10 +1528,9 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) {
if (!MaterializeArguments(args, call_target)) {
return nullptr;
}
auto val = EvaluateConstantValue(expr, call_target->ReturnType());
auto val = EvaluateCtorOrConvValue(args, arr);
return builder_->create<sem::Call>(expr, call_target, std::move(args),
current_statement_, std::move(val),
has_side_effects);
current_statement_, val, has_side_effects);
},
[&](const sem::Struct* str) -> sem::Call* {
auto* call_target = utils::GetOrCreate(
@ -1551,7 +1550,7 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) {
if (!MaterializeArguments(args, call_target)) {
return nullptr;
}
auto val = EvaluateConstantValue(expr, call_target->ReturnType());
auto val = EvaluateCtorOrConvValue(args, str);
return builder_->create<sem::Call>(expr, call_target, std::move(args),
current_statement_, std::move(val),
has_side_effects);
@ -1857,7 +1856,7 @@ sem::Expression* Resolver::Literal(const ast::LiteralExpression* literal) {
return nullptr;
}
auto val = EvaluateConstantValue(literal, ty);
auto val = EvaluateLiteralValue(literal, ty);
return builder_->create<sem::Expression>(literal, ty, current_statement_, std::move(val),
/* has_side_effects */ false);
}
@ -2077,10 +2076,10 @@ sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) {
}
}
auto val = EvaluateConstantValue(expr, op.result);
auto* val = EvaluateBinaryValue(lhs, rhs, op);
bool has_side_effects = lhs->HasSideEffects() || rhs->HasSideEffects();
auto* sem = builder_->create<sem::Expression>(expr, op.result, current_statement_,
std::move(val), has_side_effects);
auto* sem = builder_->create<sem::Expression>(expr, op.result, current_statement_, val,
has_side_effects);
sem->Behaviors() = lhs->Behaviors() + rhs->Behaviors();
return sem;
@ -2095,6 +2094,7 @@ sem::Expression* Resolver::UnaryOp(const ast::UnaryOpExpression* unary) {
const sem::Type* ty = nullptr;
const sem::Variable* source_var = nullptr;
const sem::Constant* val = nullptr;
switch (unary->op) {
case ast::UnaryOp::kAddressOf:
@ -2147,12 +2147,12 @@ sem::Expression* Resolver::UnaryOp(const ast::UnaryOpExpression* unary) {
}
}
ty = op.result;
val = EvaluateUnaryValue(expr, op);
break;
}
}
auto val = EvaluateConstantValue(unary, ty);
auto* sem = builder_->create<sem::Expression>(unary, ty, current_statement_, std::move(val),
auto* sem = builder_->create<sem::Expression>(unary, ty, current_statement_, val,
expr->HasSideEffects(), source_var);
sem->Behaviors() = expr->Behaviors();
return sem;

View File

@ -209,15 +209,17 @@ class Resolver {
/// These methods are called from the expression resolving methods, and so child-expression
/// nodes are guaranteed to have been already resolved and any constant values calculated.
////////////////////////////////////////////////////////////////////////////////////////////////
const sem::Constant* EvaluateConstantValue(const ast::Expression* expr, const sem::Type* type);
const sem::Constant* EvaluateConstantValue(const ast::IdentifierExpression* ident,
const sem::Type* type);
const sem::Constant* EvaluateConstantValue(const ast::LiteralExpression* literal,
const sem::Type* type);
const sem::Constant* EvaluateConstantValue(const ast::CallExpression* call,
const sem::Type* type);
const sem::Constant* EvaluateConstantValue(const ast::IndexAccessorExpression* call,
const sem::Type* type);
const sem::Constant* EvaluateBinaryValue(const sem::Expression* lhs,
const sem::Expression* rhs,
const IntrinsicTable::BinaryOperator&);
const sem::Constant* EvaluateBitcastValue(const sem::Expression*, const sem::Type*);
const sem::Constant* EvaluateCtorOrConvValue(
const std::vector<const sem::Expression*>& args,
const sem::Type* ty); // Note: ty is not an array or structure
const sem::Constant* EvaluateIndexValue(const sem::Expression* obj, const sem::Expression* idx);
const sem::Constant* EvaluateLiteralValue(const ast::LiteralExpression*, const sem::Type*);
const sem::Constant* EvaluateUnaryValue(const sem::Expression*,
const IntrinsicTable::UnaryOperator&);
/// The result type of a ConstantEvaluation method.
/// Can be one of three distinct values:

View File

@ -21,6 +21,7 @@
#include "src/tint/sem/constant.h"
#include "src/tint/sem/type_constructor.h"
#include "src/tint/utils/compiler_macros.h"
#include "src/tint/utils/transform.h"
using namespace tint::number_suffixes; // NOLINT
@ -354,26 +355,8 @@ const Constant* CreateComposite(ProgramBuilder& builder,
} // namespace
const sem::Constant* Resolver::EvaluateConstantValue(const ast::Expression* expr,
const sem::Type* type) {
return Switch(
expr, //
[&](const ast::IdentifierExpression* e) { return EvaluateConstantValue(e, type); },
[&](const ast::LiteralExpression* e) { return EvaluateConstantValue(e, type); },
[&](const ast::CallExpression* e) { return EvaluateConstantValue(e, type); },
[&](const ast::IndexAccessorExpression* e) { return EvaluateConstantValue(e, type); });
}
const sem::Constant* Resolver::EvaluateConstantValue(const ast::IdentifierExpression* ident,
const sem::Type*) {
if (auto* sem = builder_->Sem().Get(ident)) {
return sem->ConstantValue();
}
return {};
}
const sem::Constant* Resolver::EvaluateConstantValue(const ast::LiteralExpression* literal,
const sem::Type* type) {
const sem::Constant* Resolver::EvaluateLiteralValue(const ast::LiteralExpression* literal,
const sem::Type* type) {
return Switch(
literal,
[&](const ast::BoolLiteralExpression* lit) {
@ -403,14 +386,11 @@ const sem::Constant* Resolver::EvaluateConstantValue(const ast::LiteralExpressio
});
}
const sem::Constant* Resolver::EvaluateConstantValue(const ast::CallExpression* call,
const sem::Type* ty) {
// Note: we are building constant values for array types. The working group as verbally agreed
// to support constant expression arrays, but this is not (yet) part of the spec.
// See: https://github.com/gpuweb/gpuweb/issues/3056
const sem::Constant* Resolver::EvaluateCtorOrConvValue(
const std::vector<const sem::Expression*>& args,
const sem::Type* ty) {
// For zero value init, return 0s
if (call->args.empty()) {
if (args.empty()) {
return ZeroValue(*builder_, ty);
}
@ -420,16 +400,10 @@ const sem::Constant* Resolver::EvaluateConstantValue(const ast::CallExpression*
return nullptr; // Target type does not support constant values
}
// value_of returns a `const Constant*` for the expression `expr`, or nullptr if the expression
// does not have a constant value.
auto value_of = [&](const ast::Expression* expr) {
return static_cast<const Constant*>(builder_->Sem().Get(expr)->ConstantValue());
};
if (call->args.size() == 1) {
if (args.size() == 1) {
// Type constructor or conversion that takes a single argument.
auto& src = call->args[0]->source;
auto* arg = value_of(call->args[0]);
auto& src = args[0]->Declaration()->source;
auto* arg = static_cast<const Constant*>(args[0]->ConstantValue());
if (!arg) {
return nullptr; // Single argument is not constant.
}
@ -463,8 +437,8 @@ const sem::Constant* Resolver::EvaluateConstantValue(const ast::CallExpression*
// Helper for pushing all the argument constants to `els`.
auto push_all_args = [&] {
for (auto* expr : call->args) {
auto* arg = value_of(expr);
for (auto* expr : args) {
auto* arg = static_cast<const Constant*>(expr->ConstantValue());
if (!arg) {
return;
}
@ -472,13 +446,15 @@ const sem::Constant* Resolver::EvaluateConstantValue(const ast::CallExpression*
}
};
// TODO(crbug.com/tint/1611): Add structure support.
Switch(
ty, // What's the target type being constructed?
[&](const sem::Vector*) {
// Vector can be constructed with a mix of scalars / abstract numerics and smaller
// vectors.
for (auto* expr : call->args) {
auto* arg = value_of(expr);
for (auto* expr : args) {
auto* arg = static_cast<const Constant*>(expr->ConstantValue());
if (!arg) {
return;
}
@ -500,13 +476,14 @@ const sem::Constant* Resolver::EvaluateConstantValue(const ast::CallExpression*
[&](const sem::Matrix* m) {
// Matrix can be constructed with a set of scalars / abstract numerics, or column
// vectors.
if (call->args.size() == m->columns() * m->rows()) {
if (args.size() == m->columns() * m->rows()) {
// Matrix built from scalars / abstract numerics
for (uint32_t c = 0; c < m->columns(); c++) {
std::vector<const Constant*> column;
column.reserve(m->rows());
for (uint32_t r = 0; r < m->rows(); r++) {
auto* arg = value_of(call->args[r + c * m->rows()]);
auto* arg =
static_cast<const Constant*>(args[r + c * m->rows()]->ConstantValue());
if (!arg) {
return;
}
@ -514,7 +491,7 @@ const sem::Constant* Resolver::EvaluateConstantValue(const ast::CallExpression*
}
els.push_back(CreateComposite(*builder_, m->ColumnType(), std::move(column)));
}
} else if (call->args.size() == m->columns()) {
} else if (args.size() == m->columns()) {
// Matrix built from column vectors
push_all_args();
}
@ -532,24 +509,14 @@ const sem::Constant* Resolver::EvaluateConstantValue(const ast::CallExpression*
return CreateComposite(*builder_, ty, std::move(els));
}
const sem::Constant* Resolver::EvaluateConstantValue(const ast::IndexAccessorExpression* accessor,
const sem::Type*) {
auto* obj_sem = builder_->Sem().Get(accessor->object);
if (!obj_sem) {
return {};
}
auto obj_val = obj_sem->ConstantValue();
const sem::Constant* Resolver::EvaluateIndexValue(const sem::Expression* obj_expr,
const sem::Expression* idx_expr) {
auto obj_val = obj_expr->ConstantValue();
if (!obj_val) {
return {};
}
auto* idx_sem = builder_->Sem().Get(accessor->index);
if (!idx_sem) {
return {};
}
auto idx_val = idx_sem->ConstantValue();
auto idx_val = idx_expr->ConstantValue();
if (!idx_val) {
return {};
}
@ -563,13 +530,31 @@ const sem::Constant* Resolver::EvaluateConstantValue(const ast::IndexAccessorExp
AddWarning("index " + std::to_string(idx) + " out of bounds [0.." +
std::to_string(el_count - 1) + "]. Clamping index to " +
std::to_string(clamped),
accessor->index->source);
idx_expr->Declaration()->source);
idx = clamped;
}
return obj_val->Index(static_cast<size_t>(idx));
}
const sem::Constant* Resolver::EvaluateBitcastValue(const sem::Expression*, const sem::Type*) {
// TODO(crbug.com/tint/1581): Implement @const intrinsics
return nullptr;
}
const sem::Constant* Resolver::EvaluateBinaryValue(const sem::Expression*,
const sem::Expression*,
const IntrinsicTable::BinaryOperator&) {
// TODO(crbug.com/tint/1581): Implement @const intrinsics
return nullptr;
}
const sem::Constant* Resolver::EvaluateUnaryValue(const sem::Expression*,
const IntrinsicTable::UnaryOperator&) {
// TODO(crbug.com/tint/1581): Implement @const intrinsics
return nullptr;
}
utils::Result<const sem::Constant*> Resolver::ConvertValue(const sem::Constant* value,
const sem::Type* target_ty,
const Source& source) {