// Copyright 2022 The Tint Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "src/tint/transform/direct_variable_access.h" #include #include #include #include "src/tint/ast/traverse_expressions.h" #include "src/tint/program_builder.h" #include "src/tint/sem/abstract_int.h" #include "src/tint/sem/call.h" #include "src/tint/sem/function.h" #include "src/tint/sem/index_accessor_expression.h" #include "src/tint/sem/member_accessor_expression.h" #include "src/tint/sem/module.h" #include "src/tint/sem/statement.h" #include "src/tint/sem/struct.h" #include "src/tint/sem/variable.h" #include "src/tint/transform/utils/hoist_to_decl_before.h" #include "src/tint/utils/reverse.h" #include "src/tint/utils/scoped_assignment.h" TINT_INSTANTIATE_TYPEINFO(tint::transform::DirectVariableAccess); TINT_INSTANTIATE_TYPEINFO(tint::transform::DirectVariableAccess::Config); using namespace tint::number_suffixes; // NOLINT namespace { /// AccessRoot describes the root of an AccessShape. struct AccessRoot { /// The pointer-unwrapped type of the *transformed* variable. /// This may be different for pointers in 'private' and 'function' address space, as the pointer /// parameter type is to the *base object* instead of the input pointer type. tint::sem::Type const* type = nullptr; /// The originating module-scope variable ('private', 'storage', 'uniform', 'workgroup'), /// function-scope variable ('function'), or pointer parameter in the source program. tint::sem::Variable const* variable = nullptr; /// The address space of the variable or pointer type. tint::ast::AddressSpace address_space = tint::ast::AddressSpace::kUndefined; }; /// Inequality operator for AccessRoot bool operator!=(const AccessRoot& a, const AccessRoot& b) { return a.type != b.type || a.variable != b.variable; } /// DynamicIndex is used by DirectVariableAccess::State::AccessOp to indicate an array, matrix or /// vector index. struct DynamicIndex { /// The index of the expression in DirectVariableAccess::State::AccessChain::dynamic_indices size_t slot = 0; }; /// Inequality operator for DynamicIndex bool operator!=(const DynamicIndex& a, const DynamicIndex& b) { return a.slot != b.slot; } /// AccessOp describes a single access in an access chain. /// The access is one of: /// Symbol - a struct member access. /// DynamicIndex - a runtime index on an array, matrix column, or vector element. using AccessOp = std::variant; /// A vector of AccessOp. Describes the static "path" from a root variable to an element /// within the variable. Array accessors index expressions are held externally to the /// AccessShape, so AccessShape will be considered equal even if the array, matrix or vector /// index values differ. /// /// For example, consider the following: /// /// ``` /// struct A { /// x : array, /// y : u32, /// }; /// struct B { /// x : i32, /// y : array /// }; /// var C : B; /// ``` /// /// The following AccessShape would describe the following: /// /// +==============================+===============+=================================+ /// | AccessShape | Type | Expression | /// +==============================+===============+=================================+ /// | [ Variable 'C', Symbol 'x' ] | i32 | C.x | /// +------------------------------+---------------+---------------------------------+ /// | [ Variable 'C', Symbol 'y' ] | array | C.y | /// +------------------------------+---------------+---------------------------------+ /// | [ Variable 'C', Symbol 'y', | A | C.y[dyn_idx[0]] | /// | DynamicIndex ] | | | /// +------------------------------+---------------+---------------------------------+ /// | [ Variable 'C', Symbol 'y', | array | C.y[dyn_idx[0]].x | /// | DynamicIndex, Symbol 'x' ] | | | /// +------------------------------+---------------+---------------------------------+ /// | [ Variable 'C', Symbol 'y', | i32 | C.y[dyn_idx[0]].x[dyn_idx[1]] | /// | DynamicIndex, Symbol 'x', | | | /// | DynamicIndex ] | | | /// +------------------------------+---------------+---------------------------------+ /// | [ Variable 'C', Symbol 'y', | u32 | C.y[dyn_idx[0]].y | /// | DynamicIndex, Symbol 'y' ] | | | /// +------------------------------+---------------+---------------------------------+ /// /// Where: `dyn_idx` is the AccessChain::dynamic_indices. struct AccessShape { // The originating variable. AccessRoot root; /// The chain of access ops. tint::utils::Vector ops; /// @returns the number of DynamicIndex operations in #ops. uint32_t NumDynamicIndices() const { uint32_t count = 0; for (auto& op : ops) { if (std::holds_alternative(op)) { count++; } } return count; } }; /// Equality operator for AccessShape bool operator==(const AccessShape& a, const AccessShape& b) { return !(a.root != b.root) && a.ops == b.ops; } /// Inequality operator for AccessShape bool operator!=(const AccessShape& a, const AccessShape& b) { return !(a == b); } /// AccessChain describes a chain of access expressions originating from a variable. struct AccessChain : AccessShape { /// The array accessor index expressions. This vector is indexed by the `DynamicIndex`s in /// #indices. tint::utils::Vector dynamic_indices; /// If true, then this access chain is used as an argument to call a variant. bool used_in_call = false; }; } // namespace namespace tint::utils { /// Hasher specialization for AccessRoot template <> struct Hasher { /// The hash function for the AccessRoot /// @param d the AccessRoot to hash /// @return the hash for the given AccessRoot size_t operator()(const AccessRoot& d) const { return utils::Hash(d.type, d.variable); } }; /// Hasher specialization for DynamicIndex template <> struct Hasher { /// The hash function for the DynamicIndex /// @param d the DynamicIndex to hash /// @return the hash for the given DynamicIndex size_t operator()(const DynamicIndex& d) const { return utils::Hash(d.slot); } }; /// Hasher specialization for AccessShape template <> struct Hasher { /// The hash function for the AccessShape /// @param s the AccessShape to hash /// @return the hash for the given AccessShape size_t operator()(const AccessShape& s) const { return utils::Hash(s.root, s.ops); } }; } // namespace tint::utils namespace tint::transform { /// The PIMPL state for the DirectVariableAccess transform struct DirectVariableAccess::State { /// Constructor /// @param src the source Program /// @param options the transform options State(const Program* src, const Options& options) : ctx{&b, src, /* auto_clone_symbols */ true}, opts(options) {} /// The main function for the transform. /// @returns the ApplyResult ApplyResult Run() { if (!ctx.src->Sem().Module()->Extensions().Contains( ast::Extension::kChromiumExperimentalFullPtrParameters)) { // If the 'chromium_experimental_full_ptr_parameters' extension is not enabled, then // there's nothing for this transform to do. return SkipTransform; } // Stage 1: // Walk all the expressions of the program, starting with the expression leaves. // Whenever we find an identifier resolving to a var, pointer parameter or pointer let to // another chain, start constructing an access chain. When chains are accessed, these chains // are grown and moved up the expression tree. After this stage, we are left with all the // expression access chains to variables that we may need to transform. for (auto* node : ctx.src->ASTNodes().Objects()) { if (auto* expr = sem.Get(node)) { AppendAccessChain(expr); } } // Stage 2: // Walk the functions in dependency order, starting with the entry points. // Construct the set of function 'variants' by examining the calls made by each function to // their call target. Each variant holds a map of pointer parameter to access chains, and // will have the pointer parameters replaced with an array of u32s, used to perform the // pointer indexing in the variant. // Function call pointer arguments are replaced with an array of these dynamic indices. auto decls = sem.Module()->DependencyOrderedDeclarations(); for (auto* decl : utils::Reverse(decls)) { if (auto* fn = sem.Get(decl)) { auto* fn_info = FnInfoFor(fn); ProcessFunction(fn, fn_info); TransformFunction(fn, fn_info); } } // Stage 3: // Filter out access chains that do not need transforming. // Ensure that chain dynamic index expressions are evaluated once at the correct place ProcessAccessChains(); // Stage 4: // Replace all the access chain expressions in all functions with reconstructed expression // using the originating global variable, and any dynamic indices passed in to the function // variant. TransformAccessChainExpressions(); // Stage 5: // Actually kick the clone. CloneState state; clone_state = &state; ctx.Clone(); return Program(std::move(*ctx.dst)); } private: /// Holds symbols of the transformed pointer parameter. /// If both symbols are valid, then #base_ptr and #indices are both program-unique symbols /// derived from the original parameter name. /// If only one symbol is valid, then this is the original parameter symbol. struct PtrParamSymbols { /// The symbol of the base pointer parameter. Symbol base_ptr; /// The symbol of the dynamic indicies parameter. Symbol indices; }; /// FnVariant describes a unique variant of a function, specialized by the AccessShape of the /// pointer arguments - also known as the variant's "signature". /// /// To help understand what a variant is, consider the following WGSL: /// /// ``` /// fn F(a : ptr, b : u32, c : ptr) { /// return *a + b + *c; /// } /// /// @group(0) @binding(0) var S0 : u32; /// @group(0) @binding(0) var S1 : array; /// /// fn x() { /// F(&S0, 0, &S0); // (A) /// F(&S0, 0, &S0); // (B) /// F(&S1[0], 1, &S0); // (C) /// F(&S1[5], 2, &S0); // (D) /// F(&S1[5], 3, &S1[3]); // (E) /// F(&S1[7], 4, &S1[2]); // (F) /// } /// ``` /// /// Given the calls in x(), function F() will have 3 variants: /// (1) F - called by (A) and (B). /// Note that only 'uniform', 'storage' and 'workgroup' pointer /// parameters are considered for a variant signature, and so /// the argument for parameter 'b' is not included in the /// signature. /// (2) F - called by (C) and (D). /// Note that the array index value is external to the /// AccessShape, and so is not part of the variant signature. /// (3) F - called by (E) and (F). /// /// Each variant of the function will be emitted as a separate function by the transform, and /// would look something like: /// /// ``` /// // variant F (1) /// fn F_S0_S0(b : u32) { /// return S0 + b + S0; /// } /// /// type S1_X = array; /// /// // variant F (2) /// fn F_S1_X_S0(a : S1_X, b : u32) { /// return S1[a[0]] + b + S0; /// } /// /// // variant F (3) /// fn F_S1_X_S1_X(a : S1_X, b : u32, c : S1_X) { /// return S1[a[0]] + b + S1[c[0]]; /// } /// /// @group(0) @binding(0) var S0 : u32; /// @group(0) @binding(0) var S1 : array; /// /// fn x() { /// F_S0_S0(0); // (A) /// F(&S0, 0, &S0); // (B) /// F_S1_X_S0(S1_X(0), 1); // (C) /// F_S1_X_S0(S1_X(5), 2); // (D) /// F_S1_X_S1_X(S1_X(5), 3, S1_X(3)); // (E) /// F_S1_X_S1_X(S1_X(7), 4, S1_X(2)); // (F) /// } /// ``` struct FnVariant { /// The signature of the variant is a map of each of the function's 'uniform', 'storage' and /// 'workgroup' pointer parameters to the caller's AccessShape. using Signature = utils::Hashmap; /// The unique name of the variant. /// The symbol is in the `ctx.dst` program namespace. Symbol name; /// A map of direct calls made by this variant to the name of other function variants. utils::Hashmap calls; /// A map of input program parameter to output parameter symbols. utils::Hashmap ptr_param_symbols; /// The declaration order of the variant, in relation to other variants of the same /// function. Used to ensure deterministic ordering of the transform, as map iteration is /// not deterministic between compilers. size_t order = 0; }; /// FnInfo holds information about a function in the input program. struct FnInfo { /// A map of variant signature to the variant data. utils::Hashmap variants; /// A map of expressions that have been hoisted to a 'let' declaration in the function. utils::Hashmap hoisted_exprs; /// @returns the variants of the function in a deterministically ordered vector. utils::Vector, 8> SortedVariants() { utils::Vector, 8> out; out.Reserve(variants.Count()); for (auto it : variants) { out.Push({&it.key, &it.value}); } out.Sort([&](auto& va, auto& vb) { return va.second->order < vb.second->order; }); return out; } }; /// The program builder ProgramBuilder b; /// The clone context CloneContext ctx; /// The transform options const Options& opts; /// Alias to the semantic info in ctx.src const sem::Info& sem = ctx.src->Sem(); /// Alias to the symbols in ctx.src const SymbolTable& sym = ctx.src->Symbols(); /// Map of semantic function to the function info utils::Hashmap fns; /// Map of AccessShape to the name of a type alias for the an array used for the /// dynamic indices of an access chain, passed down as the transformed type of a variant's /// pointer parameter. utils::Hashmap dynamic_index_array_aliases; /// Map of semantic expression to AccessChain utils::Hashmap access_chains; /// Allocator for FnInfo utils::BlockAllocator fn_info_allocator; /// Allocator for AccessChain utils::BlockAllocator access_chain_allocator; /// Helper used for hoisting expressions to lets HoistToDeclBefore hoist{ctx}; /// Map of string to unique symbol (no collisions in output program). utils::Hashmap unique_symbols; /// CloneState holds pointers to the current function, variant and variant's parameters. struct CloneState { /// The current function being cloned FnInfo* current_function = nullptr; /// The current function variant being built FnVariant* current_variant = nullptr; /// The signature of the current function variant being built const FnVariant::Signature* current_variant_sig = nullptr; }; /// The clone state. /// Only valid during the lifetime of the CloneContext::Clone(). CloneState* clone_state = nullptr; /// AppendAccessChain creates or extends an existing AccessChain for the given expression, /// modifying the #access_chains map. void AppendAccessChain(const sem::Expression* expr) { // take_chain moves the AccessChain from the expression `from` to the expression `expr`. // Returns nullptr if `from` did not hold an access chain. auto take_chain = [&](const sem::Expression* from) -> AccessChain* { if (auto* chain = AccessChainFor(from)) { access_chains.Remove(from); access_chains.Add(expr, chain); return chain; } return nullptr; }; Switch( expr, [&](const sem::VariableUser* user) { // Expression resolves to a variable. auto* variable = user->Variable(); auto create_new_chain = [&] { auto* chain = access_chain_allocator.Create(); chain->root.variable = variable; chain->root.type = variable->Type(); chain->root.address_space = variable->AddressSpace(); if (auto* ptr = chain->root.type->As()) { chain->root.address_space = ptr->AddressSpace(); } access_chains.Add(expr, chain); }; Switch( variable->Declaration(), [&](const ast::Var*) { if (variable->AddressSpace() != ast::AddressSpace::kHandle) { // Start a new access chain for the non-handle 'var' access create_new_chain(); } }, [&](const ast::Parameter*) { if (variable->Type()->Is()) { // Start a new access chain for the pointer parameter access create_new_chain(); } }, [&](const ast::Let*) { if (variable->Type()->Is()) { // variable is a pointer-let. auto* init = sem.Get(variable->Declaration()->initializer); // Note: We do not use take_chain() here, as we need to preserve the // AccessChain on the let's initializer, as the let needs its // initializer updated, and the let may be used multiple times. Instead // we copy the let's AccessChain into a a new AccessChain. if (auto* init_chain = AccessChainFor(init)) { access_chains.Add(expr, access_chain_allocator.Create(*init_chain)); } } }); }, [&](const sem::StructMemberAccess* a) { // Structure member access. // Append the Symbol of the member name to the chain, and move the chain to the // member access expression. if (auto* chain = take_chain(a->Object())) { chain->ops.Push(a->Member()->Name()); } }, [&](const sem::IndexAccessorExpression* a) { // Array, matrix or vector index. // Store the index expression into AccessChain::dynamic_indices, append a // DynamicIndex to the chain, and move the chain to the index accessor expression. if (auto* chain = take_chain(a->Object())) { chain->ops.Push(DynamicIndex{chain->dynamic_indices.Length()}); chain->dynamic_indices.Push(a->Index()); } }, [&](const sem::Expression* e) { if (auto* unary = e->Declaration()->As()) { // Unary op. // If this is a '&' or '*', simply move the chain to the unary op expression. if (unary->op == ast::UnaryOp::kAddressOf || unary->op == ast::UnaryOp::kIndirection) { take_chain(sem.Get(unary->expr)); } } }); } /// MaybeHoistDynamicIndices examines the AccessChain::dynamic_indices member of @p chain, /// hoisting all expressions to their own uniquely named 'let' if none of the following are /// true: /// 1. The index expression is a constant value. /// 2. The index expression's statement is the same as @p usage. /// 3. The index expression is an identifier resolving to a 'let', 'const' or parameter, AND /// that identifier resolves to the same variable at @p usage. /// /// A dynamic index will only be hoisted once. The hoisting applies to all variants of the /// function that holds the dynamic index expression. void MaybeHoistDynamicIndices(AccessChain* chain, const sem::Statement* usage) { for (auto& idx : chain->dynamic_indices) { if (idx->ConstantValue()) { // Dynamic index is constant. continue; // Hoisting not required. } if (idx->Stmt() == usage) { // The index expression is owned by the statement of usage. continue; // Hoisting not required } if (auto* idx_variable_user = idx->UnwrapMaterialize()->As()) { auto* idx_variable = idx_variable_user->Variable(); if (idx_variable->Declaration()->IsAnyOf()) { // Dynamic index is an immutable variable continue; // Hoisting not required. } } // The dynamic index needs to be hoisted (if it hasn't been already). auto fn = FnInfoFor(idx->Stmt()->Function()); fn->hoisted_exprs.GetOrCreate(idx, [=] { // Create a name for the new 'let' auto name = b.Symbols().New("ptr_index_save"); // Insert a new 'let' just above the dynamic index statement. hoist.InsertBefore(idx->Stmt(), [this, idx, name] { return b.Decl(b.Let(name, ctx.CloneWithoutTransform(idx->Declaration()))); }); return name; }); } } /// BuildDynamicIndex builds the AST expression node for the dynamic index expression used in an /// AccessChain. This is similar to just cloning the expression, but BuildDynamicIndex() /// also: /// * Collapses constant value index expressions down to the computed value. This acts as an /// constant folding optimization and reduces noise from the transform. /// * Casts the resulting expression to a u32 if @p cast_to_u32 is true, and the expression type /// isn't implicitly usable as a u32. This is to help feed the expression into a /// `array` argument passed to a callee variant function. const ast::Expression* BuildDynamicIndex(const sem::Expression* idx, bool cast_to_u32) { if (auto* val = idx->ConstantValue()) { // Expression evaluated to a constant value. Just emit that constant. return b.Expr(val->As()); } // Expression is not a constant, clone the expression. // Note: If the dynamic index expression was hoisted to a let, then cloning will return an // identifier expression to the hoisted let. auto* expr = ctx.Clone(idx->Declaration()); if (cast_to_u32) { // The index may be fed to a dynamic index array argument, so the index // expression may need casting to u32. if (!idx->UnwrapMaterialize() ->Type() ->UnwrapRef() ->IsAnyOf()) { expr = b.Construct(b.ty.u32(), expr); } } return expr; } /// ProcessFunction scans the direct calls made by the function @p fn, adding new variants to /// the callee functions and transforming the call expression to pass dynamic indices instead of /// true pointers. /// If the function @p fn has pointer parameters that must be transformed to a caller variant, /// and the function is not called, then the function is dropped from the output of the /// transform, as it cannot be generated. /// @note ProcessFunction must be called in dependency order for the program, starting with the /// entry points. void ProcessFunction(const sem::Function* fn, FnInfo* fn_info) { if (fn_info->variants.IsEmpty()) { // Function has no variants pre-generated by callers. if (MustBeCalled(fn)) { // Drop the function, as it wasn't called and cannot be generated. ctx.Remove(ctx.src->AST().GlobalDeclarations(), fn->Declaration()); return; } // Function was not called. Create a single variant with an empty signature. FnVariant variant; variant.name = ctx.Clone(fn->Declaration()->symbol); variant.order = 0; // Unaltered comes first. fn_info->variants.Add(FnVariant::Signature{}, std::move(variant)); } // Process each of the direct calls made by this function. for (auto* call : fn->DirectCalls()) { ProcessCall(fn_info, call); } } /// ProcessCall creates new variants of the callee function by permuting the call for each of /// the variants of @p caller. ProcessCall also registers the clone callback to transform the /// call expression to pass dynamic indices instead of true pointers. void ProcessCall(FnInfo* caller, const sem::Call* call) { auto* target = call->Target()->As(); if (!target) { // Call target is not a user-declared function. return; // Not interested in this call. } if (!HasPointerParameter(target)) { return; // Not interested in this call. } bool call_needs_transforming = false; // Build the call target function variant for each variant of the caller. for (auto caller_variant_it : caller->SortedVariants()) { auto& caller_signature = *caller_variant_it.first; auto& caller_variant = *caller_variant_it.second; // Build the target variant's signature. FnVariant::Signature target_signature; for (size_t i = 0; i < call->Arguments().Length(); i++) { const auto* arg = call->Arguments()[i]; const auto* param = target->Parameters()[i]; const auto* param_ty = param->Type()->As(); if (!param_ty) { continue; // Parameter type is not a pointer. } // Fetch the access chain for the argument. auto* arg_chain = AccessChainFor(arg); if (!arg_chain) { continue; // Argument does not have an access chain } // Construct the absolute AccessShape by considering the AccessShape of the caller // variant's argument. This will propagate back through pointer parameters, to the // outermost caller. auto absolute = AbsoluteAccessShape(caller_signature, *arg_chain); // If the address space of the root variable of the access chain does not require // transformation, then there's nothing to do. if (!AddressSpaceRequiresTransform(absolute.root.address_space)) { continue; } // Record that this chain was used in a function call. // This preserves the chain during the access chain filtering stage. arg_chain->used_in_call = true; if (IsPrivateOrFunction(absolute.root.address_space)) { // Pointers in 'private' and 'function' address spaces need to be passed by // pointer argument. absolute.root.variable = param; } // Add the parameter's absolute AccessShape to the target's signature. target_signature.Add(param, std::move(absolute)); } // Construct a new FnVariant if this is the first caller of the target signature auto* target_info = FnInfoFor(target); auto& target_variant = target_info->variants.GetOrCreate(target_signature, [&] { if (target_signature.IsEmpty()) { // Call target does not require any argument changes. FnVariant variant; variant.name = ctx.Clone(target->Declaration()->symbol); variant.order = 0; // Unaltered comes first. return variant; } // Build an appropriate variant function name. // This is derived from the original function name and the pointer parameter // chains. std::stringstream ss; ss << ctx.src->Symbols().NameFor(target->Declaration()->symbol); for (auto* param : target->Parameters()) { if (auto indices = target_signature.Find(param)) { ss << "_" << AccessShapeName(*indices); } } // Build the pointer parameter symbols. utils::Hashmap ptr_param_symbols; for (auto param_it : target_signature) { auto* param = param_it.key; auto& shape = param_it.value; // Parameter needs replacing with either zero, one or two parameters: // If the parameter is in the 'private' or 'function' address space, then the // originating pointer is always passed down. This always comes first. // If the access chain has dynamic indices, then we create an array // parameter to hold the dynamic indices. bool requires_base_ptr_param = IsPrivateOrFunction(shape.root.address_space); bool requires_indices_param = shape.NumDynamicIndices() > 0; PtrParamSymbols symbols; if (requires_base_ptr_param && requires_indices_param) { auto original_name = param->Declaration()->symbol; symbols.base_ptr = UniqueSymbolWithSuffix(original_name, "_base"); symbols.indices = UniqueSymbolWithSuffix(original_name, "_indices"); } else if (requires_base_ptr_param) { symbols.base_ptr = ctx.Clone(param->Declaration()->symbol); } else if (requires_indices_param) { symbols.indices = ctx.Clone(param->Declaration()->symbol); } // Remember this base pointer name. ptr_param_symbols.Add(param, symbols); } // Build the variant. FnVariant variant; variant.name = b.Symbols().New(ss.str()); variant.order = target_info->variants.Count() + 1; variant.ptr_param_symbols = std::move(ptr_param_symbols); return variant; }); // Record the call made by caller variant to the target variant. caller_variant.calls.Add(call, target_variant.name); if (!target_signature.IsEmpty()) { // The call expression will need transforming for at least one caller variant. call_needs_transforming = true; } } if (call_needs_transforming) { // Register the clone callback to correctly transform the call expression into the // appropriate variant calls. TransformCall(call); } } /// @returns true if the address space @p address_space requires transforming given the /// transform's options. bool AddressSpaceRequiresTransform(ast::AddressSpace address_space) const { switch (address_space) { case ast::AddressSpace::kUniform: case ast::AddressSpace::kStorage: case ast::AddressSpace::kWorkgroup: return true; case ast::AddressSpace::kPrivate: return opts.transform_private; case ast::AddressSpace::kFunction: return opts.transform_function; default: return false; } } /// @returns the AccessChain for the expression @p expr, or nullptr if the expression does /// not hold an access chain. AccessChain* AccessChainFor(const sem::Expression* expr) const { if (auto chain = access_chains.Find(expr)) { return *chain; } return nullptr; } /// @returns the absolute AccessShape for @p indices, by replacing the originating pointer /// parameter with the AccessChain of variant's signature. AccessShape AbsoluteAccessShape(const FnVariant::Signature& signature, const AccessShape& shape) const { if (auto* root_param = shape.root.variable->As()) { if (auto incoming_chain = signature.Find(root_param)) { // Access chain originates from a parameter, which will be transformed into an array // of dynamic indices. Concatenate the signature's AccessShape for the parameter // to the chain's indices, skipping over the chain's initial parameter index. auto absolute = *incoming_chain; for (auto& op : shape.ops) { absolute.ops.Push(op); } return absolute; } } // Chain does not originate from a parameter, so is already absolute. return shape; } /// TransformFunction registers the clone callback to transform the function @p fn into the /// (potentially multiple) function's variants. TransformFunction will assign the current /// function and variant to #clone_state, which can be used by the other clone callbacks. void TransformFunction(const sem::Function* fn, FnInfo* fn_info) { // Register a custom handler for the specific function ctx.Replace(fn->Declaration(), [this, fn, fn_info] { // For the scope of this lambda, assign current_function to fn_info. TINT_SCOPED_ASSIGNMENT(clone_state->current_function, fn_info); // This callback expects a single function returned. As we're generating potentially // many variant functions, keep a record of the last created variant, and explicitly add // this to the module if it isn't the last. We'll return the last created variant, // taking the place of the original function. const ast::Function* pending_variant = nullptr; // For each variant of fn... for (auto variant_it : fn_info->SortedVariants()) { if (pending_variant) { b.AST().AddFunction(pending_variant); } auto& variant_sig = *variant_it.first; auto& variant = *variant_it.second; // For the rest of this scope, assign the current variant and variant signature. TINT_SCOPED_ASSIGNMENT(clone_state->current_variant_sig, &variant_sig); TINT_SCOPED_ASSIGNMENT(clone_state->current_variant, &variant); // Build the variant's parameters. // Pointer parameters in the 'uniform', 'storage' or 'workgroup' address space are // either replaced with an array of dynamic indices, or are dropped (if there are no // dynamic indices). utils::Vector params; for (auto* param : fn->Parameters()) { if (auto incoming_shape = variant_sig.Find(param)) { auto& symbols = *variant.ptr_param_symbols.Find(param); if (symbols.base_ptr.IsValid()) { auto* base_ptr_ty = b.ty.pointer(CreateASTTypeFor(ctx, incoming_shape->root.type), incoming_shape->root.address_space); params.Push(b.Param(symbols.base_ptr, base_ptr_ty)); } if (symbols.indices.IsValid()) { // Variant has dynamic indices for this variant, replace it. auto* dyn_idx_arr_type = DynamicIndexArrayType(*incoming_shape); params.Push(b.Param(symbols.indices, dyn_idx_arr_type)); } } else { // Just a regular parameter. Just clone the original parameter. params.Push(ctx.Clone(param->Declaration())); } } // Build the variant by cloning the source function. The other clone callbacks will // use clone_state->current_variant and clone_state->current_variant_sig to produce // the variant. auto* ret_ty = ctx.Clone(fn->Declaration()->return_type); auto body = ctx.Clone(fn->Declaration()->body); auto attrs = ctx.Clone(fn->Declaration()->attributes); auto ret_attrs = ctx.Clone(fn->Declaration()->return_type_attributes); pending_variant = b.create(variant.name, std::move(params), ret_ty, body, std::move(attrs), std::move(ret_attrs)); } return pending_variant; }); } /// TransformCall registers the clone callback to transform the call expression @p call to call /// the correct target variant, and to replace pointers arguments with an array of dynamic /// indices. void TransformCall(const sem::Call* call) { // Register a custom handler for the specific call expression ctx.Replace(call->Declaration(), [this, call]() { auto target_variant = clone_state->current_variant->calls.Find(call); if (!target_variant) { // The current variant does not need to transform this call. return ctx.CloneWithoutTransform(call->Declaration()); } // Build the new call expressions's arguments. utils::Vector new_args; for (size_t arg_idx = 0; arg_idx < call->Arguments().Length(); arg_idx++) { auto* arg = call->Arguments()[arg_idx]; auto* param = call->Target()->Parameters()[arg_idx]; auto* param_ty = param->Type()->As(); if (!param_ty) { // Parameter is not a pointer. // Just clone the unaltered argument. new_args.Push(ctx.Clone(arg->Declaration())); continue; // Parameter is not a pointer } auto* chain = AccessChainFor(arg); if (!chain) { // No access chain means the argument is not a pointer that needs transforming. // Just clone the unaltered argument. new_args.Push(ctx.Clone(arg->Declaration())); continue; } // Construct the absolute AccessShape by considering the AccessShape of the caller // variant's argument. This will propagate back through pointer parameters, to the // outermost caller. auto full_indices = AbsoluteAccessShape(*clone_state->current_variant_sig, *chain); // If the parameter is a pointer in the 'private' or 'function' address space, then // we need to pass an additional pointer argument to the base object. if (IsPrivateOrFunction(param_ty->AddressSpace())) { auto* root_expr = BuildAccessRootExpr(chain->root, /* deref */ false); if (!chain->root.variable->Is()) { root_expr = b.AddressOf(root_expr); } new_args.Push(root_expr); } // Get or create the dynamic indices array. if (auto* dyn_idx_arr_ty = DynamicIndexArrayType(full_indices)) { // Build an array of dynamic indices to pass as the replacement for the pointer. utils::Vector dyn_idx_args; if (auto* root_param = chain->root.variable->As()) { // Access chain originates from a pointer parameter. if (auto incoming_chain = clone_state->current_variant_sig->Find(root_param)) { auto indices = clone_state->current_variant->ptr_param_symbols.Find(root_param) ->indices; // This pointer parameter will have been replaced with a array // holding the variant's dynamic indices for the pointer. Unpack these // directly into the array constructor's arguments. auto N = incoming_chain->NumDynamicIndices(); for (uint32_t i = 0; i < N; i++) { dyn_idx_args.Push(b.IndexAccessor(indices, u32(i))); } } } // Pass the dynamic indices of the access chain into the array constructor. for (auto& dyn_idx : chain->dynamic_indices) { dyn_idx_args.Push(BuildDynamicIndex(dyn_idx, /* cast_to_u32 */ true)); } // Construct the dynamic index array, and push as an argument. new_args.Push(b.Construct(dyn_idx_arr_ty, std::move(dyn_idx_args))); } } // Make the call to the target's variant. return b.Call(*target_variant, std::move(new_args)); }); } /// ProcessAccessChains performs the following: /// * Removes all AccessChains from expressions that are not either used as a pointer argument /// in a call, or originates from a pointer parameter. /// * Hoists the dynamic index expressions of AccessChains to 'let' statements, to prevent /// multiple evaluation of the expressions, and avoid expressions resolving to different /// variables based on lexical scope. void ProcessAccessChains() { auto chain_exprs = access_chains.Keys(); chain_exprs.Sort([](const auto& expr_a, const auto& expr_b) { return expr_a->Declaration()->node_id.value < expr_b->Declaration()->node_id.value; }); for (auto* expr : chain_exprs) { auto* chain = *access_chains.Get(expr); if (!chain->used_in_call && !chain->root.variable->Is()) { // Chain was not used in a function call, and does not originate from a // parameter. This chain does not need transforming. Drop it. access_chains.Remove(expr); continue; } // Chain requires transforming. // We need to be careful that the chain does not use expressions with side-effects which // cannot be repeatedly evaluated. In this situation we can hoist the dynamic index // expressions to their own uniquely named lets (if required). MaybeHoistDynamicIndices(chain, expr->Stmt()); } } /// TransformAccessChainExpressions registers the clone callback to: /// * Transform all expressions that have an AccessChain (which aren't arguments to function /// calls, these are handled by TransformCall()), into the equivalent expression using a /// module-scope variable. /// * Replace expressions that have been hoisted to a let, with an identifier expression to that /// let. void TransformAccessChainExpressions() { // Register a custom handler for all non-function call expressions ctx.ReplaceAll([this](const ast::Expression* ast_expr) -> const ast::Expression* { if (!clone_state->current_variant) { // Expression does not belong to a function variant. return nullptr; // Just clone the expression. } auto* expr = sem.Get(ast_expr); if (!expr) { // No semantic node for the expression. return nullptr; // Just clone the expression. } // If the expression has been hoisted to a 'let', then replace the expression with an // identifier to the hoisted let. if (auto hoisted = clone_state->current_function->hoisted_exprs.Find(expr)) { return b.Expr(*hoisted); } auto* chain = AccessChainFor(expr); if (!chain) { // The expression does not have an AccessChain. return nullptr; // Just clone the expression. } auto* root_param = chain->root.variable->As(); if (!root_param) { // The expression has an access chain, but does not originate with a pointer // parameter. We don't need to change anything here. return nullptr; // Just clone the expression. } auto incoming_shape = clone_state->current_variant_sig->Find(root_param); if (!incoming_shape) { // The root parameter of the access chain is not part of the variant's signature. return nullptr; // Just clone the expression. } // Expression holds an access chain to a pointer parameter that needs transforming. // Reconstruct the expression using the variant's incoming shape. auto* chain_expr = BuildAccessRootExpr(incoming_shape->root, /* deref */ true); // Chain starts with a pointer parameter. // Replace this with the variant's incoming shape. This will bring the expression up to // the incoming pointer. auto indices = clone_state->current_variant->ptr_param_symbols.Find(root_param)->indices; for (auto param_access : incoming_shape->ops) { chain_expr = BuildAccessExpr(chain_expr, param_access, [&](size_t i) { return b.IndexAccessor(indices, AInt(i)); }); } // Now build the expression chain within the function. // For each access in the chain (excluding the pointer parameter)... for (auto& op : chain->ops) { chain_expr = BuildAccessExpr(chain_expr, op, [&](size_t i) { return BuildDynamicIndex(chain->dynamic_indices[i], false); }); } // BuildAccessExpr() always returns a non-pointer. // If the expression we're replacing is a pointer, take the address. if (expr->Type()->Is()) { chain_expr = b.AddressOf(chain_expr); } return chain_expr; }); } /// @returns the FnInfo for the given function, constructing a new FnInfo if @p fn doesn't /// already have one. FnInfo* FnInfoFor(const sem::Function* fn) { return fns.GetOrCreate(fn, [this] { return fn_info_allocator.Create(); }); } /// @returns the type alias used to hold the dynamic indices for @p shape, declaring a new alias /// if this is the first call for the given shape. const ast::TypeName* DynamicIndexArrayType(const AccessShape& shape) { auto name = dynamic_index_array_aliases.GetOrCreate(shape, [&] { // Count the number of dynamic indices uint32_t num_dyn_indices = shape.NumDynamicIndices(); if (num_dyn_indices == 0) { return Symbol{}; } auto symbol = b.Symbols().New(AccessShapeName(shape)); b.Alias(symbol, b.ty.array(b.ty.u32(), u32(num_dyn_indices))); return symbol; }); return name.IsValid() ? b.ty.type_name(name) : nullptr; } /// @returns a name describing the given shape std::string AccessShapeName(const AccessShape& shape) { std::stringstream ss; if (IsPrivateOrFunction(shape.root.address_space)) { ss << "F"; } else { ss << ctx.src->Symbols().NameFor(shape.root.variable->Declaration()->symbol); } for (auto& op : shape.ops) { ss << "_"; if (std::holds_alternative(op)) { /// The op uses a dynamic (runtime-expression) index. ss << "X"; continue; } if (auto* member = std::get_if(&op)) { ss << sym.NameFor(*member); continue; } TINT_ICE(Transform, b.Diagnostics()) << "unhandled variant for access chain"; break; } return ss.str(); } /// Builds an expresion to the root of an access, returning the new expression. /// @param root the AccessRoot /// @param deref if true, the returned expression will always be a reference type. const ast::Expression* BuildAccessRootExpr(const AccessRoot& root, bool deref) { if (auto* param = root.variable->As()) { if (auto symbols = clone_state->current_variant->ptr_param_symbols.Find(param)) { if (deref) { return b.Deref(b.Expr(symbols->base_ptr)); } return b.Expr(symbols->base_ptr); } } const ast::Expression* expr = b.Expr(ctx.Clone(root.variable->Declaration()->symbol)); if (deref) { if (root.variable->Type()->Is()) { expr = b.Deref(expr); } } return expr; } /// Builds a single access in an access chain, returning the new expression. /// The returned expression will always be of a reference type. /// @param expr the input expression /// @param access the access to perform on the current expression /// @param dynamic_index a function that obtains the i'th dynamic index const ast::Expression* BuildAccessExpr( const ast::Expression* expr, const AccessOp& access, std::function dynamic_index) { if (auto* dyn_idx = std::get_if(&access)) { /// The access uses a dynamic (runtime-expression) index. auto* idx = dynamic_index(dyn_idx->slot); return b.IndexAccessor(expr, idx); } if (auto* member = std::get_if(&access)) { /// The access is a member access. return b.MemberAccessor(expr, ctx.Clone(*member)); } TINT_ICE(Transform, b.Diagnostics()) << "unhandled variant type for access chain"; return nullptr; } /// @returns a new Symbol starting with @p symbol concatenated with @p suffix, and possibly an /// underscore and number, if the symbol is already taken. Symbol UniqueSymbolWithSuffix(Symbol symbol, const std::string& suffix) { auto str = ctx.src->Symbols().NameFor(symbol) + suffix; return unique_symbols.GetOrCreate(str, [&] { return b.Symbols().New(str); }); } /// @returns true if the function @p fn has at least one pointer parameter. static bool HasPointerParameter(const sem::Function* fn) { for (auto* param : fn->Parameters()) { if (param->Type()->Is()) { return true; } } return false; } /// @returns true if the function @p fn has at least one pointer parameter in an address space /// that must be replaced. If this function is not called, then the function cannot be sensibly /// generated, and must be stripped. static bool MustBeCalled(const sem::Function* fn) { for (auto* param : fn->Parameters()) { if (auto* ptr = param->Type()->As()) { switch (ptr->AddressSpace()) { case ast::AddressSpace::kUniform: case ast::AddressSpace::kStorage: case ast::AddressSpace::kWorkgroup: return true; default: return false; } } } return false; } /// @returns true if the given address space is 'private' or 'function'. static bool IsPrivateOrFunction(const ast::AddressSpace sc) { return sc == ast::AddressSpace::kPrivate || sc == ast::AddressSpace::kFunction; } }; DirectVariableAccess::Config::Config(const Options& opt) : options(opt) {} DirectVariableAccess::Config::~Config() = default; DirectVariableAccess::DirectVariableAccess() = default; DirectVariableAccess::~DirectVariableAccess() = default; Transform::ApplyResult DirectVariableAccess::Apply(const Program* program, const DataMap& inputs, DataMap&) const { Options options; if (auto* cfg = inputs.Get()) { options = cfg->options; } return State(program, options).Run(); } } // namespace tint::transform