// Copyright 2020 The Tint Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "src/transform/msl.h" #include #include #include #include #include "src/ast/disable_validation_decoration.h" #include "src/program_builder.h" #include "src/sem/call.h" #include "src/sem/function.h" #include "src/sem/statement.h" #include "src/sem/variable.h" #include "src/transform/array_length_from_uniform.h" #include "src/transform/canonicalize_entry_point_io.h" #include "src/transform/external_texture_transform.h" #include "src/transform/inline_pointer_lets.h" #include "src/transform/manager.h" #include "src/transform/pad_array_elements.h" #include "src/transform/promote_initializers_to_const_var.h" #include "src/transform/simplify.h" #include "src/transform/wrap_arrays_in_structs.h" #include "src/transform/zero_init_workgroup_memory.h" TINT_INSTANTIATE_TYPEINFO(tint::transform::Msl::Config); TINT_INSTANTIATE_TYPEINFO(tint::transform::Msl::Result); namespace tint { namespace transform { Msl::Msl() = default; Msl::~Msl() = default; Output Msl::Run(const Program* in, const DataMap& inputs) { Manager manager; DataMap internal_inputs; auto* cfg = inputs.Get(); // Build the configs for the internal transforms. uint32_t buffer_size_ubo_index = kDefaultBufferSizeUniformIndex; uint32_t fixed_sample_mask = 0xFFFFFFFF; if (cfg) { buffer_size_ubo_index = cfg->buffer_size_ubo_index; fixed_sample_mask = cfg->fixed_sample_mask; } auto array_length_from_uniform_cfg = ArrayLengthFromUniform::Config( sem::BindingPoint{0, buffer_size_ubo_index}); auto entry_point_io_cfg = CanonicalizeEntryPointIO::Config( CanonicalizeEntryPointIO::BuiltinStyle::kParameter, fixed_sample_mask); // Use the SSBO binding numbers as the indices for the buffer size lookups. for (auto* var : in->AST().GlobalVariables()) { auto* sem_var = in->Sem().Get(var); if (sem_var->StorageClass() == ast::StorageClass::kStorage) { array_length_from_uniform_cfg.bindpoint_to_size_index.emplace( sem_var->BindingPoint(), sem_var->BindingPoint().binding); } } // ZeroInitWorkgroupMemory must come before CanonicalizeEntryPointIO as // ZeroInitWorkgroupMemory may inject new builtin parameters. manager.Add(); manager.Add(); manager.Add(); manager.Add(); manager.Add(); manager.Add(); manager.Add(); manager.Add(); // ArrayLengthFromUniform must come after InlinePointerLets and Simplify, as // it assumes that the form of the array length argument is &var.array. manager.Add(); internal_inputs.Add( std::move(array_length_from_uniform_cfg)); internal_inputs.Add( std::move(entry_point_io_cfg)); auto out = manager.Run(in, internal_inputs); if (!out.program.IsValid()) { return out; } ProgramBuilder builder; CloneContext ctx(&builder, &out.program); // TODO(jrprice): Consider making this a standalone transform, with target // storage class(es) as transform options. HandleModuleScopeVariables(ctx); ctx.Clone(); auto result = std::make_unique( out.data.Get()->needs_buffer_sizes); return Output{Program(std::move(builder)), std::move(result)}; } void Msl::HandleModuleScopeVariables(CloneContext& ctx) const { // MSL does not allow private and workgroup variables at module-scope, so we // push these declarations into the entry point function and then pass them as // pointer parameters to any function that references them. // Similarly, texture and sampler types are converted to entry point // parameters and passed by value to functions that need them. // // Since WGSL does not allow function-scope variables to have these storage // classes, we annotate the new variable declarations with an attribute that // bypasses that validation rule. // // Before: // ``` // var v : f32 = 2.0; // // fn foo() { // v = v + 1.0; // } // // [[stage(compute)]] // fn main() { // foo(); // } // ``` // // After: // ``` // fn foo(v : ptr) { // *v = *v + 1.0; // } // // [[stage(compute)]] // fn main() { // var v : f32 = 2.0; // foo(&v); // } // ``` // Predetermine the list of function calls that need to be replaced. using CallList = std::vector; std::unordered_map calls_to_replace; std::vector functions_to_process; // Build a list of functions that transitively reference any private or // workgroup variables, or texture/sampler variables. for (auto* func_ast : ctx.src->AST().Functions()) { auto* func_sem = ctx.src->Sem().Get(func_ast); bool needs_processing = false; for (auto* var : func_sem->ReferencedModuleVariables()) { if (var->StorageClass() == ast::StorageClass::kPrivate || var->StorageClass() == ast::StorageClass::kWorkgroup || var->StorageClass() == ast::StorageClass::kUniformConstant) { needs_processing = true; break; } } if (needs_processing) { functions_to_process.push_back(func_ast); // Find all of the calls to this function that will need to be replaced. for (auto* call : func_sem->CallSites()) { auto* call_sem = ctx.src->Sem().Get(call); calls_to_replace[call_sem->Stmt()->Function()].push_back(call); } } } for (auto* func_ast : functions_to_process) { auto* func_sem = ctx.src->Sem().Get(func_ast); bool is_entry_point = func_ast->IsEntryPoint(); // Map module-scope variables onto their function-scope replacement. std::unordered_map var_to_symbol; for (auto* var : func_sem->ReferencedModuleVariables()) { if (var->StorageClass() != ast::StorageClass::kPrivate && var->StorageClass() != ast::StorageClass::kWorkgroup && var->StorageClass() != ast::StorageClass::kUniformConstant) { continue; } // This is the symbol for the variable that replaces the module-scope var. auto new_var_symbol = ctx.dst->Sym(); auto* store_type = CreateASTTypeFor(&ctx, var->Type()->UnwrapRef()); if (is_entry_point) { if (store_type->is_handle()) { // For a texture or sampler variable, redeclare it as an entry point // parameter. Disable entry point parameter validation. auto* disable_validation = ctx.dst->ASTNodes().Create( ctx.dst->ID(), ast::DisabledValidation::kEntryPointParameter); auto decos = ctx.Clone(var->Declaration()->decorations()); decos.push_back(disable_validation); auto* param = ctx.dst->Param(new_var_symbol, store_type, decos); ctx.InsertFront(func_ast->params(), param); } else { // For a private or workgroup variable, redeclare it at function // scope. Disable storage class validation on this variable. auto* disable_validation = ctx.dst->ASTNodes().Create( ctx.dst->ID(), ast::DisabledValidation::kFunctionVarStorageClass); auto* constructor = ctx.Clone(var->Declaration()->constructor()); auto* local_var = ctx.dst->Var( new_var_symbol, store_type, var->StorageClass(), constructor, ast::DecorationList{disable_validation}); ctx.InsertFront(func_ast->body()->statements(), ctx.dst->Decl(local_var)); } } else { // For a regular function, redeclare the variable as a parameter. // Use a pointer for non-handle types. auto* param_type = store_type; if (!store_type->is_handle()) { param_type = ctx.dst->ty.pointer(param_type, var->StorageClass()); } ctx.InsertBack(func_ast->params(), ctx.dst->Param(new_var_symbol, param_type)); } // Replace all uses of the module-scope variable. // For non-entry points, dereference non-handle pointer parameters. for (auto* user : var->Users()) { if (user->Stmt()->Function() == func_ast) { ast::Expression* expr = ctx.dst->Expr(new_var_symbol); if (!is_entry_point && !store_type->is_handle()) { expr = ctx.dst->Deref(expr); } ctx.Replace(user->Declaration(), expr); } } var_to_symbol[var] = new_var_symbol; } // Pass the variables as pointers to any functions that need them. for (auto* call : calls_to_replace[func_ast]) { auto* target = ctx.src->AST().Functions().Find(call->func()->symbol()); auto* target_sem = ctx.src->Sem().Get(target); // Add new arguments for any variables that are needed by the callee. // For entry points, pass non-handle types as pointers. for (auto* target_var : target_sem->ReferencedModuleVariables()) { if (target_var->StorageClass() == ast::StorageClass::kPrivate || target_var->StorageClass() == ast::StorageClass::kWorkgroup || target_var->StorageClass() == ast::StorageClass::kUniformConstant) { ast::Expression* arg = ctx.dst->Expr(var_to_symbol[target_var]); if (is_entry_point && !target_var->Type()->UnwrapRef()->is_handle()) { arg = ctx.dst->AddressOf(arg); } ctx.InsertBack(call->params(), arg); } } } } // Now remove all module-scope variables with these storage classes. for (auto* var_ast : ctx.src->AST().GlobalVariables()) { auto* var_sem = ctx.src->Sem().Get(var_ast); if (var_sem->StorageClass() == ast::StorageClass::kPrivate || var_sem->StorageClass() == ast::StorageClass::kWorkgroup || var_sem->StorageClass() == ast::StorageClass::kUniformConstant) { ctx.Remove(ctx.src->AST().GlobalDeclarations(), var_ast); } } } Msl::Config::Config(uint32_t buffer_size_ubo_idx, uint32_t sample_mask) : buffer_size_ubo_index(buffer_size_ubo_idx), fixed_sample_mask(sample_mask) {} Msl::Config::Config(const Config&) = default; Msl::Config::~Config() = default; Msl::Result::Result(bool needs_buffer_sizes) : needs_storage_buffer_sizes(needs_buffer_sizes) {} Msl::Result::Result(const Result&) = default; Msl::Result::~Result() = default; } // namespace transform } // namespace tint