// Copyright 2021 The Tint Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "src/transform/wrap_arrays_in_structs.h" #include #include "src/program_builder.h" #include "src/sem/array.h" #include "src/sem/expression.h" #include "src/utils/get_or_create.h" namespace tint { namespace transform { WrapArraysInStructs::WrappedArrayInfo::WrappedArrayInfo() = default; WrapArraysInStructs::WrappedArrayInfo::WrappedArrayInfo( const WrappedArrayInfo&) = default; WrapArraysInStructs::WrappedArrayInfo::~WrappedArrayInfo() = default; WrapArraysInStructs::WrapArraysInStructs() = default; WrapArraysInStructs::~WrapArraysInStructs() = default; Output WrapArraysInStructs::Run(const Program* in, const DataMap&) { ProgramBuilder out; CloneContext ctx(&out, in); auto& sem = ctx.src->Sem(); std::unordered_map wrapped_arrays; auto wrapper = [&](const sem::Array* array) { return WrapArray(ctx, wrapped_arrays, array); }; auto wrapper_typename = [&](const sem::Array* arr) -> ast::TypeName* { auto info = wrapper(arr); return info ? ctx.dst->create(info.wrapper_name) : nullptr; }; // Replace all array types with their corresponding wrapper ctx.ReplaceAll([&](ast::Type* ast_type) -> ast::Type* { auto* type = ctx.src->TypeOf(ast_type); if (auto* array = type->UnwrapRef()->As()) { return wrapper_typename(array); } return nullptr; }); // Fix up array accessors so `a[1]` becomes `a.arr[1]` ctx.ReplaceAll([&](ast::ArrayAccessorExpression* accessor) -> ast::ArrayAccessorExpression* { if (auto* array = As(sem.Get(accessor->array())->Type()->UnwrapRef())) { if (wrapper(array)) { // Array is wrapped in a structure. Emit a member accessor to get // to the actual array. auto* arr = ctx.Clone(accessor->array()); auto* idx = ctx.Clone(accessor->idx_expr()); auto* unwrapped = ctx.dst->MemberAccessor(arr, "arr"); return ctx.dst->IndexAccessor(accessor->source(), unwrapped, idx); } } return nullptr; }); // Fix up array constructors so `A(1,2)` becomes `tint_array_wrapper(A(1,2))` ctx.ReplaceAll([&](ast::TypeConstructorExpression* ctor) -> ast::Expression* { if (auto* array = As(sem.Get(ctor)->Type()->UnwrapRef())) { if (auto w = wrapper(array)) { // Wrap the array type constructor with another constructor for // the wrapper auto* wrapped_array_ty = ctx.Clone(ctor->type()); auto* array_ty = w.array_type(ctx); auto* arr_ctor = ctx.dst->Construct(array_ty, ctx.Clone(ctor->values())); return ctx.dst->Construct(wrapped_array_ty, arr_ctor); } } return nullptr; }); ctx.Clone(); return Output(Program(std::move(out))); } WrapArraysInStructs::WrappedArrayInfo WrapArraysInStructs::WrapArray( CloneContext& ctx, std::unordered_map& wrapped_arrays, const sem::Array* array) const { if (array->IsRuntimeSized()) { return {}; // We don't want to wrap runtime sized arrays } return utils::GetOrCreate(wrapped_arrays, array, [&] { WrappedArrayInfo info; // Generate a unique name for the array wrapper info.wrapper_name = ctx.dst->Symbols().New("tint_array_wrapper"); // Examine the element type. Is it also an array? std::function el_type; if (auto* el_array = array->ElemType()->As()) { // Array of array - call WrapArray() on the element type if (auto el = WrapArray(ctx, wrapped_arrays, el_array)) { el_type = [=](CloneContext& c) { return c.dst->create(el.wrapper_name); }; } } // If the element wasn't an array, just create the typical AST type for it if (!el_type) { el_type = [=](CloneContext& c) { return CreateASTTypeFor(&c, array->ElemType()); }; } // Construct the single structure field type info.array_type = [=](CloneContext& c) { ast::DecorationList decos; if (!array->IsStrideImplicit()) { decos.emplace_back( c.dst->create(array->Stride())); } return c.dst->create(el_type(c), array->Count(), std::move(decos)); }; // Structure() will create and append the ast::Struct to the // global declarations of `ctx.dst`. As we haven't finished building the // current module-scope statement or function, this will be placed // immediately before the usage. ctx.dst->Structure(info.wrapper_name, {ctx.dst->Member("arr", info.array_type(ctx))}); return info; }); } } // namespace transform } // namespace tint