tint: Simplify the resolver constant evaluation

And expand the handling to include matrices.

Bug: tint:1504
Change-Id: I6fd9ce239d13acf0e2f74b8ea19dfac3457e348c
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/91026
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
This commit is contained in:
Ben Clayton 2022-05-20 19:55:50 +00:00 committed by Dawn LUCI CQ
parent 8f4f449540
commit d3de38d7e3
4 changed files with 72 additions and 44 deletions

View File

@ -67,3 +67,4 @@ const char* str(CtorConvIntrinsic i) {
} }
} // namespace tint::resolver } // namespace tint::resolver

View File

@ -332,7 +332,9 @@ class Resolver {
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
/// Cast `Value` to `target_type` /// Cast `Value` to `target_type`
/// @return the casted value /// @return the casted value
sem::Constant ConstantCast(const sem::Constant& value, const sem::Type* target_elem_type); sem::Constant ConstantCast(const sem::Constant& value,
const sem::Type* target_type,
const sem::Type* target_element_type = nullptr);
sem::Constant EvaluateConstantValue(const ast::Expression* expr, const sem::Type* type); sem::Constant EvaluateConstantValue(const ast::Expression* expr, const sem::Type* type);
sem::Constant EvaluateConstantValue(const ast::LiteralExpression* literal, sem::Constant EvaluateConstantValue(const ast::LiteralExpression* literal,

View File

@ -23,6 +23,33 @@
using namespace tint::number_suffixes; // NOLINT using namespace tint::number_suffixes; // NOLINT
namespace tint::resolver { namespace tint::resolver {
namespace {
sem::Constant::Scalars CastScalars(sem::Constant::Scalars in, const sem::Type* target_type) {
sem::Constant::Scalars out;
out.reserve(in.size());
for (auto v : in) {
// TODO(crbug.com/tint/1504): Check that value fits in new type
out.emplace_back(Switch<sem::Constant::Scalar>(
target_type, //
[&](const sem::AbstractInt*) { return sem::Constant::Cast<AInt>(v); },
[&](const sem::AbstractFloat*) { return sem::Constant::Cast<AFloat>(v); },
[&](const sem::I32*) { return sem::Constant::Cast<AInt>(v); },
[&](const sem::U32*) { return sem::Constant::Cast<AInt>(v); },
[&](const sem::F32*) { return sem::Constant::Cast<AFloat>(v); },
[&](const sem::F16*) { return sem::Constant::Cast<AFloat>(v); },
[&](const sem::Bool*) { return sem::Constant::Cast<bool>(v); },
[&](Default) {
diag::List diags;
TINT_UNREACHABLE(Semantic, diags)
<< "invalid element type " << target_type->TypeInfo().name;
return sem::Constant::Scalar(false);
}));
}
return out;
}
} // namespace
sem::Constant Resolver::EvaluateConstantValue(const ast::Expression* expr, const sem::Type* type) { sem::Constant Resolver::EvaluateConstantValue(const ast::Expression* expr, const sem::Type* type) {
if (auto* e = expr->As<ast::LiteralExpression>()) { if (auto* e = expr->As<ast::LiteralExpression>()) {
@ -51,21 +78,22 @@ sem::Constant Resolver::EvaluateConstantValue(const ast::LiteralExpression* lite
sem::Constant Resolver::EvaluateConstantValue(const ast::CallExpression* call, sem::Constant Resolver::EvaluateConstantValue(const ast::CallExpression* call,
const sem::Type* type) { const sem::Type* type) {
auto* vec = type->As<sem::Vector>(); uint32_t result_size = 0;
auto* el_ty = sem::Type::ElementOf(type, &result_size);
// For now, only fold scalars and vectors if (!el_ty) {
if (!type->is_scalar() && !vec) {
return {}; return {};
} }
auto* elem_type = vec ? vec->type() : type; // ElementOf() will also return the element type of array, which we do not support.
int result_size = vec ? static_cast<int>(vec->Width()) : 1; if (type->Is<sem::Array>()) {
return {};
}
// For zero value init, return 0s // For zero value init, return 0s
if (call->args.empty()) { if (call->args.empty()) {
using Scalars = sem::Constant::Scalars; using Scalars = sem::Constant::Scalars;
auto constant = Switch( return Switch(
elem_type, el_ty,
[&](const sem::AbstractInt*) { [&](const sem::AbstractInt*) {
return sem::Constant(type, Scalars(result_size, AInt(0))); return sem::Constant(type, Scalars(result_size, AInt(0)));
}, },
@ -77,63 +105,53 @@ sem::Constant Resolver::EvaluateConstantValue(const ast::CallExpression* call,
[&](const sem::F32*) { return sem::Constant(type, Scalars(result_size, AFloat(0))); }, [&](const sem::F32*) { return sem::Constant(type, Scalars(result_size, AFloat(0))); },
[&](const sem::F16*) { return sem::Constant(type, Scalars(result_size, AFloat(0))); }, [&](const sem::F16*) { return sem::Constant(type, Scalars(result_size, AFloat(0))); },
[&](const sem::Bool*) { return sem::Constant(type, Scalars(result_size, false)); }); [&](const sem::Bool*) { return sem::Constant(type, Scalars(result_size, false)); });
if (constant.IsValid()) {
return constant;
}
} }
// Build value for type_ctor from each child value by casting to // Build value for type_ctor from each child value by casting to type_ctor's type.
// type_ctor's type.
sem::Constant::Scalars elems; sem::Constant::Scalars elems;
for (auto* expr : call->args) { for (auto* expr : call->args) {
auto* arg = builder_->Sem().Get(expr); auto* arg = builder_->Sem().Get(expr);
if (!arg || !arg->ConstantValue()) { if (!arg) {
return {}; return {};
} }
auto cast = ConstantCast(arg->ConstantValue(), elem_type); auto value = arg->ConstantValue();
elems.insert(elems.end(), cast.Elements().begin(), cast.Elements().end()); if (!value) {
return {};
}
elems.insert(elems.end(), value.Elements().begin(), value.Elements().end());
} }
// Splat single-value initializers // Splat single-value initializers
if (elems.size() == 1) { if (elems.size() == 1) {
for (int i = 0; i < result_size - 1; ++i) { for (uint32_t i = 0; i < result_size - 1; ++i) {
elems.emplace_back(elems[0]); elems.emplace_back(elems[0]);
} }
} }
return sem::Constant(type, std::move(elems)); // Finally cast the elements to the desired type.
auto cast = CastScalars(elems, el_ty);
return sem::Constant(type, std::move(cast));
} }
sem::Constant Resolver::ConstantCast(const sem::Constant& value, sem::Constant Resolver::ConstantCast(const sem::Constant& value,
const sem::Type* target_elem_type) { const sem::Type* target_type,
if (value.ElementType() == target_elem_type) { const sem::Type* target_element_type /* = nullptr */) {
if (value.Type() == target_type) {
return value; return value;
} }
sem::Constant::Scalars elems; if (target_element_type == nullptr) {
for (size_t i = 0; i < value.Elements().size(); ++i) { target_element_type = sem::Type::ElementOf(target_type);
// TODO(crbug.com/tint/1504): Check that value fits in new type }
elems.emplace_back(Switch<sem::Constant::Scalar>( if (target_element_type == nullptr) {
target_elem_type, // return {};
[&](const sem::AbstractInt*) { return value.ElementAs<AInt>(i); }, }
[&](const sem::AbstractFloat*) { return value.ElementAs<AFloat>(i); }, if (value.ElementType() == target_element_type) {
[&](const sem::I32*) { return value.ElementAs<AInt>(i); }, return sem::Constant(target_type, value.Elements());
[&](const sem::U32*) { return value.ElementAs<AInt>(i); },
[&](const sem::F32*) { return value.ElementAs<AFloat>(i); },
[&](const sem::F16*) { return value.ElementAs<AFloat>(i); },
[&](const sem::Bool*) { return value.ElementAs<bool>(i); },
[&](Default) {
diag::List diags;
TINT_UNREACHABLE(Semantic, diags)
<< "invalid element type " << target_elem_type->TypeInfo().name;
return sem::Constant::Scalar(false);
}));
} }
auto* target_type = auto elems = CastScalars(value.Elements(), target_element_type);
value.Type()->Is<sem::Vector>()
? builder_->create<sem::Vector>(target_elem_type, static_cast<uint32_t>(elems.size()))
: target_elem_type;
return sem::Constant(target_type, elems); return sem::Constant(target_type, elems);
} }

View File

@ -81,7 +81,14 @@ class Constant {
/// @return the value of the scalar `static_cast` to type T. /// @return the value of the scalar `static_cast` to type T.
template <typename T> template <typename T>
T ElementAs(size_t index) const { T ElementAs(size_t index) const {
return std::visit([](auto val) { return static_cast<T>(val); }, elems_[index]); return Cast<T>(elems_[index]);
}
/// @param s the input scalar
/// @returns the scalar `s` cast to the type `T`.
template <typename T>
static T Cast(Scalar s) {
return std::visit([](auto v) { return static_cast<T>(v); }, s);
} }
private: private: