tint: Refactor ModuleScopeVarToEntryPointParam

Split the main Process() function up into smaller functions to make it
less unwieldy.

Change-Id: Ibbe3141a82221879b9d4ed232eebdd0344f1698c
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/94000
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
James Price 2022-06-20 12:39:33 +00:00 committed by Dawn LUCI CQ
parent 159e78f50b
commit 46583621c0
1 changed files with 219 additions and 147 deletions

View File

@ -30,6 +30,10 @@ TINT_INSTANTIATE_TYPEINFO(tint::transform::ModuleScopeVarToEntryPointParam);
namespace tint::transform {
namespace {
// The name of the struct member for arrays that are wrapped in structures.
const char* kWrappedArrayMemberName = "arr";
// Returns `true` if `type` is or contains a matrix type.
bool ContainsMatrix(const sem::Type* type) {
type = type->UnwrapRef();
@ -83,6 +87,192 @@ struct ModuleScopeVarToEntryPointParam::State {
}
}
/// Process a variable `var` that is referenced in the entry point function `func`.
/// This will redeclare the variable as a function parameter, possibly as a pointer.
/// Some workgroup variables will be redeclared as a member inside a workgroup structure.
/// @param func the entry point function
/// @param var the variable
/// @param new_var_symbol the symbol to use for the replacement
/// @param workgroup_param helper function to get a symbol to a workgroup struct parameter
/// @param workgroup_parameter_members reference to a list of a workgroup struct members
/// @param is_pointer output signalling whether the replacement is a pointer
/// @param is_wrapped output signalling whether the replacement is wrapped in a struct
void ProcessVariableInEntryPoint(const ast::Function* func,
const sem::Variable* var,
Symbol new_var_symbol,
std::function<Symbol()> workgroup_param,
ast::StructMemberList& workgroup_parameter_members,
bool& is_pointer,
bool& is_wrapped) {
auto* var_ast = var->Declaration()->As<ast::Var>();
auto* ty = var->Type()->UnwrapRef();
// Helper to create an AST node for the store type of the variable.
auto store_type = [&]() { return CreateASTTypeFor(ctx, ty); };
ast::StorageClass sc = var->StorageClass();
switch (sc) {
case ast::StorageClass::kHandle: {
// For a texture or sampler variable, redeclare it as an entry point parameter.
// Disable entry point parameter validation.
auto* disable_validation =
ctx.dst->Disable(ast::DisabledValidation::kEntryPointParameter);
auto attrs = ctx.Clone(var->Declaration()->attributes);
attrs.push_back(disable_validation);
auto* param = ctx.dst->Param(new_var_symbol, store_type(), attrs);
ctx.InsertFront(func->params, param);
break;
}
case ast::StorageClass::kStorage:
case ast::StorageClass::kUniform: {
// Variables into the Storage and Uniform storage classes are redeclared as entry
// point parameters with a pointer type.
auto attributes = ctx.Clone(var->Declaration()->attributes);
attributes.push_back(
ctx.dst->Disable(ast::DisabledValidation::kEntryPointParameter));
attributes.push_back(
ctx.dst->Disable(ast::DisabledValidation::kIgnoreStorageClass));
auto* param_type = store_type();
if (auto* arr = ty->As<sem::Array>(); arr && arr->IsRuntimeSized()) {
// Wrap runtime-sized arrays in structures, so that we can declare pointers to
// them. Ideally we'd just emit the array itself as a pointer, but this is not
// representable in Tint's AST.
CloneStructTypes(ty);
auto* wrapper = ctx.dst->Structure(
ctx.dst->Sym(), {ctx.dst->Member(kWrappedArrayMemberName, param_type)});
param_type = ctx.dst->ty.Of(wrapper);
is_wrapped = true;
}
param_type = ctx.dst->ty.pointer(param_type, sc, var_ast->declared_access);
auto* param = ctx.dst->Param(new_var_symbol, param_type, attributes);
ctx.InsertFront(func->params, param);
is_pointer = true;
break;
}
case ast::StorageClass::kWorkgroup: {
if (ContainsMatrix(var->Type())) {
// Due to a bug in the MSL compiler, we use a threadgroup memory argument for
// any workgroup allocation that contains a matrix. See crbug.com/tint/938.
// TODO(jrprice): Do this for all other workgroup variables too.
// Create a member in the workgroup parameter struct.
auto member = ctx.Clone(var->Declaration()->symbol);
workgroup_parameter_members.push_back(ctx.dst->Member(member, store_type()));
CloneStructTypes(var->Type()->UnwrapRef());
// Create a function-scope variable that is a pointer to the member.
auto* member_ptr = ctx.dst->AddressOf(
ctx.dst->MemberAccessor(ctx.dst->Deref(workgroup_param()), member));
auto* local_var = ctx.dst->Let(
new_var_symbol,
ctx.dst->ty.pointer(store_type(), ast::StorageClass::kWorkgroup),
member_ptr);
ctx.InsertFront(func->body->statements, ctx.dst->Decl(local_var));
is_pointer = true;
break;
}
[[fallthrough]];
}
case ast::StorageClass::kPrivate: {
// Variables in the Private and Workgroup storage classes are redeclared at function
// scope. Disable storage class validation on this variable.
auto* disable_validation =
ctx.dst->Disable(ast::DisabledValidation::kIgnoreStorageClass);
auto* constructor = ctx.Clone(var->Declaration()->constructor);
auto* local_var = ctx.dst->Var(new_var_symbol, store_type(), sc, constructor,
ast::AttributeList{disable_validation});
ctx.InsertFront(func->body->statements, ctx.dst->Decl(local_var));
break;
}
default: {
TINT_ICE(Transform, ctx.dst->Diagnostics())
<< "unhandled module-scope storage class (" << sc << ")";
}
}
}
/// Process a variable `var` that is referenced in the user-defined function `func`.
/// This will redeclare the variable as a function parameter, possibly as a pointer.
/// @param func the user-defined function
/// @param var the variable
/// @param new_var_symbol the symbol to use for the replacement
/// @param is_pointer output signalling whether the replacement is a pointer or not
void ProcessVariableInUserFunction(const ast::Function* func,
const sem::Variable* var,
Symbol new_var_symbol,
bool& is_pointer) {
auto* var_ast = var->Declaration()->As<ast::Var>();
auto* ty = var->Type()->UnwrapRef();
auto* param_type = CreateASTTypeFor(ctx, ty);
auto sc = var->StorageClass();
switch (sc) {
case ast::StorageClass::kPrivate:
case ast::StorageClass::kStorage:
case ast::StorageClass::kUniform:
case ast::StorageClass::kHandle:
case ast::StorageClass::kWorkgroup:
break;
default:
TINT_ICE(Transform, ctx.dst->Diagnostics())
<< "unhandled module-scope storage class (" << sc << ")";
}
// Use a pointer for non-handle types.
ast::AttributeList attributes;
if (!ty->is_handle()) {
param_type = ctx.dst->ty.pointer(param_type, sc, var_ast->declared_access);
is_pointer = true;
// Disable validation of the parameter's storage class and of arguments passed to it.
attributes.push_back(ctx.dst->Disable(ast::DisabledValidation::kIgnoreStorageClass));
attributes.push_back(
ctx.dst->Disable(ast::DisabledValidation::kIgnoreInvalidPointerArgument));
}
// Redeclare the variable as a parameter.
ctx.InsertBack(func->params, ctx.dst->Param(new_var_symbol, param_type, attributes));
}
/// Replace all uses of `var` in `func` with references to `new_var`.
/// @param func the function
/// @param var the variable to replace
/// @param new_var the symbol to use for replacement
/// @param is_pointer true if `new_var` is a pointer to the new variable
/// @param is_wrapped true if `new_var` is an array wrapped in a structure
void ReplaceUsesInFunction(const ast::Function* func,
const sem::Variable* var,
Symbol new_var,
bool is_pointer,
bool is_wrapped) {
for (auto* user : var->Users()) {
if (user->Stmt()->Function()->Declaration() == func) {
const ast::Expression* expr = ctx.dst->Expr(new_var);
if (is_pointer) {
// If this identifier is used by an address-of operator, just remove the
// address-of instead of adding a deref, since we already have a pointer.
auto* ident = user->Declaration()->As<ast::IdentifierExpression>();
if (ident_to_address_of_.count(ident)) {
ctx.Replace(ident_to_address_of_[ident], expr);
continue;
}
expr = ctx.dst->Deref(expr);
}
if (is_wrapped) {
// Get the member from the wrapper structure.
expr = ctx.dst->MemberAccessor(expr, kWrappedArrayMemberName);
}
ctx.Replace(user->Declaration(), expr);
}
}
}
/// Process the module.
void Process() {
// Predetermine the list of function calls that need to be replaced.
@ -91,8 +281,7 @@ struct ModuleScopeVarToEntryPointParam::State {
std::vector<const ast::Function*> functions_to_process;
// Build a list of functions that transitively reference any module-scope
// variables.
// Build a list of functions that transitively reference any module-scope variables.
for (auto* func_ast : ctx.src->AST().Functions()) {
auto* func_sem = ctx.src->Sem().Get(func_ast);
@ -114,20 +303,18 @@ struct ModuleScopeVarToEntryPointParam::State {
}
}
// Build a list of `&ident` expressions. We'll use this later to avoid
// generating expressions of the form `&*ident`, which break WGSL validation
// rules when this expression is passed to a function.
// TODO(jrprice): We should add support for bidirectional SEM tree traversal
// so that we can do this on the fly instead.
std::unordered_map<const ast::IdentifierExpression*, const ast::UnaryOpExpression*>
ident_to_address_of;
// Build a list of `&ident` expressions. We'll use this later to avoid generating
// expressions of the form `&*ident`, which break WGSL validation rules when this expression
// is passed to a function.
// TODO(jrprice): We should add support for bidirectional SEM tree traversal so that we can
// do this on the fly instead.
for (auto* node : ctx.src->ASTNodes().Objects()) {
auto* address_of = node->As<ast::UnaryOpExpression>();
if (!address_of || address_of->op != ast::UnaryOp::kAddressOf) {
continue;
}
if (auto* ident = address_of->expr->As<ast::IdentifierExpression>()) {
ident_to_address_of[ident] = address_of;
ident_to_address_of_[ident] = address_of;
}
}
@ -141,11 +328,10 @@ struct ModuleScopeVarToEntryPointParam::State {
bool is_pointer;
bool is_wrapped;
};
const char* kWrappedArrayMemberName = "arr";
std::unordered_map<const sem::Variable*, NewVar> var_to_newvar;
// We aggregate all workgroup variables into a struct to avoid hitting
// MSL's limit for threadgroup memory arguments.
// We aggregate all workgroup variables into a struct to avoid hitting MSL's limit for
// threadgroup memory arguments.
Symbol workgroup_parameter_symbol;
ast::StructMemberList workgroup_parameter_members;
auto workgroup_param = [&]() {
@ -155,159 +341,40 @@ struct ModuleScopeVarToEntryPointParam::State {
return workgroup_parameter_symbol;
};
for (auto* global : func_sem->TransitivelyReferencedGlobals()) {
auto* var = global->Declaration()->As<ast::Var>();
if (!var) {
// Process and redeclare all variables referenced by the function.
for (auto* var : func_sem->TransitivelyReferencedGlobals()) {
if (var->StorageClass() == ast::StorageClass::kNone) {
continue;
}
auto sc = global->StorageClass();
auto* ty = global->Type()->UnwrapRef();
if (sc == ast::StorageClass::kNone) {
continue;
}
if (sc != ast::StorageClass::kPrivate && sc != ast::StorageClass::kStorage &&
sc != ast::StorageClass::kUniform && sc != ast::StorageClass::kHandle &&
sc != ast::StorageClass::kWorkgroup) {
TINT_ICE(Transform, ctx.dst->Diagnostics())
<< "unhandled module-scope storage class (" << sc << ")";
}
// This is the symbol for the variable that replaces the module-scope
// var.
// This is the symbol for the variable that replaces the module-scope var.
auto new_var_symbol = ctx.dst->Sym();
// Helper to create an AST node for the store type of the variable.
auto store_type = [&]() { return CreateASTTypeFor(ctx, ty); };
// Track whether the new variable is a pointer or not.
bool is_pointer = false;
// Track whether the new variable was wrapped in a struct or not.
bool is_wrapped = false;
// Process the variable to redeclare it as a function parameter or local variable.
if (is_entry_point) {
if (global->Type()->UnwrapRef()->is_handle()) {
// For a texture or sampler variable, redeclare it as an entry point
// parameter. Disable entry point parameter validation.
auto* disable_validation =
ctx.dst->Disable(ast::DisabledValidation::kEntryPointParameter);
auto attrs = ctx.Clone(var->attributes);
attrs.push_back(disable_validation);
auto* param = ctx.dst->Param(new_var_symbol, store_type(), attrs);
ctx.InsertFront(func_ast->params, param);
} else if (sc == ast::StorageClass::kStorage ||
sc == ast::StorageClass::kUniform) {
// Variables into the Storage and Uniform storage classes are
// redeclared as entry point parameters with a pointer type.
auto attributes = ctx.Clone(var->attributes);
attributes.push_back(
ctx.dst->Disable(ast::DisabledValidation::kEntryPointParameter));
attributes.push_back(
ctx.dst->Disable(ast::DisabledValidation::kIgnoreStorageClass));
auto* param_type = store_type();
if (auto* arr = ty->As<sem::Array>(); arr && arr->IsRuntimeSized()) {
// Wrap runtime-sized arrays in structures, so that we can declare
// pointers to them. Ideally we'd just emit the array itself as a
// pointer, but this is not representable in Tint's AST.
CloneStructTypes(ty);
auto* wrapper = ctx.dst->Structure(
ctx.dst->Sym(),
{ctx.dst->Member(kWrappedArrayMemberName, param_type)});
param_type = ctx.dst->ty.Of(wrapper);
is_wrapped = true;
}
param_type = ctx.dst->ty.pointer(param_type, sc, var->declared_access);
auto* param = ctx.dst->Param(new_var_symbol, param_type, attributes);
ctx.InsertFront(func_ast->params, param);
is_pointer = true;
} else if (sc == ast::StorageClass::kWorkgroup &&
ContainsMatrix(global->Type())) {
// Due to a bug in the MSL compiler, we use a threadgroup memory
// argument for any workgroup allocation that contains a matrix.
// See crbug.com/tint/938.
// TODO(jrprice): Do this for all other workgroup variables too.
// Create a member in the workgroup parameter struct.
auto member = ctx.Clone(var->symbol);
workgroup_parameter_members.push_back(
ctx.dst->Member(member, store_type()));
CloneStructTypes(global->Type()->UnwrapRef());
// Create a function-scope variable that is a pointer to the member.
auto* member_ptr = ctx.dst->AddressOf(
ctx.dst->MemberAccessor(ctx.dst->Deref(workgroup_param()), member));
auto* local_var = ctx.dst->Let(
new_var_symbol,
ctx.dst->ty.pointer(store_type(), ast::StorageClass::kWorkgroup),
member_ptr);
ctx.InsertFront(func_ast->body->statements, ctx.dst->Decl(local_var));
is_pointer = true;
} else {
// Variables in the Private and Workgroup storage classes are
// redeclared at function scope. Disable storage class validation on
// this variable.
auto* disable_validation =
ctx.dst->Disable(ast::DisabledValidation::kIgnoreStorageClass);
auto* constructor = ctx.Clone(var->constructor);
auto* local_var =
ctx.dst->Var(new_var_symbol, store_type(), sc, constructor,
ast::AttributeList{disable_validation});
ctx.InsertFront(func_ast->body->statements, ctx.dst->Decl(local_var));
}
ProcessVariableInEntryPoint(func_ast, var, new_var_symbol, workgroup_param,
workgroup_parameter_members, is_pointer,
is_wrapped);
} else {
// For a regular function, redeclare the variable as a parameter.
// Use a pointer for non-handle types.
auto* param_type = store_type();
ast::AttributeList attributes;
if (!global->Type()->UnwrapRef()->is_handle()) {
param_type = ctx.dst->ty.pointer(param_type, sc, var->declared_access);
is_pointer = true;
// Disable validation of the parameter's storage class and of
// arguments passed it.
attributes.push_back(
ctx.dst->Disable(ast::DisabledValidation::kIgnoreStorageClass));
attributes.push_back(ctx.dst->Disable(
ast::DisabledValidation::kIgnoreInvalidPointerArgument));
}
ctx.InsertBack(func_ast->params,
ctx.dst->Param(new_var_symbol, param_type, attributes));
ProcessVariableInUserFunction(func_ast, var, new_var_symbol, is_pointer);
}
// Replace all uses of the module-scope variable.
// For non-entry points, dereference non-handle pointer parameters.
for (auto* user : global->Users()) {
if (user->Stmt()->Function()->Declaration() == func_ast) {
const ast::Expression* expr = ctx.dst->Expr(new_var_symbol);
if (is_pointer) {
// If this identifier is used by an address-of operator, just
// remove the address-of instead of adding a deref, since we
// already have a pointer.
auto* ident = user->Declaration()->As<ast::IdentifierExpression>();
if (ident_to_address_of.count(ident)) {
ctx.Replace(ident_to_address_of[ident], expr);
continue;
}
ReplaceUsesInFunction(func_ast, var, new_var_symbol, is_pointer, is_wrapped);
expr = ctx.dst->Deref(expr);
}
if (is_wrapped) {
// Get the member from the wrapper structure.
expr = ctx.dst->MemberAccessor(expr, kWrappedArrayMemberName);
}
ctx.Replace(user->Declaration(), expr);
}
}
var_to_newvar[global] = {new_var_symbol, is_pointer, is_wrapped};
// Record the replacement symbol.
var_to_newvar[var] = {new_var_symbol, is_pointer, is_wrapped};
}
if (!workgroup_parameter_members.empty()) {
// Create the workgroup memory parameter.
// The parameter is a struct that contains members for each workgroup
// variable.
// The parameter is a struct that contains members for each workgroup variable.
auto* str =
ctx.dst->Structure(ctx.dst->Sym(), std::move(workgroup_parameter_members));
auto* param_type =
@ -335,8 +402,8 @@ struct ModuleScopeVarToEntryPointParam::State {
bool is_handle = target_var->Type()->UnwrapRef()->is_handle();
const ast::Expression* arg = ctx.dst->Expr(new_var.symbol);
if (new_var.is_wrapped) {
// The variable is wrapped in a struct, so we need to pass a pointer
// to the struct member instead.
// The variable is wrapped in a struct, so we need to pass a pointer to the
// struct member instead.
arg = ctx.dst->AddressOf(
ctx.dst->MemberAccessor(ctx.dst->Deref(arg), kWrappedArrayMemberName));
} else if (is_entry_point && !is_handle && !new_var.is_pointer) {
@ -359,7 +426,12 @@ struct ModuleScopeVarToEntryPointParam::State {
}
private:
// The structures that have already been cloned by this transform.
std::unordered_set<const sem::Struct*> cloned_structs_;
// Map from identifier expression to the address-of expression that uses it.
std::unordered_map<const ast::IdentifierExpression*, const ast::UnaryOpExpression*>
ident_to_address_of_;
};
ModuleScopeVarToEntryPointParam::ModuleScopeVarToEntryPointParam() = default;