238 lines
8.7 KiB
C++
238 lines
8.7 KiB
C++
// Copyright 2020 The Tint Authors.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
#include "src/transform/msl.h"
|
|
|
|
#include <unordered_map>
|
|
#include <utility>
|
|
#include <vector>
|
|
|
|
#include "src/ast/disable_validation_decoration.h"
|
|
#include "src/program_builder.h"
|
|
#include "src/sem/call.h"
|
|
#include "src/sem/function.h"
|
|
#include "src/sem/statement.h"
|
|
#include "src/sem/variable.h"
|
|
#include "src/transform/canonicalize_entry_point_io.h"
|
|
#include "src/transform/external_texture_transform.h"
|
|
#include "src/transform/manager.h"
|
|
#include "src/transform/pad_array_elements.h"
|
|
#include "src/transform/promote_initializers_to_const_var.h"
|
|
#include "src/transform/wrap_arrays_in_structs.h"
|
|
|
|
namespace tint {
|
|
namespace transform {
|
|
|
|
// Compiler-generated default constructor.
Msl::Msl() = default;
|
|
// Compiler-generated default destructor.
Msl::~Msl() = default;
|
|
|
|
// Runs the MSL transform over `in`.
//
// A fixed set of sub-transforms is registered with a Manager and run first;
// if the resulting program is not valid, that output is returned unchanged.
// Otherwise the result is cloned into a fresh ProgramBuilder, with
// HandleModuleScopeVariables() applied during the clone.
//
// The incoming DataMap parameter is unnamed and not read; the sub-transforms
// are driven by a DataMap built locally in this function.
Output Msl::Run(const Program* in, const DataMap&) {
  // Configuration handed to the sub-transforms.
  DataMap transform_config;
  transform_config.Add<CanonicalizeEntryPointIO::Config>(
      CanonicalizeEntryPointIO::BuiltinStyle::kParameter);

  // Sub-transforms, registered in this order.
  Manager pipeline;
  pipeline.Add<CanonicalizeEntryPointIO>();
  pipeline.Add<ExternalTextureTransform>();
  pipeline.Add<PromoteInitializersToConstVar>();
  pipeline.Add<WrapArraysInStructs>();
  pipeline.Add<PadArrayElements>();

  auto sanitized = pipeline.Run(in, transform_config);
  if (!sanitized.program.IsValid()) {
    // Nothing more can be done with an invalid program; return it as-is.
    return sanitized;
  }

  ProgramBuilder builder;
  CloneContext ctx(&builder, &sanitized.program);
  // TODO(jrprice): Consider making this a standalone transform, with target
  // storage class(es) as transform options.
  HandleModuleScopeVariables(ctx);
  ctx.Clone();
  return Output{Program(std::move(builder))};
}
|
|
|
|
// Registers, on `ctx`, the rewrites that move module-scope variables in the
// private, workgroup and uniform-constant storage classes into functions.
void Msl::HandleModuleScopeVariables(CloneContext& ctx) const {
  // MSL does not allow private and workgroup variables at module-scope.
  // Their declarations are therefore pushed down into the entry point
  // function, and every other function that references them receives them as
  // pointer parameters. Texture and sampler variables are treated similarly:
  // they become entry point parameters and are passed by value to the
  // functions that need them.
  //
  // WGSL itself does not allow function-scope variables to have these storage
  // classes, so the newly created declarations carry an attribute that
  // bypasses that validation rule.
  //
  // Before:
  // ```
  // var<private> v : f32 = 2.0;
  //
  // fn foo() {
  //   v = v + 1.0;
  // }
  //
  // [[stage(compute)]]
  // fn main() {
  //   foo();
  // }
  // ```
  //
  // After:
  // ```
  // fn foo(v : ptr<private, f32>) {
  //   *v = *v + 1.0;
  // }
  //
  // [[stage(compute)]]
  // fn main() {
  //   var<private> v : f32 = 2.0;
  //   foo(&v);
  // }
  // ```

  // Returns true for the storage classes handled by this rewrite.
  auto is_relocated = [](ast::StorageClass storage_class) {
    return storage_class == ast::StorageClass::kPrivate ||
           storage_class == ast::StorageClass::kWorkgroup ||
           storage_class == ast::StorageClass::kUniformConstant;
  };

  // Call expressions that need extra arguments, grouped by the function that
  // contains the call. Collected up front, before any rewriting is queued.
  using CallList = std::vector<const ast::CallExpression*>;
  std::unordered_map<const ast::Function*, CallList> calls_by_caller;

  // Functions that reference at least one relocated variable.
  std::vector<ast::Function*> worklist;

  for (auto* fn : ctx.src->AST().Functions()) {
    auto* fn_sem = ctx.src->Sem().Get(fn);

    bool touches_relocated_var = false;
    for (auto* global : fn_sem->ReferencedModuleVariables()) {
      if (is_relocated(global->StorageClass())) {
        touches_relocated_var = true;
        break;
      }
    }
    if (!touches_relocated_var) {
      continue;
    }

    worklist.push_back(fn);

    // Record every call to this function under the function it appears in.
    for (auto* call_site : fn_sem->CallSites()) {
      auto* call_site_sem = ctx.src->Sem().Get(call_site);
      calls_by_caller[call_site_sem->Stmt()->Function()].push_back(call_site);
    }
  }

  for (auto* fn : worklist) {
    auto* fn_sem = ctx.src->Sem().Get(fn);
    const bool entry_point = fn->IsEntryPoint();

    // Module-scope variable -> symbol of its replacement inside `fn`.
    std::unordered_map<const sem::Variable*, Symbol> replacement_for;

    for (auto* global : fn_sem->ReferencedModuleVariables()) {
      if (!is_relocated(global->StorageClass())) {
        continue;
      }

      // Symbol naming the parameter / local that stands in for `global`.
      auto replacement = ctx.dst->Sym();
      replacement_for[global] = replacement;

      auto* store_type = CreateASTTypeFor(&ctx, global->Type()->UnwrapRef());
      const bool is_handle = store_type->is_handle();

      if (!entry_point) {
        // Regular function: the variable becomes a trailing parameter, typed
        // as a pointer unless it is a handle.
        auto* param_type = store_type;
        if (!is_handle) {
          param_type = ctx.dst->ty.pointer(param_type, global->StorageClass());
        }
        auto* param = ctx.dst->Param(replacement, param_type);
        ctx.InsertBack(fn->params(), param);
      } else if (is_handle) {
        // Entry point + texture/sampler: redeclare as a leading entry point
        // parameter, keeping the original decorations and disabling entry
        // point parameter validation.
        auto* skip_param_validation =
            ctx.dst->ASTNodes().Create<ast::DisableValidationDecoration>(
                ctx.dst->ID(), ast::DisabledValidation::kEntryPointParameter);
        auto param_decos = ctx.Clone(global->Declaration()->decorations());
        param_decos.push_back(skip_param_validation);
        auto* param = ctx.dst->Param(replacement, store_type, param_decos);
        ctx.InsertFront(fn->params(), param);
      } else {
        // Entry point + private/workgroup: redeclare at the top of the
        // function body, disabling storage class validation on the new
        // variable.
        auto* skip_storage_validation =
            ctx.dst->ASTNodes().Create<ast::DisableValidationDecoration>(
                ctx.dst->ID(),
                ast::DisabledValidation::kFunctionVarStorageClass);
        auto* init = ctx.Clone(global->Declaration()->constructor());
        auto* local =
            ctx.dst->Var(replacement, store_type, global->StorageClass(), init,
                         ast::DecorationList{skip_storage_validation});
        auto* local_decl = ctx.dst->Decl(local);
        ctx.InsertFront(fn->body()->statements(), local_decl);
      }

      // Redirect every use of the module-scope variable inside `fn` to the
      // replacement. Pointer parameters (non-entry-point, non-handle) are
      // dereferenced at each use.
      const bool via_pointer = !entry_point && !is_handle;
      for (auto* use : global->Users()) {
        if (use->Stmt()->Function() != fn) {
          continue;
        }
        ast::Expression* new_expr = ctx.dst->Expr(replacement);
        if (via_pointer) {
          new_expr = ctx.dst->Deref(new_expr);
        }
        ctx.Replace(use->Declaration(), new_expr);
      }
    }

    // Forward the replacements to every callee in `fn` that needs them.
    for (auto* call : calls_by_caller[fn]) {
      auto* callee = ctx.src->AST().Functions().Find(call->func()->symbol());
      auto* callee_sem = ctx.src->Sem().Get(callee);

      // One extra argument per relocated variable the callee references.
      // In an entry point the replacement is the variable itself, so
      // non-handle types are passed by address.
      for (auto* needed : callee_sem->ReferencedModuleVariables()) {
        if (!is_relocated(needed->StorageClass())) {
          continue;
        }
        ast::Expression* arg = ctx.dst->Expr(replacement_for[needed]);
        if (entry_point && !needed->Type()->UnwrapRef()->is_handle()) {
          arg = ctx.dst->AddressOf(arg);
        }
        ctx.InsertBack(call->params(), arg);
      }
    }
  }

  // Finally, drop the module-scope declarations of all relocated variables.
  for (auto* global_decl : ctx.src->AST().GlobalVariables()) {
    auto* global_sem = ctx.src->Sem().Get(global_decl);
    if (is_relocated(global_sem->StorageClass())) {
      ctx.Remove(ctx.src->AST().GlobalDeclarations(), global_decl);
    }
  }
}
|
|
|
|
} // namespace transform
|
|
} // namespace tint
|