Revert "msl: Use a struct for threadgroup memory arguments"

This reverts commit af8cd3b7f5.

Reason for revert: breaking roll into Dawn.

Original change's description:
> msl: Use a struct for threadgroup memory arguments
>
> MSL has a limit on the number of threadgroup memory arguments, so use
> a struct to support an arbitrary number of workgroup variables.
>
> Bug: tint:938
> Change-Id: I40e4a8d99bc4ae074010479a56e13e2e0acdded3
> Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/64380
> Kokoro: Kokoro <noreply+kokoro@google.com>
> Auto-Submit: James Price <jrprice@google.com>
> Reviewed-by: Ben Clayton <bclayton@google.com>
> Commit-Queue: James Price <jrprice@google.com>

TBR=bclayton@google.com,jrprice@google.com,noreply+kokoro@google.com,tint-scoped@luci-project-accounts.iam.gserviceaccount.com

Change-Id: I58a07c4ab7e92bda205e2bbbab41e0b347aeb1e8
No-Presubmit: true
No-Tree-Checks: true
No-Try: true
Bug: tint:938
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/65162
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Commit-Queue: Corentin Wallez <cwallez@chromium.org>
Kokoro: Corentin Wallez <cwallez@chromium.org>
This commit is contained in:
Corentin Wallez
2021-09-27 19:00:15 +00:00
committed by Tint LUCI CQ
parent af8cd3b7f5
commit 40ef4a8269
12 changed files with 62 additions and 1856 deletions

View File

@@ -47,27 +47,6 @@ bool ContainsMatrix(const sem::Type* type) {
}
return false;
}
// Clone any struct types that are contained in `ty` (including `ty` itself),
// and add it to the global declarations now, so that they precede new global
// declarations that need to reference them.
void CloneStructTypes(const sem::Type* ty, CloneContext& ctx) {
if (auto* str = ty->As<sem::Struct>()) {
// Recurse into members.
for (auto* member : str->Members()) {
CloneStructTypes(member->Type(), ctx);
}
// Clone the struct and add it to the global declaration list.
// Remove the old declaration.
auto* ast_str = str->Declaration();
ctx.dst->AST().AddTypeDecl(ctx.Clone(const_cast<ast::Struct*>(ast_str)));
ctx.Remove(ctx.src->AST().GlobalDeclarations(), ast_str);
} else if (auto* arr = ty->As<sem::Array>()) {
CloneStructTypes(arr->ElemType(), ctx);
}
}
} // namespace
ModuleScopeVarToEntryPointParam::ModuleScopeVarToEntryPointParam() = default;
@@ -133,17 +112,6 @@ void ModuleScopeVarToEntryPointParam::Run(CloneContext& ctx,
// Map module-scope variables onto their function-scope replacement.
std::unordered_map<const sem::Variable*, Symbol> var_to_symbol;
// We aggregate all workgroup variables into a struct to avoid hitting MSL's
// limit for threadgroup memory arguments.
Symbol workgroup_parameter_symbol;
ast::StructMemberList workgroup_parameter_members;
auto workgroup_param = [&]() {
if (!workgroup_parameter_symbol.IsValid()) {
workgroup_parameter_symbol = ctx.dst->Sym();
}
return workgroup_parameter_symbol;
};
for (auto* var : func_sem->ReferencedModuleVariables()) {
if (var->StorageClass() != ast::StorageClass::kPrivate &&
var->StorageClass() != ast::StorageClass::kWorkgroup &&
@@ -154,16 +122,13 @@ void ModuleScopeVarToEntryPointParam::Run(CloneContext& ctx,
// This is the symbol for the variable that replaces the module-scope var.
auto new_var_symbol = ctx.dst->Sym();
// Helper to create an AST node for the store type of the variable.
auto store_type = [&]() {
return CreateASTTypeFor(ctx, var->Type()->UnwrapRef());
};
auto* store_type = CreateASTTypeFor(ctx, var->Type()->UnwrapRef());
// Track whether the new variable is a pointer or not.
bool is_pointer = false;
if (is_entry_point) {
if (var->Type()->UnwrapRef()->is_handle()) {
if (store_type->is_handle()) {
// For a texture or sampler variable, redeclare it as an entry point
// parameter. Disable entry point parameter validation.
auto* disable_validation =
@@ -171,7 +136,7 @@ void ModuleScopeVarToEntryPointParam::Run(CloneContext& ctx,
ctx.dst->ID(), ast::DisabledValidation::kEntryPointParameter);
auto decos = ctx.Clone(var->Declaration()->decorations());
decos.push_back(disable_validation);
auto* param = ctx.dst->Param(new_var_symbol, store_type(), decos);
auto* param = ctx.dst->Param(new_var_symbol, store_type, decos);
ctx.InsertFront(func_ast->params(), param);
} else {
if (var->StorageClass() == ast::StorageClass::kWorkgroup &&
@@ -179,24 +144,15 @@ void ModuleScopeVarToEntryPointParam::Run(CloneContext& ctx,
// Due to a bug in the MSL compiler, we use a threadgroup memory
// argument for any workgroup allocation that contains a matrix.
// See crbug.com/tint/938.
// TODO(jrprice): Do this for all other workgroup variables too.
// Create a member in the workgroup parameter struct.
auto member = ctx.Clone(var->Declaration()->symbol());
workgroup_parameter_members.push_back(
ctx.dst->Member(member, store_type()));
CloneStructTypes(var->Type()->UnwrapRef(), ctx);
// Create a function-scope variable that is a pointer to the member.
auto* member_ptr = ctx.dst->AddressOf(ctx.dst->MemberAccessor(
ctx.dst->Deref(workgroup_param()), member));
auto* local_var =
ctx.dst->Const(new_var_symbol,
ctx.dst->ty.pointer(
store_type(), ast::StorageClass::kWorkgroup),
member_ptr);
ctx.InsertFront(func_ast->body()->statements(),
ctx.dst->Decl(local_var));
auto* disable_validation =
ctx.dst->ASTNodes().Create<ast::DisableValidationDecoration>(
ctx.dst->ID(),
ast::DisabledValidation::kEntryPointParameter);
auto* param_type =
ctx.dst->ty.pointer(store_type, var->StorageClass());
auto* param = ctx.dst->Param(new_var_symbol, param_type,
{disable_validation});
ctx.InsertFront(func_ast->params(), param);
is_pointer = true;
} else {
// For any other private or workgroup variable, redeclare it at
@@ -208,7 +164,7 @@ void ModuleScopeVarToEntryPointParam::Run(CloneContext& ctx,
ast::DisabledValidation::kIgnoreStorageClass);
auto* constructor = ctx.Clone(var->Declaration()->constructor());
auto* local_var = ctx.dst->Var(
new_var_symbol, store_type(), var->StorageClass(), constructor,
new_var_symbol, store_type, var->StorageClass(), constructor,
ast::DecorationList{disable_validation});
ctx.InsertFront(func_ast->body()->statements(),
ctx.dst->Decl(local_var));
@@ -217,21 +173,13 @@ void ModuleScopeVarToEntryPointParam::Run(CloneContext& ctx,
} else {
// For a regular function, redeclare the variable as a parameter.
// Use a pointer for non-handle types.
auto* param_type = store_type();
ast::DecorationList attributes;
if (!param_type->is_handle()) {
auto* param_type = store_type;
if (!store_type->is_handle()) {
param_type = ctx.dst->ty.pointer(param_type, var->StorageClass());
is_pointer = true;
// Disable validation of arguments passed to this pointer parameter,
// as we will sometimes pass pointers to struct members.
attributes.push_back(
ctx.dst->ASTNodes().Create<ast::DisableValidationDecoration>(
ctx.dst->ID(),
ast::DisabledValidation::kIgnoreInvalidPointerArgument));
}
ctx.InsertBack(func_ast->params(),
ctx.dst->Param(new_var_symbol, param_type, attributes));
ctx.dst->Param(new_var_symbol, param_type));
}
// Replace all uses of the module-scope variable.
@@ -258,22 +206,6 @@ void ModuleScopeVarToEntryPointParam::Run(CloneContext& ctx,
var_to_symbol[var] = new_var_symbol;
}
if (!workgroup_parameter_members.empty()) {
// Create the workgroup memory parameter.
// The parameter is a struct that contains members for each workgroup
// variable.
auto* str = ctx.dst->Structure(ctx.dst->Sym(),
std::move(workgroup_parameter_members));
auto* param_type = ctx.dst->ty.pointer(ctx.dst->ty.Of(str),
ast::StorageClass::kWorkgroup);
auto* disable_validation =
ctx.dst->ASTNodes().Create<ast::DisableValidationDecoration>(
ctx.dst->ID(), ast::DisabledValidation::kEntryPointParameter);
auto* param =
ctx.dst->Param(workgroup_param(), param_type, {disable_validation});
ctx.InsertFront(func_ast->params(), param);
}
// Pass the variables as pointers to any functions that need them.
for (auto* call : calls_to_replace[func_ast]) {
auto* target = ctx.src->AST().Functions().Find(call->func()->symbol());

View File

@@ -78,12 +78,12 @@ fn main() {
fn no_uses() {
}
fn bar(a : f32, b : f32, [[internal(disable_validation__ignore_invalid_pointer_argument)]] tint_symbol : ptr<private, f32>, [[internal(disable_validation__ignore_invalid_pointer_argument)]] tint_symbol_1 : ptr<workgroup, f32>) {
fn bar(a : f32, b : f32, tint_symbol : ptr<private, f32>, tint_symbol_1 : ptr<workgroup, f32>) {
*(tint_symbol) = a;
*(tint_symbol_1) = b;
}
fn foo(a : f32, [[internal(disable_validation__ignore_invalid_pointer_argument)]] tint_symbol_2 : ptr<private, f32>, [[internal(disable_validation__ignore_invalid_pointer_argument)]] tint_symbol_3 : ptr<workgroup, f32>) {
fn foo(a : f32, tint_symbol_2 : ptr<private, f32>, tint_symbol_3 : ptr<workgroup, f32>) {
let b : f32 = 2.0;
bar(a, b, tint_symbol_2, tint_symbol_3);
no_uses();
@@ -181,7 +181,7 @@ fn bar(p : ptr<private, f32>) {
*(p) = 0.0;
}
fn foo([[internal(disable_validation__ignore_invalid_pointer_argument)]] tint_symbol : ptr<private, f32>) {
fn foo(tint_symbol : ptr<private, f32>) {
bar(tint_symbol);
}
@@ -340,13 +340,8 @@ fn main() {
)";
auto* expect = R"(
struct tint_symbol_2 {
m : mat2x2<f32>;
};
[[stage(compute), workgroup_size(1)]]
fn main([[internal(disable_validation__entry_point_parameter)]] tint_symbol_1 : ptr<workgroup, tint_symbol_2>) {
let tint_symbol : ptr<workgroup, mat2x2<f32>> = &((*(tint_symbol_1)).m);
fn main([[internal(disable_validation__entry_point_parameter)]] tint_symbol : ptr<workgroup, mat2x2<f32>>) {
let x = *(tint_symbol);
}
)";
@@ -381,13 +376,8 @@ struct S2 {
s : S1;
};
struct tint_symbol_2 {
m : array<S2, 4u>;
};
[[stage(compute), workgroup_size(1)]]
fn main([[internal(disable_validation__entry_point_parameter)]] tint_symbol_1 : ptr<workgroup, tint_symbol_2>) {
let tint_symbol : ptr<workgroup, array<S2, 4u>> = &((*(tint_symbol_1)).m);
fn main([[internal(disable_validation__entry_point_parameter)]] tint_symbol : ptr<workgroup, array<S2, 4u>>) {
let x = *(tint_symbol);
}
)";