mirror of
https://github.com/encounter/dawn-cmake.git
synced 2025-07-23 21:46:00 +00:00
Nothing uses these, and the mutability of these breaks const-correctness. Switch functions that used to return `const utils::Vector<T, N>&` to returning `utils::VectorRef<T>`. Removes the templated size from the public interface. Replace all `const utils::VectorRef<T>&` with `utils::Vector<T>`, there's no point in using yet another level of pointer indirection. Change-Id: Ib96e3171500606d9afffbb13f40023552a74fffc Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/113021 Kokoro: Kokoro <noreply+kokoro@google.com> Commit-Queue: Ben Clayton <bclayton@google.com> Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
1218 lines
55 KiB
C++
1218 lines
55 KiB
C++
// Copyright 2022 The Tint Authors.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
#include "src/tint/transform/direct_variable_access.h"

#include <algorithm>
#include <sstream>
#include <string>
#include <utility>
#include <variant>

#include "src/tint/ast/traverse_expressions.h"
#include "src/tint/program_builder.h"
#include "src/tint/sem/abstract_int.h"
#include "src/tint/sem/call.h"
#include "src/tint/sem/function.h"
#include "src/tint/sem/index_accessor_expression.h"
#include "src/tint/sem/member_accessor_expression.h"
#include "src/tint/sem/module.h"
#include "src/tint/sem/statement.h"
#include "src/tint/sem/struct.h"
#include "src/tint/sem/variable.h"
#include "src/tint/transform/utils/hoist_to_decl_before.h"
#include "src/tint/utils/reverse.h"
#include "src/tint/utils/scoped_assignment.h"
|
|
|
|
TINT_INSTANTIATE_TYPEINFO(tint::transform::DirectVariableAccess);
|
|
TINT_INSTANTIATE_TYPEINFO(tint::transform::DirectVariableAccess::Config);
|
|
|
|
using namespace tint::number_suffixes; // NOLINT
|
|
|
|
namespace {
|
|
|
|
/// AccessRoot describes the root of an AccessShape.
|
|
struct AccessRoot {
|
|
/// The pointer-unwrapped type of the *transformed* variable.
|
|
/// This may be different for pointers in 'private' and 'function' address space, as the pointer
|
|
/// parameter type is to the *base object* instead of the input pointer type.
|
|
tint::sem::Type const* type = nullptr;
|
|
/// The originating module-scope variable ('private', 'storage', 'uniform', 'workgroup'),
|
|
/// function-scope variable ('function'), or pointer parameter in the source program.
|
|
tint::sem::Variable const* variable = nullptr;
|
|
/// The address space of the variable or pointer type.
|
|
tint::ast::AddressSpace address_space = tint::ast::AddressSpace::kUndefined;
|
|
};
|
|
|
|
/// Inequality operator for AccessRoot
|
|
bool operator!=(const AccessRoot& a, const AccessRoot& b) {
|
|
return a.type != b.type || a.variable != b.variable;
|
|
}
|
|
|
|
/// DynamicIndex is used by AccessOp to indicate an array, matrix or vector index whose value
/// is not part of the AccessShape.
struct DynamicIndex {
    /// The index of the expression in AccessChain::dynamic_indices
    size_t slot = 0;
};

/// Equality operator for DynamicIndex.
/// Provided so that DynamicIndex is equality-comparable with `==`. For example,
/// std::variant::operator== (as used by AccessOp) requires `==` on every alternative.
bool operator==(const DynamicIndex& a, const DynamicIndex& b) {
    return a.slot == b.slot;
}

/// Inequality operator for DynamicIndex. The logical negation of operator==.
bool operator!=(const DynamicIndex& a, const DynamicIndex& b) {
    return !(a == b);
}
|
|
|
|
/// AccessOp describes a single access in an access chain.
|
|
/// The access is one of:
|
|
/// Symbol - a struct member access.
|
|
/// DynamicIndex - a runtime index on an array, matrix column, or vector element.
|
|
using AccessOp = std::variant<tint::Symbol, DynamicIndex>;
|
|
|
|
/// A vector of AccessOp. Describes the static "path" from a root variable to an element
|
|
/// within the variable. Array accessors index expressions are held externally to the
|
|
/// AccessShape, so AccessShape will be considered equal even if the array, matrix or vector
|
|
/// index values differ.
|
|
///
|
|
/// For example, consider the following:
|
|
///
|
|
/// ```
|
|
/// struct A {
|
|
/// x : array<i32, 8>,
|
|
/// y : u32,
|
|
/// };
|
|
/// struct B {
|
|
/// x : i32,
|
|
/// y : array<A, 4>
|
|
/// };
|
|
/// var<workgroup> C : B;
|
|
/// ```
|
|
///
|
|
/// The following AccessShape would describe the following:
|
|
///
|
|
/// +==============================+===============+=================================+
|
|
/// | AccessShape | Type | Expression |
|
|
/// +==============================+===============+=================================+
|
|
/// | [ Variable 'C', Symbol 'x' ] | i32 | C.x |
|
|
/// +------------------------------+---------------+---------------------------------+
|
|
/// | [ Variable 'C', Symbol 'y' ] | array<A, 4> | C.y |
|
|
/// +------------------------------+---------------+---------------------------------+
|
|
/// | [ Variable 'C', Symbol 'y', | A | C.y[dyn_idx[0]] |
|
|
/// | DynamicIndex ] | | |
|
|
/// +------------------------------+---------------+---------------------------------+
|
|
/// | [ Variable 'C', Symbol 'y', | array<i32, 8> | C.y[dyn_idx[0]].x |
|
|
/// | DynamicIndex, Symbol 'x' ] | | |
|
|
/// +------------------------------+---------------+---------------------------------+
|
|
/// | [ Variable 'C', Symbol 'y', | i32 | C.y[dyn_idx[0]].x[dyn_idx[1]] |
|
|
/// | DynamicIndex, Symbol 'x', | | |
|
|
/// | DynamicIndex ] | | |
|
|
/// +------------------------------+---------------+---------------------------------+
|
|
/// | [ Variable 'C', Symbol 'y', | u32 | C.y[dyn_idx[0]].y |
|
|
/// | DynamicIndex, Symbol 'y' ] | | |
|
|
/// +------------------------------+---------------+---------------------------------+
|
|
///
|
|
/// Where: `dyn_idx` is the AccessChain::dynamic_indices.
|
|
struct AccessShape {
|
|
// The originating variable.
|
|
AccessRoot root;
|
|
/// The chain of access ops.
|
|
tint::utils::Vector<AccessOp, 8> ops;
|
|
|
|
/// @returns the number of DynamicIndex operations in #ops.
|
|
uint32_t NumDynamicIndices() const {
|
|
uint32_t count = 0;
|
|
for (auto& op : ops) {
|
|
if (std::holds_alternative<DynamicIndex>(op)) {
|
|
count++;
|
|
}
|
|
}
|
|
return count;
|
|
}
|
|
};
|
|
|
|
/// Equality operator for AccessShape
|
|
bool operator==(const AccessShape& a, const AccessShape& b) {
|
|
return !(a.root != b.root) && a.ops == b.ops;
|
|
}
|
|
|
|
/// Inequality operator for AccessShape
|
|
bool operator!=(const AccessShape& a, const AccessShape& b) {
|
|
return !(a == b);
|
|
}
|
|
|
|
/// AccessChain describes a chain of access expressions originating from a variable.
|
|
struct AccessChain : AccessShape {
|
|
/// The array accessor index expressions. This vector is indexed by the `DynamicIndex`s in
|
|
/// #indices.
|
|
tint::utils::Vector<const tint::sem::Expression*, 8> dynamic_indices;
|
|
/// If true, then this access chain is used as an argument to call a variant.
|
|
bool used_in_call = false;
|
|
};
|
|
|
|
} // namespace
|
|
|
|
namespace tint::utils {
|
|
|
|
/// Hasher specialization for AccessRoot
|
|
template <>
|
|
struct Hasher<AccessRoot> {
|
|
/// The hash function for the AccessRoot
|
|
/// @param d the AccessRoot to hash
|
|
/// @return the hash for the given AccessRoot
|
|
size_t operator()(const AccessRoot& d) const { return utils::Hash(d.type, d.variable); }
|
|
};
|
|
|
|
/// Hasher specialization for DynamicIndex
|
|
template <>
|
|
struct Hasher<DynamicIndex> {
|
|
/// The hash function for the DynamicIndex
|
|
/// @param d the DynamicIndex to hash
|
|
/// @return the hash for the given DynamicIndex
|
|
size_t operator()(const DynamicIndex& d) const { return utils::Hash(d.slot); }
|
|
};
|
|
|
|
/// Hasher specialization for AccessShape
|
|
template <>
|
|
struct Hasher<AccessShape> {
|
|
/// The hash function for the AccessShape
|
|
/// @param s the AccessShape to hash
|
|
/// @return the hash for the given AccessShape
|
|
size_t operator()(const AccessShape& s) const { return utils::Hash(s.root, s.ops); }
|
|
};
|
|
|
|
} // namespace tint::utils
|
|
|
|
namespace tint::transform {
|
|
|
|
/// The PIMPL state for the DirectVariableAccess transform
|
|
struct DirectVariableAccess::State {
|
|
/// Constructor
|
|
/// @param src the source Program
|
|
/// @param options the transform options
|
|
State(const Program* src, const Options& options)
|
|
: ctx{&b, src, /* auto_clone_symbols */ true}, opts(options) {}
|
|
|
|
/// The main function for the transform.
|
|
/// @returns the ApplyResult
|
|
ApplyResult Run() {
|
|
if (!ctx.src->Sem().Module()->Extensions().Contains(
|
|
ast::Extension::kChromiumExperimentalFullPtrParameters)) {
|
|
// If the 'chromium_experimental_full_ptr_parameters' extension is not enabled, then
|
|
// there's nothing for this transform to do.
|
|
return SkipTransform;
|
|
}
|
|
|
|
// Stage 1:
|
|
// Walk all the expressions of the program, starting with the expression leaves.
|
|
// Whenever we find an identifier resolving to a var, pointer parameter or pointer let to
|
|
// another chain, start constructing an access chain. When chains are accessed, these chains
|
|
// are grown and moved up the expression tree. After this stage, we are left with all the
|
|
// expression access chains to variables that we may need to transform.
|
|
for (auto* node : ctx.src->ASTNodes().Objects()) {
|
|
if (auto* expr = sem.Get<sem::Expression>(node)) {
|
|
AppendAccessChain(expr);
|
|
}
|
|
}
|
|
|
|
// Stage 2:
|
|
// Walk the functions in dependency order, starting with the entry points.
|
|
// Construct the set of function 'variants' by examining the calls made by each function to
|
|
// their call target. Each variant holds a map of pointer parameter to access chains, and
|
|
// will have the pointer parameters replaced with an array of u32s, used to perform the
|
|
// pointer indexing in the variant.
|
|
// Function call pointer arguments are replaced with an array of these dynamic indices.
|
|
auto decls = sem.Module()->DependencyOrderedDeclarations();
|
|
for (auto* decl : utils::Reverse(decls)) {
|
|
if (auto* fn = sem.Get<sem::Function>(decl)) {
|
|
auto* fn_info = FnInfoFor(fn);
|
|
ProcessFunction(fn, fn_info);
|
|
TransformFunction(fn, fn_info);
|
|
}
|
|
}
|
|
|
|
// Stage 3:
|
|
// Filter out access chains that do not need transforming.
|
|
// Ensure that chain dynamic index expressions are evaluated once at the correct place
|
|
ProcessAccessChains();
|
|
|
|
// Stage 4:
|
|
// Replace all the access chain expressions in all functions with reconstructed expression
|
|
// using the originating global variable, and any dynamic indices passed in to the function
|
|
// variant.
|
|
TransformAccessChainExpressions();
|
|
|
|
// Stage 5:
|
|
// Actually kick the clone.
|
|
CloneState state;
|
|
clone_state = &state;
|
|
ctx.Clone();
|
|
return Program(std::move(*ctx.dst));
|
|
}
|
|
|
|
private:
|
|
/// Holds symbols of the transformed pointer parameter.
|
|
/// If both symbols are valid, then #base_ptr and #indices are both program-unique symbols
|
|
/// derived from the original parameter name.
|
|
/// If only one symbol is valid, then this is the original parameter symbol.
|
|
struct PtrParamSymbols {
|
|
/// The symbol of the base pointer parameter.
|
|
Symbol base_ptr;
|
|
/// The symbol of the dynamic indicies parameter.
|
|
Symbol indices;
|
|
};
|
|
|
|
/// FnVariant describes a unique variant of a function, specialized by the AccessShape of the
|
|
/// pointer arguments - also known as the variant's "signature".
|
|
///
|
|
/// To help understand what a variant is, consider the following WGSL:
|
|
///
|
|
/// ```
|
|
/// fn F(a : ptr<storage, u32>, b : u32, c : ptr<storage, u32>) {
|
|
/// return *a + b + *c;
|
|
/// }
|
|
///
|
|
/// @group(0) @binding(0) var<storage> S0 : u32;
|
|
/// @group(0) @binding(0) var<storage> S1 : array<u32, 64>;
|
|
///
|
|
/// fn x() {
|
|
/// F(&S0, 0, &S0); // (A)
|
|
/// F(&S0, 0, &S0); // (B)
|
|
/// F(&S1[0], 1, &S0); // (C)
|
|
/// F(&S1[5], 2, &S0); // (D)
|
|
/// F(&S1[5], 3, &S1[3]); // (E)
|
|
/// F(&S1[7], 4, &S1[2]); // (F)
|
|
/// }
|
|
/// ```
|
|
///
|
|
/// Given the calls in x(), function F() will have 3 variants:
|
|
/// (1) F<S0,S0> - called by (A) and (B).
|
|
/// Note that only 'uniform', 'storage' and 'workgroup' pointer
|
|
/// parameters are considered for a variant signature, and so
|
|
/// the argument for parameter 'b' is not included in the
|
|
/// signature.
|
|
/// (2) F<S1[dyn_idx],S0> - called by (C) and (D).
|
|
/// Note that the array index value is external to the
|
|
/// AccessShape, and so is not part of the variant signature.
|
|
/// (3) F<S1[dyn_idx],S1[dyn_idx]> - called by (E) and (F).
|
|
///
|
|
/// Each variant of the function will be emitted as a separate function by the transform, and
|
|
/// would look something like:
|
|
///
|
|
/// ```
|
|
/// // variant F<S0,S0> (1)
|
|
/// fn F_S0_S0(b : u32) {
|
|
/// return S0 + b + S0;
|
|
/// }
|
|
///
|
|
/// type S1_X = array<u32, 1>;
|
|
///
|
|
/// // variant F<S1[dyn_idx],S0> (2)
|
|
/// fn F_S1_X_S0(a : S1_X, b : u32) {
|
|
/// return S1[a[0]] + b + S0;
|
|
/// }
|
|
///
|
|
/// // variant F<S1[dyn_idx],S1[dyn_idx]> (3)
|
|
/// fn F_S1_X_S1_X(a : S1_X, b : u32, c : S1_X) {
|
|
/// return S1[a[0]] + b + S1[c[0]];
|
|
/// }
|
|
///
|
|
/// @group(0) @binding(0) var<storage> S0 : u32;
|
|
/// @group(0) @binding(0) var<storage> S1 : array<u32, 64>;
|
|
///
|
|
/// fn x() {
|
|
/// F_S0_S0(0); // (A)
|
|
/// F(&S0, 0, &S0); // (B)
|
|
/// F_S1_X_S0(S1_X(0), 1); // (C)
|
|
/// F_S1_X_S0(S1_X(5), 2); // (D)
|
|
/// F_S1_X_S1_X(S1_X(5), 3, S1_X(3)); // (E)
|
|
/// F_S1_X_S1_X(S1_X(7), 4, S1_X(2)); // (F)
|
|
/// }
|
|
/// ```
|
|
struct FnVariant {
|
|
/// The signature of the variant is a map of each of the function's 'uniform', 'storage' and
|
|
/// 'workgroup' pointer parameters to the caller's AccessShape.
|
|
using Signature = utils::Hashmap<const sem::Parameter*, AccessShape, 4>;
|
|
|
|
/// The unique name of the variant.
|
|
/// The symbol is in the `ctx.dst` program namespace.
|
|
Symbol name;
|
|
|
|
/// A map of direct calls made by this variant to the name of other function variants.
|
|
utils::Hashmap<const sem::Call*, Symbol, 4> calls;
|
|
|
|
/// A map of input program parameter to output parameter symbols.
|
|
utils::Hashmap<const sem::Parameter*, PtrParamSymbols, 4> ptr_param_symbols;
|
|
|
|
/// The declaration order of the variant, in relation to other variants of the same
|
|
/// function. Used to ensure deterministic ordering of the transform, as map iteration is
|
|
/// not deterministic between compilers.
|
|
size_t order = 0;
|
|
};
|
|
|
|
/// FnInfo holds information about a function in the input program.
|
|
struct FnInfo {
|
|
/// A map of variant signature to the variant data.
|
|
utils::Hashmap<FnVariant::Signature, FnVariant, 8> variants;
|
|
/// A map of expressions that have been hoisted to a 'let' declaration in the function.
|
|
utils::Hashmap<const sem::Expression*, Symbol, 8> hoisted_exprs;
|
|
|
|
/// @returns the variants of the function in a deterministically ordered vector.
|
|
utils::Vector<std::pair<const FnVariant::Signature*, FnVariant*>, 8> SortedVariants() {
|
|
utils::Vector<std::pair<const FnVariant::Signature*, FnVariant*>, 8> out;
|
|
out.Reserve(variants.Count());
|
|
for (auto it : variants) {
|
|
out.Push({&it.key, &it.value});
|
|
}
|
|
out.Sort([&](auto& va, auto& vb) { return va.second->order < vb.second->order; });
|
|
return out;
|
|
}
|
|
};
|
|
|
|
/// The program builder
|
|
ProgramBuilder b;
|
|
/// The clone context
|
|
CloneContext ctx;
|
|
/// The transform options
|
|
const Options& opts;
|
|
/// Alias to the semantic info in ctx.src
|
|
const sem::Info& sem = ctx.src->Sem();
|
|
/// Alias to the symbols in ctx.src
|
|
const SymbolTable& sym = ctx.src->Symbols();
|
|
/// Map of semantic function to the function info
|
|
utils::Hashmap<const sem::Function*, FnInfo*, 8> fns;
|
|
/// Map of AccessShape to the name of a type alias for the an array<u32, N> used for the
|
|
/// dynamic indices of an access chain, passed down as the transformed type of a variant's
|
|
/// pointer parameter.
|
|
utils::Hashmap<AccessShape, Symbol, 8> dynamic_index_array_aliases;
|
|
/// Map of semantic expression to AccessChain
|
|
utils::Hashmap<const sem::Expression*, AccessChain*, 32> access_chains;
|
|
/// Allocator for FnInfo
|
|
utils::BlockAllocator<FnInfo> fn_info_allocator;
|
|
/// Allocator for AccessChain
|
|
utils::BlockAllocator<AccessChain> access_chain_allocator;
|
|
/// Helper used for hoisting expressions to lets
|
|
HoistToDeclBefore hoist{ctx};
|
|
/// Map of string to unique symbol (no collisions in output program).
|
|
utils::Hashmap<std::string, Symbol, 8> unique_symbols;
|
|
|
|
/// CloneState holds pointers to the current function, variant and variant's parameters.
|
|
struct CloneState {
|
|
/// The current function being cloned
|
|
FnInfo* current_function = nullptr;
|
|
/// The current function variant being built
|
|
FnVariant* current_variant = nullptr;
|
|
/// The signature of the current function variant being built
|
|
const FnVariant::Signature* current_variant_sig = nullptr;
|
|
};
|
|
|
|
/// The clone state.
|
|
/// Only valid during the lifetime of the CloneContext::Clone().
|
|
CloneState* clone_state = nullptr;
|
|
|
|
/// AppendAccessChain creates or extends an existing AccessChain for the given expression,
|
|
/// modifying the #access_chains map.
|
|
void AppendAccessChain(const sem::Expression* expr) {
|
|
// take_chain moves the AccessChain from the expression `from` to the expression `expr`.
|
|
// Returns nullptr if `from` did not hold an access chain.
|
|
auto take_chain = [&](const sem::Expression* from) -> AccessChain* {
|
|
if (auto* chain = AccessChainFor(from)) {
|
|
access_chains.Remove(from);
|
|
access_chains.Add(expr, chain);
|
|
return chain;
|
|
}
|
|
return nullptr;
|
|
};
|
|
|
|
Switch(
|
|
expr,
|
|
[&](const sem::VariableUser* user) {
|
|
// Expression resolves to a variable.
|
|
auto* variable = user->Variable();
|
|
|
|
auto create_new_chain = [&] {
|
|
auto* chain = access_chain_allocator.Create();
|
|
chain->root.variable = variable;
|
|
chain->root.type = variable->Type();
|
|
chain->root.address_space = variable->AddressSpace();
|
|
if (auto* ptr = chain->root.type->As<sem::Pointer>()) {
|
|
chain->root.address_space = ptr->AddressSpace();
|
|
}
|
|
access_chains.Add(expr, chain);
|
|
};
|
|
|
|
Switch(
|
|
variable->Declaration(),
|
|
[&](const ast::Var*) {
|
|
if (variable->AddressSpace() != ast::AddressSpace::kHandle) {
|
|
// Start a new access chain for the non-handle 'var' access
|
|
create_new_chain();
|
|
}
|
|
},
|
|
[&](const ast::Parameter*) {
|
|
if (variable->Type()->Is<sem::Pointer>()) {
|
|
// Start a new access chain for the pointer parameter access
|
|
create_new_chain();
|
|
}
|
|
},
|
|
[&](const ast::Let*) {
|
|
if (variable->Type()->Is<sem::Pointer>()) {
|
|
// variable is a pointer-let.
|
|
auto* init = sem.Get(variable->Declaration()->initializer);
|
|
// Note: We do not use take_chain() here, as we need to preserve the
|
|
// AccessChain on the let's initializer, as the let needs its
|
|
// initializer updated, and the let may be used multiple times. Instead
|
|
// we copy the let's AccessChain into a a new AccessChain.
|
|
if (auto* init_chain = AccessChainFor(init)) {
|
|
access_chains.Add(expr, access_chain_allocator.Create(*init_chain));
|
|
}
|
|
}
|
|
});
|
|
},
|
|
[&](const sem::StructMemberAccess* a) {
|
|
// Structure member access.
|
|
// Append the Symbol of the member name to the chain, and move the chain to the
|
|
// member access expression.
|
|
if (auto* chain = take_chain(a->Object())) {
|
|
chain->ops.Push(a->Member()->Name());
|
|
}
|
|
},
|
|
[&](const sem::IndexAccessorExpression* a) {
|
|
// Array, matrix or vector index.
|
|
// Store the index expression into AccessChain::dynamic_indices, append a
|
|
// DynamicIndex to the chain, and move the chain to the index accessor expression.
|
|
if (auto* chain = take_chain(a->Object())) {
|
|
chain->ops.Push(DynamicIndex{chain->dynamic_indices.Length()});
|
|
chain->dynamic_indices.Push(a->Index());
|
|
}
|
|
},
|
|
[&](const sem::Expression* e) {
|
|
if (auto* unary = e->Declaration()->As<ast::UnaryOpExpression>()) {
|
|
// Unary op.
|
|
// If this is a '&' or '*', simply move the chain to the unary op expression.
|
|
if (unary->op == ast::UnaryOp::kAddressOf ||
|
|
unary->op == ast::UnaryOp::kIndirection) {
|
|
take_chain(sem.Get(unary->expr));
|
|
}
|
|
}
|
|
});
|
|
}
|
|
|
|
/// MaybeHoistDynamicIndices examines the AccessChain::dynamic_indices member of @p chain,
|
|
/// hoisting all expressions to their own uniquely named 'let' if none of the following are
|
|
/// true:
|
|
/// 1. The index expression is a constant value.
|
|
/// 2. The index expression's statement is the same as @p usage.
|
|
/// 3. The index expression is an identifier resolving to a 'let', 'const' or parameter, AND
|
|
/// that identifier resolves to the same variable at @p usage.
|
|
///
|
|
/// A dynamic index will only be hoisted once. The hoisting applies to all variants of the
|
|
/// function that holds the dynamic index expression.
|
|
void MaybeHoistDynamicIndices(AccessChain* chain, const sem::Statement* usage) {
|
|
for (auto& idx : chain->dynamic_indices) {
|
|
if (idx->ConstantValue()) {
|
|
// Dynamic index is constant.
|
|
continue; // Hoisting not required.
|
|
}
|
|
|
|
if (idx->Stmt() == usage) {
|
|
// The index expression is owned by the statement of usage.
|
|
continue; // Hoisting not required
|
|
}
|
|
|
|
if (auto* idx_variable_user = idx->UnwrapMaterialize()->As<sem::VariableUser>()) {
|
|
auto* idx_variable = idx_variable_user->Variable();
|
|
if (idx_variable->Declaration()->IsAnyOf<ast::Let, ast::Parameter>()) {
|
|
// Dynamic index is an immutable variable
|
|
continue; // Hoisting not required.
|
|
}
|
|
}
|
|
|
|
// The dynamic index needs to be hoisted (if it hasn't been already).
|
|
auto fn = FnInfoFor(idx->Stmt()->Function());
|
|
fn->hoisted_exprs.GetOrCreate(idx, [=] {
|
|
// Create a name for the new 'let'
|
|
auto name = b.Symbols().New("ptr_index_save");
|
|
// Insert a new 'let' just above the dynamic index statement.
|
|
hoist.InsertBefore(idx->Stmt(), [this, idx, name] {
|
|
return b.Decl(b.Let(name, ctx.CloneWithoutTransform(idx->Declaration())));
|
|
});
|
|
return name;
|
|
});
|
|
}
|
|
}
|
|
|
|
/// BuildDynamicIndex builds the AST expression node for the dynamic index expression used in an
|
|
/// AccessChain. This is similar to just cloning the expression, but BuildDynamicIndex()
|
|
/// also:
|
|
/// * Collapses constant value index expressions down to the computed value. This acts as an
|
|
/// constant folding optimization and reduces noise from the transform.
|
|
/// * Casts the resulting expression to a u32 if @p cast_to_u32 is true, and the expression type
|
|
/// isn't implicitly usable as a u32. This is to help feed the expression into a
|
|
/// `array<u32, N>` argument passed to a callee variant function.
|
|
const ast::Expression* BuildDynamicIndex(const sem::Expression* idx, bool cast_to_u32) {
|
|
if (auto* val = idx->ConstantValue()) {
|
|
// Expression evaluated to a constant value. Just emit that constant.
|
|
return b.Expr(val->As<AInt>());
|
|
}
|
|
|
|
// Expression is not a constant, clone the expression.
|
|
// Note: If the dynamic index expression was hoisted to a let, then cloning will return an
|
|
// identifier expression to the hoisted let.
|
|
auto* expr = ctx.Clone(idx->Declaration());
|
|
|
|
if (cast_to_u32) {
|
|
// The index may be fed to a dynamic index array<u32, N> argument, so the index
|
|
// expression may need casting to u32.
|
|
if (!idx->UnwrapMaterialize()
|
|
->Type()
|
|
->UnwrapRef()
|
|
->IsAnyOf<sem::U32, sem::AbstractInt>()) {
|
|
expr = b.Construct(b.ty.u32(), expr);
|
|
}
|
|
}
|
|
|
|
return expr;
|
|
}
|
|
|
|
/// ProcessFunction scans the direct calls made by the function @p fn, adding new variants to
|
|
/// the callee functions and transforming the call expression to pass dynamic indices instead of
|
|
/// true pointers.
|
|
/// If the function @p fn has pointer parameters that must be transformed to a caller variant,
|
|
/// and the function is not called, then the function is dropped from the output of the
|
|
/// transform, as it cannot be generated.
|
|
/// @note ProcessFunction must be called in dependency order for the program, starting with the
|
|
/// entry points.
|
|
void ProcessFunction(const sem::Function* fn, FnInfo* fn_info) {
|
|
if (fn_info->variants.IsEmpty()) {
|
|
// Function has no variants pre-generated by callers.
|
|
if (MustBeCalled(fn)) {
|
|
// Drop the function, as it wasn't called and cannot be generated.
|
|
ctx.Remove(ctx.src->AST().GlobalDeclarations(), fn->Declaration());
|
|
return;
|
|
}
|
|
|
|
// Function was not called. Create a single variant with an empty signature.
|
|
FnVariant variant;
|
|
variant.name = ctx.Clone(fn->Declaration()->symbol);
|
|
variant.order = 0; // Unaltered comes first.
|
|
fn_info->variants.Add(FnVariant::Signature{}, std::move(variant));
|
|
}
|
|
|
|
// Process each of the direct calls made by this function.
|
|
for (auto* call : fn->DirectCalls()) {
|
|
ProcessCall(fn_info, call);
|
|
}
|
|
}
|
|
|
|
/// ProcessCall creates new variants of the callee function by permuting the call for each of
|
|
/// the variants of @p caller. ProcessCall also registers the clone callback to transform the
|
|
/// call expression to pass dynamic indices instead of true pointers.
|
|
void ProcessCall(FnInfo* caller, const sem::Call* call) {
|
|
auto* target = call->Target()->As<sem::Function>();
|
|
if (!target) {
|
|
// Call target is not a user-declared function.
|
|
return; // Not interested in this call.
|
|
}
|
|
|
|
if (!HasPointerParameter(target)) {
|
|
return; // Not interested in this call.
|
|
}
|
|
|
|
bool call_needs_transforming = false;
|
|
|
|
// Build the call target function variant for each variant of the caller.
|
|
for (auto caller_variant_it : caller->SortedVariants()) {
|
|
auto& caller_signature = *caller_variant_it.first;
|
|
auto& caller_variant = *caller_variant_it.second;
|
|
|
|
// Build the target variant's signature.
|
|
FnVariant::Signature target_signature;
|
|
for (size_t i = 0; i < call->Arguments().Length(); i++) {
|
|
const auto* arg = call->Arguments()[i];
|
|
const auto* param = target->Parameters()[i];
|
|
const auto* param_ty = param->Type()->As<sem::Pointer>();
|
|
if (!param_ty) {
|
|
continue; // Parameter type is not a pointer.
|
|
}
|
|
|
|
// Fetch the access chain for the argument.
|
|
auto* arg_chain = AccessChainFor(arg);
|
|
if (!arg_chain) {
|
|
continue; // Argument does not have an access chain
|
|
}
|
|
|
|
// Construct the absolute AccessShape by considering the AccessShape of the caller
|
|
// variant's argument. This will propagate back through pointer parameters, to the
|
|
// outermost caller.
|
|
auto absolute = AbsoluteAccessShape(caller_signature, *arg_chain);
|
|
|
|
// If the address space of the root variable of the access chain does not require
|
|
// transformation, then there's nothing to do.
|
|
if (!AddressSpaceRequiresTransform(absolute.root.address_space)) {
|
|
continue;
|
|
}
|
|
|
|
// Record that this chain was used in a function call.
|
|
// This preserves the chain during the access chain filtering stage.
|
|
arg_chain->used_in_call = true;
|
|
|
|
if (IsPrivateOrFunction(absolute.root.address_space)) {
|
|
// Pointers in 'private' and 'function' address spaces need to be passed by
|
|
// pointer argument.
|
|
absolute.root.variable = param;
|
|
}
|
|
|
|
// Add the parameter's absolute AccessShape to the target's signature.
|
|
target_signature.Add(param, std::move(absolute));
|
|
}
|
|
|
|
// Construct a new FnVariant if this is the first caller of the target signature
|
|
auto* target_info = FnInfoFor(target);
|
|
auto& target_variant = target_info->variants.GetOrCreate(target_signature, [&] {
|
|
if (target_signature.IsEmpty()) {
|
|
// Call target does not require any argument changes.
|
|
FnVariant variant;
|
|
variant.name = ctx.Clone(target->Declaration()->symbol);
|
|
variant.order = 0; // Unaltered comes first.
|
|
return variant;
|
|
}
|
|
|
|
// Build an appropriate variant function name.
|
|
// This is derived from the original function name and the pointer parameter
|
|
// chains.
|
|
std::stringstream ss;
|
|
ss << ctx.src->Symbols().NameFor(target->Declaration()->symbol);
|
|
for (auto* param : target->Parameters()) {
|
|
if (auto indices = target_signature.Find(param)) {
|
|
ss << "_" << AccessShapeName(*indices);
|
|
}
|
|
}
|
|
|
|
// Build the pointer parameter symbols.
|
|
utils::Hashmap<const sem::Parameter*, PtrParamSymbols, 4> ptr_param_symbols;
|
|
for (auto param_it : target_signature) {
|
|
auto* param = param_it.key;
|
|
auto& shape = param_it.value;
|
|
|
|
// Parameter needs replacing with either zero, one or two parameters:
|
|
// If the parameter is in the 'private' or 'function' address space, then the
|
|
// originating pointer is always passed down. This always comes first.
|
|
// If the access chain has dynamic indices, then we create an array<u32, N>
|
|
// parameter to hold the dynamic indices.
|
|
bool requires_base_ptr_param = IsPrivateOrFunction(shape.root.address_space);
|
|
bool requires_indices_param = shape.NumDynamicIndices() > 0;
|
|
|
|
PtrParamSymbols symbols;
|
|
if (requires_base_ptr_param && requires_indices_param) {
|
|
auto original_name = param->Declaration()->symbol;
|
|
symbols.base_ptr = UniqueSymbolWithSuffix(original_name, "_base");
|
|
symbols.indices = UniqueSymbolWithSuffix(original_name, "_indices");
|
|
} else if (requires_base_ptr_param) {
|
|
symbols.base_ptr = ctx.Clone(param->Declaration()->symbol);
|
|
} else if (requires_indices_param) {
|
|
symbols.indices = ctx.Clone(param->Declaration()->symbol);
|
|
}
|
|
|
|
// Remember this base pointer name.
|
|
ptr_param_symbols.Add(param, symbols);
|
|
}
|
|
|
|
// Build the variant.
|
|
FnVariant variant;
|
|
variant.name = b.Symbols().New(ss.str());
|
|
variant.order = target_info->variants.Count() + 1;
|
|
variant.ptr_param_symbols = std::move(ptr_param_symbols);
|
|
return variant;
|
|
});
|
|
|
|
// Record the call made by caller variant to the target variant.
|
|
caller_variant.calls.Add(call, target_variant.name);
|
|
if (!target_signature.IsEmpty()) {
|
|
// The call expression will need transforming for at least one caller variant.
|
|
call_needs_transforming = true;
|
|
}
|
|
}
|
|
|
|
if (call_needs_transforming) {
|
|
// Register the clone callback to correctly transform the call expression into the
|
|
// appropriate variant calls.
|
|
TransformCall(call);
|
|
}
|
|
}
|
|
|
|
/// @returns true if the address space @p address_space requires transforming given the
|
|
/// transform's options.
|
|
bool AddressSpaceRequiresTransform(ast::AddressSpace address_space) const {
|
|
switch (address_space) {
|
|
case ast::AddressSpace::kUniform:
|
|
case ast::AddressSpace::kStorage:
|
|
case ast::AddressSpace::kWorkgroup:
|
|
return true;
|
|
case ast::AddressSpace::kPrivate:
|
|
return opts.transform_private;
|
|
case ast::AddressSpace::kFunction:
|
|
return opts.transform_function;
|
|
default:
|
|
return false;
|
|
}
|
|
}
|
|
|
|
/// @returns the AccessChain for the expression @p expr, or nullptr if the expression does
|
|
/// not hold an access chain.
|
|
AccessChain* AccessChainFor(const sem::Expression* expr) const {
|
|
if (auto chain = access_chains.Find(expr)) {
|
|
return *chain;
|
|
}
|
|
return nullptr;
|
|
}
|
|
|
|
/// @returns the absolute AccessShape for @p indices, by replacing the originating pointer
|
|
/// parameter with the AccessChain of variant's signature.
|
|
AccessShape AbsoluteAccessShape(const FnVariant::Signature& signature,
|
|
const AccessShape& shape) const {
|
|
if (auto* root_param = shape.root.variable->As<sem::Parameter>()) {
|
|
if (auto incoming_chain = signature.Find(root_param)) {
|
|
// Access chain originates from a parameter, which will be transformed into an array
|
|
// of dynamic indices. Concatenate the signature's AccessShape for the parameter
|
|
// to the chain's indices, skipping over the chain's initial parameter index.
|
|
auto absolute = *incoming_chain;
|
|
for (auto& op : shape.ops) {
|
|
absolute.ops.Push(op);
|
|
}
|
|
return absolute;
|
|
}
|
|
}
|
|
|
|
// Chain does not originate from a parameter, so is already absolute.
|
|
return shape;
|
|
}
|
|
|
|
/// TransformFunction registers the clone callback that transforms the function @p fn into its
/// (potentially multiple) variants, one for each entry returned by
/// `fn_info->SortedVariants()`. While building each variant, the callback assigns the current
/// function, variant and variant signature to #clone_state. The other clone callbacks read
/// these to produce the variant.
/// @param fn the function to transform
/// @param fn_info the FnInfo holding the variants of @p fn
void TransformFunction(const sem::Function* fn, FnInfo* fn_info) {
    // Register a custom handler for the specific function
    ctx.Replace(fn->Declaration(), [this, fn, fn_info] {
        // For the scope of this lambda, assign current_function to fn_info.
        TINT_SCOPED_ASSIGNMENT(clone_state->current_function, fn_info);

        // This callback expects a single function returned. As we're generating potentially
        // many variant functions, keep a record of the last created variant, and explicitly add
        // this to the module if it isn't the last. We'll return the last created variant,
        // taking the place of the original function.
        const ast::Function* pending_variant = nullptr;

        // For each variant of fn...
        for (auto variant_it : fn_info->SortedVariants()) {
            // The variant built by the previous iteration is not the last one, so add it to the
            // module now.
            if (pending_variant) {
                b.AST().AddFunction(pending_variant);
            }

            auto& variant_sig = *variant_it.first;
            auto& variant = *variant_it.second;

            // For the rest of this scope, assign the current variant and variant signature.
            TINT_SCOPED_ASSIGNMENT(clone_state->current_variant_sig, &variant_sig);
            TINT_SCOPED_ASSIGNMENT(clone_state->current_variant, &variant);

            // Build the variant's parameters.
            // Each pointer parameter that appears in the variant's signature is replaced by up
            // to two parameters:
            //  * a pointer to the base object, if the variant recorded a base_ptr symbol for it;
            //  * followed by an array of dynamic indices, if the variant recorded an indices
            //    symbol for it.
            // If neither symbol is valid, the parameter is dropped. Parameters not in the
            // signature are cloned unchanged.
            utils::Vector<const ast::Parameter*, 8> params;
            for (auto* param : fn->Parameters()) {
                if (auto incoming_shape = variant_sig.Find(param)) {
                    auto& symbols = *variant.ptr_param_symbols.Find(param);
                    if (symbols.base_ptr.IsValid()) {
                        // Pointer to the root object of the incoming shape, in the root's
                        // address space.
                        auto* base_ptr_ty =
                            b.ty.pointer(CreateASTTypeFor(ctx, incoming_shape->root.type),
                                         incoming_shape->root.address_space);
                        params.Push(b.Param(symbols.base_ptr, base_ptr_ty));
                    }
                    if (symbols.indices.IsValid()) {
                        // Variant has dynamic indices for this parameter, replace it with the
                        // dynamic index array.
                        auto* dyn_idx_arr_type = DynamicIndexArrayType(*incoming_shape);
                        params.Push(b.Param(symbols.indices, dyn_idx_arr_type));
                    }
                } else {
                    // Just a regular parameter. Just clone the original parameter.
                    params.Push(ctx.Clone(param->Declaration()));
                }
            }

            // Build the variant by cloning the source function. The other clone callbacks will
            // use clone_state->current_variant and clone_state->current_variant_sig to produce
            // the variant.
            auto* ret_ty = ctx.Clone(fn->Declaration()->return_type);
            auto body = ctx.Clone(fn->Declaration()->body);
            auto attrs = ctx.Clone(fn->Declaration()->attributes);
            auto ret_attrs = ctx.Clone(fn->Declaration()->return_type_attributes);
            pending_variant =
                b.create<ast::Function>(variant.name, std::move(params), ret_ty, body,
                                        std::move(attrs), std::move(ret_attrs));
        }

        // The last variant replaces the original function.
        // NOTE(review): if fn_info has no variants, this returns nullptr. Presumably the
        // CloneContext then omits the function from the output — confirm.
        return pending_variant;
    });
}
|
|
|
|
/// TransformCall registers the clone callback that rewrites the call expression @p call.
/// The rewritten call targets the variant selected for the current caller variant. Pointer
/// arguments that hold an AccessChain are replaced by, in order:
///  * a pointer to the base object, for 'private' / 'function' pointer parameters only;
///  * an array of the chain's dynamic indices, if the absolute access shape has any.
/// If the current variant has no recorded target for @p call, the call is cloned unchanged.
/// @param call the call expression to transform
void TransformCall(const sem::Call* call) {
    // Register a custom handler for the specific call expression
    ctx.Replace(call->Declaration(), [this, call]() {
        auto target_variant = clone_state->current_variant->calls.Find(call);
        if (!target_variant) {
            // The current variant does not need to transform this call.
            return ctx.CloneWithoutTransform(call->Declaration());
        }

        // Build the new call expression's arguments.
        utils::Vector<const ast::Expression*, 8> new_args;
        for (size_t arg_idx = 0; arg_idx < call->Arguments().Length(); arg_idx++) {
            auto* arg = call->Arguments()[arg_idx];
            auto* param = call->Target()->Parameters()[arg_idx];
            auto* param_ty = param->Type()->As<sem::Pointer>();
            if (!param_ty) {
                // Parameter is not a pointer.
                // Just clone the unaltered argument.
                new_args.Push(ctx.Clone(arg->Declaration()));
                continue;  // Parameter is not a pointer
            }

            auto* chain = AccessChainFor(arg);
            if (!chain) {
                // No access chain means the argument is not a pointer that needs transforming.
                // Just clone the unaltered argument.
                new_args.Push(ctx.Clone(arg->Declaration()));
                continue;
            }

            // Construct the absolute AccessShape by considering the AccessShape of the caller
            // variant's argument. This will propagate back through pointer parameters, to the
            // outermost caller.
            auto full_indices = AbsoluteAccessShape(*clone_state->current_variant_sig, *chain);

            // If the parameter is a pointer in the 'private' or 'function' address space, then
            // we need to pass an additional pointer argument to the base object.
            if (IsPrivateOrFunction(param_ty->AddressSpace())) {
                auto* root_expr = BuildAccessRootExpr(chain->root, /* deref */ false);
                // A root that is a parameter is already a pointer. Otherwise take the address
                // of the root variable.
                if (!chain->root.variable->Is<sem::Parameter>()) {
                    root_expr = b.AddressOf(root_expr);
                }
                new_args.Push(root_expr);
            }

            // Get or create the dynamic indices array. No array argument is passed if the
            // absolute shape has no dynamic indices.
            if (auto* dyn_idx_arr_ty = DynamicIndexArrayType(full_indices)) {
                // Build an array of dynamic indices to pass as the replacement for the pointer.
                utils::Vector<const ast::Expression*, 8> dyn_idx_args;
                if (auto* root_param = chain->root.variable->As<sem::Parameter>()) {
                    // Access chain originates from a pointer parameter.
                    if (auto incoming_chain =
                            clone_state->current_variant_sig->Find(root_param)) {
                        auto indices =
                            clone_state->current_variant->ptr_param_symbols.Find(root_param)
                                ->indices;

                        // This pointer parameter will have been replaced with an array<u32, N>
                        // holding the variant's dynamic indices for the pointer. Unpack these
                        // directly into the array constructor's arguments.
                        auto N = incoming_chain->NumDynamicIndices();
                        for (uint32_t i = 0; i < N; i++) {
                            dyn_idx_args.Push(b.IndexAccessor(indices, u32(i)));
                        }
                    }
                }
                // Pass the dynamic indices of the access chain into the array constructor.
                // These follow any indices unpacked from the incoming parameter above.
                for (auto& dyn_idx : chain->dynamic_indices) {
                    dyn_idx_args.Push(BuildDynamicIndex(dyn_idx, /* cast_to_u32 */ true));
                }
                // Construct the dynamic index array, and push as an argument.
                new_args.Push(b.Construct(dyn_idx_arr_ty, std::move(dyn_idx_args)));
            }
        }

        // Make the call to the target's variant.
        return b.Call(*target_variant, std::move(new_args));
    });
}
|
|
|
|
/// ProcessAccessChains performs the following:
|
|
/// * Removes all AccessChains from expressions that are not either used as a pointer argument
|
|
/// in a call, or originates from a pointer parameter.
|
|
/// * Hoists the dynamic index expressions of AccessChains to 'let' statements, to prevent
|
|
/// multiple evaluation of the expressions, and avoid expressions resolving to different
|
|
/// variables based on lexical scope.
|
|
void ProcessAccessChains() {
|
|
auto chain_exprs = access_chains.Keys();
|
|
chain_exprs.Sort([](const auto& expr_a, const auto& expr_b) {
|
|
return expr_a->Declaration()->node_id.value < expr_b->Declaration()->node_id.value;
|
|
});
|
|
|
|
for (auto* expr : chain_exprs) {
|
|
auto* chain = *access_chains.Get(expr);
|
|
if (!chain->used_in_call && !chain->root.variable->Is<sem::Parameter>()) {
|
|
// Chain was not used in a function call, and does not originate from a
|
|
// parameter. This chain does not need transforming. Drop it.
|
|
access_chains.Remove(expr);
|
|
continue;
|
|
}
|
|
|
|
// Chain requires transforming.
|
|
|
|
// We need to be careful that the chain does not use expressions with side-effects which
|
|
// cannot be repeatedly evaluated. In this situation we can hoist the dynamic index
|
|
// expressions to their own uniquely named lets (if required).
|
|
MaybeHoistDynamicIndices(chain, expr->Stmt());
|
|
}
|
|
}
|
|
|
|
/// TransformAccessChainExpressions registers the clone callback to:
|
|
/// * Transform all expressions that have an AccessChain (which aren't arguments to function
|
|
/// calls, these are handled by TransformCall()), into the equivalent expression using a
|
|
/// module-scope variable.
|
|
/// * Replace expressions that have been hoisted to a let, with an identifier expression to that
|
|
/// let.
|
|
void TransformAccessChainExpressions() {
|
|
// Register a custom handler for all non-function call expressions
|
|
ctx.ReplaceAll([this](const ast::Expression* ast_expr) -> const ast::Expression* {
|
|
if (!clone_state->current_variant) {
|
|
// Expression does not belong to a function variant.
|
|
return nullptr; // Just clone the expression.
|
|
}
|
|
|
|
auto* expr = sem.Get<sem::Expression>(ast_expr);
|
|
if (!expr) {
|
|
// No semantic node for the expression.
|
|
return nullptr; // Just clone the expression.
|
|
}
|
|
|
|
// If the expression has been hoisted to a 'let', then replace the expression with an
|
|
// identifier to the hoisted let.
|
|
if (auto hoisted = clone_state->current_function->hoisted_exprs.Find(expr)) {
|
|
return b.Expr(*hoisted);
|
|
}
|
|
|
|
auto* chain = AccessChainFor(expr);
|
|
if (!chain) {
|
|
// The expression does not have an AccessChain.
|
|
return nullptr; // Just clone the expression.
|
|
}
|
|
|
|
auto* root_param = chain->root.variable->As<sem::Parameter>();
|
|
if (!root_param) {
|
|
// The expression has an access chain, but does not originate with a pointer
|
|
// parameter. We don't need to change anything here.
|
|
return nullptr; // Just clone the expression.
|
|
}
|
|
|
|
auto incoming_shape = clone_state->current_variant_sig->Find(root_param);
|
|
if (!incoming_shape) {
|
|
// The root parameter of the access chain is not part of the variant's signature.
|
|
return nullptr; // Just clone the expression.
|
|
}
|
|
|
|
// Expression holds an access chain to a pointer parameter that needs transforming.
|
|
// Reconstruct the expression using the variant's incoming shape.
|
|
|
|
auto* chain_expr = BuildAccessRootExpr(incoming_shape->root, /* deref */ true);
|
|
|
|
// Chain starts with a pointer parameter.
|
|
// Replace this with the variant's incoming shape. This will bring the expression up to
|
|
// the incoming pointer.
|
|
auto indices =
|
|
clone_state->current_variant->ptr_param_symbols.Find(root_param)->indices;
|
|
for (auto param_access : incoming_shape->ops) {
|
|
chain_expr = BuildAccessExpr(chain_expr, param_access, [&](size_t i) {
|
|
return b.IndexAccessor(indices, AInt(i));
|
|
});
|
|
}
|
|
|
|
// Now build the expression chain within the function.
|
|
|
|
// For each access in the chain (excluding the pointer parameter)...
|
|
for (auto& op : chain->ops) {
|
|
chain_expr = BuildAccessExpr(chain_expr, op, [&](size_t i) {
|
|
return BuildDynamicIndex(chain->dynamic_indices[i], false);
|
|
});
|
|
}
|
|
|
|
// BuildAccessExpr() always returns a non-pointer.
|
|
// If the expression we're replacing is a pointer, take the address.
|
|
if (expr->Type()->Is<sem::Pointer>()) {
|
|
chain_expr = b.AddressOf(chain_expr);
|
|
}
|
|
|
|
return chain_expr;
|
|
});
|
|
}
|
|
|
|
/// @returns the FnInfo for the given function, constructing a new FnInfo if @p fn doesn't
|
|
/// already have one.
|
|
FnInfo* FnInfoFor(const sem::Function* fn) {
|
|
return fns.GetOrCreate(fn, [this] { return fn_info_allocator.Create(); });
|
|
}
|
|
|
|
/// @returns the type alias used to hold the dynamic indices for @p shape, declaring a new alias
|
|
/// if this is the first call for the given shape.
|
|
const ast::TypeName* DynamicIndexArrayType(const AccessShape& shape) {
|
|
auto name = dynamic_index_array_aliases.GetOrCreate(shape, [&] {
|
|
// Count the number of dynamic indices
|
|
uint32_t num_dyn_indices = shape.NumDynamicIndices();
|
|
if (num_dyn_indices == 0) {
|
|
return Symbol{};
|
|
}
|
|
auto symbol = b.Symbols().New(AccessShapeName(shape));
|
|
b.Alias(symbol, b.ty.array(b.ty.u32(), u32(num_dyn_indices)));
|
|
return symbol;
|
|
});
|
|
return name.IsValid() ? b.ty.type_name(name) : nullptr;
|
|
}
|
|
|
|
/// @returns a name describing the given shape
|
|
std::string AccessShapeName(const AccessShape& shape) {
|
|
std::stringstream ss;
|
|
|
|
if (IsPrivateOrFunction(shape.root.address_space)) {
|
|
ss << "F";
|
|
} else {
|
|
ss << ctx.src->Symbols().NameFor(shape.root.variable->Declaration()->symbol);
|
|
}
|
|
|
|
for (auto& op : shape.ops) {
|
|
ss << "_";
|
|
|
|
if (std::holds_alternative<DynamicIndex>(op)) {
|
|
/// The op uses a dynamic (runtime-expression) index.
|
|
ss << "X";
|
|
continue;
|
|
}
|
|
|
|
if (auto* member = std::get_if<Symbol>(&op)) {
|
|
ss << sym.NameFor(*member);
|
|
continue;
|
|
}
|
|
|
|
TINT_ICE(Transform, b.Diagnostics()) << "unhandled variant for access chain";
|
|
break;
|
|
}
|
|
return ss.str();
|
|
}
|
|
|
|
/// Builds an expresion to the root of an access, returning the new expression.
|
|
/// @param root the AccessRoot
|
|
/// @param deref if true, the returned expression will always be a reference type.
|
|
const ast::Expression* BuildAccessRootExpr(const AccessRoot& root, bool deref) {
|
|
if (auto* param = root.variable->As<sem::Parameter>()) {
|
|
if (auto symbols = clone_state->current_variant->ptr_param_symbols.Find(param)) {
|
|
if (deref) {
|
|
return b.Deref(b.Expr(symbols->base_ptr));
|
|
}
|
|
return b.Expr(symbols->base_ptr);
|
|
}
|
|
}
|
|
|
|
const ast::Expression* expr = b.Expr(ctx.Clone(root.variable->Declaration()->symbol));
|
|
if (deref) {
|
|
if (root.variable->Type()->Is<sem::Pointer>()) {
|
|
expr = b.Deref(expr);
|
|
}
|
|
}
|
|
return expr;
|
|
}
|
|
|
|
/// Builds a single access in an access chain, returning the new expression.
|
|
/// The returned expression will always be of a reference type.
|
|
/// @param expr the input expression
|
|
/// @param access the access to perform on the current expression
|
|
/// @param dynamic_index a function that obtains the i'th dynamic index
|
|
const ast::Expression* BuildAccessExpr(
|
|
const ast::Expression* expr,
|
|
const AccessOp& access,
|
|
std::function<const ast::Expression*(size_t)> dynamic_index) {
|
|
if (auto* dyn_idx = std::get_if<DynamicIndex>(&access)) {
|
|
/// The access uses a dynamic (runtime-expression) index.
|
|
auto* idx = dynamic_index(dyn_idx->slot);
|
|
return b.IndexAccessor(expr, idx);
|
|
}
|
|
|
|
if (auto* member = std::get_if<Symbol>(&access)) {
|
|
/// The access is a member access.
|
|
return b.MemberAccessor(expr, ctx.Clone(*member));
|
|
}
|
|
|
|
TINT_ICE(Transform, b.Diagnostics()) << "unhandled variant type for access chain";
|
|
return nullptr;
|
|
}
|
|
|
|
/// @returns a new Symbol starting with @p symbol concatenated with @p suffix, and possibly an
|
|
/// underscore and number, if the symbol is already taken.
|
|
Symbol UniqueSymbolWithSuffix(Symbol symbol, const std::string& suffix) {
|
|
auto str = ctx.src->Symbols().NameFor(symbol) + suffix;
|
|
return unique_symbols.GetOrCreate(str, [&] { return b.Symbols().New(str); });
|
|
}
|
|
|
|
/// @returns true if the function @p fn has at least one pointer parameter.
|
|
static bool HasPointerParameter(const sem::Function* fn) {
|
|
for (auto* param : fn->Parameters()) {
|
|
if (param->Type()->Is<sem::Pointer>()) {
|
|
return true;
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
|
|
/// @returns true if the function @p fn has at least one pointer parameter in an address space
|
|
/// that must be replaced. If this function is not called, then the function cannot be sensibly
|
|
/// generated, and must be stripped.
|
|
static bool MustBeCalled(const sem::Function* fn) {
|
|
for (auto* param : fn->Parameters()) {
|
|
if (auto* ptr = param->Type()->As<sem::Pointer>()) {
|
|
switch (ptr->AddressSpace()) {
|
|
case ast::AddressSpace::kUniform:
|
|
case ast::AddressSpace::kStorage:
|
|
case ast::AddressSpace::kWorkgroup:
|
|
return true;
|
|
default:
|
|
return false;
|
|
}
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
|
|
/// @returns true if the given address space is 'private' or 'function'.
|
|
static bool IsPrivateOrFunction(const ast::AddressSpace sc) {
|
|
return sc == ast::AddressSpace::kPrivate || sc == ast::AddressSpace::kFunction;
|
|
}
|
|
};
|
|
|
|
DirectVariableAccess::Config::Config(const Options& opt) : options(opt) {}
|
|
|
|
DirectVariableAccess::Config::~Config() = default;
|
|
|
|
DirectVariableAccess::DirectVariableAccess() = default;
|
|
|
|
DirectVariableAccess::~DirectVariableAccess() = default;
|
|
|
|
Transform::ApplyResult DirectVariableAccess::Apply(const Program* program,
|
|
const DataMap& inputs,
|
|
DataMap&) const {
|
|
Options options;
|
|
if (auto* cfg = inputs.Get<Config>()) {
|
|
options = cfg->options;
|
|
}
|
|
return State(program, options).Run();
|
|
}
|
|
|
|
} // namespace tint::transform
|