diff --git a/src/tint/transform/module_scope_var_to_entry_point_param.cc b/src/tint/transform/module_scope_var_to_entry_point_param.cc index cc89fde782..5a4d222544 100644 --- a/src/tint/transform/module_scope_var_to_entry_point_param.cc +++ b/src/tint/transform/module_scope_var_to_entry_point_param.cc @@ -30,6 +30,10 @@ TINT_INSTANTIATE_TYPEINFO(tint::transform::ModuleScopeVarToEntryPointParam); namespace tint::transform { namespace { + +// The name of the struct member for arrays that are wrapped in structures. +const char* kWrappedArrayMemberName = "arr"; + // Returns `true` if `type` is or contains a matrix type. bool ContainsMatrix(const sem::Type* type) { type = type->UnwrapRef(); @@ -83,6 +87,192 @@ struct ModuleScopeVarToEntryPointParam::State { } } + /// Process a variable `var` that is referenced in the entry point function `func`. + /// This will redeclare the variable as a function parameter, possibly as a pointer. + /// Some workgroup variables will be redeclared as a member inside a workgroup structure. + /// @param func the entry point function + /// @param var the variable + /// @param new_var_symbol the symbol to use for the replacement + /// @param workgroup_param helper function to get a symbol to a workgroup struct parameter + /// @param workgroup_parameter_members reference to a list of a workgroup struct members + /// @param is_pointer output signalling whether the replacement is a pointer + /// @param is_wrapped output signalling whether the replacement is wrapped in a struct + void ProcessVariableInEntryPoint(const ast::Function* func, + const sem::Variable* var, + Symbol new_var_symbol, + std::function workgroup_param, + ast::StructMemberList& workgroup_parameter_members, + bool& is_pointer, + bool& is_wrapped) { + auto* var_ast = var->Declaration()->As(); + auto* ty = var->Type()->UnwrapRef(); + + // Helper to create an AST node for the store type of the variable. + auto store_type = [&]() { return CreateASTTypeFor(ctx, ty); }; + + ast::StorageClass sc = var->StorageClass(); + switch (sc) { + case ast::StorageClass::kHandle: { + // For a texture or sampler variable, redeclare it as an entry point parameter. + // Disable entry point parameter validation. + auto* disable_validation = + ctx.dst->Disable(ast::DisabledValidation::kEntryPointParameter); + auto attrs = ctx.Clone(var->Declaration()->attributes); + attrs.push_back(disable_validation); + auto* param = ctx.dst->Param(new_var_symbol, store_type(), attrs); + ctx.InsertFront(func->params, param); + + break; + } + case ast::StorageClass::kStorage: + case ast::StorageClass::kUniform: { + // Variables into the Storage and Uniform storage classes are redeclared as entry + // point parameters with a pointer type. + auto attributes = ctx.Clone(var->Declaration()->attributes); + attributes.push_back( + ctx.dst->Disable(ast::DisabledValidation::kEntryPointParameter)); + attributes.push_back( + ctx.dst->Disable(ast::DisabledValidation::kIgnoreStorageClass)); + + auto* param_type = store_type(); + if (auto* arr = ty->As(); arr && arr->IsRuntimeSized()) { + // Wrap runtime-sized arrays in structures, so that we can declare pointers to + // them. Ideally we'd just emit the array itself as a pointer, but this is not + // representable in Tint's AST. + CloneStructTypes(ty); + auto* wrapper = ctx.dst->Structure( + ctx.dst->Sym(), {ctx.dst->Member(kWrappedArrayMemberName, param_type)}); + param_type = ctx.dst->ty.Of(wrapper); + is_wrapped = true; + } + + param_type = ctx.dst->ty.pointer(param_type, sc, var_ast->declared_access); + auto* param = ctx.dst->Param(new_var_symbol, param_type, attributes); + ctx.InsertFront(func->params, param); + is_pointer = true; + + break; + } + case ast::StorageClass::kWorkgroup: { + if (ContainsMatrix(var->Type())) { + // Due to a bug in the MSL compiler, we use a threadgroup memory argument for + // any workgroup allocation that contains a matrix. See crbug.com/tint/938. + // TODO(jrprice): Do this for all other workgroup variables too. + + // Create a member in the workgroup parameter struct. + auto member = ctx.Clone(var->Declaration()->symbol); + workgroup_parameter_members.push_back(ctx.dst->Member(member, store_type())); + CloneStructTypes(var->Type()->UnwrapRef()); + + // Create a function-scope variable that is a pointer to the member. + auto* member_ptr = ctx.dst->AddressOf( + ctx.dst->MemberAccessor(ctx.dst->Deref(workgroup_param()), member)); + auto* local_var = ctx.dst->Let( + new_var_symbol, + ctx.dst->ty.pointer(store_type(), ast::StorageClass::kWorkgroup), + member_ptr); + ctx.InsertFront(func->body->statements, ctx.dst->Decl(local_var)); + is_pointer = true; + + break; + } + [[fallthrough]]; + } + case ast::StorageClass::kPrivate: { + // Variables in the Private and Workgroup storage classes are redeclared at function + // scope. Disable storage class validation on this variable. + auto* disable_validation = + ctx.dst->Disable(ast::DisabledValidation::kIgnoreStorageClass); + auto* constructor = ctx.Clone(var->Declaration()->constructor); + auto* local_var = ctx.dst->Var(new_var_symbol, store_type(), sc, constructor, + ast::AttributeList{disable_validation}); + ctx.InsertFront(func->body->statements, ctx.dst->Decl(local_var)); + + break; + } + default: { + TINT_ICE(Transform, ctx.dst->Diagnostics()) + << "unhandled module-scope storage class (" << sc << ")"; + } + } + } + + /// Process a variable `var` that is referenced in the user-defined function `func`. + /// This will redeclare the variable as a function parameter, possibly as a pointer. + /// @param func the user-defined function + /// @param var the variable + /// @param new_var_symbol the symbol to use for the replacement + /// @param is_pointer output signalling whether the replacement is a pointer or not + void ProcessVariableInUserFunction(const ast::Function* func, + const sem::Variable* var, + Symbol new_var_symbol, + bool& is_pointer) { + auto* var_ast = var->Declaration()->As(); + auto* ty = var->Type()->UnwrapRef(); + auto* param_type = CreateASTTypeFor(ctx, ty); + auto sc = var->StorageClass(); + switch (sc) { + case ast::StorageClass::kPrivate: + case ast::StorageClass::kStorage: + case ast::StorageClass::kUniform: + case ast::StorageClass::kHandle: + case ast::StorageClass::kWorkgroup: + break; + default: + TINT_ICE(Transform, ctx.dst->Diagnostics()) + << "unhandled module-scope storage class (" << sc << ")"; + } + + // Use a pointer for non-handle types. + ast::AttributeList attributes; + if (!ty->is_handle()) { + param_type = ctx.dst->ty.pointer(param_type, sc, var_ast->declared_access); + is_pointer = true; + + // Disable validation of the parameter's storage class and of arguments passed to it. + attributes.push_back(ctx.dst->Disable(ast::DisabledValidation::kIgnoreStorageClass)); + attributes.push_back( + ctx.dst->Disable(ast::DisabledValidation::kIgnoreInvalidPointerArgument)); + } + + // Redeclare the variable as a parameter. + ctx.InsertBack(func->params, ctx.dst->Param(new_var_symbol, param_type, attributes)); + } + + /// Replace all uses of `var` in `func` with references to `new_var`. + /// @param func the function + /// @param var the variable to replace + /// @param new_var the symbol to use for replacement + /// @param is_pointer true if `new_var` is a pointer to the new variable + /// @param is_wrapped true if `new_var` is an array wrapped in a structure + void ReplaceUsesInFunction(const ast::Function* func, + const sem::Variable* var, + Symbol new_var, + bool is_pointer, + bool is_wrapped) { + for (auto* user : var->Users()) { + if (user->Stmt()->Function()->Declaration() == func) { + const ast::Expression* expr = ctx.dst->Expr(new_var); + if (is_pointer) { + // If this identifier is used by an address-of operator, just remove the + // address-of instead of adding a deref, since we already have a pointer. + auto* ident = user->Declaration()->As(); + if (ident_to_address_of_.count(ident)) { + ctx.Replace(ident_to_address_of_[ident], expr); + continue; + } + + expr = ctx.dst->Deref(expr); + } + if (is_wrapped) { + // Get the member from the wrapper structure. + expr = ctx.dst->MemberAccessor(expr, kWrappedArrayMemberName); + } + ctx.Replace(user->Declaration(), expr); + } + } + } + /// Process the module. void Process() { // Predetermine the list of function calls that need to be replaced. @@ -91,8 +281,7 @@ struct ModuleScopeVarToEntryPointParam::State { std::vector functions_to_process; - // Build a list of functions that transitively reference any module-scope - // variables. + // Build a list of functions that transitively reference any module-scope variables. for (auto* func_ast : ctx.src->AST().Functions()) { auto* func_sem = ctx.src->Sem().Get(func_ast); @@ -114,20 +303,18 @@ struct ModuleScopeVarToEntryPointParam::State { } } - // Build a list of `&ident` expressions. We'll use this later to avoid - // generating expressions of the form `&*ident`, which break WGSL validation - // rules when this expression is passed to a function. - // TODO(jrprice): We should add support for bidirectional SEM tree traversal - // so that we can do this on the fly instead. - std::unordered_map - ident_to_address_of; + // Build a list of `&ident` expressions. We'll use this later to avoid generating + // expressions of the form `&*ident`, which break WGSL validation rules when this expression + // is passed to a function. + // TODO(jrprice): We should add support for bidirectional SEM tree traversal so that we can + // do this on the fly instead. for (auto* node : ctx.src->ASTNodes().Objects()) { auto* address_of = node->As(); if (!address_of || address_of->op != ast::UnaryOp::kAddressOf) { continue; } if (auto* ident = address_of->expr->As()) { - ident_to_address_of[ident] = address_of; + ident_to_address_of_[ident] = address_of; } } @@ -141,11 +328,10 @@ struct ModuleScopeVarToEntryPointParam::State { bool is_pointer; bool is_wrapped; }; - const char* kWrappedArrayMemberName = "arr"; std::unordered_map var_to_newvar; - // We aggregate all workgroup variables into a struct to avoid hitting - // MSL's limit for threadgroup memory arguments. + // We aggregate all workgroup variables into a struct to avoid hitting MSL's limit for + // threadgroup memory arguments. Symbol workgroup_parameter_symbol; ast::StructMemberList workgroup_parameter_members; auto workgroup_param = [&]() { @@ -155,159 +341,40 @@ struct ModuleScopeVarToEntryPointParam::State { return workgroup_parameter_symbol; }; - for (auto* global : func_sem->TransitivelyReferencedGlobals()) { - auto* var = global->Declaration()->As(); - if (!var) { + // Process and redeclare all variables referenced by the function. + for (auto* var : func_sem->TransitivelyReferencedGlobals()) { + if (var->StorageClass() == ast::StorageClass::kNone) { continue; } - auto sc = global->StorageClass(); - auto* ty = global->Type()->UnwrapRef(); - if (sc == ast::StorageClass::kNone) { - continue; - } - if (sc != ast::StorageClass::kPrivate && sc != ast::StorageClass::kStorage && - sc != ast::StorageClass::kUniform && sc != ast::StorageClass::kHandle && - sc != ast::StorageClass::kWorkgroup) { - TINT_ICE(Transform, ctx.dst->Diagnostics()) - << "unhandled module-scope storage class (" << sc << ")"; - } - // This is the symbol for the variable that replaces the module-scope - // var. + // This is the symbol for the variable that replaces the module-scope var. auto new_var_symbol = ctx.dst->Sym(); - // Helper to create an AST node for the store type of the variable. - auto store_type = [&]() { return CreateASTTypeFor(ctx, ty); }; - // Track whether the new variable is a pointer or not. bool is_pointer = false; // Track whether the new variable was wrapped in a struct or not. bool is_wrapped = false; + // Process the variable to redeclare it as a function parameter or local variable. if (is_entry_point) { - if (global->Type()->UnwrapRef()->is_handle()) { - // For a texture or sampler variable, redeclare it as an entry point - // parameter. Disable entry point parameter validation. - auto* disable_validation = - ctx.dst->Disable(ast::DisabledValidation::kEntryPointParameter); - auto attrs = ctx.Clone(var->attributes); - attrs.push_back(disable_validation); - auto* param = ctx.dst->Param(new_var_symbol, store_type(), attrs); - ctx.InsertFront(func_ast->params, param); - } else if (sc == ast::StorageClass::kStorage || - sc == ast::StorageClass::kUniform) { - // Variables into the Storage and Uniform storage classes are - // redeclared as entry point parameters with a pointer type. - auto attributes = ctx.Clone(var->attributes); - attributes.push_back( - ctx.dst->Disable(ast::DisabledValidation::kEntryPointParameter)); - attributes.push_back( - ctx.dst->Disable(ast::DisabledValidation::kIgnoreStorageClass)); - - auto* param_type = store_type(); - if (auto* arr = ty->As(); arr && arr->IsRuntimeSized()) { - // Wrap runtime-sized arrays in structures, so that we can declare - // pointers to them. Ideally we'd just emit the array itself as a - // pointer, but this is not representable in Tint's AST. - CloneStructTypes(ty); - auto* wrapper = ctx.dst->Structure( - ctx.dst->Sym(), - {ctx.dst->Member(kWrappedArrayMemberName, param_type)}); - param_type = ctx.dst->ty.Of(wrapper); - is_wrapped = true; - } - - param_type = ctx.dst->ty.pointer(param_type, sc, var->declared_access); - auto* param = ctx.dst->Param(new_var_symbol, param_type, attributes); - ctx.InsertFront(func_ast->params, param); - is_pointer = true; - } else if (sc == ast::StorageClass::kWorkgroup && - ContainsMatrix(global->Type())) { - // Due to a bug in the MSL compiler, we use a threadgroup memory - // argument for any workgroup allocation that contains a matrix. - // See crbug.com/tint/938. - // TODO(jrprice): Do this for all other workgroup variables too. - - // Create a member in the workgroup parameter struct. - auto member = ctx.Clone(var->symbol); - workgroup_parameter_members.push_back( - ctx.dst->Member(member, store_type())); - CloneStructTypes(global->Type()->UnwrapRef()); - - // Create a function-scope variable that is a pointer to the member. - auto* member_ptr = ctx.dst->AddressOf( - ctx.dst->MemberAccessor(ctx.dst->Deref(workgroup_param()), member)); - auto* local_var = ctx.dst->Let( - new_var_symbol, - ctx.dst->ty.pointer(store_type(), ast::StorageClass::kWorkgroup), - member_ptr); - ctx.InsertFront(func_ast->body->statements, ctx.dst->Decl(local_var)); - is_pointer = true; - } else { - // Variables in the Private and Workgroup storage classes are - // redeclared at function scope. Disable storage class validation on - // this variable. - auto* disable_validation = - ctx.dst->Disable(ast::DisabledValidation::kIgnoreStorageClass); - auto* constructor = ctx.Clone(var->constructor); - auto* local_var = - ctx.dst->Var(new_var_symbol, store_type(), sc, constructor, - ast::AttributeList{disable_validation}); - ctx.InsertFront(func_ast->body->statements, ctx.dst->Decl(local_var)); - } + ProcessVariableInEntryPoint(func_ast, var, new_var_symbol, workgroup_param, + workgroup_parameter_members, is_pointer, + is_wrapped); } else { - // For a regular function, redeclare the variable as a parameter. - // Use a pointer for non-handle types. - auto* param_type = store_type(); - ast::AttributeList attributes; - if (!global->Type()->UnwrapRef()->is_handle()) { - param_type = ctx.dst->ty.pointer(param_type, sc, var->declared_access); - is_pointer = true; - - // Disable validation of the parameter's storage class and of - // arguments passed it. - attributes.push_back( - ctx.dst->Disable(ast::DisabledValidation::kIgnoreStorageClass)); - attributes.push_back(ctx.dst->Disable( - ast::DisabledValidation::kIgnoreInvalidPointerArgument)); - } - ctx.InsertBack(func_ast->params, - ctx.dst->Param(new_var_symbol, param_type, attributes)); + ProcessVariableInUserFunction(func_ast, var, new_var_symbol, is_pointer); } // Replace all uses of the module-scope variable. - // For non-entry points, dereference non-handle pointer parameters. - for (auto* user : global->Users()) { - if (user->Stmt()->Function()->Declaration() == func_ast) { - const ast::Expression* expr = ctx.dst->Expr(new_var_symbol); - if (is_pointer) { - // If this identifier is used by an address-of operator, just - // remove the address-of instead of adding a deref, since we - // already have a pointer. - auto* ident = user->Declaration()->As(); - if (ident_to_address_of.count(ident)) { - ctx.Replace(ident_to_address_of[ident], expr); - continue; - } + ReplaceUsesInFunction(func_ast, var, new_var_symbol, is_pointer, is_wrapped); - expr = ctx.dst->Deref(expr); - } - if (is_wrapped) { - // Get the member from the wrapper structure. - expr = ctx.dst->MemberAccessor(expr, kWrappedArrayMemberName); - } - ctx.Replace(user->Declaration(), expr); - } - } - - var_to_newvar[global] = {new_var_symbol, is_pointer, is_wrapped}; + // Record the replacement symbol. + var_to_newvar[var] = {new_var_symbol, is_pointer, is_wrapped}; } if (!workgroup_parameter_members.empty()) { // Create the workgroup memory parameter. - // The parameter is a struct that contains members for each workgroup - // variable. + // The parameter is a struct that contains members for each workgroup variable. auto* str = ctx.dst->Structure(ctx.dst->Sym(), std::move(workgroup_parameter_members)); auto* param_type = @@ -335,8 +402,8 @@ struct ModuleScopeVarToEntryPointParam::State { bool is_handle = target_var->Type()->UnwrapRef()->is_handle(); const ast::Expression* arg = ctx.dst->Expr(new_var.symbol); if (new_var.is_wrapped) { - // The variable is wrapped in a struct, so we need to pass a pointer - // to the struct member instead. + // The variable is wrapped in a struct, so we need to pass a pointer to the + // struct member instead. arg = ctx.dst->AddressOf( ctx.dst->MemberAccessor(ctx.dst->Deref(arg), kWrappedArrayMemberName)); } else if (is_entry_point && !is_handle && !new_var.is_pointer) { @@ -359,7 +426,12 @@ struct ModuleScopeVarToEntryPointParam::State { } private: + // The structures that have already been cloned by this transform. std::unordered_set cloned_structs_; + + // Map from identifier expression to the address-of expression that uses it. + std::unordered_map + ident_to_address_of_; }; ModuleScopeVarToEntryPointParam::ModuleScopeVarToEntryPointParam() = default;