tint/resolver: Optimize constant evaluation methods

Materialize() was re-evaluating the constant values for the incoming
semantic expression, despite this already being evaluated. Just use the
sem::Expression::ConstantValue().

resolver.cc already has all the semantic pointers, so pass them in
instead of pointlessly hitting the ast -> sem map.

Change-Id: If2bc7cd10f79079fb811e9d83c5150dd3c0c244c
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/95764
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
This commit is contained in:
Ben Clayton 2022-07-07 17:30:11 +00:00 committed by Dawn LUCI CQ
parent 51265542e9
commit d5f53ab580
3 changed files with 71 additions and 84 deletions

View File

@ -1306,11 +1306,11 @@ const sem::Expression* Resolver::Materialize(const sem::Expression* expr,
if (!validator_.Materialize(target_ty, src_ty, decl->source)) { if (!validator_.Materialize(target_ty, src_ty, decl->source)) {
return nullptr; return nullptr;
} }
auto expr_val = EvaluateConstantValue(decl, expr->Type()); auto expr_val = expr->ConstantValue();
if (!expr_val) { if (!expr_val) {
TINT_ICE(Resolver, builder_->Diagnostics()) TINT_ICE(Resolver, builder_->Diagnostics())
<< decl->source << "EvaluateConstantValue(" << decl->TypeInfo().name << decl->source << "Materialize(" << decl->TypeInfo().name
<< ") returned invalid value"; << ") called on expression with no constant value";
return nullptr; return nullptr;
} }
auto materialized_val = ConvertValue(expr_val, target_ty, decl->source); auto materialized_val = ConvertValue(expr_val, target_ty, decl->source);
@ -1422,7 +1422,7 @@ sem::Expression* Resolver::IndexAccessor(const ast::IndexAccessorExpression* exp
ty = builder_->create<sem::Reference>(ty, ref->StorageClass(), ref->Access()); ty = builder_->create<sem::Reference>(ty, ref->StorageClass(), ref->Access());
} }
auto val = EvaluateConstantValue(expr, ty); auto val = EvaluateIndexValue(obj, idx);
bool has_side_effects = idx->HasSideEffects() || obj->HasSideEffects(); bool has_side_effects = idx->HasSideEffects() || obj->HasSideEffects();
auto* sem = builder_->create<sem::IndexAccessorExpression>( auto* sem = builder_->create<sem::IndexAccessorExpression>(
expr, ty, obj, idx, current_statement_, std::move(val), has_side_effects, expr, ty, obj, idx, current_statement_, std::move(val), has_side_effects,
@ -1441,7 +1441,7 @@ sem::Expression* Resolver::Bitcast(const ast::BitcastExpression* expr) {
return nullptr; return nullptr;
} }
auto val = EvaluateConstantValue(expr, ty); auto val = EvaluateBitcastValue(inner, ty);
auto* sem = builder_->create<sem::Expression>(expr, ty, current_statement_, std::move(val), auto* sem = builder_->create<sem::Expression>(expr, ty, current_statement_, std::move(val),
inner->HasSideEffects()); inner->HasSideEffects());
@ -1489,9 +1489,9 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) {
if (!MaterializeArguments(args, call_target)) { if (!MaterializeArguments(args, call_target)) {
return nullptr; return nullptr;
} }
auto val = EvaluateConstantValue(expr, call_target->ReturnType()); auto val = EvaluateCtorOrConvValue(args, call_target->ReturnType());
return builder_->create<sem::Call>(expr, call_target, std::move(args), current_statement_, return builder_->create<sem::Call>(expr, call_target, std::move(args), current_statement_,
std::move(val), has_side_effects); val, has_side_effects);
}; };
// ct_ctor_or_conv is a helper for building either a sem::TypeConstructor or sem::TypeConversion // ct_ctor_or_conv is a helper for building either a sem::TypeConstructor or sem::TypeConversion
@ -1528,10 +1528,9 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) {
if (!MaterializeArguments(args, call_target)) { if (!MaterializeArguments(args, call_target)) {
return nullptr; return nullptr;
} }
auto val = EvaluateConstantValue(expr, call_target->ReturnType()); auto val = EvaluateCtorOrConvValue(args, arr);
return builder_->create<sem::Call>(expr, call_target, std::move(args), return builder_->create<sem::Call>(expr, call_target, std::move(args),
current_statement_, std::move(val), current_statement_, val, has_side_effects);
has_side_effects);
}, },
[&](const sem::Struct* str) -> sem::Call* { [&](const sem::Struct* str) -> sem::Call* {
auto* call_target = utils::GetOrCreate( auto* call_target = utils::GetOrCreate(
@ -1551,7 +1550,7 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) {
if (!MaterializeArguments(args, call_target)) { if (!MaterializeArguments(args, call_target)) {
return nullptr; return nullptr;
} }
auto val = EvaluateConstantValue(expr, call_target->ReturnType()); auto val = EvaluateCtorOrConvValue(args, str);
return builder_->create<sem::Call>(expr, call_target, std::move(args), return builder_->create<sem::Call>(expr, call_target, std::move(args),
current_statement_, std::move(val), current_statement_, std::move(val),
has_side_effects); has_side_effects);
@ -1857,7 +1856,7 @@ sem::Expression* Resolver::Literal(const ast::LiteralExpression* literal) {
return nullptr; return nullptr;
} }
auto val = EvaluateConstantValue(literal, ty); auto val = EvaluateLiteralValue(literal, ty);
return builder_->create<sem::Expression>(literal, ty, current_statement_, std::move(val), return builder_->create<sem::Expression>(literal, ty, current_statement_, std::move(val),
/* has_side_effects */ false); /* has_side_effects */ false);
} }
@ -2077,10 +2076,10 @@ sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) {
} }
} }
auto val = EvaluateConstantValue(expr, op.result); auto* val = EvaluateBinaryValue(lhs, rhs, op);
bool has_side_effects = lhs->HasSideEffects() || rhs->HasSideEffects(); bool has_side_effects = lhs->HasSideEffects() || rhs->HasSideEffects();
auto* sem = builder_->create<sem::Expression>(expr, op.result, current_statement_, auto* sem = builder_->create<sem::Expression>(expr, op.result, current_statement_, val,
std::move(val), has_side_effects); has_side_effects);
sem->Behaviors() = lhs->Behaviors() + rhs->Behaviors(); sem->Behaviors() = lhs->Behaviors() + rhs->Behaviors();
return sem; return sem;
@ -2095,6 +2094,7 @@ sem::Expression* Resolver::UnaryOp(const ast::UnaryOpExpression* unary) {
const sem::Type* ty = nullptr; const sem::Type* ty = nullptr;
const sem::Variable* source_var = nullptr; const sem::Variable* source_var = nullptr;
const sem::Constant* val = nullptr;
switch (unary->op) { switch (unary->op) {
case ast::UnaryOp::kAddressOf: case ast::UnaryOp::kAddressOf:
@ -2147,12 +2147,12 @@ sem::Expression* Resolver::UnaryOp(const ast::UnaryOpExpression* unary) {
} }
} }
ty = op.result; ty = op.result;
val = EvaluateUnaryValue(expr, op);
break; break;
} }
} }
auto val = EvaluateConstantValue(unary, ty); auto* sem = builder_->create<sem::Expression>(unary, ty, current_statement_, val,
auto* sem = builder_->create<sem::Expression>(unary, ty, current_statement_, std::move(val),
expr->HasSideEffects(), source_var); expr->HasSideEffects(), source_var);
sem->Behaviors() = expr->Behaviors(); sem->Behaviors() = expr->Behaviors();
return sem; return sem;

View File

@ -209,15 +209,17 @@ class Resolver {
/// These methods are called from the expression resolving methods, and so child-expression /// These methods are called from the expression resolving methods, and so child-expression
/// nodes are guaranteed to have been already resolved and any constant values calculated. /// nodes are guaranteed to have been already resolved and any constant values calculated.
//////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////
const sem::Constant* EvaluateConstantValue(const ast::Expression* expr, const sem::Type* type); const sem::Constant* EvaluateBinaryValue(const sem::Expression* lhs,
const sem::Constant* EvaluateConstantValue(const ast::IdentifierExpression* ident, const sem::Expression* rhs,
const sem::Type* type); const IntrinsicTable::BinaryOperator&);
const sem::Constant* EvaluateConstantValue(const ast::LiteralExpression* literal, const sem::Constant* EvaluateBitcastValue(const sem::Expression*, const sem::Type*);
const sem::Type* type); const sem::Constant* EvaluateCtorOrConvValue(
const sem::Constant* EvaluateConstantValue(const ast::CallExpression* call, const std::vector<const sem::Expression*>& args,
const sem::Type* type); const sem::Type* ty); // Note: ty is not an array or structure
const sem::Constant* EvaluateConstantValue(const ast::IndexAccessorExpression* call, const sem::Constant* EvaluateIndexValue(const sem::Expression* obj, const sem::Expression* idx);
const sem::Type* type); const sem::Constant* EvaluateLiteralValue(const ast::LiteralExpression*, const sem::Type*);
const sem::Constant* EvaluateUnaryValue(const sem::Expression*,
const IntrinsicTable::UnaryOperator&);
/// The result type of a ConstantEvaluation method. /// The result type of a ConstantEvaluation method.
/// Can be one of three distinct values: /// Can be one of three distinct values:

View File

@ -21,6 +21,7 @@
#include "src/tint/sem/constant.h" #include "src/tint/sem/constant.h"
#include "src/tint/sem/type_constructor.h" #include "src/tint/sem/type_constructor.h"
#include "src/tint/utils/compiler_macros.h" #include "src/tint/utils/compiler_macros.h"
#include "src/tint/utils/transform.h"
using namespace tint::number_suffixes; // NOLINT using namespace tint::number_suffixes; // NOLINT
@ -354,25 +355,7 @@ const Constant* CreateComposite(ProgramBuilder& builder,
} // namespace } // namespace
const sem::Constant* Resolver::EvaluateConstantValue(const ast::Expression* expr, const sem::Constant* Resolver::EvaluateLiteralValue(const ast::LiteralExpression* literal,
const sem::Type* type) {
return Switch(
expr, //
[&](const ast::IdentifierExpression* e) { return EvaluateConstantValue(e, type); },
[&](const ast::LiteralExpression* e) { return EvaluateConstantValue(e, type); },
[&](const ast::CallExpression* e) { return EvaluateConstantValue(e, type); },
[&](const ast::IndexAccessorExpression* e) { return EvaluateConstantValue(e, type); });
}
const sem::Constant* Resolver::EvaluateConstantValue(const ast::IdentifierExpression* ident,
const sem::Type*) {
if (auto* sem = builder_->Sem().Get(ident)) {
return sem->ConstantValue();
}
return {};
}
const sem::Constant* Resolver::EvaluateConstantValue(const ast::LiteralExpression* literal,
const sem::Type* type) { const sem::Type* type) {
return Switch( return Switch(
literal, literal,
@ -403,14 +386,11 @@ const sem::Constant* Resolver::EvaluateConstantValue(const ast::LiteralExpressio
}); });
} }
const sem::Constant* Resolver::EvaluateConstantValue(const ast::CallExpression* call, const sem::Constant* Resolver::EvaluateCtorOrConvValue(
const std::vector<const sem::Expression*>& args,
const sem::Type* ty) { const sem::Type* ty) {
// Note: we are building constant values for array types. The working group as verbally agreed
// to support constant expression arrays, but this is not (yet) part of the spec.
// See: https://github.com/gpuweb/gpuweb/issues/3056
// For zero value init, return 0s // For zero value init, return 0s
if (call->args.empty()) { if (args.empty()) {
return ZeroValue(*builder_, ty); return ZeroValue(*builder_, ty);
} }
@ -420,16 +400,10 @@ const sem::Constant* Resolver::EvaluateConstantValue(const ast::CallExpression*
return nullptr; // Target type does not support constant values return nullptr; // Target type does not support constant values
} }
// value_of returns a `const Constant*` for the expression `expr`, or nullptr if the expression if (args.size() == 1) {
// does not have a constant value.
auto value_of = [&](const ast::Expression* expr) {
return static_cast<const Constant*>(builder_->Sem().Get(expr)->ConstantValue());
};
if (call->args.size() == 1) {
// Type constructor or conversion that takes a single argument. // Type constructor or conversion that takes a single argument.
auto& src = call->args[0]->source; auto& src = args[0]->Declaration()->source;
auto* arg = value_of(call->args[0]); auto* arg = static_cast<const Constant*>(args[0]->ConstantValue());
if (!arg) { if (!arg) {
return nullptr; // Single argument is not constant. return nullptr; // Single argument is not constant.
} }
@ -463,8 +437,8 @@ const sem::Constant* Resolver::EvaluateConstantValue(const ast::CallExpression*
// Helper for pushing all the argument constants to `els`. // Helper for pushing all the argument constants to `els`.
auto push_all_args = [&] { auto push_all_args = [&] {
for (auto* expr : call->args) { for (auto* expr : args) {
auto* arg = value_of(expr); auto* arg = static_cast<const Constant*>(expr->ConstantValue());
if (!arg) { if (!arg) {
return; return;
} }
@ -472,13 +446,15 @@ const sem::Constant* Resolver::EvaluateConstantValue(const ast::CallExpression*
} }
}; };
// TODO(crbug.com/tint/1611): Add structure support.
Switch( Switch(
ty, // What's the target type being constructed? ty, // What's the target type being constructed?
[&](const sem::Vector*) { [&](const sem::Vector*) {
// Vector can be constructed with a mix of scalars / abstract numerics and smaller // Vector can be constructed with a mix of scalars / abstract numerics and smaller
// vectors. // vectors.
for (auto* expr : call->args) { for (auto* expr : args) {
auto* arg = value_of(expr); auto* arg = static_cast<const Constant*>(expr->ConstantValue());
if (!arg) { if (!arg) {
return; return;
} }
@ -500,13 +476,14 @@ const sem::Constant* Resolver::EvaluateConstantValue(const ast::CallExpression*
[&](const sem::Matrix* m) { [&](const sem::Matrix* m) {
// Matrix can be constructed with a set of scalars / abstract numerics, or column // Matrix can be constructed with a set of scalars / abstract numerics, or column
// vectors. // vectors.
if (call->args.size() == m->columns() * m->rows()) { if (args.size() == m->columns() * m->rows()) {
// Matrix built from scalars / abstract numerics // Matrix built from scalars / abstract numerics
for (uint32_t c = 0; c < m->columns(); c++) { for (uint32_t c = 0; c < m->columns(); c++) {
std::vector<const Constant*> column; std::vector<const Constant*> column;
column.reserve(m->rows()); column.reserve(m->rows());
for (uint32_t r = 0; r < m->rows(); r++) { for (uint32_t r = 0; r < m->rows(); r++) {
auto* arg = value_of(call->args[r + c * m->rows()]); auto* arg =
static_cast<const Constant*>(args[r + c * m->rows()]->ConstantValue());
if (!arg) { if (!arg) {
return; return;
} }
@ -514,7 +491,7 @@ const sem::Constant* Resolver::EvaluateConstantValue(const ast::CallExpression*
} }
els.push_back(CreateComposite(*builder_, m->ColumnType(), std::move(column))); els.push_back(CreateComposite(*builder_, m->ColumnType(), std::move(column)));
} }
} else if (call->args.size() == m->columns()) { } else if (args.size() == m->columns()) {
// Matrix built from column vectors // Matrix built from column vectors
push_all_args(); push_all_args();
} }
@ -532,24 +509,14 @@ const sem::Constant* Resolver::EvaluateConstantValue(const ast::CallExpression*
return CreateComposite(*builder_, ty, std::move(els)); return CreateComposite(*builder_, ty, std::move(els));
} }
const sem::Constant* Resolver::EvaluateConstantValue(const ast::IndexAccessorExpression* accessor, const sem::Constant* Resolver::EvaluateIndexValue(const sem::Expression* obj_expr,
const sem::Type*) { const sem::Expression* idx_expr) {
auto* obj_sem = builder_->Sem().Get(accessor->object); auto obj_val = obj_expr->ConstantValue();
if (!obj_sem) {
return {};
}
auto obj_val = obj_sem->ConstantValue();
if (!obj_val) { if (!obj_val) {
return {}; return {};
} }
auto* idx_sem = builder_->Sem().Get(accessor->index); auto idx_val = idx_expr->ConstantValue();
if (!idx_sem) {
return {};
}
auto idx_val = idx_sem->ConstantValue();
if (!idx_val) { if (!idx_val) {
return {}; return {};
} }
@ -563,13 +530,31 @@ const sem::Constant* Resolver::EvaluateConstantValue(const ast::IndexAccessorExp
AddWarning("index " + std::to_string(idx) + " out of bounds [0.." + AddWarning("index " + std::to_string(idx) + " out of bounds [0.." +
std::to_string(el_count - 1) + "]. Clamping index to " + std::to_string(el_count - 1) + "]. Clamping index to " +
std::to_string(clamped), std::to_string(clamped),
accessor->index->source); idx_expr->Declaration()->source);
idx = clamped; idx = clamped;
} }
return obj_val->Index(static_cast<size_t>(idx)); return obj_val->Index(static_cast<size_t>(idx));
} }
const sem::Constant* Resolver::EvaluateBitcastValue(const sem::Expression*, const sem::Type*) {
// TODO(crbug.com/tint/1581): Implement @const intrinsics
return nullptr;
}
const sem::Constant* Resolver::EvaluateBinaryValue(const sem::Expression*,
const sem::Expression*,
const IntrinsicTable::BinaryOperator&) {
// TODO(crbug.com/tint/1581): Implement @const intrinsics
return nullptr;
}
const sem::Constant* Resolver::EvaluateUnaryValue(const sem::Expression*,
const IntrinsicTable::UnaryOperator&) {
// TODO(crbug.com/tint/1581): Implement @const intrinsics
return nullptr;
}
utils::Result<const sem::Constant*> Resolver::ConvertValue(const sem::Constant* value, utils::Result<const sem::Constant*> Resolver::ConvertValue(const sem::Constant* value,
const sem::Type* target_ty, const sem::Type* target_ty,
const Source& source) { const Source& source) {