msl: Use a struct for threadgroup memory arguments

MSL has a limit on the number of threadgroup memory arguments, so use
a struct to support an arbitrary number of workgroup variables.

This commit introduces a `State` object to this transform, which is
used to track which structs have been cloned eagerly, in order to
avoid duplicating them.

Bug: tint:938
Change-Id: Ia467db186e176a08f160455eab5fd3b3662f56b8
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/65360
Auto-Submit: James Price <jrprice@google.com>
Kokoro: James Price <jrprice@google.com>
Commit-Queue: James Price <jrprice@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
This commit is contained in:
James Price
2021-09-29 18:56:17 +00:00
committed by Tint LUCI CQ
parent efe1f14685
commit 1ca6fbad8f
13 changed files with 2054 additions and 189 deletions

View File

@@ -42,6 +42,8 @@ std::string DisableValidationDecoration::InternalName() const {
return "disable_validation__ignore_constructible_function_parameter";
case DisabledValidation::kIgnoreStrideDecoration:
return "disable_validation__ignore_stride";
case DisabledValidation::kIgnoreInvalidPointerArgument:
return "disable_validation__ignore_invalid_pointer_argument";
}
return "<invalid>";
}

View File

@@ -43,6 +43,10 @@ enum class DisabledValidation {
/// When applied to a member decoration, a stride decoration may be applied to
/// non-array types.
kIgnoreStrideDecoration,
/// When applied to a pointer function parameter, the validator will not
/// require a function call argument passed for that parameter to have a
/// certain form.
kIgnoreInvalidPointerArgument,
};
/// An internal decoration used to tell the validator to ignore specific

View File

@@ -2659,7 +2659,10 @@ bool Resolver::ValidateFunctionCall(const ast::CallExpression* call,
}
}
if (!is_valid) {
if (!is_valid &&
IsValidationEnabled(
param->declaration->decorations(),
ast::DisabledValidation::kIgnoreInvalidPointerArgument)) {
AddError(
"expected an address-of expression of a variable identifier "
"expression or a function parameter",

View File

@@ -49,110 +49,164 @@ bool ContainsMatrix(const sem::Type* type) {
}
} // namespace
ModuleScopeVarToEntryPointParam::ModuleScopeVarToEntryPointParam() = default;
/// State holds the current transform state.
struct ModuleScopeVarToEntryPointParam::State {
/// The clone context.
CloneContext& ctx;
ModuleScopeVarToEntryPointParam::~ModuleScopeVarToEntryPointParam() = default;
/// Constructor
/// @param context the clone context
explicit State(CloneContext& context) : ctx(context) {}
void ModuleScopeVarToEntryPointParam::Run(CloneContext& ctx,
const DataMap&,
DataMap&) {
// Predetermine the list of function calls that need to be replaced.
using CallList = std::vector<const ast::CallExpression*>;
std::unordered_map<const ast::Function*, CallList> calls_to_replace;
std::vector<ast::Function*> functions_to_process;
// Build a list of functions that transitively reference any private or
// workgroup variables, or texture/sampler variables.
for (auto* func_ast : ctx.src->AST().Functions()) {
auto* func_sem = ctx.src->Sem().Get(func_ast);
bool needs_processing = false;
for (auto* var : func_sem->ReferencedModuleVariables()) {
if (var->StorageClass() == ast::StorageClass::kPrivate ||
var->StorageClass() == ast::StorageClass::kWorkgroup ||
var->StorageClass() == ast::StorageClass::kUniformConstant) {
needs_processing = true;
break;
/// Clone any struct types that are contained in `ty` (including `ty` itself),
/// and add it to the global declarations now, so that they precede new global
/// declarations that need to reference them.
/// @param ty the type to clone
void CloneStructTypes(const sem::Type* ty) {
if (auto* str = ty->As<sem::Struct>()) {
if (!cloned_structs_.emplace(str).second) {
// The struct has already been cloned.
return;
}
}
if (needs_processing) {
functions_to_process.push_back(func_ast);
// Find all of the calls to this function that will need to be replaced.
for (auto* call : func_sem->CallSites()) {
auto* call_sem = ctx.src->Sem().Get(call);
calls_to_replace[call_sem->Stmt()->Function()].push_back(call);
// Recurse into members.
for (auto* member : str->Members()) {
CloneStructTypes(member->Type());
}
// Clone the struct and add it to the global declaration list.
// Remove the old declaration.
auto* ast_str = str->Declaration();
ctx.dst->AST().AddTypeDecl(ctx.Clone(const_cast<ast::Struct*>(ast_str)));
ctx.Remove(ctx.src->AST().GlobalDeclarations(), ast_str);
} else if (auto* arr = ty->As<sem::Array>()) {
CloneStructTypes(arr->ElemType());
}
}
// Build a list of `&ident` expressions. We'll use this later to avoid
// generating expressions of the form `&*ident`, which break WGSL validation
// rules when this expression is passed to a function.
// TODO(jrprice): We should add support for bidirectional SEM tree traversal
// so that we can do this on the fly instead.
std::unordered_map<ast::IdentifierExpression*, ast::UnaryOpExpression*>
ident_to_address_of;
for (auto* node : ctx.src->ASTNodes().Objects()) {
auto* address_of = node->As<ast::UnaryOpExpression>();
if (!address_of || address_of->op() != ast::UnaryOp::kAddressOf) {
continue;
/// Process the module.
void Process() {
// Predetermine the list of function calls that need to be replaced.
using CallList = std::vector<const ast::CallExpression*>;
std::unordered_map<const ast::Function*, CallList> calls_to_replace;
std::vector<ast::Function*> functions_to_process;
// Build a list of functions that transitively reference any private or
// workgroup variables, or texture/sampler variables.
for (auto* func_ast : ctx.src->AST().Functions()) {
auto* func_sem = ctx.src->Sem().Get(func_ast);
bool needs_processing = false;
for (auto* var : func_sem->ReferencedModuleVariables()) {
if (var->StorageClass() == ast::StorageClass::kPrivate ||
var->StorageClass() == ast::StorageClass::kWorkgroup ||
var->StorageClass() == ast::StorageClass::kUniformConstant) {
needs_processing = true;
break;
}
}
if (needs_processing) {
functions_to_process.push_back(func_ast);
// Find all of the calls to this function that will need to be replaced.
for (auto* call : func_sem->CallSites()) {
auto* call_sem = ctx.src->Sem().Get(call);
calls_to_replace[call_sem->Stmt()->Function()].push_back(call);
}
}
}
if (auto* ident = address_of->expr()->As<ast::IdentifierExpression>()) {
ident_to_address_of[ident] = address_of;
}
}
for (auto* func_ast : functions_to_process) {
auto* func_sem = ctx.src->Sem().Get(func_ast);
bool is_entry_point = func_ast->IsEntryPoint();
// Map module-scope variables onto their function-scope replacement.
std::unordered_map<const sem::Variable*, Symbol> var_to_symbol;
for (auto* var : func_sem->ReferencedModuleVariables()) {
if (var->StorageClass() != ast::StorageClass::kPrivate &&
var->StorageClass() != ast::StorageClass::kWorkgroup &&
var->StorageClass() != ast::StorageClass::kUniformConstant) {
// Build a list of `&ident` expressions. We'll use this later to avoid
// generating expressions of the form `&*ident`, which break WGSL validation
// rules when this expression is passed to a function.
// TODO(jrprice): We should add support for bidirectional SEM tree traversal
// so that we can do this on the fly instead.
std::unordered_map<ast::IdentifierExpression*, ast::UnaryOpExpression*>
ident_to_address_of;
for (auto* node : ctx.src->ASTNodes().Objects()) {
auto* address_of = node->As<ast::UnaryOpExpression>();
if (!address_of || address_of->op() != ast::UnaryOp::kAddressOf) {
continue;
}
if (auto* ident = address_of->expr()->As<ast::IdentifierExpression>()) {
ident_to_address_of[ident] = address_of;
}
}
// This is the symbol for the variable that replaces the module-scope var.
auto new_var_symbol = ctx.dst->Sym();
for (auto* func_ast : functions_to_process) {
auto* func_sem = ctx.src->Sem().Get(func_ast);
bool is_entry_point = func_ast->IsEntryPoint();
auto* store_type = CreateASTTypeFor(ctx, var->Type()->UnwrapRef());
// Map module-scope variables onto their function-scope replacement.
std::unordered_map<const sem::Variable*, Symbol> var_to_symbol;
// Track whether the new variable is a pointer or not.
bool is_pointer = false;
// We aggregate all workgroup variables into a struct to avoid hitting
// MSL's limit for threadgroup memory arguments.
Symbol workgroup_parameter_symbol;
ast::StructMemberList workgroup_parameter_members;
auto workgroup_param = [&]() {
if (!workgroup_parameter_symbol.IsValid()) {
workgroup_parameter_symbol = ctx.dst->Sym();
}
return workgroup_parameter_symbol;
};
if (is_entry_point) {
if (store_type->is_handle()) {
// For a texture or sampler variable, redeclare it as an entry point
// parameter. Disable entry point parameter validation.
auto* disable_validation =
ctx.dst->ASTNodes().Create<ast::DisableValidationDecoration>(
ctx.dst->ID(), ast::DisabledValidation::kEntryPointParameter);
auto decos = ctx.Clone(var->Declaration()->decorations());
decos.push_back(disable_validation);
auto* param = ctx.dst->Param(new_var_symbol, store_type, decos);
ctx.InsertFront(func_ast->params(), param);
} else {
if (var->StorageClass() == ast::StorageClass::kWorkgroup &&
ContainsMatrix(var->Type())) {
// Due to a bug in the MSL compiler, we use a threadgroup memory
// argument for any workgroup allocation that contains a matrix.
// See crbug.com/tint/938.
for (auto* var : func_sem->ReferencedModuleVariables()) {
if (var->StorageClass() != ast::StorageClass::kPrivate &&
var->StorageClass() != ast::StorageClass::kWorkgroup &&
var->StorageClass() != ast::StorageClass::kUniformConstant) {
continue;
}
// This is the symbol for the variable that replaces the module-scope
// var.
auto new_var_symbol = ctx.dst->Sym();
// Helper to create an AST node for the store type of the variable.
auto store_type = [&]() {
return CreateASTTypeFor(ctx, var->Type()->UnwrapRef());
};
// Track whether the new variable is a pointer or not.
bool is_pointer = false;
if (is_entry_point) {
if (var->Type()->UnwrapRef()->is_handle()) {
// For a texture or sampler variable, redeclare it as an entry point
// parameter. Disable entry point parameter validation.
auto* disable_validation =
ctx.dst->ASTNodes().Create<ast::DisableValidationDecoration>(
ctx.dst->ID(),
ast::DisabledValidation::kEntryPointParameter);
auto* param_type =
ctx.dst->ty.pointer(store_type, var->StorageClass());
auto* param = ctx.dst->Param(new_var_symbol, param_type,
{disable_validation});
auto decos = ctx.Clone(var->Declaration()->decorations());
decos.push_back(disable_validation);
auto* param = ctx.dst->Param(new_var_symbol, store_type(), decos);
ctx.InsertFront(func_ast->params(), param);
} else if (var->StorageClass() == ast::StorageClass::kWorkgroup &&
ContainsMatrix(var->Type())) {
// Due to a bug in the MSL compiler, we use a threadgroup memory
// argument for any workgroup allocation that contains a matrix.
// See crbug.com/tint/938.
// TODO(jrprice): Do this for all other workgroup variables too.
// Create a member in the workgroup parameter struct.
auto member = ctx.Clone(var->Declaration()->symbol());
workgroup_parameter_members.push_back(
ctx.dst->Member(member, store_type()));
CloneStructTypes(var->Type()->UnwrapRef());
// Create a function-scope variable that is a pointer to the member.
auto* member_ptr = ctx.dst->AddressOf(ctx.dst->MemberAccessor(
ctx.dst->Deref(workgroup_param()), member));
auto* local_var =
ctx.dst->Const(new_var_symbol,
ctx.dst->ty.pointer(
store_type(), ast::StorageClass::kWorkgroup),
member_ptr);
ctx.InsertFront(func_ast->body()->statements(),
ctx.dst->Decl(local_var));
is_pointer = true;
} else {
// For any other private or workgroup variable, redeclare it at
@@ -164,83 +218,123 @@ void ModuleScopeVarToEntryPointParam::Run(CloneContext& ctx,
ast::DisabledValidation::kIgnoreStorageClass);
auto* constructor = ctx.Clone(var->Declaration()->constructor());
auto* local_var = ctx.dst->Var(
new_var_symbol, store_type, var->StorageClass(), constructor,
new_var_symbol, store_type(), var->StorageClass(), constructor,
ast::DecorationList{disable_validation});
ctx.InsertFront(func_ast->body()->statements(),
ctx.dst->Decl(local_var));
}
}
} else {
// For a regular function, redeclare the variable as a parameter.
// Use a pointer for non-handle types.
auto* param_type = store_type;
if (!store_type->is_handle()) {
param_type = ctx.dst->ty.pointer(param_type, var->StorageClass());
is_pointer = true;
}
ctx.InsertBack(func_ast->params(),
ctx.dst->Param(new_var_symbol, param_type));
}
} else {
// For a regular function, redeclare the variable as a parameter.
// Use a pointer for non-handle types.
auto* param_type = store_type();
ast::DecorationList attributes;
if (!param_type->is_handle()) {
param_type = ctx.dst->ty.pointer(param_type, var->StorageClass());
is_pointer = true;
// Replace all uses of the module-scope variable.
// For non-entry points, dereference non-handle pointer parameters.
for (auto* user : var->Users()) {
if (user->Stmt()->Function() == func_ast) {
ast::Expression* expr = ctx.dst->Expr(new_var_symbol);
if (is_pointer) {
// If this identifier is used by an address-of operator, just remove
// the address-of instead of adding a deref, since we already have a
// pointer.
auto* ident = user->Declaration()->As<ast::IdentifierExpression>();
if (ident_to_address_of.count(ident)) {
ctx.Replace(ident_to_address_of[ident], expr);
continue;
// Disable validation of arguments passed to this pointer parameter,
// as we will sometimes pass pointers to struct members.
attributes.push_back(
ctx.dst->ASTNodes().Create<ast::DisableValidationDecoration>(
ctx.dst->ID(),
ast::DisabledValidation::kIgnoreInvalidPointerArgument));
}
ctx.InsertBack(
func_ast->params(),
ctx.dst->Param(new_var_symbol, param_type, attributes));
}
// Replace all uses of the module-scope variable.
// For non-entry points, dereference non-handle pointer parameters.
for (auto* user : var->Users()) {
if (user->Stmt()->Function() == func_ast) {
ast::Expression* expr = ctx.dst->Expr(new_var_symbol);
if (is_pointer) {
// If this identifier is used by an address-of operator, just
// remove the address-of instead of adding a deref, since we
// already have a pointer.
auto* ident =
user->Declaration()->As<ast::IdentifierExpression>();
if (ident_to_address_of.count(ident)) {
ctx.Replace(ident_to_address_of[ident], expr);
continue;
}
expr = ctx.dst->Deref(expr);
}
expr = ctx.dst->Deref(expr);
ctx.Replace(user->Declaration(), expr);
}
ctx.Replace(user->Declaration(), expr);
}
var_to_symbol[var] = new_var_symbol;
}
var_to_symbol[var] = new_var_symbol;
if (!workgroup_parameter_members.empty()) {
// Create the workgroup memory parameter.
// The parameter is a struct that contains members for each workgroup
// variable.
auto* str = ctx.dst->Structure(ctx.dst->Sym(),
std::move(workgroup_parameter_members));
auto* param_type = ctx.dst->ty.pointer(ctx.dst->ty.Of(str),
ast::StorageClass::kWorkgroup);
auto* disable_validation =
ctx.dst->ASTNodes().Create<ast::DisableValidationDecoration>(
ctx.dst->ID(), ast::DisabledValidation::kEntryPointParameter);
auto* param =
ctx.dst->Param(workgroup_param(), param_type, {disable_validation});
ctx.InsertFront(func_ast->params(), param);
}
// Pass the variables as pointers to any functions that need them.
for (auto* call : calls_to_replace[func_ast]) {
auto* target = ctx.src->AST().Functions().Find(call->func()->symbol());
auto* target_sem = ctx.src->Sem().Get(target);
// Add new arguments for any variables that are needed by the callee.
// For entry points, pass non-handle types as pointers.
for (auto* target_var : target_sem->ReferencedModuleVariables()) {
bool is_handle = target_var->Type()->UnwrapRef()->is_handle();
bool is_workgroup_matrix =
target_var->StorageClass() == ast::StorageClass::kWorkgroup &&
ContainsMatrix(target_var->Type());
if (target_var->StorageClass() == ast::StorageClass::kPrivate ||
target_var->StorageClass() == ast::StorageClass::kWorkgroup ||
target_var->StorageClass() ==
ast::StorageClass::kUniformConstant) {
ast::Expression* arg = ctx.dst->Expr(var_to_symbol[target_var]);
if (is_entry_point && !is_handle && !is_workgroup_matrix) {
arg = ctx.dst->AddressOf(arg);
}
ctx.InsertBack(call->params(), arg);
}
}
}
}
// Pass the variables as pointers to any functions that need them.
for (auto* call : calls_to_replace[func_ast]) {
auto* target = ctx.src->AST().Functions().Find(call->func()->symbol());
auto* target_sem = ctx.src->Sem().Get(target);
// Add new arguments for any variables that are needed by the callee.
// For entry points, pass non-handle types as pointers.
for (auto* target_var : target_sem->ReferencedModuleVariables()) {
bool is_handle = target_var->Type()->UnwrapRef()->is_handle();
bool is_workgroup_matrix =
target_var->StorageClass() == ast::StorageClass::kWorkgroup &&
ContainsMatrix(target_var->Type());
if (target_var->StorageClass() == ast::StorageClass::kPrivate ||
target_var->StorageClass() == ast::StorageClass::kWorkgroup ||
target_var->StorageClass() == ast::StorageClass::kUniformConstant) {
ast::Expression* arg = ctx.dst->Expr(var_to_symbol[target_var]);
if (is_entry_point && !is_handle && !is_workgroup_matrix) {
arg = ctx.dst->AddressOf(arg);
}
ctx.InsertBack(call->params(), arg);
}
// Now remove all module-scope variables with these storage classes.
for (auto* var_ast : ctx.src->AST().GlobalVariables()) {
auto* var_sem = ctx.src->Sem().Get(var_ast);
if (var_sem->StorageClass() == ast::StorageClass::kPrivate ||
var_sem->StorageClass() == ast::StorageClass::kWorkgroup ||
var_sem->StorageClass() == ast::StorageClass::kUniformConstant) {
ctx.Remove(ctx.src->AST().GlobalDeclarations(), var_ast);
}
}
}
// Now remove all module-scope variables with these storage classes.
for (auto* var_ast : ctx.src->AST().GlobalVariables()) {
auto* var_sem = ctx.src->Sem().Get(var_ast);
if (var_sem->StorageClass() == ast::StorageClass::kPrivate ||
var_sem->StorageClass() == ast::StorageClass::kWorkgroup ||
var_sem->StorageClass() == ast::StorageClass::kUniformConstant) {
ctx.Remove(ctx.src->AST().GlobalDeclarations(), var_ast);
}
}
private:
std::unordered_set<const sem::Struct*> cloned_structs_;
};
ModuleScopeVarToEntryPointParam::ModuleScopeVarToEntryPointParam() = default;
ModuleScopeVarToEntryPointParam::~ModuleScopeVarToEntryPointParam() = default;
void ModuleScopeVarToEntryPointParam::Run(CloneContext& ctx,
const DataMap&,
DataMap&) {
State state{ctx};
state.Process();
ctx.Clone();
}

View File

@@ -74,6 +74,8 @@ class ModuleScopeVarToEntryPointParam
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override;
struct State;
};
} // namespace transform

View File

@@ -78,12 +78,12 @@ fn main() {
fn no_uses() {
}
fn bar(a : f32, b : f32, tint_symbol : ptr<private, f32>, tint_symbol_1 : ptr<workgroup, f32>) {
fn bar(a : f32, b : f32, [[internal(disable_validation__ignore_invalid_pointer_argument)]] tint_symbol : ptr<private, f32>, [[internal(disable_validation__ignore_invalid_pointer_argument)]] tint_symbol_1 : ptr<workgroup, f32>) {
*(tint_symbol) = a;
*(tint_symbol_1) = b;
}
fn foo(a : f32, tint_symbol_2 : ptr<private, f32>, tint_symbol_3 : ptr<workgroup, f32>) {
fn foo(a : f32, [[internal(disable_validation__ignore_invalid_pointer_argument)]] tint_symbol_2 : ptr<private, f32>, [[internal(disable_validation__ignore_invalid_pointer_argument)]] tint_symbol_3 : ptr<workgroup, f32>) {
let b : f32 = 2.0;
bar(a, b, tint_symbol_2, tint_symbol_3);
no_uses();
@@ -181,7 +181,7 @@ fn bar(p : ptr<private, f32>) {
*(p) = 0.0;
}
fn foo(tint_symbol : ptr<private, f32>) {
fn foo([[internal(disable_validation__ignore_invalid_pointer_argument)]] tint_symbol : ptr<private, f32>) {
bar(tint_symbol);
}
@@ -340,8 +340,13 @@ fn main() {
)";
auto* expect = R"(
struct tint_symbol_2 {
m : mat2x2<f32>;
};
[[stage(compute), workgroup_size(1)]]
fn main([[internal(disable_validation__entry_point_parameter)]] tint_symbol : ptr<workgroup, mat2x2<f32>>) {
fn main([[internal(disable_validation__entry_point_parameter)]] tint_symbol_1 : ptr<workgroup, tint_symbol_2>) {
let tint_symbol : ptr<workgroup, mat2x2<f32>> = &((*(tint_symbol_1)).m);
let x = *(tint_symbol);
}
)";
@@ -376,8 +381,13 @@ struct S2 {
s : S1;
};
struct tint_symbol_2 {
m : array<S2, 4u>;
};
[[stage(compute), workgroup_size(1)]]
fn main([[internal(disable_validation__entry_point_parameter)]] tint_symbol : ptr<workgroup, array<S2, 4u>>) {
fn main([[internal(disable_validation__entry_point_parameter)]] tint_symbol_1 : ptr<workgroup, tint_symbol_2>) {
let tint_symbol : ptr<workgroup, array<S2, 4u>> = &((*(tint_symbol_1)).m);
let x = *(tint_symbol);
}
)";
@@ -387,6 +397,49 @@ fn main([[internal(disable_validation__entry_point_parameter)]] tint_symbol : pt
EXPECT_EQ(expect, str(got));
}
// Test that we do not duplicate a struct type used by multiple workgroup
// variables that are promoted to threadgroup memory arguments.
TEST_F(ModuleScopeVarToEntryPointParamTest, DuplicateThreadgroupArgumentTypes) {
auto* src = R"(
struct S {
m : mat2x2<f32>;
};
var<workgroup> a : S;
var<workgroup> b : S;
[[stage(compute), workgroup_size(1)]]
fn main() {
let x = a;
let y = b;
}
)";
auto* expect = R"(
struct S {
m : mat2x2<f32>;
};
struct tint_symbol_3 {
a : S;
b : S;
};
[[stage(compute), workgroup_size(1)]]
fn main([[internal(disable_validation__entry_point_parameter)]] tint_symbol_1 : ptr<workgroup, tint_symbol_3>) {
let tint_symbol : ptr<workgroup, S> = &((*(tint_symbol_1)).a);
let tint_symbol_2 : ptr<workgroup, S> = &((*(tint_symbol_1)).b);
let x = *(tint_symbol);
let y = *(tint_symbol_2);
}
)";
auto got = Run<ModuleScopeVarToEntryPointParam>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(ModuleScopeVarToEntryPointParamTest, EmtpyModule) {
auto* src = "";

View File

@@ -142,6 +142,10 @@ TEST_F(MslGeneratorImplTest, WorkgroupMatrix) {
EXPECT_EQ(gen.result(), R"(#include <metal_stdlib>
using namespace metal;
struct tint_symbol_3 {
float2x2 m;
};
void comp_main_inner(uint local_invocation_index, threadgroup float2x2* const tint_symbol) {
{
*(tint_symbol) = float2x2();
@@ -150,8 +154,8 @@ void comp_main_inner(uint local_invocation_index, threadgroup float2x2* const ti
float2x2 const x = *(tint_symbol);
}
kernel void comp_main(threadgroup float2x2* tint_symbol_1 [[threadgroup(0)]], uint local_invocation_index [[thread_index_in_threadgroup]]) {
comp_main_inner(local_invocation_index, tint_symbol_1);
kernel void comp_main(threadgroup tint_symbol_3* tint_symbol_2 [[threadgroup(0)]], uint local_invocation_index [[thread_index_in_threadgroup]]) {
comp_main_inner(local_invocation_index, &((*(tint_symbol_2)).m));
return;
}
@@ -178,6 +182,9 @@ using namespace metal;
struct tint_array_wrapper {
float2x2 arr[4];
};
struct tint_symbol_3 {
tint_array_wrapper m;
};
void comp_main_inner(uint local_invocation_index, threadgroup tint_array_wrapper* const tint_symbol) {
for(uint idx = local_invocation_index; (idx < 4u); idx = (idx + 1u)) {
@@ -188,8 +195,8 @@ void comp_main_inner(uint local_invocation_index, threadgroup tint_array_wrapper
tint_array_wrapper const x = *(tint_symbol);
}
kernel void comp_main(threadgroup tint_array_wrapper* tint_symbol_1 [[threadgroup(0)]], uint local_invocation_index [[thread_index_in_threadgroup]]) {
comp_main_inner(local_invocation_index, tint_symbol_1);
kernel void comp_main(threadgroup tint_symbol_3* tint_symbol_2 [[threadgroup(0)]], uint local_invocation_index [[thread_index_in_threadgroup]]) {
comp_main_inner(local_invocation_index, &((*(tint_symbol_2)).m));
return;
}
@@ -227,6 +234,9 @@ struct S1 {
struct S2 {
S1 s;
};
struct tint_symbol_4 {
S2 s;
};
void comp_main_inner(uint local_invocation_index, threadgroup S2* const tint_symbol_1) {
{
@@ -237,8 +247,8 @@ void comp_main_inner(uint local_invocation_index, threadgroup S2* const tint_sym
S2 const x = *(tint_symbol_1);
}
kernel void comp_main(threadgroup S2* tint_symbol_2 [[threadgroup(0)]], uint local_invocation_index [[thread_index_in_threadgroup]]) {
comp_main_inner(local_invocation_index, tint_symbol_2);
kernel void comp_main(threadgroup tint_symbol_4* tint_symbol_3 [[threadgroup(0)]], uint local_invocation_index [[thread_index_in_threadgroup]]) {
comp_main_inner(local_invocation_index, &((*(tint_symbol_3)).s));
return;
}
@@ -291,6 +301,22 @@ TEST_F(MslGeneratorImplTest, WorkgroupMatrix_Multiples) {
EXPECT_EQ(gen.result(), R"(#include <metal_stdlib>
using namespace metal;
struct tint_symbol_7 {
float2x2 m1;
float2x3 m2;
float2x4 m3;
};
struct tint_symbol_15 {
float3x2 m4;
float3x3 m5;
float3x4 m6;
};
struct tint_symbol_23 {
float4x2 m7;
float4x3 m8;
float4x4 m9;
};
void main1_inner(uint local_invocation_index, threadgroup float2x2* const tint_symbol, threadgroup float2x3* const tint_symbol_1, threadgroup float2x4* const tint_symbol_2) {
{
*(tint_symbol) = float2x2();
@@ -303,42 +329,42 @@ void main1_inner(uint local_invocation_index, threadgroup float2x2* const tint_s
float2x4 const a3 = *(tint_symbol_2);
}
kernel void main1(threadgroup float2x2* tint_symbol_3 [[threadgroup(0)]], threadgroup float2x3* tint_symbol_4 [[threadgroup(1)]], threadgroup float2x4* tint_symbol_5 [[threadgroup(2)]], uint local_invocation_index [[thread_index_in_threadgroup]]) {
main1_inner(local_invocation_index, tint_symbol_3, tint_symbol_4, tint_symbol_5);
kernel void main1(threadgroup tint_symbol_7* tint_symbol_4 [[threadgroup(0)]], uint local_invocation_index [[thread_index_in_threadgroup]]) {
main1_inner(local_invocation_index, &((*(tint_symbol_4)).m1), &((*(tint_symbol_4)).m2), &((*(tint_symbol_4)).m3));
return;
}
void main2_inner(uint local_invocation_index_1, threadgroup float3x2* const tint_symbol_6, threadgroup float3x3* const tint_symbol_7, threadgroup float3x4* const tint_symbol_8) {
void main2_inner(uint local_invocation_index_1, threadgroup float3x2* const tint_symbol_8, threadgroup float3x3* const tint_symbol_9, threadgroup float3x4* const tint_symbol_10) {
{
*(tint_symbol_6) = float3x2();
*(tint_symbol_7) = float3x3();
*(tint_symbol_8) = float3x4();
*(tint_symbol_8) = float3x2();
*(tint_symbol_9) = float3x3();
*(tint_symbol_10) = float3x4();
}
threadgroup_barrier(mem_flags::mem_threadgroup);
float3x2 const a1 = *(tint_symbol_6);
float3x3 const a2 = *(tint_symbol_7);
float3x4 const a3 = *(tint_symbol_8);
float3x2 const a1 = *(tint_symbol_8);
float3x3 const a2 = *(tint_symbol_9);
float3x4 const a3 = *(tint_symbol_10);
}
kernel void main2(threadgroup float3x2* tint_symbol_9 [[threadgroup(0)]], threadgroup float3x3* tint_symbol_10 [[threadgroup(1)]], threadgroup float3x4* tint_symbol_11 [[threadgroup(2)]], uint local_invocation_index_1 [[thread_index_in_threadgroup]]) {
main2_inner(local_invocation_index_1, tint_symbol_9, tint_symbol_10, tint_symbol_11);
kernel void main2(threadgroup tint_symbol_15* tint_symbol_12 [[threadgroup(0)]], uint local_invocation_index_1 [[thread_index_in_threadgroup]]) {
main2_inner(local_invocation_index_1, &((*(tint_symbol_12)).m4), &((*(tint_symbol_12)).m5), &((*(tint_symbol_12)).m6));
return;
}
void main3_inner(uint local_invocation_index_2, threadgroup float4x2* const tint_symbol_12, threadgroup float4x3* const tint_symbol_13, threadgroup float4x4* const tint_symbol_14) {
void main3_inner(uint local_invocation_index_2, threadgroup float4x2* const tint_symbol_16, threadgroup float4x3* const tint_symbol_17, threadgroup float4x4* const tint_symbol_18) {
{
*(tint_symbol_12) = float4x2();
*(tint_symbol_13) = float4x3();
*(tint_symbol_14) = float4x4();
*(tint_symbol_16) = float4x2();
*(tint_symbol_17) = float4x3();
*(tint_symbol_18) = float4x4();
}
threadgroup_barrier(mem_flags::mem_threadgroup);
float4x2 const a1 = *(tint_symbol_12);
float4x3 const a2 = *(tint_symbol_13);
float4x4 const a3 = *(tint_symbol_14);
float4x2 const a1 = *(tint_symbol_16);
float4x3 const a2 = *(tint_symbol_17);
float4x4 const a3 = *(tint_symbol_18);
}
kernel void main3(threadgroup float4x2* tint_symbol_15 [[threadgroup(0)]], threadgroup float4x3* tint_symbol_16 [[threadgroup(1)]], threadgroup float4x4* tint_symbol_17 [[threadgroup(2)]], uint local_invocation_index_2 [[thread_index_in_threadgroup]]) {
main3_inner(local_invocation_index_2, tint_symbol_15, tint_symbol_16, tint_symbol_17);
kernel void main3(threadgroup tint_symbol_23* tint_symbol_20 [[threadgroup(0)]], uint local_invocation_index_2 [[thread_index_in_threadgroup]]) {
main3_inner(local_invocation_index_2, &((*(tint_symbol_20)).m7), &((*(tint_symbol_20)).m8), &((*(tint_symbol_20)).m9));
return;
}
@@ -353,18 +379,12 @@ kernel void main4_no_usages() {
ASSERT_TRUE(allocations.count("main2"));
ASSERT_TRUE(allocations.count("main3"));
EXPECT_EQ(allocations.count("main4_no_usages"), 0u);
ASSERT_EQ(allocations["main1"].size(), 3u);
EXPECT_EQ(allocations["main1"][0], 2u * 2u * sizeof(float));
EXPECT_EQ(allocations["main1"][1], 2u * 4u * sizeof(float));
EXPECT_EQ(allocations["main1"][2], 2u * 4u * sizeof(float));
ASSERT_EQ(allocations["main2"].size(), 3u);
EXPECT_EQ(allocations["main2"][0], 3u * 2u * sizeof(float));
EXPECT_EQ(allocations["main2"][1], 3u * 4u * sizeof(float));
EXPECT_EQ(allocations["main2"][2], 3u * 4u * sizeof(float));
ASSERT_EQ(allocations["main3"].size(), 3u);
EXPECT_EQ(allocations["main3"][0], 4u * 2u * sizeof(float));
EXPECT_EQ(allocations["main3"][1], 4u * 4u * sizeof(float));
EXPECT_EQ(allocations["main3"][2], 4u * 4u * sizeof(float));
ASSERT_EQ(allocations["main1"].size(), 1u);
EXPECT_EQ(allocations["main1"][0], 20u * sizeof(float));
ASSERT_EQ(allocations["main2"].size(), 1u);
EXPECT_EQ(allocations["main2"][0], 32u * sizeof(float));
ASSERT_EQ(allocations["main3"].size(), 1u);
EXPECT_EQ(allocations["main3"][0], 40u * sizeof(float));
}
} // namespace