// Copyright 2021 The Tint Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "src/transform/calculate_array_length.h" #include #include #include "src/ast/call_statement.h" #include "src/ast/disable_validation_decoration.h" #include "src/program_builder.h" #include "src/sem/block_statement.h" #include "src/sem/call.h" #include "src/sem/statement.h" #include "src/sem/struct.h" #include "src/sem/variable.h" #include "src/transform/simplify_pointers.h" #include "src/utils/hash.h" #include "src/utils/map.h" TINT_INSTANTIATE_TYPEINFO(tint::transform::CalculateArrayLength); TINT_INSTANTIATE_TYPEINFO( tint::transform::CalculateArrayLength::BufferSizeIntrinsic); namespace tint { namespace transform { namespace { /// ArrayUsage describes a runtime array usage. /// It is used as a key by the array_length_by_usage map. struct ArrayUsage { ast::BlockStatement const* const block; sem::Node const* const buffer; bool operator==(const ArrayUsage& rhs) const { return block == rhs.block && buffer == rhs.buffer; } struct Hasher { inline std::size_t operator()(const ArrayUsage& u) const { return utils::Hash(u.block, u.buffer); } }; }; } // namespace CalculateArrayLength::BufferSizeIntrinsic::BufferSizeIntrinsic(ProgramID pid) : Base(pid) {} CalculateArrayLength::BufferSizeIntrinsic::~BufferSizeIntrinsic() = default; std::string CalculateArrayLength::BufferSizeIntrinsic::InternalName() const { return "intrinsic_buffer_size"; } const CalculateArrayLength::BufferSizeIntrinsic* CalculateArrayLength::BufferSizeIntrinsic::Clone(CloneContext* ctx) const { return ctx->dst->ASTNodes().Create( ctx->dst->ID()); } CalculateArrayLength::CalculateArrayLength() = default; CalculateArrayLength::~CalculateArrayLength() = default; void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) { auto& sem = ctx.src->Sem(); if (!Requires(ctx)) { return; } // get_buffer_size_intrinsic() emits the function decorated with // BufferSizeIntrinsic that is transformed by the HLSL writer into a call to // [RW]ByteAddressBuffer.GetDimensions(). std::unordered_map buffer_size_intrinsics; auto get_buffer_size_intrinsic = [&](const sem::Struct* buffer_type) { return utils::GetOrCreate(buffer_size_intrinsics, buffer_type, [&] { auto name = ctx.dst->Sym(); auto* buffer_typename = ctx.dst->ty.type_name(ctx.Clone(buffer_type->Declaration()->name)); auto* disable_validation = ctx.dst->Disable( ast::DisabledValidation::kIgnoreConstructibleFunctionParameter); auto* func = ctx.dst->create( name, ast::VariableList{ // Note: The buffer parameter requires the kStorage StorageClass // in order for HLSL to emit this as a ByteAddressBuffer. ctx.dst->create( ctx.dst->Sym("buffer"), ast::StorageClass::kStorage, ast::Access::kUndefined, buffer_typename, true, nullptr, ast::DecorationList{disable_validation}), ctx.dst->Param("result", ctx.dst->ty.pointer(ctx.dst->ty.u32(), ast::StorageClass::kFunction)), }, ctx.dst->ty.void_(), nullptr, ast::DecorationList{ ctx.dst->ASTNodes().Create(ctx.dst->ID()), }, ast::DecorationList{}); ctx.InsertAfter(ctx.src->AST().GlobalDeclarations(), buffer_type->Declaration(), func); return name; }); }; std::unordered_map array_length_by_usage; // Find all the arrayLength() calls... for (auto* node : ctx.src->ASTNodes().Objects()) { if (auto* call_expr = node->As()) { auto* call = sem.Get(call_expr); if (auto* intrinsic = call->Target()->As()) { if (intrinsic->Type() == sem::IntrinsicType::kArrayLength) { // We're dealing with an arrayLength() call // https://gpuweb.github.io/gpuweb/wgsl/#array-types states: // // * The last member of the structure type defining the store type for // a variable in the storage storage class may be a runtime-sized // array. // * A runtime-sized array must not be used as the store type or // contained within a store type in any other cases. // * An expression must not evaluate to a runtime-sized array type. // // We can assume that the arrayLength() call has a single argument of // the form: arrayLength(&X.Y) where X is an expression that resolves // to the storage buffer structure, and Y is the runtime sized array. auto* arg = call_expr->args[0]; auto* address_of = arg->As(); if (!address_of || address_of->op != ast::UnaryOp::kAddressOf) { TINT_ICE(Transform, ctx.dst->Diagnostics()) << "arrayLength() expected pointer to member access, got " << address_of->TypeInfo().name; } auto* array_expr = address_of->expr; auto* accessor = array_expr->As(); if (!accessor) { TINT_ICE(Transform, ctx.dst->Diagnostics()) << "arrayLength() expected pointer to member access, got " "pointer to " << array_expr->TypeInfo().name; break; } auto* storage_buffer_expr = accessor->structure; auto* storage_buffer_sem = sem.Get(storage_buffer_expr); auto* storage_buffer_type = storage_buffer_sem->Type()->UnwrapRef()->As(); // Generate BufferSizeIntrinsic for this storage type if we haven't // already auto buffer_size = get_buffer_size_intrinsic(storage_buffer_type); if (!storage_buffer_type) { TINT_ICE(Transform, ctx.dst->Diagnostics()) << "arrayLength(X.Y) expected X to be sem::Struct, got " << storage_buffer_type->FriendlyName(ctx.src->Symbols()); break; } // Find the current statement block auto* block = call->Stmt()->Block()->Declaration(); // If the storage_buffer_expr is resolves to a variable (typically // true) then key the array_length from the variable. If not, key off // the expression semantic node, which will be unique per call to // arrayLength(). const sem::Node* storage_buffer_usage = storage_buffer_sem; if (auto* user = storage_buffer_sem->As()) { storage_buffer_usage = user->Variable(); } auto array_length = utils::GetOrCreate( array_length_by_usage, {block, storage_buffer_usage}, [&] { // First time this array length is used for this block. // Let's calculate it. // Semantic info for the runtime array structure member auto* array_member_sem = storage_buffer_type->Members().back(); // Construct the variable that'll hold the result of // RWByteAddressBuffer.GetDimensions() auto* buffer_size_result = ctx.dst->Decl( ctx.dst->Var(ctx.dst->Sym(), ctx.dst->ty.u32(), ast::StorageClass::kNone, ctx.dst->Expr(0u))); // Call storage_buffer.GetDimensions(&buffer_size_result) auto* call_get_dims = ctx.dst->CallStmt(ctx.dst->Call( // BufferSizeIntrinsic(X, ARGS...) is // translated to: // X.GetDimensions(ARGS..) by the writer buffer_size, ctx.Clone(storage_buffer_expr), ctx.dst->AddressOf( ctx.dst->Expr(buffer_size_result->variable->symbol)))); // Calculate actual array length // total_storage_buffer_size - array_offset // array_length = ---------------------------------------- // array_stride auto name = ctx.dst->Sym(); uint32_t array_offset = array_member_sem->Offset(); uint32_t array_stride = array_member_sem->Size(); auto* array_length_var = ctx.dst->Decl(ctx.dst->Const( name, ctx.dst->ty.u32(), ctx.dst->Div( ctx.dst->Sub(buffer_size_result->variable->symbol, array_offset), array_stride))); // Insert the array length calculations at the top of the block ctx.InsertBefore(block->statements, block->statements[0], buffer_size_result); ctx.InsertBefore(block->statements, block->statements[0], call_get_dims); ctx.InsertBefore(block->statements, block->statements[0], array_length_var); return name; }); // Replace the call to arrayLength() with the array length variable ctx.Replace(call_expr, ctx.dst->Expr(array_length)); } } } } ctx.Clone(); } } // namespace transform } // namespace tint