mirror of
https://github.com/encounter/dawn-cmake.git
synced 2025-08-10 14:09:07 +00:00
This is required for implementing module-level conversions in the spir-v backend (upcoming CL). Bug: tint:865 Change-Id: I7fd38c6b1628c791851917165991bc247fc113c2 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/54740 Kokoro: Kokoro <noreply+kokoro@google.com> Commit-Queue: Antonio Maiorano <amaiorano@google.com> Reviewed-by: Ben Clayton <bclayton@google.com>
354 lines
10 KiB
C++
354 lines
10 KiB
C++
// Copyright 2021 The Tint Authors.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
#include "src/transform/fold_constants.h"
|
|
|
|
#include <unordered_map>
|
|
#include <utility>
|
|
#include <vector>
|
|
|
|
#include "src/program_builder.h"
|
|
|
|
namespace tint {
|
|
|
|
namespace {
|
|
|
|
using i32 = ProgramBuilder::i32;
|
|
using u32 = ProgramBuilder::u32;
|
|
using f32 = ProgramBuilder::f32;
|
|
|
|
/// A Value is a sequence of scalars
|
|
struct Value {
|
|
enum class Type {
|
|
i32, //
|
|
u32,
|
|
f32,
|
|
bool_
|
|
};
|
|
|
|
union Scalar {
|
|
ProgramBuilder::i32 i32;
|
|
ProgramBuilder::u32 u32;
|
|
ProgramBuilder::f32 f32;
|
|
bool bool_;
|
|
|
|
Scalar(ProgramBuilder::i32 v) : i32(v) {} // NOLINT
|
|
Scalar(ProgramBuilder::u32 v) : u32(v) {} // NOLINT
|
|
Scalar(ProgramBuilder::f32 v) : f32(v) {} // NOLINT
|
|
Scalar(bool v) : bool_(v) {} // NOLINT
|
|
};
|
|
|
|
using Elems = std::vector<Scalar>;
|
|
|
|
Type type;
|
|
Elems elems;
|
|
|
|
Value() {}
|
|
|
|
Value(ProgramBuilder::i32 v) : type(Type::i32), elems{v} {} // NOLINT
|
|
Value(ProgramBuilder::u32 v) : type(Type::u32), elems{v} {} // NOLINT
|
|
Value(ProgramBuilder::f32 v) : type(Type::f32), elems{v} {} // NOLINT
|
|
Value(bool v) : type(Type::bool_), elems{v} {} // NOLINT
|
|
|
|
explicit Value(Type t, Elems e = {}) : type(t), elems(std::move(e)) {}
|
|
|
|
bool Valid() const { return elems.size() != 0; }
|
|
operator bool() const { return Valid(); }
|
|
|
|
void Append(const Value& value) {
|
|
TINT_ASSERT(value.type == type);
|
|
elems.insert(elems.end(), value.elems.begin(), value.elems.end());
|
|
}
|
|
|
|
/// Calls `func`(s) with s being the current scalar value at `index`.
|
|
/// `func` is typically a lambda of the form '[](auto&& s)'.
|
|
template <typename Func>
|
|
auto WithScalarAt(size_t index, Func&& func) const {
|
|
switch (type) {
|
|
case Value::Type::i32: {
|
|
return func(elems[index].i32);
|
|
}
|
|
case Value::Type::u32: {
|
|
return func(elems[index].u32);
|
|
}
|
|
case Value::Type::f32: {
|
|
return func(elems[index].f32);
|
|
}
|
|
case Value::Type::bool_: {
|
|
return func(elems[index].bool_);
|
|
}
|
|
}
|
|
TINT_ASSERT(false && "Unreachable");
|
|
return func(~0);
|
|
}
|
|
};
|
|
|
|
/// Returns the Value::Type that maps to the ast::Type*
|
|
Value::Type AstToValueType(ast::Type* t) {
|
|
if (t->Is<ast::I32>()) {
|
|
return Value::Type::i32;
|
|
} else if (t->Is<ast::U32>()) {
|
|
return Value::Type::u32;
|
|
} else if (t->Is<ast::F32>()) {
|
|
return Value::Type::f32;
|
|
} else if (t->Is<ast::Bool>()) {
|
|
return Value::Type::bool_;
|
|
}
|
|
TINT_ASSERT(false && "Invalid type");
|
|
return {};
|
|
}
|
|
|
|
/// Cast `Value` to `target_type`
|
|
/// @return the casted value
|
|
Value Cast(const Value& value, Value::Type target_type) {
|
|
if (value.type == target_type) {
|
|
return value;
|
|
}
|
|
|
|
Value result(target_type);
|
|
for (size_t i = 0; i < value.elems.size(); ++i) {
|
|
switch (target_type) {
|
|
case Value::Type::i32:
|
|
result.Append(value.WithScalarAt(
|
|
i, [](auto&& s) { return static_cast<i32>(s); }));
|
|
break;
|
|
|
|
case Value::Type::u32:
|
|
result.Append(value.WithScalarAt(
|
|
i, [](auto&& s) { return static_cast<u32>(s); }));
|
|
break;
|
|
|
|
case Value::Type::f32:
|
|
result.Append(value.WithScalarAt(
|
|
i, [](auto&& s) { return static_cast<f32>(s); }));
|
|
break;
|
|
|
|
case Value::Type::bool_:
|
|
result.Append(value.WithScalarAt(
|
|
i, [](auto&& s) { return static_cast<bool>(s); }));
|
|
break;
|
|
}
|
|
}
|
|
|
|
return result;
|
|
}
|
|
|
|
/// Type that maps `ast::Expression*` to `Value`
|
|
using ExprToValue = std::unordered_map<const ast::Expression*, Value>;
|
|
|
|
/// Adds mapping of `expr` to `value` to `expr_to_value`
|
|
/// @returns true if add succeded
|
|
bool AddExpr(ExprToValue& expr_to_value,
|
|
const ast::Expression* expr,
|
|
Value value) {
|
|
auto r = expr_to_value.emplace(expr, std::move(value));
|
|
return r.second;
|
|
}
|
|
|
|
/// @returns the `Value` in `expr_to_value` at `expr`, leaving it in the map, or
|
|
/// invalid Value if not in map
|
|
Value PeekExpr(ExprToValue& expr_to_value, ast::Expression* expr) {
|
|
auto iter = expr_to_value.find(expr);
|
|
if (iter != expr_to_value.end()) {
|
|
return iter->second;
|
|
}
|
|
return {};
|
|
}
|
|
|
|
/// @returns the `Value` in `expr_to_value` at `expr`, removing it from the map,
|
|
/// or invalid Value if not in map
|
|
Value TakeExpr(ExprToValue& expr_to_value, ast::Expression* expr) {
|
|
auto iter = expr_to_value.find(expr);
|
|
if (iter != expr_to_value.end()) {
|
|
auto result = std::move(iter->second);
|
|
expr_to_value.erase(iter);
|
|
return result;
|
|
}
|
|
return {};
|
|
}
|
|
|
|
/// Folds a `ScalarConstructorExpression` into a `Value`
|
|
Value Fold(const ast::ScalarConstructorExpression* scalar_ctor) {
|
|
auto* literal = scalar_ctor->literal();
|
|
if (auto* lit = literal->As<ast::SintLiteral>()) {
|
|
return {lit->value_as_i32()};
|
|
}
|
|
if (auto* lit = literal->As<ast::UintLiteral>()) {
|
|
return {lit->value_as_u32()};
|
|
}
|
|
if (auto* lit = literal->As<ast::FloatLiteral>()) {
|
|
return {lit->value()};
|
|
}
|
|
if (auto* lit = literal->As<ast::BoolLiteral>()) {
|
|
return {lit->IsTrue()};
|
|
}
|
|
TINT_ASSERT(false && "Unreachable");
|
|
return {};
|
|
}
|
|
|
|
/// Folds a `TypeConstructorExpression` into a `Value` if possible.
|
|
/// @returns a valid `Value` with 1 element for scalars, and 2/3/4 elements for
|
|
/// vectors.
|
|
Value Fold(const ast::TypeConstructorExpression* type_ctor,
|
|
ExprToValue& expr_to_value) {
|
|
auto& ctor_values = type_ctor->values();
|
|
auto* type = type_ctor->type();
|
|
auto* vec = type->As<ast::Vector>();
|
|
|
|
// For now, only fold scalars and vectors
|
|
if (!type->is_scalar() && !vec) {
|
|
return {};
|
|
}
|
|
|
|
auto* elem_type = vec ? vec->type() : type;
|
|
int result_size = vec ? static_cast<int>(vec->size()) : 1;
|
|
|
|
// For zero value init, return 0s
|
|
if (ctor_values.empty()) {
|
|
if (elem_type->Is<ast::I32>()) {
|
|
return Value(Value::Type::i32, Value::Elems(result_size, 0));
|
|
} else if (elem_type->Is<ast::U32>()) {
|
|
return Value(Value::Type::u32, Value::Elems(result_size, 0u));
|
|
} else if (elem_type->Is<ast::F32>()) {
|
|
return Value(Value::Type::f32, Value::Elems(result_size, 0.0f));
|
|
} else if (elem_type->Is<ast::Bool>()) {
|
|
return Value(Value::Type::bool_, Value::Elems(result_size, false));
|
|
}
|
|
}
|
|
|
|
// If not all ctor_values are foldable, we can't fold this node
|
|
for (auto* cv : ctor_values) {
|
|
if (!PeekExpr(expr_to_value, cv)) {
|
|
return {};
|
|
}
|
|
}
|
|
|
|
// Build value for type_ctor from each child value by casting to
|
|
// type_ctor's type.
|
|
Value new_value(AstToValueType(elem_type));
|
|
for (auto* cv : ctor_values) {
|
|
auto value = TakeExpr(expr_to_value, cv);
|
|
new_value.Append(Cast(value, AstToValueType(elem_type)));
|
|
}
|
|
|
|
// Splat single-value initializers
|
|
if (new_value.elems.size() == 1) {
|
|
auto first_value = new_value;
|
|
for (int i = 0; i < result_size - 1; ++i) {
|
|
new_value.Append(first_value);
|
|
}
|
|
}
|
|
|
|
return new_value;
|
|
}
|
|
|
|
/// @returns a `ConstructorExpression` to replace `expr` with, or nullptr if we
|
|
/// shouldn't replace it.
|
|
ast::ConstructorExpression* Build(CloneContext& ctx,
|
|
const ast::Expression* expr,
|
|
const Value& value) {
|
|
// If original ctor expression had no init values, don't replace the
|
|
// expression
|
|
if (auto* ctor = expr->As<ast::TypeConstructorExpression>()) {
|
|
if (ctor->values().size() == 0) {
|
|
return nullptr;
|
|
}
|
|
}
|
|
|
|
auto make_ast_type = [&]() -> ast::Type* {
|
|
switch (value.type) {
|
|
case Value::Type::i32:
|
|
return ctx.dst->ty.i32();
|
|
case Value::Type::u32:
|
|
return ctx.dst->ty.u32();
|
|
case Value::Type::f32:
|
|
return ctx.dst->ty.f32();
|
|
case Value::Type::bool_:
|
|
return ctx.dst->ty.bool_();
|
|
}
|
|
return nullptr;
|
|
};
|
|
|
|
if (auto* type_ctor = expr->As<ast::TypeConstructorExpression>()) {
|
|
if (auto* vec = type_ctor->type()->As<ast::Vector>()) {
|
|
uint32_t vec_size = static_cast<uint32_t>(vec->size());
|
|
|
|
// We'd like to construct the new vector with the same number of
|
|
// constructor args that the original node had, but after folding
|
|
// constants, cases like the following are problematic:
|
|
//
|
|
// vec3<f32> = vec3<f32>(vec2<f32>, 1.0) // vec_size=3, ctor_size=2
|
|
//
|
|
// In this case, creating a vec3 with 2 args is invalid, so we should
|
|
// create it with 3. So what we do is construct with vec_size args,
|
|
// except if the original vector was single-value initialized, in which
|
|
// case, we only construct with one arg again.
|
|
uint32_t ctor_size = (type_ctor->values().size() == 1) ? 1 : vec_size;
|
|
|
|
ast::ExpressionList ctors;
|
|
for (uint32_t i = 0; i < ctor_size; ++i) {
|
|
value.WithScalarAt(
|
|
i, [&](auto&& s) { ctors.emplace_back(ctx.dst->Expr(s)); });
|
|
}
|
|
|
|
return ctx.dst->vec(make_ast_type(), vec_size, ctors);
|
|
} else if (type_ctor->type()->is_scalar()) {
|
|
return value.WithScalarAt(0, [&](auto&& s) { return ctx.dst->Expr(s); });
|
|
}
|
|
}
|
|
return nullptr;
|
|
}
|
|
|
|
} // namespace
|
|
|
|
namespace transform {
|
|
|
|
FoldConstants::FoldConstants() = default;
|
|
|
|
FoldConstants::~FoldConstants() = default;
|
|
|
|
Output FoldConstants::Run(const Program* in, const DataMap&) {
|
|
ProgramBuilder out;
|
|
CloneContext ctx(&out, in);
|
|
|
|
ExprToValue expr_to_value;
|
|
|
|
// Visit inner expressions before outer expressions
|
|
for (auto* node : ctx.src->ASTNodes().Objects()) {
|
|
if (auto* scalar_ctor = node->As<ast::ScalarConstructorExpression>()) {
|
|
if (auto v = Fold(scalar_ctor)) {
|
|
AddExpr(expr_to_value, scalar_ctor, std::move(v));
|
|
}
|
|
}
|
|
if (auto* type_ctor = node->As<ast::TypeConstructorExpression>()) {
|
|
if (auto v = Fold(type_ctor, expr_to_value)) {
|
|
AddExpr(expr_to_value, type_ctor, std::move(v));
|
|
}
|
|
}
|
|
}
|
|
|
|
for (auto& kvp : expr_to_value) {
|
|
if (auto* ctor_expr = Build(ctx, kvp.first, kvp.second)) {
|
|
ctx.Replace(kvp.first, ctor_expr);
|
|
}
|
|
}
|
|
|
|
ctx.Clone();
|
|
|
|
return Output(Program(std::move(out)));
|
|
}
|
|
|
|
} // namespace transform
|
|
} // namespace tint
|