writer/msl: Avoid generating unnecessary pointers

When moving private and workgroup variables into the entry point,
generate pointers to pass as arguments to sub-functions on demand,
instead of upfront. This removes a bunch of unnecessary dereferences
for accesses inside the entry point, and one function variable.

Change-Id: I7d1aabdf14eae33b569b3316dfc0f9fbd288131e
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/54300
Auto-Submit: James Price <jrprice@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: James Price <jrprice@google.com>
This commit is contained in:
James Price
2021-06-11 12:29:56 +00:00
committed by Tint LUCI CQ
parent 3628949efc
commit 2940c7002c
16 changed files with 114 additions and 157 deletions

View File

@@ -89,8 +89,7 @@ void Msl::HandlePrivateAndWorkgroupVariables(CloneContext& ctx) const {
// [[stage(compute)]]
// fn main() {
// var<private> v : f32 = 2.0;
// let v_ptr : ptr<private, f32> = &f32;
// foo(v_ptr);
// foo(&v);
// }
// ```
@@ -127,6 +126,7 @@ void Msl::HandlePrivateAndWorkgroupVariables(CloneContext& ctx) const {
for (auto* func_ast : functions_to_process) {
auto* func_sem = ctx.src->Sem().Get(func_ast);
bool is_entry_point = func_ast->IsEntryPoint();
// Map module-scope variables onto their function-scope replacement.
std::unordered_map<const sem::Variable*, Symbol> var_to_symbol;
@@ -137,12 +137,12 @@ void Msl::HandlePrivateAndWorkgroupVariables(CloneContext& ctx) const {
continue;
}
// This is the symbol for the pointer that replaces the module-scope var.
// This is the symbol for the variable that replaces the module-scope var.
auto new_var_symbol = ctx.dst->Sym();
auto* store_type = CreateASTTypeFor(&ctx, var->Type()->UnwrapRef());
if (func_ast->IsEntryPoint()) {
if (is_entry_point) {
// For an entry point, redeclare the variable at function-scope.
// Disable storage class validation on this variable.
auto* disable_validation =
@@ -151,16 +151,10 @@ void Msl::HandlePrivateAndWorkgroupVariables(CloneContext& ctx) const {
ast::DisabledValidation::kFunctionVarStorageClass);
auto* constructor = ctx.Clone(var->Declaration()->constructor());
auto* local_var =
ctx.dst->Var(ctx.dst->Sym(), store_type, var->StorageClass(),
ctx.dst->Var(new_var_symbol, store_type, var->StorageClass(),
constructor, ast::DecorationList{disable_validation});
ctx.InsertBefore(func_ast->body()->statements(),
*func_ast->body()->begin(), ctx.dst->Decl(local_var));
// Now take the address of the variable.
auto* ptr = ctx.dst->Const(new_var_symbol, nullptr,
ctx.dst->AddressOf(local_var));
ctx.InsertBefore(func_ast->body()->statements(),
*func_ast->body()->begin(), ctx.dst->Decl(ptr));
} else {
// For a regular function, redeclare the variable as a pointer function
// parameter.
@@ -169,18 +163,22 @@ void Msl::HandlePrivateAndWorkgroupVariables(CloneContext& ctx) const {
ctx.dst->Param(new_var_symbol, ptr_type));
}
// Replace all uses of the module-scope variable with the pointer
// replacement (dereferenced).
// Replace all uses of the module-scope variable.
for (auto* user : var->Users()) {
if (user->Stmt()->Function() == func_ast) {
ctx.Replace(user->Declaration(), ctx.dst->Deref(new_var_symbol));
ast::Expression* expr = ctx.dst->Expr(new_var_symbol);
if (!is_entry_point) {
// For non-entry points, dereference the pointer argument.
expr = ctx.dst->Deref(expr);
}
ctx.Replace(user->Declaration(), expr);
}
}
var_to_symbol[var] = new_var_symbol;
}
// Pass the pointers through to any functions that need them.
// Pass the variables as pointers to any functions that need them.
for (auto* call : calls_to_replace[func_ast]) {
auto* target = ctx.src->AST().Functions().Find(call->func()->symbol());
auto* target_sem = ctx.src->Sem().Get(target);
@@ -189,8 +187,12 @@ void Msl::HandlePrivateAndWorkgroupVariables(CloneContext& ctx) const {
for (auto* target_var : target_sem->ReferencedModuleVariables()) {
if (target_var->StorageClass() == ast::StorageClass::kPrivate ||
target_var->StorageClass() == ast::StorageClass::kWorkgroup) {
ctx.InsertBack(call->params(),
ctx.dst->Expr(var_to_symbol[target_var]));
ast::Expression* arg = ctx.dst->Expr(var_to_symbol[target_var]);
if (is_entry_point) {
// For entry points, pass the address of the variable.
arg = ctx.dst->AddressOf(arg);
}
ctx.InsertBack(call->params(), arg);
}
}
}

View File

@@ -36,11 +36,9 @@ fn main() {
auto* expect = R"(
[[stage(compute)]]
fn main() {
[[internal(disable_validation__function_var_storage_class)]] var<workgroup> tint_symbol_1 : f32;
let tint_symbol = &(tint_symbol_1);
[[internal(disable_validation__function_var_storage_class)]] var<private> tint_symbol_3 : f32;
let tint_symbol_2 = &(tint_symbol_3);
*(tint_symbol) = *(tint_symbol_2);
[[internal(disable_validation__function_var_storage_class)]] var<workgroup> tint_symbol : f32;
[[internal(disable_validation__function_var_storage_class)]] var<private> tint_symbol_1 : f32;
tint_symbol = tint_symbol_1;
}
)";
@@ -91,11 +89,9 @@ fn foo(a : f32, tint_symbol_2 : ptr<private, f32>, tint_symbol_3 : ptr<workgroup
[[stage(compute)]]
fn main() {
[[internal(disable_validation__function_var_storage_class)]] var<private> tint_symbol_5 : f32;
let tint_symbol_4 = &(tint_symbol_5);
[[internal(disable_validation__function_var_storage_class)]] var<workgroup> tint_symbol_7 : f32;
let tint_symbol_6 = &(tint_symbol_7);
foo(1.0, tint_symbol_4, tint_symbol_6);
[[internal(disable_validation__function_var_storage_class)]] var<private> tint_symbol_4 : f32;
[[internal(disable_validation__function_var_storage_class)]] var<workgroup> tint_symbol_5 : f32;
foo(1.0, &(tint_symbol_4), &(tint_symbol_5));
}
)";
@@ -118,11 +114,9 @@ fn main() {
auto* expect = R"(
[[stage(compute)]]
fn main() {
[[internal(disable_validation__function_var_storage_class)]] var<private> tint_symbol_1 : f32 = 1.0;
let tint_symbol = &(tint_symbol_1);
[[internal(disable_validation__function_var_storage_class)]] var<private> tint_symbol_3 : f32 = f32();
let tint_symbol_2 = &(tint_symbol_3);
let x : f32 = (*(tint_symbol) + *(tint_symbol_2));
[[internal(disable_validation__function_var_storage_class)]] var<private> tint_symbol : f32 = 1.0;
[[internal(disable_validation__function_var_storage_class)]] var<private> tint_symbol_1 : f32 = f32();
let x : f32 = (tint_symbol + tint_symbol_1);
}
)";
@@ -148,12 +142,10 @@ fn main() {
auto* expect = R"(
[[stage(compute)]]
fn main() {
[[internal(disable_validation__function_var_storage_class)]] var<private> tint_symbol_1 : f32;
let tint_symbol = &(tint_symbol_1);
[[internal(disable_validation__function_var_storage_class)]] var<workgroup> tint_symbol_3 : f32;
let tint_symbol_2 = &(tint_symbol_3);
let p_ptr : ptr<private, f32> = &(*(tint_symbol));
let w_ptr : ptr<workgroup, f32> = &(*(tint_symbol_2));
[[internal(disable_validation__function_var_storage_class)]] var<private> tint_symbol : f32;
[[internal(disable_validation__function_var_storage_class)]] var<workgroup> tint_symbol_1 : f32;
let p_ptr : ptr<private, f32> = &(tint_symbol);
let w_ptr : ptr<workgroup, f32> = &(tint_symbol_1);
let x : f32 = (*(p_ptr) + *(w_ptr));
*(p_ptr) = x;
}