msl: Handle workgroup matrix allocations

Use a threadgroup memory argument for any workgroup variable that
contains a matrix.

The generator now provides a list of threadgroup memory arguments for
each entry point, so that the runtime knows how many bytes to allocate
for each argument.

Bug: tint:938
Change-Id: Ia4af33cd6a44c4f74258793443eb737c2931f5eb
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/64042
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: James Price <jrprice@google.com>
This commit is contained in:
James Price
2021-09-13 19:56:01 +00:00
committed by Tint LUCI CQ
parent de767b1842
commit acaecab29d
8 changed files with 405 additions and 206 deletions

View File

@@ -15,6 +15,7 @@
#include "src/transform/module_scope_var_to_entry_point_param.h"
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
@@ -29,6 +30,24 @@ TINT_INSTANTIATE_TYPEINFO(tint::transform::ModuleScopeVarToEntryPointParam);
namespace tint {
namespace transform {
namespace {
// Returns `true` if `type` is or contains a matrix type.
bool ContainsMatrix(const sem::Type* type) {
type = type->UnwrapRef();
if (type->Is<sem::Matrix>()) {
return true;
} else if (auto* ary = type->As<sem::Array>()) {
return ContainsMatrix(ary->ElemType());
} else if (auto* str = type->As<sem::Struct>()) {
for (auto* member : str->Members()) {
if (ContainsMatrix(member->Type())) {
return true;
}
}
}
return false;
}
} // namespace
ModuleScopeVarToEntryPointParam::ModuleScopeVarToEntryPointParam() = default;
@@ -105,6 +124,9 @@ void ModuleScopeVarToEntryPointParam::Run(CloneContext& ctx,
auto* store_type = CreateASTTypeFor(ctx, var->Type()->UnwrapRef());
// Track whether the new variable is a pointer or not.
bool is_pointer = false;
if (is_entry_point) {
if (store_type->is_handle()) {
// For a texture or sampler variable, redeclare it as an entry point
@@ -117,17 +139,36 @@ void ModuleScopeVarToEntryPointParam::Run(CloneContext& ctx,
auto* param = ctx.dst->Param(new_var_symbol, store_type, decos);
ctx.InsertFront(func_ast->params(), param);
} else {
// For a private or workgroup variable, redeclare it at function
// scope. Disable storage class validation on this variable.
auto* disable_validation =
ctx.dst->ASTNodes().Create<ast::DisableValidationDecoration>(
ctx.dst->ID(), ast::DisabledValidation::kIgnoreStorageClass);
auto* constructor = ctx.Clone(var->Declaration()->constructor());
auto* local_var = ctx.dst->Var(
new_var_symbol, store_type, var->StorageClass(), constructor,
ast::DecorationList{disable_validation});
ctx.InsertFront(func_ast->body()->statements(),
ctx.dst->Decl(local_var));
if (var->StorageClass() == ast::StorageClass::kWorkgroup &&
ContainsMatrix(var->Type())) {
// Due to a bug in the MSL compiler, we use a threadgroup memory
// argument for any workgroup allocation that contains a matrix.
// See crbug.com/tint/938.
auto* disable_validation =
ctx.dst->ASTNodes().Create<ast::DisableValidationDecoration>(
ctx.dst->ID(),
ast::DisabledValidation::kEntryPointParameter);
auto* param_type =
ctx.dst->ty.pointer(store_type, var->StorageClass());
auto* param = ctx.dst->Param(new_var_symbol, param_type,
{disable_validation});
ctx.InsertFront(func_ast->params(), param);
is_pointer = true;
} else {
// For any other private or workgroup variable, redeclare it at
// function scope. Disable storage class validation on this
// variable.
auto* disable_validation =
ctx.dst->ASTNodes().Create<ast::DisableValidationDecoration>(
ctx.dst->ID(),
ast::DisabledValidation::kIgnoreStorageClass);
auto* constructor = ctx.Clone(var->Declaration()->constructor());
auto* local_var = ctx.dst->Var(
new_var_symbol, store_type, var->StorageClass(), constructor,
ast::DecorationList{disable_validation});
ctx.InsertFront(func_ast->body()->statements(),
ctx.dst->Decl(local_var));
}
}
} else {
// For a regular function, redeclare the variable as a parameter.
@@ -135,6 +176,7 @@ void ModuleScopeVarToEntryPointParam::Run(CloneContext& ctx,
auto* param_type = store_type;
if (!store_type->is_handle()) {
param_type = ctx.dst->ty.pointer(param_type, var->StorageClass());
is_pointer = true;
}
ctx.InsertBack(func_ast->params(),
ctx.dst->Param(new_var_symbol, param_type));
@@ -145,7 +187,7 @@ void ModuleScopeVarToEntryPointParam::Run(CloneContext& ctx,
for (auto* user : var->Users()) {
if (user->Stmt()->Function() == func_ast) {
ast::Expression* expr = ctx.dst->Expr(new_var_symbol);
if (!is_entry_point && !store_type->is_handle()) {
if (is_pointer) {
// If this identifier is used by an address-of operator, just remove
// the address-of instead of adding a deref, since we already have a
// pointer.
@@ -172,11 +214,15 @@ void ModuleScopeVarToEntryPointParam::Run(CloneContext& ctx,
// Add new arguments for any variables that are needed by the callee.
// For entry points, pass non-handle types as pointers.
for (auto* target_var : target_sem->ReferencedModuleVariables()) {
bool is_handle = target_var->Type()->UnwrapRef()->is_handle();
bool is_workgroup_matrix =
target_var->StorageClass() == ast::StorageClass::kWorkgroup &&
ContainsMatrix(target_var->Type());
if (target_var->StorageClass() == ast::StorageClass::kPrivate ||
target_var->StorageClass() == ast::StorageClass::kWorkgroup ||
target_var->StorageClass() == ast::StorageClass::kUniformConstant) {
ast::Expression* arg = ctx.dst->Expr(var_to_symbol[target_var]);
if (is_entry_point && !target_var->Type()->UnwrapRef()->is_handle()) {
if (is_entry_point && !is_handle && !is_workgroup_matrix) {
arg = ctx.dst->AddressOf(arg);
}
ctx.InsertBack(call->params(), arg);

View File

@@ -329,6 +329,64 @@ fn main([[group(0), binding(0), internal(disable_validation__entry_point_paramet
EXPECT_EQ(expect, str(got));
}
TEST_F(ModuleScopeVarToEntryPointParamTest, Matrix) {
auto* src = R"(
var<workgroup> m : mat2x2<f32>;
[[stage(compute), workgroup_size(1)]]
fn main() {
let x = m;
}
)";
auto* expect = R"(
[[stage(compute), workgroup_size(1)]]
fn main([[internal(disable_validation__entry_point_parameter)]] tint_symbol : ptr<workgroup, mat2x2<f32>>) {
let x = *(tint_symbol);
}
)";
auto got = Run<ModuleScopeVarToEntryPointParam>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(ModuleScopeVarToEntryPointParamTest, NestedMatrix) {
auto* src = R"(
struct S1 {
m : mat2x2<f32>;
};
struct S2 {
s : S1;
};
var<workgroup> m : array<S2, 4>;
[[stage(compute), workgroup_size(1)]]
fn main() {
let x = m;
}
)";
auto* expect = R"(
struct S1 {
m : mat2x2<f32>;
};
struct S2 {
s : S1;
};
[[stage(compute), workgroup_size(1)]]
fn main([[internal(disable_validation__entry_point_parameter)]] tint_symbol : ptr<workgroup, array<S2, 4u>>) {
let x = *(tint_symbol);
}
)";
auto got = Run<ModuleScopeVarToEntryPointParam>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(ModuleScopeVarToEntryPointParamTest, EmtpyModule) {
auto* src = "";