writer/msl: Handle texture and sampler variables

Move these module-scope variables to entry point parameters and pass
them as arguments to functions that use them. Disable entry point IO
validation for them.

Emit [[texture()]] and [[sampler()]] attributes on these entry point
parameters.

Fixed: tint:145
Change-Id: I936a80801875a5d0b6cd98a2e8f3e297a2f53509
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/53961
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: James Price <jrprice@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
This commit is contained in:
James Price
2021-06-11 12:34:26 +00:00
committed by Tint LUCI CQ
parent 2940c7002c
commit 830b97ffa9
90 changed files with 3212 additions and 47 deletions

View File

@@ -52,15 +52,17 @@ Output Msl::Run(const Program* in, const DataMap&) {
CloneContext ctx(&builder, &out.program);
// TODO(jrprice): Consider making this a standalone transform, with target
// storage class(es) as transform options.
HandlePrivateAndWorkgroupVariables(ctx);
HandleModuleScopeVariables(ctx);
ctx.Clone();
return Output{Program(std::move(builder))};
}
void Msl::HandlePrivateAndWorkgroupVariables(CloneContext& ctx) const {
void Msl::HandleModuleScopeVariables(CloneContext& ctx) const {
// MSL does not allow private and workgroup variables at module-scope, so we
// push these declarations into the entry point function and then pass them as
// pointer parameters to any function that references them.
// Similarly, texture and sampler types are converted to entry point
// parameters and passed by value to functions that need them.
//
// Since WGSL does not allow function-scope variables to have these storage
// classes, we annotate the new variable declarations with an attribute that
@@ -100,14 +102,15 @@ void Msl::HandlePrivateAndWorkgroupVariables(CloneContext& ctx) const {
std::vector<ast::Function*> functions_to_process;
// Build a list of functions that transitively reference any private or
// workgroup variables.
// workgroup variables, or texture/sampler variables.
for (auto* func_ast : ctx.src->AST().Functions()) {
auto* func_sem = ctx.src->Sem().Get(func_ast);
bool needs_processing = false;
for (auto* var : func_sem->ReferencedModuleVariables()) {
if (var->StorageClass() == ast::StorageClass::kPrivate ||
var->StorageClass() == ast::StorageClass::kWorkgroup) {
var->StorageClass() == ast::StorageClass::kWorkgroup ||
var->StorageClass() == ast::StorageClass::kUniformConstant) {
needs_processing = true;
break;
}
@@ -133,7 +136,8 @@ void Msl::HandlePrivateAndWorkgroupVariables(CloneContext& ctx) const {
for (auto* var : func_sem->ReferencedModuleVariables()) {
if (var->StorageClass() != ast::StorageClass::kPrivate &&
var->StorageClass() != ast::StorageClass::kWorkgroup) {
var->StorageClass() != ast::StorageClass::kWorkgroup &&
var->StorageClass() != ast::StorageClass::kUniformConstant) {
continue;
}
@@ -143,32 +147,47 @@ void Msl::HandlePrivateAndWorkgroupVariables(CloneContext& ctx) const {
auto* store_type = CreateASTTypeFor(&ctx, var->Type()->UnwrapRef());
if (is_entry_point) {
// For an entry point, redeclare the variable at function-scope.
// Disable storage class validation on this variable.
auto* disable_validation =
ctx.dst->ASTNodes().Create<ast::DisableValidationDecoration>(
ctx.dst->ID(),
ast::DisabledValidation::kFunctionVarStorageClass);
auto* constructor = ctx.Clone(var->Declaration()->constructor());
auto* local_var =
ctx.dst->Var(new_var_symbol, store_type, var->StorageClass(),
constructor, ast::DecorationList{disable_validation});
ctx.InsertBefore(func_ast->body()->statements(),
*func_ast->body()->begin(), ctx.dst->Decl(local_var));
if (store_type->is_handle()) {
// For a texture or sampler variable, redeclare it as an entry point
// parameter. Disable entry point parameter validation.
auto* disable_validation =
ctx.dst->ASTNodes().Create<ast::DisableValidationDecoration>(
ctx.dst->ID(), ast::DisabledValidation::kEntryPointParameter);
auto decos = ctx.Clone(var->Declaration()->decorations());
decos.push_back(disable_validation);
auto* param = ctx.dst->Param(new_var_symbol, store_type, decos);
ctx.InsertFront(func_ast->params(), param);
} else {
// For a private or workgroup variable, redeclare it at function
// scope. Disable storage class validation on this variable.
auto* disable_validation =
ctx.dst->ASTNodes().Create<ast::DisableValidationDecoration>(
ctx.dst->ID(),
ast::DisabledValidation::kFunctionVarStorageClass);
auto* constructor = ctx.Clone(var->Declaration()->constructor());
auto* local_var = ctx.dst->Var(new_var_symbol, store_type,
var->StorageClass(), constructor,
ast::DecorationList{disable_validation});
ctx.InsertFront(func_ast->body()->statements(),
ctx.dst->Decl(local_var));
}
} else {
// For a regular function, redeclare the variable as a pointer function
// parameter.
auto* ptr_type = ctx.dst->ty.pointer(store_type, var->StorageClass());
// For a regular function, redeclare the variable as a parameter.
// Use a pointer for non-handle types.
auto* param_type = store_type;
if (!store_type->is_handle()) {
param_type = ctx.dst->ty.pointer(param_type, var->StorageClass());
}
ctx.InsertBack(func_ast->params(),
ctx.dst->Param(new_var_symbol, ptr_type));
ctx.dst->Param(new_var_symbol, param_type));
}
// Replace all uses of the module-scope variable.
// For non-entry points, dereference non-handle pointer parameters.
for (auto* user : var->Users()) {
if (user->Stmt()->Function() == func_ast) {
ast::Expression* expr = ctx.dst->Expr(new_var_symbol);
if (!is_entry_point) {
// For non-entry points, dereference the pointer argument.
if (!is_entry_point && !store_type->is_handle()) {
expr = ctx.dst->Deref(expr);
}
ctx.Replace(user->Declaration(), expr);
@@ -183,13 +202,14 @@ void Msl::HandlePrivateAndWorkgroupVariables(CloneContext& ctx) const {
auto* target = ctx.src->AST().Functions().Find(call->func()->symbol());
auto* target_sem = ctx.src->Sem().Get(target);
// Add new arguments for any referenced private and workgroup variables.
// Add new arguments for any variables that are needed by the callee.
// For entry points, pass non-handle types as pointers.
for (auto* target_var : target_sem->ReferencedModuleVariables()) {
if (target_var->StorageClass() == ast::StorageClass::kPrivate ||
target_var->StorageClass() == ast::StorageClass::kWorkgroup) {
target_var->StorageClass() == ast::StorageClass::kWorkgroup ||
target_var->StorageClass() == ast::StorageClass::kUniformConstant) {
ast::Expression* arg = ctx.dst->Expr(var_to_symbol[target_var]);
if (is_entry_point) {
// For entry points, pass the address of the variable.
if (is_entry_point && !target_var->Type()->UnwrapRef()->is_handle()) {
arg = ctx.dst->AddressOf(arg);
}
ctx.InsertBack(call->params(), arg);
@@ -198,11 +218,13 @@ void Msl::HandlePrivateAndWorkgroupVariables(CloneContext& ctx) const {
}
}
// Now remove all module-scope private and workgroup variables.
for (auto* var : ctx.src->AST().GlobalVariables()) {
if (var->declared_storage_class() == ast::StorageClass::kPrivate ||
var->declared_storage_class() == ast::StorageClass::kWorkgroup) {
ctx.Remove(ctx.src->AST().GlobalDeclarations(), var);
// Now remove all module-scope variables with these storage classes.
for (auto* var_ast : ctx.src->AST().GlobalVariables()) {
auto* var_sem = ctx.src->Sem().Get(var_ast);
if (var_sem->StorageClass() == ast::StorageClass::kPrivate ||
var_sem->StorageClass() == ast::StorageClass::kWorkgroup ||
var_sem->StorageClass() == ast::StorageClass::kUniformConstant) {
ctx.Remove(ctx.src->AST().GlobalDeclarations(), var_ast);
}
}
}

View File

@@ -36,10 +36,10 @@ class Msl : public Transform {
Output Run(const Program* program, const DataMap& data = {}) override;
private:
/// Pushes module-scope variables with private or workgroup storage classes
/// into the entry point function, and passes them as function parameters to
/// any functions that need them.
void HandlePrivateAndWorkgroupVariables(CloneContext& ctx) const;
/// Pushes module-scope variables with certain storage classes into the entry
/// point function, and passes them as function parameters to any functions
/// that need them.
void HandleModuleScopeVariables(CloneContext& ctx) const;
};
} // namespace transform

View File

@@ -22,7 +22,7 @@ namespace {
using MslTest = TransformTest;
TEST_F(MslTest, HandlePrivateAndWorkgroupVariables_Basic) {
TEST_F(MslTest, HandleModuleScopeVariables_Basic) {
auto* src = R"(
var<private> p : f32;
var<workgroup> w : f32;
@@ -47,7 +47,7 @@ fn main() {
EXPECT_EQ(expect, str(got));
}
TEST_F(MslTest, HandlePrivateAndWorkgroupVariables_FunctionCalls) {
TEST_F(MslTest, HandleModuleScopeVariables_FunctionCalls) {
auto* src = R"(
var<private> p : f32;
var<workgroup> w : f32;
@@ -100,7 +100,7 @@ fn main() {
EXPECT_EQ(expect, str(got));
}
TEST_F(MslTest, HandlePrivateAndWorkgroupVariables_Constructors) {
TEST_F(MslTest, HandleModuleScopeVariables_Constructors) {
auto* src = R"(
var<private> a : f32 = 1.0;
var<private> b : f32 = f32();
@@ -125,7 +125,7 @@ fn main() {
EXPECT_EQ(expect, str(got));
}
TEST_F(MslTest, HandlePrivateAndWorkgroupVariables_Pointers) {
TEST_F(MslTest, HandleModuleScopeVariables_Pointers) {
auto* src = R"(
var<private> p : f32;
var<workgroup> w : f32;
@@ -156,7 +156,7 @@ fn main() {
EXPECT_EQ(expect, str(got));
}
TEST_F(MslTest, HandlePrivateAndWorkgroupVariables_UnusedVariables) {
TEST_F(MslTest, HandleModuleScopeVariables_UnusedVariables) {
auto* src = R"(
var<private> p : f32;
var<workgroup> w : f32;
@@ -177,7 +177,7 @@ fn main() {
EXPECT_EQ(expect, str(got));
}
TEST_F(MslTest, HandlePrivateAndWorkgroupVariables_OtherVariables) {
TEST_F(MslTest, HandleModuleScopeVariables_OtherVariables) {
auto* src = R"(
[[block]]
struct S {
@@ -208,7 +208,85 @@ fn main() {
EXPECT_EQ(expect, str(got));
}
TEST_F(MslTest, HandlePrivateAndWorkgroupVariables_EmtpyModule) {
TEST_F(MslTest, HandleModuleScopeVariables_HandleTypes_Basic) {
auto* src = R"(
[[group(0), binding(0)]] var t : texture_2d<f32>;
[[group(0), binding(1)]] var s : sampler;
[[stage(compute)]]
fn main() {
ignore(t);
ignore(s);
}
)";
auto* expect = R"(
[[stage(compute)]]
fn main([[group(0), binding(0), internal(disable_validation__entry_point_parameter)]] tint_symbol : texture_2d<f32>, [[group(0), binding(1), internal(disable_validation__entry_point_parameter)]] tint_symbol_1 : sampler) {
ignore(tint_symbol);
ignore(tint_symbol_1);
}
)";
auto got = Run<Msl>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(MslTest, HandleModuleScopeVariables_HandleTypes_FunctionCalls) {
auto* src = R"(
[[group(0), binding(0)]] var t : texture_2d<f32>;
[[group(0), binding(1)]] var s : sampler;
fn no_uses() {
}
fn bar(a : f32, b : f32) {
ignore(t);
ignore(s);
}
fn foo(a : f32) {
let b : f32 = 2.0;
ignore(t);
bar(a, b);
no_uses();
}
[[stage(compute)]]
fn main() {
foo(1.0);
}
)";
auto* expect = R"(
fn no_uses() {
}
fn bar(a : f32, b : f32, tint_symbol : texture_2d<f32>, tint_symbol_1 : sampler) {
ignore(tint_symbol);
ignore(tint_symbol_1);
}
fn foo(a : f32, tint_symbol_2 : texture_2d<f32>, tint_symbol_3 : sampler) {
let b : f32 = 2.0;
ignore(tint_symbol_2);
bar(a, b, tint_symbol_2, tint_symbol_3);
no_uses();
}
[[stage(compute)]]
fn main([[group(0), binding(0), internal(disable_validation__entry_point_parameter)]] tint_symbol_4 : texture_2d<f32>, [[group(0), binding(1), internal(disable_validation__entry_point_parameter)]] tint_symbol_5 : sampler) {
foo(1.0, tint_symbol_4, tint_symbol_5);
}
)";
auto got = Run<Msl>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(MslTest, HandleModuleScopeVariables_EmtpyModule) {
auto* src = "";
auto got = Run<Msl>(src);

View File

@@ -111,6 +111,25 @@ ast::Type* Transform::CreateASTTypeFor(CloneContext* ctx, const sem::Type* ty) {
if (auto* s = ty->As<sem::Reference>()) {
return CreateASTTypeFor(ctx, s->StoreType());
}
if (auto* t = ty->As<sem::DepthTexture>()) {
return ctx->dst->create<ast::DepthTexture>(t->dim());
}
if (auto* t = ty->As<sem::MultisampledTexture>()) {
return ctx->dst->create<ast::MultisampledTexture>(
t->dim(), CreateASTTypeFor(ctx, t->type()));
}
if (auto* t = ty->As<sem::SampledTexture>()) {
return ctx->dst->create<ast::SampledTexture>(
t->dim(), CreateASTTypeFor(ctx, t->type()));
}
if (auto* t = ty->As<sem::StorageTexture>()) {
return ctx->dst->create<ast::StorageTexture>(
t->dim(), t->image_format(), CreateASTTypeFor(ctx, t->type()),
t->access());
}
if (auto* s = ty->As<sem::Sampler>()) {
return ctx->dst->create<ast::Sampler>(s->kind());
}
TINT_UNREACHABLE(ctx->dst->Diagnostics())
<< "Unhandled type: " << ty->TypeInfo().name;
return nullptr;