tint: Simplify the resolver constant evaluation
And expand the handling to include matrices. Bug: tint:1504 Change-Id: I6fd9ce239d13acf0e2f74b8ea19dfac3457e348c Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/91026 Commit-Queue: Ben Clayton <bclayton@google.com> Reviewed-by: Antonio Maiorano <amaiorano@google.com>
This commit is contained in:
parent
8f4f449540
commit
d3de38d7e3
|
@ -67,3 +67,4 @@ const char* str(CtorConvIntrinsic i) {
|
|||
}
|
||||
|
||||
} // namespace tint::resolver
|
||||
|
||||
|
|
|
@ -332,7 +332,9 @@ class Resolver {
|
|||
//////////////////////////////////////////////////////////////////////////////
|
||||
/// Cast `Value` to `target_type`
|
||||
/// @return the casted value
|
||||
sem::Constant ConstantCast(const sem::Constant& value, const sem::Type* target_elem_type);
|
||||
sem::Constant ConstantCast(const sem::Constant& value,
|
||||
const sem::Type* target_type,
|
||||
const sem::Type* target_element_type = nullptr);
|
||||
|
||||
sem::Constant EvaluateConstantValue(const ast::Expression* expr, const sem::Type* type);
|
||||
sem::Constant EvaluateConstantValue(const ast::LiteralExpression* literal,
|
||||
|
|
|
@ -23,6 +23,33 @@
|
|||
using namespace tint::number_suffixes; // NOLINT
|
||||
|
||||
namespace tint::resolver {
|
||||
namespace {
|
||||
|
||||
sem::Constant::Scalars CastScalars(sem::Constant::Scalars in, const sem::Type* target_type) {
|
||||
sem::Constant::Scalars out;
|
||||
out.reserve(in.size());
|
||||
for (auto v : in) {
|
||||
// TODO(crbug.com/tint/1504): Check that value fits in new type
|
||||
out.emplace_back(Switch<sem::Constant::Scalar>(
|
||||
target_type, //
|
||||
[&](const sem::AbstractInt*) { return sem::Constant::Cast<AInt>(v); },
|
||||
[&](const sem::AbstractFloat*) { return sem::Constant::Cast<AFloat>(v); },
|
||||
[&](const sem::I32*) { return sem::Constant::Cast<AInt>(v); },
|
||||
[&](const sem::U32*) { return sem::Constant::Cast<AInt>(v); },
|
||||
[&](const sem::F32*) { return sem::Constant::Cast<AFloat>(v); },
|
||||
[&](const sem::F16*) { return sem::Constant::Cast<AFloat>(v); },
|
||||
[&](const sem::Bool*) { return sem::Constant::Cast<bool>(v); },
|
||||
[&](Default) {
|
||||
diag::List diags;
|
||||
TINT_UNREACHABLE(Semantic, diags)
|
||||
<< "invalid element type " << target_type->TypeInfo().name;
|
||||
return sem::Constant::Scalar(false);
|
||||
}));
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
sem::Constant Resolver::EvaluateConstantValue(const ast::Expression* expr, const sem::Type* type) {
|
||||
if (auto* e = expr->As<ast::LiteralExpression>()) {
|
||||
|
@ -51,21 +78,22 @@ sem::Constant Resolver::EvaluateConstantValue(const ast::LiteralExpression* lite
|
|||
|
||||
sem::Constant Resolver::EvaluateConstantValue(const ast::CallExpression* call,
|
||||
const sem::Type* type) {
|
||||
auto* vec = type->As<sem::Vector>();
|
||||
|
||||
// For now, only fold scalars and vectors
|
||||
if (!type->is_scalar() && !vec) {
|
||||
uint32_t result_size = 0;
|
||||
auto* el_ty = sem::Type::ElementOf(type, &result_size);
|
||||
if (!el_ty) {
|
||||
return {};
|
||||
}
|
||||
|
||||
auto* elem_type = vec ? vec->type() : type;
|
||||
int result_size = vec ? static_cast<int>(vec->Width()) : 1;
|
||||
// ElementOf() will also return the element type of array, which we do not support.
|
||||
if (type->Is<sem::Array>()) {
|
||||
return {};
|
||||
}
|
||||
|
||||
// For zero value init, return 0s
|
||||
if (call->args.empty()) {
|
||||
using Scalars = sem::Constant::Scalars;
|
||||
auto constant = Switch(
|
||||
elem_type,
|
||||
return Switch(
|
||||
el_ty,
|
||||
[&](const sem::AbstractInt*) {
|
||||
return sem::Constant(type, Scalars(result_size, AInt(0)));
|
||||
},
|
||||
|
@ -77,63 +105,53 @@ sem::Constant Resolver::EvaluateConstantValue(const ast::CallExpression* call,
|
|||
[&](const sem::F32*) { return sem::Constant(type, Scalars(result_size, AFloat(0))); },
|
||||
[&](const sem::F16*) { return sem::Constant(type, Scalars(result_size, AFloat(0))); },
|
||||
[&](const sem::Bool*) { return sem::Constant(type, Scalars(result_size, false)); });
|
||||
if (constant.IsValid()) {
|
||||
return constant;
|
||||
}
|
||||
}
|
||||
|
||||
// Build value for type_ctor from each child value by casting to
|
||||
// type_ctor's type.
|
||||
// Build value for type_ctor from each child value by casting to type_ctor's type.
|
||||
sem::Constant::Scalars elems;
|
||||
for (auto* expr : call->args) {
|
||||
auto* arg = builder_->Sem().Get(expr);
|
||||
if (!arg || !arg->ConstantValue()) {
|
||||
if (!arg) {
|
||||
return {};
|
||||
}
|
||||
auto cast = ConstantCast(arg->ConstantValue(), elem_type);
|
||||
elems.insert(elems.end(), cast.Elements().begin(), cast.Elements().end());
|
||||
auto value = arg->ConstantValue();
|
||||
if (!value) {
|
||||
return {};
|
||||
}
|
||||
elems.insert(elems.end(), value.Elements().begin(), value.Elements().end());
|
||||
}
|
||||
|
||||
// Splat single-value initializers
|
||||
if (elems.size() == 1) {
|
||||
for (int i = 0; i < result_size - 1; ++i) {
|
||||
for (uint32_t i = 0; i < result_size - 1; ++i) {
|
||||
elems.emplace_back(elems[0]);
|
||||
}
|
||||
}
|
||||
|
||||
return sem::Constant(type, std::move(elems));
|
||||
// Finally cast the elements to the desired type.
|
||||
auto cast = CastScalars(elems, el_ty);
|
||||
|
||||
return sem::Constant(type, std::move(cast));
|
||||
}
|
||||
|
||||
sem::Constant Resolver::ConstantCast(const sem::Constant& value,
|
||||
const sem::Type* target_elem_type) {
|
||||
if (value.ElementType() == target_elem_type) {
|
||||
const sem::Type* target_type,
|
||||
const sem::Type* target_element_type /* = nullptr */) {
|
||||
if (value.Type() == target_type) {
|
||||
return value;
|
||||
}
|
||||
|
||||
sem::Constant::Scalars elems;
|
||||
for (size_t i = 0; i < value.Elements().size(); ++i) {
|
||||
// TODO(crbug.com/tint/1504): Check that value fits in new type
|
||||
elems.emplace_back(Switch<sem::Constant::Scalar>(
|
||||
target_elem_type, //
|
||||
[&](const sem::AbstractInt*) { return value.ElementAs<AInt>(i); },
|
||||
[&](const sem::AbstractFloat*) { return value.ElementAs<AFloat>(i); },
|
||||
[&](const sem::I32*) { return value.ElementAs<AInt>(i); },
|
||||
[&](const sem::U32*) { return value.ElementAs<AInt>(i); },
|
||||
[&](const sem::F32*) { return value.ElementAs<AFloat>(i); },
|
||||
[&](const sem::F16*) { return value.ElementAs<AFloat>(i); },
|
||||
[&](const sem::Bool*) { return value.ElementAs<bool>(i); },
|
||||
[&](Default) {
|
||||
diag::List diags;
|
||||
TINT_UNREACHABLE(Semantic, diags)
|
||||
<< "invalid element type " << target_elem_type->TypeInfo().name;
|
||||
return sem::Constant::Scalar(false);
|
||||
}));
|
||||
if (target_element_type == nullptr) {
|
||||
target_element_type = sem::Type::ElementOf(target_type);
|
||||
}
|
||||
if (target_element_type == nullptr) {
|
||||
return {};
|
||||
}
|
||||
if (value.ElementType() == target_element_type) {
|
||||
return sem::Constant(target_type, value.Elements());
|
||||
}
|
||||
|
||||
auto* target_type =
|
||||
value.Type()->Is<sem::Vector>()
|
||||
? builder_->create<sem::Vector>(target_elem_type, static_cast<uint32_t>(elems.size()))
|
||||
: target_elem_type;
|
||||
auto elems = CastScalars(value.Elements(), target_element_type);
|
||||
|
||||
return sem::Constant(target_type, elems);
|
||||
}
|
||||
|
|
|
@ -81,7 +81,14 @@ class Constant {
|
|||
/// @return the value of the scalar `static_cast` to type T.
|
||||
template <typename T>
|
||||
T ElementAs(size_t index) const {
|
||||
return std::visit([](auto val) { return static_cast<T>(val); }, elems_[index]);
|
||||
return Cast<T>(elems_[index]);
|
||||
}
|
||||
|
||||
/// @param s the input scalar
|
||||
/// @returns the scalar `s` cast to the type `T`.
|
||||
template <typename T>
|
||||
static T Cast(Scalar s) {
|
||||
return std::visit([](auto v) { return static_cast<T>(v); }, s);
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
Loading…
Reference in New Issue