tint: Simplify the resolver constant evaluation

And expand the handling to include matrices.

Bug: tint:1504
Change-Id: I6fd9ce239d13acf0e2f74b8ea19dfac3457e348c
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/91026
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
This commit is contained in:
Ben Clayton 2022-05-20 19:55:50 +00:00 committed by Dawn LUCI CQ
parent 8f4f449540
commit d3de38d7e3
4 changed files with 72 additions and 44 deletions

View File

@ -67,3 +67,4 @@ const char* str(CtorConvIntrinsic i) {
}
} // namespace tint::resolver

View File

@ -332,7 +332,9 @@ class Resolver {
//////////////////////////////////////////////////////////////////////////////
/// Cast `Value` to `target_type`
/// @return the casted value
sem::Constant ConstantCast(const sem::Constant& value, const sem::Type* target_elem_type);
sem::Constant ConstantCast(const sem::Constant& value,
const sem::Type* target_type,
const sem::Type* target_element_type = nullptr);
sem::Constant EvaluateConstantValue(const ast::Expression* expr, const sem::Type* type);
sem::Constant EvaluateConstantValue(const ast::LiteralExpression* literal,

View File

@ -23,6 +23,33 @@
using namespace tint::number_suffixes; // NOLINT
namespace tint::resolver {
namespace {
sem::Constant::Scalars CastScalars(sem::Constant::Scalars in, const sem::Type* target_type) {
sem::Constant::Scalars out;
out.reserve(in.size());
for (auto v : in) {
// TODO(crbug.com/tint/1504): Check that value fits in new type
out.emplace_back(Switch<sem::Constant::Scalar>(
target_type, //
[&](const sem::AbstractInt*) { return sem::Constant::Cast<AInt>(v); },
[&](const sem::AbstractFloat*) { return sem::Constant::Cast<AFloat>(v); },
[&](const sem::I32*) { return sem::Constant::Cast<AInt>(v); },
[&](const sem::U32*) { return sem::Constant::Cast<AInt>(v); },
[&](const sem::F32*) { return sem::Constant::Cast<AFloat>(v); },
[&](const sem::F16*) { return sem::Constant::Cast<AFloat>(v); },
[&](const sem::Bool*) { return sem::Constant::Cast<bool>(v); },
[&](Default) {
diag::List diags;
TINT_UNREACHABLE(Semantic, diags)
<< "invalid element type " << target_type->TypeInfo().name;
return sem::Constant::Scalar(false);
}));
}
return out;
}
} // namespace
sem::Constant Resolver::EvaluateConstantValue(const ast::Expression* expr, const sem::Type* type) {
if (auto* e = expr->As<ast::LiteralExpression>()) {
@ -51,21 +78,22 @@ sem::Constant Resolver::EvaluateConstantValue(const ast::LiteralExpression* lite
sem::Constant Resolver::EvaluateConstantValue(const ast::CallExpression* call,
const sem::Type* type) {
auto* vec = type->As<sem::Vector>();
// For now, only fold scalars and vectors
if (!type->is_scalar() && !vec) {
uint32_t result_size = 0;
auto* el_ty = sem::Type::ElementOf(type, &result_size);
if (!el_ty) {
return {};
}
auto* elem_type = vec ? vec->type() : type;
int result_size = vec ? static_cast<int>(vec->Width()) : 1;
// ElementOf() will also return the element type of array, which we do not support.
if (type->Is<sem::Array>()) {
return {};
}
// For zero value init, return 0s
if (call->args.empty()) {
using Scalars = sem::Constant::Scalars;
auto constant = Switch(
elem_type,
return Switch(
el_ty,
[&](const sem::AbstractInt*) {
return sem::Constant(type, Scalars(result_size, AInt(0)));
},
@ -77,63 +105,53 @@ sem::Constant Resolver::EvaluateConstantValue(const ast::CallExpression* call,
[&](const sem::F32*) { return sem::Constant(type, Scalars(result_size, AFloat(0))); },
[&](const sem::F16*) { return sem::Constant(type, Scalars(result_size, AFloat(0))); },
[&](const sem::Bool*) { return sem::Constant(type, Scalars(result_size, false)); });
if (constant.IsValid()) {
return constant;
}
}
// Build value for type_ctor from each child value by casting to
// type_ctor's type.
// Build value for type_ctor from each child value by casting to type_ctor's type.
sem::Constant::Scalars elems;
for (auto* expr : call->args) {
auto* arg = builder_->Sem().Get(expr);
if (!arg || !arg->ConstantValue()) {
if (!arg) {
return {};
}
auto cast = ConstantCast(arg->ConstantValue(), elem_type);
elems.insert(elems.end(), cast.Elements().begin(), cast.Elements().end());
auto value = arg->ConstantValue();
if (!value) {
return {};
}
elems.insert(elems.end(), value.Elements().begin(), value.Elements().end());
}
// Splat single-value initializers
if (elems.size() == 1) {
for (int i = 0; i < result_size - 1; ++i) {
for (uint32_t i = 0; i < result_size - 1; ++i) {
elems.emplace_back(elems[0]);
}
}
return sem::Constant(type, std::move(elems));
// Finally cast the elements to the desired type.
auto cast = CastScalars(elems, el_ty);
return sem::Constant(type, std::move(cast));
}
sem::Constant Resolver::ConstantCast(const sem::Constant& value,
const sem::Type* target_elem_type) {
if (value.ElementType() == target_elem_type) {
const sem::Type* target_type,
const sem::Type* target_element_type /* = nullptr */) {
if (value.Type() == target_type) {
return value;
}
sem::Constant::Scalars elems;
for (size_t i = 0; i < value.Elements().size(); ++i) {
// TODO(crbug.com/tint/1504): Check that value fits in new type
elems.emplace_back(Switch<sem::Constant::Scalar>(
target_elem_type, //
[&](const sem::AbstractInt*) { return value.ElementAs<AInt>(i); },
[&](const sem::AbstractFloat*) { return value.ElementAs<AFloat>(i); },
[&](const sem::I32*) { return value.ElementAs<AInt>(i); },
[&](const sem::U32*) { return value.ElementAs<AInt>(i); },
[&](const sem::F32*) { return value.ElementAs<AFloat>(i); },
[&](const sem::F16*) { return value.ElementAs<AFloat>(i); },
[&](const sem::Bool*) { return value.ElementAs<bool>(i); },
[&](Default) {
diag::List diags;
TINT_UNREACHABLE(Semantic, diags)
<< "invalid element type " << target_elem_type->TypeInfo().name;
return sem::Constant::Scalar(false);
}));
if (target_element_type == nullptr) {
target_element_type = sem::Type::ElementOf(target_type);
}
if (target_element_type == nullptr) {
return {};
}
if (value.ElementType() == target_element_type) {
return sem::Constant(target_type, value.Elements());
}
auto* target_type =
value.Type()->Is<sem::Vector>()
? builder_->create<sem::Vector>(target_elem_type, static_cast<uint32_t>(elems.size()))
: target_elem_type;
auto elems = CastScalars(value.Elements(), target_element_type);
return sem::Constant(target_type, elems);
}

View File

@ -81,7 +81,14 @@ class Constant {
/// @return the value of the scalar `static_cast` to type T.
template <typename T>
T ElementAs(size_t index) const {
return std::visit([](auto val) { return static_cast<T>(val); }, elems_[index]);
return Cast<T>(elems_[index]);
}
/// @param s the input scalar
/// @returns the scalar `s` cast to the type `T`.
template <typename T>
static T Cast(Scalar s) {
return std::visit([](auto v) { return static_cast<T>(v); }, s);
}
private: