msl: Handle buffer variables in transform

This removes a lot of awkward logic from the MSL writer, and means
that we now handle all module-scope variables with the same transform.

Change-Id: I782e36a4b88dafbc3f8364f7caa7f95c6ae3f5f1
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/67643
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
This commit is contained in:
James Price
2021-10-28 15:00:39 +00:00
parent c6ce5785d0
commit e548db90f6
141 changed files with 1681 additions and 1546 deletions

View File

@@ -92,21 +92,18 @@ struct ModuleScopeVarToEntryPointParam::State {
std::vector<const ast::Function*> functions_to_process;
// Build a list of functions that transitively reference any private or
// workgroup variables, or texture/sampler variables.
// Build a list of functions that transitively reference any module-scope
// variables.
for (auto* func_ast : ctx.src->AST().Functions()) {
auto* func_sem = ctx.src->Sem().Get(func_ast);
bool needs_processing = false;
for (auto* var : func_sem->ReferencedModuleVariables()) {
if (var->StorageClass() == ast::StorageClass::kPrivate ||
var->StorageClass() == ast::StorageClass::kWorkgroup ||
var->StorageClass() == ast::StorageClass::kUniformConstant) {
if (var->StorageClass() != ast::StorageClass::kNone) {
needs_processing = true;
break;
}
}
if (needs_processing) {
functions_to_process.push_back(func_ast);
@@ -140,8 +137,12 @@ struct ModuleScopeVarToEntryPointParam::State {
auto* func_sem = ctx.src->Sem().Get(func_ast);
bool is_entry_point = func_ast->IsEntryPoint();
// Map module-scope variables onto their function-scope replacement.
std::unordered_map<const sem::Variable*, Symbol> var_to_symbol;
// Map module-scope variables onto their replacement.
struct NewVar {
Symbol symbol;
bool is_pointer;
};
std::unordered_map<const sem::Variable*, NewVar> var_to_newvar;
// We aggregate all workgroup variables into a struct to avoid hitting
// MSL's limit for threadgroup memory arguments.
@@ -155,11 +156,18 @@ struct ModuleScopeVarToEntryPointParam::State {
};
for (auto* var : func_sem->ReferencedModuleVariables()) {
if (var->StorageClass() != ast::StorageClass::kPrivate &&
var->StorageClass() != ast::StorageClass::kWorkgroup &&
var->StorageClass() != ast::StorageClass::kUniformConstant) {
auto sc = var->StorageClass();
if (sc == ast::StorageClass::kNone) {
continue;
}
if (sc != ast::StorageClass::kPrivate &&
sc != ast::StorageClass::kStorage &&
sc != ast::StorageClass::kUniform &&
sc != ast::StorageClass::kUniformConstant &&
sc != ast::StorageClass::kWorkgroup) {
TINT_ICE(Transform, ctx.dst->Diagnostics())
<< "unhandled module-scope storage class (" << sc << ")";
}
// This is the symbol for the variable that replaces the module-scope
// var.
@@ -185,7 +193,26 @@ struct ModuleScopeVarToEntryPointParam::State {
decos.push_back(disable_validation);
auto* param = ctx.dst->Param(new_var_symbol, store_type(), decos);
ctx.InsertFront(func_ast->params, param);
} else if (var->StorageClass() == ast::StorageClass::kWorkgroup &&
} else if (sc == ast::StorageClass::kStorage ||
sc == ast::StorageClass::kUniform) {
// Variables into the Storage and Uniform storage classes are
// redeclared as entry point parameters with a pointer type.
auto attributes = ctx.Clone(var->Declaration()->decorations);
attributes.push_back(
ctx.dst->ASTNodes().Create<ast::DisableValidationDecoration>(
ctx.dst->ID(),
ast::DisabledValidation::kEntryPointParameter));
attributes.push_back(
ctx.dst->ASTNodes().Create<ast::DisableValidationDecoration>(
ctx.dst->ID(),
ast::DisabledValidation::kIgnoreStorageClass));
auto* param_type = ctx.dst->ty.pointer(
store_type(), sc, var->Declaration()->declared_access);
auto* param =
ctx.dst->Param(new_var_symbol, param_type, attributes);
ctx.InsertFront(func_ast->params, param);
is_pointer = true;
} else if (sc == ast::StorageClass::kWorkgroup &&
ContainsMatrix(var->Type())) {
// Due to a bug in the MSL compiler, we use a threadgroup memory
// argument for any workgroup allocation that contains a matrix.
@@ -210,17 +237,17 @@ struct ModuleScopeVarToEntryPointParam::State {
ctx.dst->Decl(local_var));
is_pointer = true;
} else {
// For any other private or workgroup variable, redeclare it at
// function scope. Disable storage class validation on this
// variable.
// Variables in the Private and Workgroup storage classes are
// redeclared at function scope. Disable storage class validation on
// this variable.
auto* disable_validation =
ctx.dst->ASTNodes().Create<ast::DisableValidationDecoration>(
ctx.dst->ID(),
ast::DisabledValidation::kIgnoreStorageClass);
auto* constructor = ctx.Clone(var->Declaration()->constructor);
auto* local_var = ctx.dst->Var(
new_var_symbol, store_type(), var->StorageClass(), constructor,
ast::DecorationList{disable_validation});
auto* local_var =
ctx.dst->Var(new_var_symbol, store_type(), sc, constructor,
ast::DecorationList{disable_validation});
ctx.InsertFront(func_ast->body->statements,
ctx.dst->Decl(local_var));
}
@@ -230,11 +257,16 @@ struct ModuleScopeVarToEntryPointParam::State {
auto* param_type = store_type();
ast::DecorationList attributes;
if (!var->Type()->UnwrapRef()->is_handle()) {
param_type = ctx.dst->ty.pointer(param_type, var->StorageClass());
param_type = ctx.dst->ty.pointer(
param_type, sc, var->Declaration()->declared_access);
is_pointer = true;
// Disable validation of arguments passed to this pointer parameter,
// as we will sometimes pass pointers to struct members.
// Disable validation of the parameter's storage class and of
// arguments passed it.
attributes.push_back(
ctx.dst->ASTNodes().Create<ast::DisableValidationDecoration>(
ctx.dst->ID(),
ast::DisabledValidation::kIgnoreStorageClass));
attributes.push_back(
ctx.dst->ASTNodes().Create<ast::DisableValidationDecoration>(
ctx.dst->ID(),
@@ -267,7 +299,7 @@ struct ModuleScopeVarToEntryPointParam::State {
}
}
var_to_symbol[var] = new_var_symbol;
var_to_newvar[var] = {new_var_symbol, is_pointer};
}
if (!workgroup_parameter_members.empty()) {
@@ -294,21 +326,20 @@ struct ModuleScopeVarToEntryPointParam::State {
// Add new arguments for any variables that are needed by the callee.
// For entry points, pass non-handle types as pointers.
for (auto* target_var : target_sem->ReferencedModuleVariables()) {
bool is_handle = target_var->Type()->UnwrapRef()->is_handle();
bool is_workgroup_matrix =
target_var->StorageClass() == ast::StorageClass::kWorkgroup &&
ContainsMatrix(target_var->Type());
if (target_var->StorageClass() == ast::StorageClass::kPrivate ||
target_var->StorageClass() == ast::StorageClass::kWorkgroup ||
target_var->StorageClass() ==
ast::StorageClass::kUniformConstant) {
const ast::Expression* arg =
ctx.dst->Expr(var_to_symbol[target_var]);
if (is_entry_point && !is_handle && !is_workgroup_matrix) {
arg = ctx.dst->AddressOf(arg);
}
ctx.InsertBack(call->args, arg);
auto sc = target_var->StorageClass();
if (sc == ast::StorageClass::kNone) {
continue;
}
auto new_var = var_to_newvar[target_var];
bool is_handle = target_var->Type()->UnwrapRef()->is_handle();
const ast::Expression* arg = ctx.dst->Expr(new_var.symbol);
if (is_entry_point && !is_handle && !new_var.is_pointer) {
// We need to pass a pointer and we don't already have one, so take
// the address of the new variable.
arg = ctx.dst->AddressOf(arg);
}
ctx.InsertBack(call->args, arg);
}
}
}
@@ -316,9 +347,7 @@ struct ModuleScopeVarToEntryPointParam::State {
// Now remove all module-scope variables with these storage classes.
for (auto* var_ast : ctx.src->AST().GlobalVariables()) {
auto* var_sem = ctx.src->Sem().Get(var_ast);
if (var_sem->StorageClass() == ast::StorageClass::kPrivate ||
var_sem->StorageClass() == ast::StorageClass::kWorkgroup ||
var_sem->StorageClass() == ast::StorageClass::kUniformConstant) {
if (var_sem->StorageClass() != ast::StorageClass::kNone) {
ctx.Remove(ctx.src->AST().GlobalDeclarations(), var_ast);
}
}

View File

@@ -22,22 +22,27 @@ namespace transform {
/// Move module-scope variables into the entry point as parameters.
///
/// MSL does not allow private and workgroup variables at module-scope, so we
/// push these declarations into the entry point function and then pass them as
/// pointer parameters to any function that references them.
/// Similarly, texture and sampler types are converted to entry point
/// parameters and passed by value to functions that need them.
/// MSL does not allow module-scope variables to have any address space other
/// than `constant`. This transform moves all module-scope declarations into the
/// entry point function (either as parameters or function-scope variables) and
/// then passes them as pointer parameters to any function that references them.
///
/// Since WGSL does not allow function-scope variables to have these storage
/// classes, we annotate the new variable declarations with an attribute that
/// bypasses that validation rule.
/// Since WGSL does not allow entry point parameters or function-scope variables
/// to have these storage classes, we annotate the new variable declarations
/// with an attribute that bypasses that validation rule.
///
/// Before:
/// ```
/// var<private> v : f32 = 2.0;
/// [[block]]
/// struct S {
/// f : f32;
/// };
/// [[binding(0), group(0)]]
/// var<storage, read> s : S;
/// var<private> p : f32 = 2.0;
///
/// fn foo() {
/// v = v + 1.0;
/// p = p + f;
/// }
///
/// [[stage(compute), workgroup_size(1)]]
@@ -48,14 +53,14 @@ namespace transform {
///
/// After:
/// ```
/// fn foo(v : ptr<private, f32>) {
/// *v = *v + 1.0;
/// fn foo(p : ptr<private, f32>, sptr : ptr<storage, S, read>) {
/// *p = *p + (*sptr).f;
/// }
///
/// [[stage(compute), workgroup_size(1)]]
/// fn main() {
/// var<private> v : f32 = 2.0;
/// foo(&v);
/// fn main(sptr : ptr<storage, S, read>) {
/// var<private> p : f32 = 2.0;
/// foo(&p, sptr);
/// }
/// ```
class ModuleScopeVarToEntryPointParam

View File

@@ -78,12 +78,12 @@ fn main() {
fn no_uses() {
}
fn bar(a : f32, b : f32, [[internal(disable_validation__ignore_invalid_pointer_argument)]] tint_symbol : ptr<private, f32>, [[internal(disable_validation__ignore_invalid_pointer_argument)]] tint_symbol_1 : ptr<workgroup, f32>) {
fn bar(a : f32, b : f32, [[internal(disable_validation__ignore_storage_class), internal(disable_validation__ignore_invalid_pointer_argument)]] tint_symbol : ptr<private, f32>, [[internal(disable_validation__ignore_storage_class), internal(disable_validation__ignore_invalid_pointer_argument)]] tint_symbol_1 : ptr<workgroup, f32>) {
*(tint_symbol) = a;
*(tint_symbol_1) = b;
}
fn foo(a : f32, [[internal(disable_validation__ignore_invalid_pointer_argument)]] tint_symbol_2 : ptr<private, f32>, [[internal(disable_validation__ignore_invalid_pointer_argument)]] tint_symbol_3 : ptr<workgroup, f32>) {
fn foo(a : f32, [[internal(disable_validation__ignore_storage_class), internal(disable_validation__ignore_invalid_pointer_argument)]] tint_symbol_2 : ptr<private, f32>, [[internal(disable_validation__ignore_storage_class), internal(disable_validation__ignore_invalid_pointer_argument)]] tint_symbol_3 : ptr<workgroup, f32>) {
let b : f32 = 2.0;
bar(a, b, tint_symbol_2, tint_symbol_3);
no_uses();
@@ -181,7 +181,7 @@ fn bar(p : ptr<private, f32>) {
*(p) = 0.0;
}
fn foo([[internal(disable_validation__ignore_invalid_pointer_argument)]] tint_symbol : ptr<private, f32>) {
fn foo([[internal(disable_validation__ignore_storage_class), internal(disable_validation__ignore_invalid_pointer_argument)]] tint_symbol : ptr<private, f32>) {
bar(tint_symbol);
}
@@ -197,28 +197,7 @@ fn main() {
EXPECT_EQ(expect, str(got));
}
TEST_F(ModuleScopeVarToEntryPointParamTest, UnusedVariables) {
auto* src = R"(
var<private> p : f32;
var<workgroup> w : f32;
[[stage(compute), workgroup_size(1)]]
fn main() {
}
)";
auto* expect = R"(
[[stage(compute), workgroup_size(1)]]
fn main() {
}
)";
auto got = Run<ModuleScopeVarToEntryPointParam>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(ModuleScopeVarToEntryPointParamTest, OtherVariables) {
TEST_F(ModuleScopeVarToEntryPointParamTest, Buffers_Basic) {
auto* src = R"(
[[block]]
struct S {
@@ -227,9 +206,13 @@ struct S {
[[group(0), binding(0)]]
var<uniform> u : S;
[[group(0), binding(1)]]
var<storage> s : S;
[[stage(compute), workgroup_size(1)]]
fn main() {
_ = u;
_ = s;
}
)";
@@ -239,10 +222,75 @@ struct S {
a : f32;
};
[[group(0), binding(0)]] var<uniform> u : S;
[[stage(compute), workgroup_size(1)]]
fn main([[group(0), binding(0), internal(disable_validation__entry_point_parameter), internal(disable_validation__ignore_storage_class)]] tint_symbol : ptr<uniform, S>, [[group(0), binding(1), internal(disable_validation__entry_point_parameter), internal(disable_validation__ignore_storage_class)]] tint_symbol_1 : ptr<storage, S>) {
_ = *(tint_symbol);
_ = *(tint_symbol_1);
}
)";
auto got = Run<ModuleScopeVarToEntryPointParam>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(ModuleScopeVarToEntryPointParamTest, Buffers_FunctionCalls) {
auto* src = R"(
[[block]]
struct S {
a : f32;
};
[[group(0), binding(0)]]
var<uniform> u : S;
[[group(0), binding(1)]]
var<storage> s : S;
fn no_uses() {
}
fn bar(a : f32, b : f32) {
_ = u;
_ = s;
}
fn foo(a : f32) {
let b : f32 = 2.0;
_ = u;
bar(a, b);
no_uses();
}
[[stage(compute), workgroup_size(1)]]
fn main() {
foo(1.0);
}
)";
auto* expect = R"(
[[block]]
struct S {
a : f32;
};
fn no_uses() {
}
fn bar(a : f32, b : f32, [[internal(disable_validation__ignore_storage_class), internal(disable_validation__ignore_invalid_pointer_argument)]] tint_symbol : ptr<uniform, S>, [[internal(disable_validation__ignore_storage_class), internal(disable_validation__ignore_invalid_pointer_argument)]] tint_symbol_1 : ptr<storage, S>) {
_ = *(tint_symbol);
_ = *(tint_symbol_1);
}
fn foo(a : f32, [[internal(disable_validation__ignore_storage_class), internal(disable_validation__ignore_invalid_pointer_argument)]] tint_symbol_2 : ptr<uniform, S>, [[internal(disable_validation__ignore_storage_class), internal(disable_validation__ignore_invalid_pointer_argument)]] tint_symbol_3 : ptr<storage, S>) {
let b : f32 = 2.0;
_ = *(tint_symbol_2);
bar(a, b, tint_symbol_2, tint_symbol_3);
no_uses();
}
[[stage(compute), workgroup_size(1)]]
fn main([[group(0), binding(0), internal(disable_validation__entry_point_parameter), internal(disable_validation__ignore_storage_class)]] tint_symbol_4 : ptr<uniform, S>, [[group(0), binding(1), internal(disable_validation__entry_point_parameter), internal(disable_validation__ignore_storage_class)]] tint_symbol_5 : ptr<storage, S>) {
foo(1.0, tint_symbol_4, tint_symbol_5);
}
)";
@@ -258,16 +306,16 @@ TEST_F(ModuleScopeVarToEntryPointParamTest, HandleTypes_Basic) {
[[stage(compute), workgroup_size(1)]]
fn main() {
ignore(t);
ignore(s);
_ = t;
_ = s;
}
)";
auto* expect = R"(
[[stage(compute), workgroup_size(1)]]
fn main([[group(0), binding(0), internal(disable_validation__entry_point_parameter)]] tint_symbol : texture_2d<f32>, [[group(0), binding(1), internal(disable_validation__entry_point_parameter)]] tint_symbol_1 : sampler) {
ignore(tint_symbol);
ignore(tint_symbol_1);
_ = tint_symbol;
_ = tint_symbol_1;
}
)";
@@ -285,13 +333,13 @@ fn no_uses() {
}
fn bar(a : f32, b : f32) {
ignore(t);
ignore(s);
_ = t;
_ = s;
}
fn foo(a : f32) {
let b : f32 = 2.0;
ignore(t);
_ = t;
bar(a, b);
no_uses();
}
@@ -307,13 +355,13 @@ fn no_uses() {
}
fn bar(a : f32, b : f32, tint_symbol : texture_2d<f32>, tint_symbol_1 : sampler) {
ignore(tint_symbol);
ignore(tint_symbol_1);
_ = tint_symbol;
_ = tint_symbol_1;
}
fn foo(a : f32, tint_symbol_2 : texture_2d<f32>, tint_symbol_3 : sampler) {
let b : f32 = 2.0;
ignore(tint_symbol_2);
_ = tint_symbol_2;
bar(a, b, tint_symbol_2, tint_symbol_3);
no_uses();
}
@@ -440,6 +488,45 @@ fn main([[internal(disable_validation__entry_point_parameter)]] tint_symbol_1 :
EXPECT_EQ(expect, str(got));
}
TEST_F(ModuleScopeVarToEntryPointParamTest, UnusedVariables) {
auto* src = R"(
[[block]]
struct S {
a : f32;
};
var<private> p : f32;
var<workgroup> w : f32;
[[group(0), binding(0)]]
var<uniform> ub : S;
[[group(0), binding(1)]]
var<storage> sb : S;
[[group(0), binding(2)]] var t : texture_2d<f32>;
[[group(0), binding(3)]] var s : sampler;
[[stage(compute), workgroup_size(1)]]
fn main() {
}
)";
auto* expect = R"(
[[block]]
struct S {
a : f32;
};
[[stage(compute), workgroup_size(1)]]
fn main() {
}
)";
auto got = Run<ModuleScopeVarToEntryPointParam>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(ModuleScopeVarToEntryPointParamTest, EmtpyModule) {
auto* src = "";