// Copyright 2021 The Tint Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "src/transform/zero_init_workgroup_memory.h" #include #include #include #include #include #include "src/ast/workgroup_decoration.h" #include "src/program_builder.h" #include "src/sem/atomic_type.h" #include "src/sem/function.h" #include "src/sem/variable.h" #include "src/utils/get_or_create.h" #include "src/utils/unique_vector.h" TINT_INSTANTIATE_TYPEINFO(tint::transform::ZeroInitWorkgroupMemory); namespace tint { namespace transform { /// PIMPL state for the ZeroInitWorkgroupMemory transform struct ZeroInitWorkgroupMemory::State { /// The clone context CloneContext& ctx; /// An alias to *ctx.dst ProgramBuilder& b = *ctx.dst; /// The constant size of the workgroup. If 0, then #workgroup_size_expr should /// be used instead. uint32_t workgroup_size_const = 0; /// The size of the workgroup as an expression generator. Use if /// #workgroup_size_const is 0. std::function workgroup_size_expr; /// ArrayIndex represents a function on the local invocation index, of /// the form: `array_index = (local_invocation_index % modulo) / division` struct ArrayIndex { /// The RHS of the modulus part of the expression uint32_t modulo = 1; /// The RHS of the division part of the expression uint32_t division = 1; /// Equality operator /// @param i the ArrayIndex to compare to this ArrayIndex /// @returns true if `i` and this ArrayIndex are equal bool operator==(const ArrayIndex& i) const { return modulo == i.modulo && division == i.division; } /// Hash function for the ArrayIndex type struct Hasher { /// @param i the ArrayIndex to calculate a hash for /// @returns the hash value for the ArrayIndex `i` size_t operator()(const ArrayIndex& i) const { return utils::Hash(i.modulo, i.division); } }; }; /// A list of unique ArrayIndex using ArrayIndices = utils::UniqueVector; /// Expression holds information about an expression that is being built for a /// statement will zero workgroup values. struct Expression { /// The AST expression node const ast::Expression* expr = nullptr; /// The number of iterations required to zero the value uint32_t num_iterations = 0; /// All array indices used by this expression ArrayIndices array_indices; }; /// Statement holds information about a statement that will zero workgroup /// values. struct Statement { /// The AST statement node const ast::Statement* stmt; /// The number of iterations required to zero the value uint32_t num_iterations; /// All array indices used by this statement ArrayIndices array_indices; }; /// All statements that zero workgroup memory std::vector statements; /// A map of ArrayIndex to the name reserved for the `let` declaration of that /// index. std::unordered_map array_index_names; /// Constructor /// @param c the CloneContext used for the transform explicit State(CloneContext& c) : ctx(c) {} /// Run inserts the workgroup memory zero-initialization logic at the top of /// the given function /// @param fn a compute shader entry point function void Run(const ast::Function* fn) { auto& sem = ctx.src->Sem(); CalculateWorkgroupSize( ast::GetDecoration(fn->decorations)); // Generate a list of statements to zero initialize each of the // workgroup storage variables used by `fn`. This will populate #statements. auto* func = sem.Get(fn); for (auto* var : func->ReferencedModuleVariables()) { if (var->StorageClass() == ast::StorageClass::kWorkgroup) { BuildZeroingStatements( var->Type()->UnwrapRef(), [&](uint32_t num_values) { auto var_name = ctx.Clone(var->Declaration()->symbol); return Expression{b.Expr(var_name), num_values, ArrayIndices{}}; }); } } if (statements.empty()) { return; // No workgroup variables to initialize. } // Scan the entry point for an existing local_invocation_index builtin // parameter std::function local_index; for (auto* param : fn->params) { if (auto* builtin = ast::GetDecoration(param->decorations)) { if (builtin->builtin == ast::Builtin::kLocalInvocationIndex) { local_index = [=] { return b.Expr(ctx.Clone(param->symbol)); }; break; } } if (auto* str = sem.Get(param)->Type()->As()) { for (auto* member : str->Members()) { if (auto* builtin = ast::GetDecoration( member->Declaration()->decorations)) { if (builtin->builtin == ast::Builtin::kLocalInvocationIndex) { local_index = [=] { auto* param_expr = b.Expr(ctx.Clone(param->symbol)); auto member_name = ctx.Clone(member->Declaration()->symbol); return b.MemberAccessor(param_expr, member_name); }; break; } } } } } if (!local_index) { // No existing local index parameter. Append one to the entry point. auto* param = b.Param(b.Symbols().New("local_invocation_index"), b.ty.u32(), {b.Builtin(ast::Builtin::kLocalInvocationIndex)}); ctx.InsertBack(fn->params, param); local_index = [=] { return b.Expr(param->symbol); }; } // Take the zeroing statements and bin them by the number of iterations // required to zero the workgroup data. We then emit these in blocks, // possibly wrapped in if-statements or for-loops. std::unordered_map> stmts_by_num_iterations; std::vector num_sorted_iterations; for (auto& s : statements) { auto& stmts = stmts_by_num_iterations[s.num_iterations]; if (stmts.empty()) { num_sorted_iterations.emplace_back(s.num_iterations); } stmts.emplace_back(s); } std::sort(num_sorted_iterations.begin(), num_sorted_iterations.end()); // Loop over the statements, grouped by num_iterations. for (auto num_iterations : num_sorted_iterations) { auto& stmts = stmts_by_num_iterations[num_iterations]; // Gather all the array indices used by all the statements in the block. ArrayIndices array_indices; for (auto& s : stmts) { for (auto& idx : s.array_indices) { array_indices.add(idx); } } // Determine the block type used to emit these statements. if (workgroup_size_const == 0 || num_iterations > workgroup_size_const) { // Either the workgroup size is dynamic, or smaller than num_iterations. // In either case, we need to generate a for loop to ensure we // initialize all the array elements. // // for (var idx : u32 = local_index; // idx < num_iterations; // idx += workgroup_size) { // ... // } auto idx = b.Symbols().New("idx"); auto* init = b.Decl(b.Var(idx, b.ty.u32(), local_index())); auto* cond = b.create( ast::BinaryOp::kLessThan, b.Expr(idx), b.Expr(num_iterations)); auto* cont = b.Assign( idx, b.Add(idx, workgroup_size_const ? b.Expr(workgroup_size_const) : workgroup_size_expr())); auto block = DeclareArrayIndices(num_iterations, array_indices, [&] { return b.Expr(idx); }); for (auto& s : stmts) { block.emplace_back(s.stmt); } auto* for_loop = b.For(init, cond, cont, b.Block(block)); ctx.InsertFront(fn->body->statements, for_loop); } else if (num_iterations < workgroup_size_const) { // Workgroup size is a known constant, but is greater than // num_iterations. Emit an if statement: // // if (local_index < num_iterations) { // ... // } auto* cond = b.create( ast::BinaryOp::kLessThan, local_index(), b.Expr(num_iterations)); auto block = DeclareArrayIndices(num_iterations, array_indices, [&] { return b.Expr(local_index()); }); for (auto& s : stmts) { block.emplace_back(s.stmt); } auto* if_stmt = b.If(cond, b.Block(block)); ctx.InsertFront(fn->body->statements, if_stmt); } else { // Workgroup size exactly equals num_iterations. // No need for any conditionals. Just emit a basic block: // // { // ... // } auto block = DeclareArrayIndices(num_iterations, array_indices, [&] { return b.Expr(local_index()); }); for (auto& s : stmts) { block.emplace_back(s.stmt); } ctx.InsertFront(fn->body->statements, b.Block(block)); } } // Append a single workgroup barrier after the zero initialization. ctx.InsertFront(fn->body->statements, b.CallStmt(b.Call("workgroupBarrier"))); } /// BuildZeroingExpr is a function that builds a sub-expression used to zero /// workgroup values. `num_values` is the number of elements that the /// expression will be used to zero. Returns the expression. using BuildZeroingExpr = std::function; /// BuildZeroingStatements() generates the statements required to zero /// initialize the workgroup storage expression of type `ty`. /// @param ty the expression type /// @param get_expr a function that builds the AST nodes for the expression. void BuildZeroingStatements(const sem::Type* ty, const BuildZeroingExpr& get_expr) { if (CanTriviallyZero(ty)) { auto var = get_expr(1u); auto* zero_init = b.Construct(CreateASTTypeFor(ctx, ty)); statements.emplace_back(Statement{b.Assign(var.expr, zero_init), var.num_iterations, var.array_indices}); return; } if (auto* atomic = ty->As()) { auto* zero_init = b.Construct(CreateASTTypeFor(ctx, atomic->Type())); auto expr = get_expr(1u); auto* store = b.Call("atomicStore", b.AddressOf(expr.expr), zero_init); statements.emplace_back(Statement{b.CallStmt(store), expr.num_iterations, expr.array_indices}); return; } if (auto* str = ty->As()) { for (auto* member : str->Members()) { auto name = ctx.Clone(member->Declaration()->symbol); BuildZeroingStatements(member->Type(), [&](uint32_t num_values) { auto s = get_expr(num_values); return Expression{b.MemberAccessor(s.expr, name), s.num_iterations, s.array_indices}; }); } return; } if (auto* arr = ty->As()) { BuildZeroingStatements(arr->ElemType(), [&](uint32_t num_values) { // num_values is the number of values to zero for the element type. // The number of iterations required to zero the array and its elements // is: // `num_values * arr->Count()` // The index for this array is: // `(idx % modulo) / division` auto modulo = num_values * arr->Count(); auto division = num_values; auto a = get_expr(modulo); auto array_indices = a.array_indices; array_indices.add(ArrayIndex{modulo, division}); auto index = utils::GetOrCreate(array_index_names, ArrayIndex{modulo, division}, [&] { return b.Symbols().New("i"); }); return Expression{b.IndexAccessor(a.expr, index), a.num_iterations, array_indices}; }); return; } TINT_UNREACHABLE(Transform, b.Diagnostics()) << "could not zero workgroup type: " << ty->type_name(); } /// DeclareArrayIndices returns a list of statements that contain the `let` /// declarations for all of the ArrayIndices. /// @param num_iterations the number of iterations for the block /// @param array_indices the list of array indices to generate `let` /// declarations for /// @param iteration a function that returns the index of the current /// iteration. /// @returns the list of `let` statements that declare the array indices ast::StatementList DeclareArrayIndices( uint32_t num_iterations, const ArrayIndices& array_indices, const std::function& iteration) { ast::StatementList stmts; std::map indices_by_name; for (auto index : array_indices) { auto name = array_index_names.at(index); auto* mod = (num_iterations > index.modulo) ? b.create( ast::BinaryOp::kModulo, iteration(), b.Expr(index.modulo)) : iteration(); auto* div = (index.division != 1u) ? b.Div(mod, index.division) : mod; auto* decl = b.Decl(b.Const(name, b.ty.u32(), div)); stmts.emplace_back(decl); } return stmts; } /// CalculateWorkgroupSize initializes the members #workgroup_size_const and /// #workgroup_size_expr with the linear workgroup size. /// @param deco the workgroup decoration applied to the entry point function void CalculateWorkgroupSize(const ast::WorkgroupDecoration* deco) { bool is_signed = false; workgroup_size_const = 1u; workgroup_size_expr = nullptr; for (auto* expr : deco->Values()) { if (!expr) { continue; } auto* sem = ctx.src->Sem().Get(expr); if (auto c = sem->ConstantValue()) { if (c.ElementType()->Is()) { workgroup_size_const *= static_cast(c.Elements()[0].i32); continue; } else if (c.ElementType()->Is()) { workgroup_size_const *= c.Elements()[0].u32; continue; } } // Constant value could not be found. Build expression instead. workgroup_size_expr = [this, expr, size = workgroup_size_expr] { auto* e = ctx.Clone(expr); if (ctx.src->TypeOf(expr)->UnwrapRef()->Is()) { e = b.Construct(e); } return size ? b.Mul(size(), e) : e; }; } if (workgroup_size_expr) { if (workgroup_size_const != 1) { // Fold workgroup_size_const in to workgroup_size_expr workgroup_size_expr = [this, is_signed, const_size = workgroup_size_const, expr_size = workgroup_size_expr] { return is_signed ? b.Mul(expr_size(), static_cast(const_size)) : b.Mul(expr_size(), const_size); }; } // Indicate that workgroup_size_expr should be used instead of the // constant. workgroup_size_const = 0; } } /// @returns true if a variable with store type `ty` can be efficiently zeroed /// by assignment of a type constructor without operands. If /// CanTriviallyZero() returns false, then the type needs to be /// initialized by decomposing the initialization into multiple /// sub-initializations. /// @param ty the type to inspect bool CanTriviallyZero(const sem::Type* ty) { if (ty->Is()) { return false; } if (auto* str = ty->As()) { for (auto* member : str->Members()) { if (!CanTriviallyZero(member->Type())) { return false; } } } if (ty->Is()) { return false; } // True for all other storable types return true; } }; ZeroInitWorkgroupMemory::ZeroInitWorkgroupMemory() = default; ZeroInitWorkgroupMemory::~ZeroInitWorkgroupMemory() = default; void ZeroInitWorkgroupMemory::Run(CloneContext& ctx, const DataMap&, DataMap&) { for (auto* fn : ctx.src->AST().Functions()) { if (fn->PipelineStage() == ast::PipelineStage::kCompute) { State{ctx}.Run(fn); } } ctx.Clone(); } } // namespace transform } // namespace tint