mirror of
https://github.com/encounter/dawn-cmake.git
synced 2025-12-17 17:05:31 +00:00
msl: Handle workgroup matrix allocations
Use a threadgroup memory argument for any workgroup variable that contains a matrix. The generator now provides a list of threadgroup memory arguments for each entry point, so that the runtime knows how many bytes to allocate for each argument. Bug: tint:938 Change-Id: Ia4af33cd6a44c4f74258793443eb737c2931f5eb Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/64042 Kokoro: Kokoro <noreply+kokoro@google.com> Reviewed-by: Ben Clayton <bclayton@google.com> Commit-Queue: James Price <jrprice@google.com>
This commit is contained in:
committed by
Tint LUCI CQ
parent
de767b1842
commit
acaecab29d
@@ -15,6 +15,7 @@
|
||||
#include "src/transform/module_scope_var_to_entry_point_param.h"
|
||||
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
@@ -29,6 +30,24 @@ TINT_INSTANTIATE_TYPEINFO(tint::transform::ModuleScopeVarToEntryPointParam);
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
namespace {
|
||||
// Returns `true` if `type` is or contains a matrix type.
|
||||
bool ContainsMatrix(const sem::Type* type) {
|
||||
type = type->UnwrapRef();
|
||||
if (type->Is<sem::Matrix>()) {
|
||||
return true;
|
||||
} else if (auto* ary = type->As<sem::Array>()) {
|
||||
return ContainsMatrix(ary->ElemType());
|
||||
} else if (auto* str = type->As<sem::Struct>()) {
|
||||
for (auto* member : str->Members()) {
|
||||
if (ContainsMatrix(member->Type())) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
ModuleScopeVarToEntryPointParam::ModuleScopeVarToEntryPointParam() = default;
|
||||
|
||||
@@ -105,6 +124,9 @@ void ModuleScopeVarToEntryPointParam::Run(CloneContext& ctx,
|
||||
|
||||
auto* store_type = CreateASTTypeFor(ctx, var->Type()->UnwrapRef());
|
||||
|
||||
// Track whether the new variable is a pointer or not.
|
||||
bool is_pointer = false;
|
||||
|
||||
if (is_entry_point) {
|
||||
if (store_type->is_handle()) {
|
||||
// For a texture or sampler variable, redeclare it as an entry point
|
||||
@@ -117,17 +139,36 @@ void ModuleScopeVarToEntryPointParam::Run(CloneContext& ctx,
|
||||
auto* param = ctx.dst->Param(new_var_symbol, store_type, decos);
|
||||
ctx.InsertFront(func_ast->params(), param);
|
||||
} else {
|
||||
// For a private or workgroup variable, redeclare it at function
|
||||
// scope. Disable storage class validation on this variable.
|
||||
auto* disable_validation =
|
||||
ctx.dst->ASTNodes().Create<ast::DisableValidationDecoration>(
|
||||
ctx.dst->ID(), ast::DisabledValidation::kIgnoreStorageClass);
|
||||
auto* constructor = ctx.Clone(var->Declaration()->constructor());
|
||||
auto* local_var = ctx.dst->Var(
|
||||
new_var_symbol, store_type, var->StorageClass(), constructor,
|
||||
ast::DecorationList{disable_validation});
|
||||
ctx.InsertFront(func_ast->body()->statements(),
|
||||
ctx.dst->Decl(local_var));
|
||||
if (var->StorageClass() == ast::StorageClass::kWorkgroup &&
|
||||
ContainsMatrix(var->Type())) {
|
||||
// Due to a bug in the MSL compiler, we use a threadgroup memory
|
||||
// argument for any workgroup allocation that contains a matrix.
|
||||
// See crbug.com/tint/938.
|
||||
auto* disable_validation =
|
||||
ctx.dst->ASTNodes().Create<ast::DisableValidationDecoration>(
|
||||
ctx.dst->ID(),
|
||||
ast::DisabledValidation::kEntryPointParameter);
|
||||
auto* param_type =
|
||||
ctx.dst->ty.pointer(store_type, var->StorageClass());
|
||||
auto* param = ctx.dst->Param(new_var_symbol, param_type,
|
||||
{disable_validation});
|
||||
ctx.InsertFront(func_ast->params(), param);
|
||||
is_pointer = true;
|
||||
} else {
|
||||
// For any other private or workgroup variable, redeclare it at
|
||||
// function scope. Disable storage class validation on this
|
||||
// variable.
|
||||
auto* disable_validation =
|
||||
ctx.dst->ASTNodes().Create<ast::DisableValidationDecoration>(
|
||||
ctx.dst->ID(),
|
||||
ast::DisabledValidation::kIgnoreStorageClass);
|
||||
auto* constructor = ctx.Clone(var->Declaration()->constructor());
|
||||
auto* local_var = ctx.dst->Var(
|
||||
new_var_symbol, store_type, var->StorageClass(), constructor,
|
||||
ast::DecorationList{disable_validation});
|
||||
ctx.InsertFront(func_ast->body()->statements(),
|
||||
ctx.dst->Decl(local_var));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// For a regular function, redeclare the variable as a parameter.
|
||||
@@ -135,6 +176,7 @@ void ModuleScopeVarToEntryPointParam::Run(CloneContext& ctx,
|
||||
auto* param_type = store_type;
|
||||
if (!store_type->is_handle()) {
|
||||
param_type = ctx.dst->ty.pointer(param_type, var->StorageClass());
|
||||
is_pointer = true;
|
||||
}
|
||||
ctx.InsertBack(func_ast->params(),
|
||||
ctx.dst->Param(new_var_symbol, param_type));
|
||||
@@ -145,7 +187,7 @@ void ModuleScopeVarToEntryPointParam::Run(CloneContext& ctx,
|
||||
for (auto* user : var->Users()) {
|
||||
if (user->Stmt()->Function() == func_ast) {
|
||||
ast::Expression* expr = ctx.dst->Expr(new_var_symbol);
|
||||
if (!is_entry_point && !store_type->is_handle()) {
|
||||
if (is_pointer) {
|
||||
// If this identifier is used by an address-of operator, just remove
|
||||
// the address-of instead of adding a deref, since we already have a
|
||||
// pointer.
|
||||
@@ -172,11 +214,15 @@ void ModuleScopeVarToEntryPointParam::Run(CloneContext& ctx,
|
||||
// Add new arguments for any variables that are needed by the callee.
|
||||
// For entry points, pass non-handle types as pointers.
|
||||
for (auto* target_var : target_sem->ReferencedModuleVariables()) {
|
||||
bool is_handle = target_var->Type()->UnwrapRef()->is_handle();
|
||||
bool is_workgroup_matrix =
|
||||
target_var->StorageClass() == ast::StorageClass::kWorkgroup &&
|
||||
ContainsMatrix(target_var->Type());
|
||||
if (target_var->StorageClass() == ast::StorageClass::kPrivate ||
|
||||
target_var->StorageClass() == ast::StorageClass::kWorkgroup ||
|
||||
target_var->StorageClass() == ast::StorageClass::kUniformConstant) {
|
||||
ast::Expression* arg = ctx.dst->Expr(var_to_symbol[target_var]);
|
||||
if (is_entry_point && !target_var->Type()->UnwrapRef()->is_handle()) {
|
||||
if (is_entry_point && !is_handle && !is_workgroup_matrix) {
|
||||
arg = ctx.dst->AddressOf(arg);
|
||||
}
|
||||
ctx.InsertBack(call->params(), arg);
|
||||
|
||||
@@ -329,6 +329,64 @@ fn main([[group(0), binding(0), internal(disable_validation__entry_point_paramet
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(ModuleScopeVarToEntryPointParamTest, Matrix) {
|
||||
auto* src = R"(
|
||||
var<workgroup> m : mat2x2<f32>;
|
||||
|
||||
[[stage(compute), workgroup_size(1)]]
|
||||
fn main() {
|
||||
let x = m;
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
[[stage(compute), workgroup_size(1)]]
|
||||
fn main([[internal(disable_validation__entry_point_parameter)]] tint_symbol : ptr<workgroup, mat2x2<f32>>) {
|
||||
let x = *(tint_symbol);
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<ModuleScopeVarToEntryPointParam>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(ModuleScopeVarToEntryPointParamTest, NestedMatrix) {
|
||||
auto* src = R"(
|
||||
struct S1 {
|
||||
m : mat2x2<f32>;
|
||||
};
|
||||
struct S2 {
|
||||
s : S1;
|
||||
};
|
||||
var<workgroup> m : array<S2, 4>;
|
||||
|
||||
[[stage(compute), workgroup_size(1)]]
|
||||
fn main() {
|
||||
let x = m;
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
struct S1 {
|
||||
m : mat2x2<f32>;
|
||||
};
|
||||
|
||||
struct S2 {
|
||||
s : S1;
|
||||
};
|
||||
|
||||
[[stage(compute), workgroup_size(1)]]
|
||||
fn main([[internal(disable_validation__entry_point_parameter)]] tint_symbol : ptr<workgroup, array<S2, 4u>>) {
|
||||
let x = *(tint_symbol);
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<ModuleScopeVarToEntryPointParam>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(ModuleScopeVarToEntryPointParamTest, EmtpyModule) {
|
||||
auto* src = "";
|
||||
|
||||
|
||||
Reference in New Issue
Block a user