tint: Simplify the resolver constant evaluation
And expand the handling to include matrices. Bug: tint:1504 Change-Id: I6fd9ce239d13acf0e2f74b8ea19dfac3457e348c Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/91026 Commit-Queue: Ben Clayton <bclayton@google.com> Reviewed-by: Antonio Maiorano <amaiorano@google.com>
This commit is contained in:
parent
8f4f449540
commit
d3de38d7e3
|
@ -67,3 +67,4 @@ const char* str(CtorConvIntrinsic i) {
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace tint::resolver
|
} // namespace tint::resolver
|
||||||
|
|
||||||
|
|
|
@ -332,7 +332,9 @@ class Resolver {
|
||||||
//////////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////////
|
||||||
/// Cast `Value` to `target_type`
|
/// Cast `Value` to `target_type`
|
||||||
/// @return the casted value
|
/// @return the casted value
|
||||||
sem::Constant ConstantCast(const sem::Constant& value, const sem::Type* target_elem_type);
|
sem::Constant ConstantCast(const sem::Constant& value,
|
||||||
|
const sem::Type* target_type,
|
||||||
|
const sem::Type* target_element_type = nullptr);
|
||||||
|
|
||||||
sem::Constant EvaluateConstantValue(const ast::Expression* expr, const sem::Type* type);
|
sem::Constant EvaluateConstantValue(const ast::Expression* expr, const sem::Type* type);
|
||||||
sem::Constant EvaluateConstantValue(const ast::LiteralExpression* literal,
|
sem::Constant EvaluateConstantValue(const ast::LiteralExpression* literal,
|
||||||
|
|
|
@ -23,6 +23,33 @@
|
||||||
using namespace tint::number_suffixes; // NOLINT
|
using namespace tint::number_suffixes; // NOLINT
|
||||||
|
|
||||||
namespace tint::resolver {
|
namespace tint::resolver {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
sem::Constant::Scalars CastScalars(sem::Constant::Scalars in, const sem::Type* target_type) {
|
||||||
|
sem::Constant::Scalars out;
|
||||||
|
out.reserve(in.size());
|
||||||
|
for (auto v : in) {
|
||||||
|
// TODO(crbug.com/tint/1504): Check that value fits in new type
|
||||||
|
out.emplace_back(Switch<sem::Constant::Scalar>(
|
||||||
|
target_type, //
|
||||||
|
[&](const sem::AbstractInt*) { return sem::Constant::Cast<AInt>(v); },
|
||||||
|
[&](const sem::AbstractFloat*) { return sem::Constant::Cast<AFloat>(v); },
|
||||||
|
[&](const sem::I32*) { return sem::Constant::Cast<AInt>(v); },
|
||||||
|
[&](const sem::U32*) { return sem::Constant::Cast<AInt>(v); },
|
||||||
|
[&](const sem::F32*) { return sem::Constant::Cast<AFloat>(v); },
|
||||||
|
[&](const sem::F16*) { return sem::Constant::Cast<AFloat>(v); },
|
||||||
|
[&](const sem::Bool*) { return sem::Constant::Cast<bool>(v); },
|
||||||
|
[&](Default) {
|
||||||
|
diag::List diags;
|
||||||
|
TINT_UNREACHABLE(Semantic, diags)
|
||||||
|
<< "invalid element type " << target_type->TypeInfo().name;
|
||||||
|
return sem::Constant::Scalar(false);
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
sem::Constant Resolver::EvaluateConstantValue(const ast::Expression* expr, const sem::Type* type) {
|
sem::Constant Resolver::EvaluateConstantValue(const ast::Expression* expr, const sem::Type* type) {
|
||||||
if (auto* e = expr->As<ast::LiteralExpression>()) {
|
if (auto* e = expr->As<ast::LiteralExpression>()) {
|
||||||
|
@ -51,21 +78,22 @@ sem::Constant Resolver::EvaluateConstantValue(const ast::LiteralExpression* lite
|
||||||
|
|
||||||
sem::Constant Resolver::EvaluateConstantValue(const ast::CallExpression* call,
|
sem::Constant Resolver::EvaluateConstantValue(const ast::CallExpression* call,
|
||||||
const sem::Type* type) {
|
const sem::Type* type) {
|
||||||
auto* vec = type->As<sem::Vector>();
|
uint32_t result_size = 0;
|
||||||
|
auto* el_ty = sem::Type::ElementOf(type, &result_size);
|
||||||
// For now, only fold scalars and vectors
|
if (!el_ty) {
|
||||||
if (!type->is_scalar() && !vec) {
|
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
auto* elem_type = vec ? vec->type() : type;
|
// ElementOf() will also return the element type of array, which we do not support.
|
||||||
int result_size = vec ? static_cast<int>(vec->Width()) : 1;
|
if (type->Is<sem::Array>()) {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
// For zero value init, return 0s
|
// For zero value init, return 0s
|
||||||
if (call->args.empty()) {
|
if (call->args.empty()) {
|
||||||
using Scalars = sem::Constant::Scalars;
|
using Scalars = sem::Constant::Scalars;
|
||||||
auto constant = Switch(
|
return Switch(
|
||||||
elem_type,
|
el_ty,
|
||||||
[&](const sem::AbstractInt*) {
|
[&](const sem::AbstractInt*) {
|
||||||
return sem::Constant(type, Scalars(result_size, AInt(0)));
|
return sem::Constant(type, Scalars(result_size, AInt(0)));
|
||||||
},
|
},
|
||||||
|
@ -77,63 +105,53 @@ sem::Constant Resolver::EvaluateConstantValue(const ast::CallExpression* call,
|
||||||
[&](const sem::F32*) { return sem::Constant(type, Scalars(result_size, AFloat(0))); },
|
[&](const sem::F32*) { return sem::Constant(type, Scalars(result_size, AFloat(0))); },
|
||||||
[&](const sem::F16*) { return sem::Constant(type, Scalars(result_size, AFloat(0))); },
|
[&](const sem::F16*) { return sem::Constant(type, Scalars(result_size, AFloat(0))); },
|
||||||
[&](const sem::Bool*) { return sem::Constant(type, Scalars(result_size, false)); });
|
[&](const sem::Bool*) { return sem::Constant(type, Scalars(result_size, false)); });
|
||||||
if (constant.IsValid()) {
|
|
||||||
return constant;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build value for type_ctor from each child value by casting to
|
// Build value for type_ctor from each child value by casting to type_ctor's type.
|
||||||
// type_ctor's type.
|
|
||||||
sem::Constant::Scalars elems;
|
sem::Constant::Scalars elems;
|
||||||
for (auto* expr : call->args) {
|
for (auto* expr : call->args) {
|
||||||
auto* arg = builder_->Sem().Get(expr);
|
auto* arg = builder_->Sem().Get(expr);
|
||||||
if (!arg || !arg->ConstantValue()) {
|
if (!arg) {
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
auto cast = ConstantCast(arg->ConstantValue(), elem_type);
|
auto value = arg->ConstantValue();
|
||||||
elems.insert(elems.end(), cast.Elements().begin(), cast.Elements().end());
|
if (!value) {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
elems.insert(elems.end(), value.Elements().begin(), value.Elements().end());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Splat single-value initializers
|
// Splat single-value initializers
|
||||||
if (elems.size() == 1) {
|
if (elems.size() == 1) {
|
||||||
for (int i = 0; i < result_size - 1; ++i) {
|
for (uint32_t i = 0; i < result_size - 1; ++i) {
|
||||||
elems.emplace_back(elems[0]);
|
elems.emplace_back(elems[0]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return sem::Constant(type, std::move(elems));
|
// Finally cast the elements to the desired type.
|
||||||
|
auto cast = CastScalars(elems, el_ty);
|
||||||
|
|
||||||
|
return sem::Constant(type, std::move(cast));
|
||||||
}
|
}
|
||||||
|
|
||||||
sem::Constant Resolver::ConstantCast(const sem::Constant& value,
|
sem::Constant Resolver::ConstantCast(const sem::Constant& value,
|
||||||
const sem::Type* target_elem_type) {
|
const sem::Type* target_type,
|
||||||
if (value.ElementType() == target_elem_type) {
|
const sem::Type* target_element_type /* = nullptr */) {
|
||||||
|
if (value.Type() == target_type) {
|
||||||
return value;
|
return value;
|
||||||
}
|
}
|
||||||
|
|
||||||
sem::Constant::Scalars elems;
|
if (target_element_type == nullptr) {
|
||||||
for (size_t i = 0; i < value.Elements().size(); ++i) {
|
target_element_type = sem::Type::ElementOf(target_type);
|
||||||
// TODO(crbug.com/tint/1504): Check that value fits in new type
|
}
|
||||||
elems.emplace_back(Switch<sem::Constant::Scalar>(
|
if (target_element_type == nullptr) {
|
||||||
target_elem_type, //
|
return {};
|
||||||
[&](const sem::AbstractInt*) { return value.ElementAs<AInt>(i); },
|
}
|
||||||
[&](const sem::AbstractFloat*) { return value.ElementAs<AFloat>(i); },
|
if (value.ElementType() == target_element_type) {
|
||||||
[&](const sem::I32*) { return value.ElementAs<AInt>(i); },
|
return sem::Constant(target_type, value.Elements());
|
||||||
[&](const sem::U32*) { return value.ElementAs<AInt>(i); },
|
|
||||||
[&](const sem::F32*) { return value.ElementAs<AFloat>(i); },
|
|
||||||
[&](const sem::F16*) { return value.ElementAs<AFloat>(i); },
|
|
||||||
[&](const sem::Bool*) { return value.ElementAs<bool>(i); },
|
|
||||||
[&](Default) {
|
|
||||||
diag::List diags;
|
|
||||||
TINT_UNREACHABLE(Semantic, diags)
|
|
||||||
<< "invalid element type " << target_elem_type->TypeInfo().name;
|
|
||||||
return sem::Constant::Scalar(false);
|
|
||||||
}));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
auto* target_type =
|
auto elems = CastScalars(value.Elements(), target_element_type);
|
||||||
value.Type()->Is<sem::Vector>()
|
|
||||||
? builder_->create<sem::Vector>(target_elem_type, static_cast<uint32_t>(elems.size()))
|
|
||||||
: target_elem_type;
|
|
||||||
|
|
||||||
return sem::Constant(target_type, elems);
|
return sem::Constant(target_type, elems);
|
||||||
}
|
}
|
||||||
|
|
|
@ -81,7 +81,14 @@ class Constant {
|
||||||
/// @return the value of the scalar `static_cast` to type T.
|
/// @return the value of the scalar `static_cast` to type T.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
T ElementAs(size_t index) const {
|
T ElementAs(size_t index) const {
|
||||||
return std::visit([](auto val) { return static_cast<T>(val); }, elems_[index]);
|
return Cast<T>(elems_[index]);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// @param s the input scalar
|
||||||
|
/// @returns the scalar `s` cast to the type `T`.
|
||||||
|
template <typename T>
|
||||||
|
static T Cast(Scalar s) {
|
||||||
|
return std::visit([](auto v) { return static_cast<T>(v); }, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
Loading…
Reference in New Issue