dawn-cmake/src/tint/transform/decompose_strided_array.cc
Ben Clayton e5a67ac891 tint: Remove ast::CallExpression -> sem::Call implicit mapping
With abstract materialization, any ast::Expression may map to the new
sem::Materialize node. Because of this, we can't assume that an
ast::CallExpression maps to a sem::Call, as it might be wrapped by a
sem::Materialize.

Remove the mapping, and fix up all the code that was relying on this.

Fixes are done by either:
• Calling `UnwrapMaterialize()->As<sem::Call>()` on the semantic
  expression. This is done when the logic may assume it's possible for
  the expression to be a Materialize node.
• Using the explicit sem::Info::Get<sem::Call>() template argument to
  cast the semantic type to sem::Call. This is done when the logic
  knows it is impossible for the expression to be a Materialize.

The backends have been stubbed, as we'll want to emit the constant value
for these nodes. It's likely that we'll just use the FoldConstants
transform to strip all Materialize nodes from the tree. For now, be
defensive.

Bug: tint:1504
Change-Id: If9231b300fc30c7fe886c17a804ead8ee2988285
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/90533
Commit-Queue: Ben Clayton <bclayton@chromium.org>
Reviewed-by: David Neto <dneto@google.com>
2022-05-19 21:50:59 +00:00

155 lines
6.4 KiB
C++

// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/decompose_strided_array.h"
#include <unordered_map>
#include <utility>
#include <vector>
#include "src/tint/program_builder.h"
#include "src/tint/sem/call.h"
#include "src/tint/sem/expression.h"
#include "src/tint/sem/member_accessor_expression.h"
#include "src/tint/sem/type_constructor.h"
#include "src/tint/transform/simplify_pointers.h"
#include "src/tint/utils/hash.h"
#include "src/tint/utils/map.h"
// Type-info instantiation for the DecomposeStridedArray transform. The macro
// is defined outside this file; presumably it provides the runtime type
// information that the As<>() style casts rely on — TODO confirm.
TINT_INSTANTIATE_TYPEINFO(tint::transform::DecomposeStridedArray);
namespace tint::transform {
namespace {
// Alias for a map keyed by a semantic array type, whose value is a Symbol.
// NOTE(review): this alias is not referenced in the visible part of this file;
// Run() declares a local with the same map type spelled out in full.
using DecomposedArrays = std::unordered_map<const sem::Array*, Symbol>;
} // namespace
// Constructor: uses the compiler-generated default implementation.
DecomposeStridedArray::DecomposeStridedArray() = default;
// Destructor: uses the compiler-generated default implementation.
DecomposeStridedArray::~DecomposeStridedArray() = default;
bool DecomposeStridedArray::ShouldRun(const Program* program, const DataMap&) const {
for (auto* node : program->ASTNodes().Objects()) {
if (auto* ast = node->As<ast::Array>()) {
if (ast::GetAttribute<ast::StrideAttribute>(ast->attributes)) {
return true;
}
}
}
return false;
}
// Runs the transform.
//
// Registers three replacement callbacks on `ctx` and then clones the program:
//  1. ast::Array types: arrays whose stride is not implicit have their element
//     type replaced by a generated single-member structure; arrays that merely
//     carry a @stride attribute have the attribute stripped.
//  2. ast::IndexAccessorExpression on a decomposed array: `arr[i]` becomes
//     `arr[i].el`.
//  3. ast::CallExpression that is a type constructor of an array: each argument
//     is wrapped in a constructor call of the generated structure.
//
// @param ctx the clone context holding the source program (`ctx.src`) and the
//            builder for the target program (`ctx.dst`)
// The two DataMap parameters are unused.
void DecomposeStridedArray::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
    const auto& sem = ctx.src->Sem();

    // Name of the single member of each generated wrapper structure.
    static constexpr const char* kMemberName = "el";

    // Maps an array type in the source program to the name of the struct wrapper
    // type in the target program.
    std::unordered_map<const sem::Array*, Symbol> decomposed;

    // Find and replace all arrays with a @stride attribute with an array that has
    // the @stride removed. If the source array stride does not match the natural
    // stride for the array element type, then replace the array element type with
    // a structure, holding a single field with a @size attribute equal to the
    // array stride.
    ctx.ReplaceAll([&](const ast::Array* ast) -> const ast::Array* {
        if (auto* arr = sem.Get(ast)) {
            if (!arr->IsStrideImplicit()) {
                // Create the wrapper structure once per semantic array type and
                // reuse its name on subsequent encounters.
                auto el_ty = utils::GetOrCreate(decomposed, arr, [&] {
                    auto name = ctx.dst->Symbols().New("strided_arr");
                    auto* member_ty = ctx.Clone(ast->type);
                    auto* member = ctx.dst->Member(kMemberName, member_ty,
                                                   {ctx.dst->MemberSize(arr->Stride())});
                    ctx.dst->Structure(name, {member});
                    return name;
                });
                auto* count = ctx.Clone(ast->count);
                return ctx.dst->ty.array(ctx.dst->ty.type_name(el_ty), count);
            }
            if (ast::GetAttribute<ast::StrideAttribute>(ast->attributes)) {
                // Strip the @stride attribute
                auto* ty = ctx.Clone(ast->type);
                auto* count = ctx.Clone(ast->count);
                return ctx.dst->ty.array(ty, count);
            }
        }
        // No replacement: the array is cloned as usual.
        return nullptr;
    });

    // Find all array index-accessors expressions for arrays that have had their
    // element changed to a single field structure. These expressions are adjusted
    // to insert an additional member accessor for the single structure field.
    // Example: `arr[i]` -> `arr[i].el`
    ctx.ReplaceAll([&](const ast::IndexAccessorExpression* idx) -> const ast::Expression* {
        if (auto* ty = ctx.src->TypeOf(idx->object)) {
            if (auto* arr = ty->UnwrapRef()->As<sem::Array>()) {
                if (!arr->IsStrideImplicit()) {
                    // Clone without re-applying this replacement to `idx` itself,
                    // then append the member accessor.
                    auto* expr = ctx.CloneWithoutTransform(idx);
                    return ctx.dst->MemberAccessor(expr, kMemberName);
                }
            }
        }
        return nullptr;
    });

    // Find all array type constructor expressions for array types that have had
    // their element changed to a single field structure. These constructors are
    // adjusted to wrap each of the arguments with an additional constructor for
    // the new element structure type.
    // Example:
    //   `@stride(32) array<i32, 3>(1, 2, 3)`
    // ->
    //   `array<strided_arr, 3>(strided_arr(1), strided_arr(2), strided_arr(3))`
    ctx.ReplaceAll([&](const ast::CallExpression* expr) -> const ast::Expression* {
        if (!expr->args.empty()) {
            if (auto* call = sem.Get(expr)->UnwrapMaterialize()->As<sem::Call>()) {
                if (auto* ctor = call->Target()->As<sem::TypeConstructor>()) {
                    if (auto* arr = ctor->ReturnType()->As<sem::Array>()) {
                        // Begin by cloning the array constructor type or name
                        // If this is an unaliased array, this may add a new entry to
                        // decomposed.
                        // If this is an aliased array, decomposed should already be
                        // populated with any strided aliases.
                        // NOTE: this must happen before the lookup in `decomposed`
                        // below.
                        ast::CallExpression::Target target;
                        if (expr->target.type) {
                            target.type = ctx.Clone(expr->target.type);
                        } else {
                            target.name = ctx.Clone(expr->target.name);
                        }

                        ast::ExpressionList args;
                        if (auto it = decomposed.find(arr); it != decomposed.end()) {
                            // Copy the wrapper name out of the map before cloning
                            // any argument. Cloning an argument may reach the
                            // ast::Array replacement above (e.g. through a nested
                            // array constructor), which can insert into
                            // `decomposed`. An insertion that triggers a rehash
                            // invalidates `it`, so the iterator must not be
                            // dereferenced once cloning has started.
                            Symbol wrapper = it->second;
                            args.reserve(expr->args.size());
                            for (auto* arg : expr->args) {
                                args.emplace_back(ctx.dst->Call(wrapper, ctx.Clone(arg)));
                            }
                        } else {
                            // Array was not decomposed: arguments are cloned as-is.
                            args = ctx.Clone(expr->args);
                        }

                        return target.type ? ctx.dst->Construct(target.type, args)
                                           : ctx.dst->Call(target.name, args);
                    }
                }
            }
        }
        return nullptr;
    });

    ctx.Clone();
}
} // namespace tint::transform