mirror of
https://github.com/encounter/dawn-cmake.git
synced 2025-12-12 14:46:08 +00:00
writer/msl: Handle texture and sampler variables
Move these module-scope variables to entry point parameters and pass them as arguments to functions that use them. Disable entry point IO validation for them. Emit [[texture()]] and [[sampler()]] attributes on these entry point parameters. Fixed: tint:145 Change-Id: I936a80801875a5d0b6cd98a2e8f3e297a2f53509 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/53961 Kokoro: Kokoro <noreply+kokoro@google.com> Commit-Queue: James Price <jrprice@google.com> Reviewed-by: Ben Clayton <bclayton@google.com>
This commit is contained in:
committed by
Tint LUCI CQ
parent
2940c7002c
commit
830b97ffa9
@@ -52,15 +52,17 @@ Output Msl::Run(const Program* in, const DataMap&) {
|
||||
CloneContext ctx(&builder, &out.program);
|
||||
// TODO(jrprice): Consider making this a standalone transform, with target
|
||||
// storage class(es) as transform options.
|
||||
HandlePrivateAndWorkgroupVariables(ctx);
|
||||
HandleModuleScopeVariables(ctx);
|
||||
ctx.Clone();
|
||||
return Output{Program(std::move(builder))};
|
||||
}
|
||||
|
||||
void Msl::HandlePrivateAndWorkgroupVariables(CloneContext& ctx) const {
|
||||
void Msl::HandleModuleScopeVariables(CloneContext& ctx) const {
|
||||
// MSL does not allow private and workgroup variables at module-scope, so we
|
||||
// push these declarations into the entry point function and then pass them as
|
||||
// pointer parameters to any function that references them.
|
||||
// Similarly, texture and sampler types are converted to entry point
|
||||
// parameters and passed by value to functions that need them.
|
||||
//
|
||||
// Since WGSL does not allow function-scope variables to have these storage
|
||||
// classes, we annotate the new variable declarations with an attribute that
|
||||
@@ -100,14 +102,15 @@ void Msl::HandlePrivateAndWorkgroupVariables(CloneContext& ctx) const {
|
||||
std::vector<ast::Function*> functions_to_process;
|
||||
|
||||
// Build a list of functions that transitively reference any private or
|
||||
// workgroup variables.
|
||||
// workgroup variables, or texture/sampler variables.
|
||||
for (auto* func_ast : ctx.src->AST().Functions()) {
|
||||
auto* func_sem = ctx.src->Sem().Get(func_ast);
|
||||
|
||||
bool needs_processing = false;
|
||||
for (auto* var : func_sem->ReferencedModuleVariables()) {
|
||||
if (var->StorageClass() == ast::StorageClass::kPrivate ||
|
||||
var->StorageClass() == ast::StorageClass::kWorkgroup) {
|
||||
var->StorageClass() == ast::StorageClass::kWorkgroup ||
|
||||
var->StorageClass() == ast::StorageClass::kUniformConstant) {
|
||||
needs_processing = true;
|
||||
break;
|
||||
}
|
||||
@@ -133,7 +136,8 @@ void Msl::HandlePrivateAndWorkgroupVariables(CloneContext& ctx) const {
|
||||
|
||||
for (auto* var : func_sem->ReferencedModuleVariables()) {
|
||||
if (var->StorageClass() != ast::StorageClass::kPrivate &&
|
||||
var->StorageClass() != ast::StorageClass::kWorkgroup) {
|
||||
var->StorageClass() != ast::StorageClass::kWorkgroup &&
|
||||
var->StorageClass() != ast::StorageClass::kUniformConstant) {
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -143,32 +147,47 @@ void Msl::HandlePrivateAndWorkgroupVariables(CloneContext& ctx) const {
|
||||
auto* store_type = CreateASTTypeFor(&ctx, var->Type()->UnwrapRef());
|
||||
|
||||
if (is_entry_point) {
|
||||
// For an entry point, redeclare the variable at function-scope.
|
||||
// Disable storage class validation on this variable.
|
||||
auto* disable_validation =
|
||||
ctx.dst->ASTNodes().Create<ast::DisableValidationDecoration>(
|
||||
ctx.dst->ID(),
|
||||
ast::DisabledValidation::kFunctionVarStorageClass);
|
||||
auto* constructor = ctx.Clone(var->Declaration()->constructor());
|
||||
auto* local_var =
|
||||
ctx.dst->Var(new_var_symbol, store_type, var->StorageClass(),
|
||||
constructor, ast::DecorationList{disable_validation});
|
||||
ctx.InsertBefore(func_ast->body()->statements(),
|
||||
*func_ast->body()->begin(), ctx.dst->Decl(local_var));
|
||||
if (store_type->is_handle()) {
|
||||
// For a texture or sampler variable, redeclare it as an entry point
|
||||
// parameter. Disable entry point parameter validation.
|
||||
auto* disable_validation =
|
||||
ctx.dst->ASTNodes().Create<ast::DisableValidationDecoration>(
|
||||
ctx.dst->ID(), ast::DisabledValidation::kEntryPointParameter);
|
||||
auto decos = ctx.Clone(var->Declaration()->decorations());
|
||||
decos.push_back(disable_validation);
|
||||
auto* param = ctx.dst->Param(new_var_symbol, store_type, decos);
|
||||
ctx.InsertFront(func_ast->params(), param);
|
||||
} else {
|
||||
// For a private or workgroup variable, redeclare it at function
|
||||
// scope. Disable storage class validation on this variable.
|
||||
auto* disable_validation =
|
||||
ctx.dst->ASTNodes().Create<ast::DisableValidationDecoration>(
|
||||
ctx.dst->ID(),
|
||||
ast::DisabledValidation::kFunctionVarStorageClass);
|
||||
auto* constructor = ctx.Clone(var->Declaration()->constructor());
|
||||
auto* local_var = ctx.dst->Var(new_var_symbol, store_type,
|
||||
var->StorageClass(), constructor,
|
||||
ast::DecorationList{disable_validation});
|
||||
ctx.InsertFront(func_ast->body()->statements(),
|
||||
ctx.dst->Decl(local_var));
|
||||
}
|
||||
} else {
|
||||
// For a regular function, redeclare the variable as a pointer function
|
||||
// parameter.
|
||||
auto* ptr_type = ctx.dst->ty.pointer(store_type, var->StorageClass());
|
||||
// For a regular function, redeclare the variable as a parameter.
|
||||
// Use a pointer for non-handle types.
|
||||
auto* param_type = store_type;
|
||||
if (!store_type->is_handle()) {
|
||||
param_type = ctx.dst->ty.pointer(param_type, var->StorageClass());
|
||||
}
|
||||
ctx.InsertBack(func_ast->params(),
|
||||
ctx.dst->Param(new_var_symbol, ptr_type));
|
||||
ctx.dst->Param(new_var_symbol, param_type));
|
||||
}
|
||||
|
||||
// Replace all uses of the module-scope variable.
|
||||
// For non-entry points, dereference non-handle pointer parameters.
|
||||
for (auto* user : var->Users()) {
|
||||
if (user->Stmt()->Function() == func_ast) {
|
||||
ast::Expression* expr = ctx.dst->Expr(new_var_symbol);
|
||||
if (!is_entry_point) {
|
||||
// For non-entry points, dereference the pointer argument.
|
||||
if (!is_entry_point && !store_type->is_handle()) {
|
||||
expr = ctx.dst->Deref(expr);
|
||||
}
|
||||
ctx.Replace(user->Declaration(), expr);
|
||||
@@ -183,13 +202,14 @@ void Msl::HandlePrivateAndWorkgroupVariables(CloneContext& ctx) const {
|
||||
auto* target = ctx.src->AST().Functions().Find(call->func()->symbol());
|
||||
auto* target_sem = ctx.src->Sem().Get(target);
|
||||
|
||||
// Add new arguments for any referenced private and workgroup variables.
|
||||
// Add new arguments for any variables that are needed by the callee.
|
||||
// For entry points, pass non-handle types as pointers.
|
||||
for (auto* target_var : target_sem->ReferencedModuleVariables()) {
|
||||
if (target_var->StorageClass() == ast::StorageClass::kPrivate ||
|
||||
target_var->StorageClass() == ast::StorageClass::kWorkgroup) {
|
||||
target_var->StorageClass() == ast::StorageClass::kWorkgroup ||
|
||||
target_var->StorageClass() == ast::StorageClass::kUniformConstant) {
|
||||
ast::Expression* arg = ctx.dst->Expr(var_to_symbol[target_var]);
|
||||
if (is_entry_point) {
|
||||
// For entry points, pass the address of the variable.
|
||||
if (is_entry_point && !target_var->Type()->UnwrapRef()->is_handle()) {
|
||||
arg = ctx.dst->AddressOf(arg);
|
||||
}
|
||||
ctx.InsertBack(call->params(), arg);
|
||||
@@ -198,11 +218,13 @@ void Msl::HandlePrivateAndWorkgroupVariables(CloneContext& ctx) const {
|
||||
}
|
||||
}
|
||||
|
||||
// Now remove all module-scope private and workgroup variables.
|
||||
for (auto* var : ctx.src->AST().GlobalVariables()) {
|
||||
if (var->declared_storage_class() == ast::StorageClass::kPrivate ||
|
||||
var->declared_storage_class() == ast::StorageClass::kWorkgroup) {
|
||||
ctx.Remove(ctx.src->AST().GlobalDeclarations(), var);
|
||||
// Now remove all module-scope variables with these storage classes.
|
||||
for (auto* var_ast : ctx.src->AST().GlobalVariables()) {
|
||||
auto* var_sem = ctx.src->Sem().Get(var_ast);
|
||||
if (var_sem->StorageClass() == ast::StorageClass::kPrivate ||
|
||||
var_sem->StorageClass() == ast::StorageClass::kWorkgroup ||
|
||||
var_sem->StorageClass() == ast::StorageClass::kUniformConstant) {
|
||||
ctx.Remove(ctx.src->AST().GlobalDeclarations(), var_ast);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -36,10 +36,10 @@ class Msl : public Transform {
|
||||
Output Run(const Program* program, const DataMap& data = {}) override;
|
||||
|
||||
private:
|
||||
/// Pushes module-scope variables with private or workgroup storage classes
|
||||
/// into the entry point function, and passes them as function parameters to
|
||||
/// any functions that need them.
|
||||
void HandlePrivateAndWorkgroupVariables(CloneContext& ctx) const;
|
||||
/// Pushes module-scope variables with certain storage classes into the entry
|
||||
/// point function, and passes them as function parameters to any functions
|
||||
/// that need them.
|
||||
void HandleModuleScopeVariables(CloneContext& ctx) const;
|
||||
};
|
||||
|
||||
} // namespace transform
|
||||
|
||||
@@ -22,7 +22,7 @@ namespace {
|
||||
|
||||
using MslTest = TransformTest;
|
||||
|
||||
TEST_F(MslTest, HandlePrivateAndWorkgroupVariables_Basic) {
|
||||
TEST_F(MslTest, HandleModuleScopeVariables_Basic) {
|
||||
auto* src = R"(
|
||||
var<private> p : f32;
|
||||
var<workgroup> w : f32;
|
||||
@@ -47,7 +47,7 @@ fn main() {
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(MslTest, HandlePrivateAndWorkgroupVariables_FunctionCalls) {
|
||||
TEST_F(MslTest, HandleModuleScopeVariables_FunctionCalls) {
|
||||
auto* src = R"(
|
||||
var<private> p : f32;
|
||||
var<workgroup> w : f32;
|
||||
@@ -100,7 +100,7 @@ fn main() {
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(MslTest, HandlePrivateAndWorkgroupVariables_Constructors) {
|
||||
TEST_F(MslTest, HandleModuleScopeVariables_Constructors) {
|
||||
auto* src = R"(
|
||||
var<private> a : f32 = 1.0;
|
||||
var<private> b : f32 = f32();
|
||||
@@ -125,7 +125,7 @@ fn main() {
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(MslTest, HandlePrivateAndWorkgroupVariables_Pointers) {
|
||||
TEST_F(MslTest, HandleModuleScopeVariables_Pointers) {
|
||||
auto* src = R"(
|
||||
var<private> p : f32;
|
||||
var<workgroup> w : f32;
|
||||
@@ -156,7 +156,7 @@ fn main() {
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(MslTest, HandlePrivateAndWorkgroupVariables_UnusedVariables) {
|
||||
TEST_F(MslTest, HandleModuleScopeVariables_UnusedVariables) {
|
||||
auto* src = R"(
|
||||
var<private> p : f32;
|
||||
var<workgroup> w : f32;
|
||||
@@ -177,7 +177,7 @@ fn main() {
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(MslTest, HandlePrivateAndWorkgroupVariables_OtherVariables) {
|
||||
TEST_F(MslTest, HandleModuleScopeVariables_OtherVariables) {
|
||||
auto* src = R"(
|
||||
[[block]]
|
||||
struct S {
|
||||
@@ -208,7 +208,85 @@ fn main() {
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(MslTest, HandlePrivateAndWorkgroupVariables_EmtpyModule) {
|
||||
TEST_F(MslTest, HandleModuleScopeVariables_HandleTypes_Basic) {
|
||||
auto* src = R"(
|
||||
[[group(0), binding(0)]] var t : texture_2d<f32>;
|
||||
[[group(0), binding(1)]] var s : sampler;
|
||||
|
||||
[[stage(compute)]]
|
||||
fn main() {
|
||||
ignore(t);
|
||||
ignore(s);
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
[[stage(compute)]]
|
||||
fn main([[group(0), binding(0), internal(disable_validation__entry_point_parameter)]] tint_symbol : texture_2d<f32>, [[group(0), binding(1), internal(disable_validation__entry_point_parameter)]] tint_symbol_1 : sampler) {
|
||||
ignore(tint_symbol);
|
||||
ignore(tint_symbol_1);
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Msl>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(MslTest, HandleModuleScopeVariables_HandleTypes_FunctionCalls) {
|
||||
auto* src = R"(
|
||||
[[group(0), binding(0)]] var t : texture_2d<f32>;
|
||||
[[group(0), binding(1)]] var s : sampler;
|
||||
|
||||
fn no_uses() {
|
||||
}
|
||||
|
||||
fn bar(a : f32, b : f32) {
|
||||
ignore(t);
|
||||
ignore(s);
|
||||
}
|
||||
|
||||
fn foo(a : f32) {
|
||||
let b : f32 = 2.0;
|
||||
ignore(t);
|
||||
bar(a, b);
|
||||
no_uses();
|
||||
}
|
||||
|
||||
[[stage(compute)]]
|
||||
fn main() {
|
||||
foo(1.0);
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn no_uses() {
|
||||
}
|
||||
|
||||
fn bar(a : f32, b : f32, tint_symbol : texture_2d<f32>, tint_symbol_1 : sampler) {
|
||||
ignore(tint_symbol);
|
||||
ignore(tint_symbol_1);
|
||||
}
|
||||
|
||||
fn foo(a : f32, tint_symbol_2 : texture_2d<f32>, tint_symbol_3 : sampler) {
|
||||
let b : f32 = 2.0;
|
||||
ignore(tint_symbol_2);
|
||||
bar(a, b, tint_symbol_2, tint_symbol_3);
|
||||
no_uses();
|
||||
}
|
||||
|
||||
[[stage(compute)]]
|
||||
fn main([[group(0), binding(0), internal(disable_validation__entry_point_parameter)]] tint_symbol_4 : texture_2d<f32>, [[group(0), binding(1), internal(disable_validation__entry_point_parameter)]] tint_symbol_5 : sampler) {
|
||||
foo(1.0, tint_symbol_4, tint_symbol_5);
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Msl>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(MslTest, HandleModuleScopeVariables_EmtpyModule) {
|
||||
auto* src = "";
|
||||
|
||||
auto got = Run<Msl>(src);
|
||||
|
||||
@@ -111,6 +111,25 @@ ast::Type* Transform::CreateASTTypeFor(CloneContext* ctx, const sem::Type* ty) {
|
||||
if (auto* s = ty->As<sem::Reference>()) {
|
||||
return CreateASTTypeFor(ctx, s->StoreType());
|
||||
}
|
||||
if (auto* t = ty->As<sem::DepthTexture>()) {
|
||||
return ctx->dst->create<ast::DepthTexture>(t->dim());
|
||||
}
|
||||
if (auto* t = ty->As<sem::MultisampledTexture>()) {
|
||||
return ctx->dst->create<ast::MultisampledTexture>(
|
||||
t->dim(), CreateASTTypeFor(ctx, t->type()));
|
||||
}
|
||||
if (auto* t = ty->As<sem::SampledTexture>()) {
|
||||
return ctx->dst->create<ast::SampledTexture>(
|
||||
t->dim(), CreateASTTypeFor(ctx, t->type()));
|
||||
}
|
||||
if (auto* t = ty->As<sem::StorageTexture>()) {
|
||||
return ctx->dst->create<ast::StorageTexture>(
|
||||
t->dim(), t->image_format(), CreateASTTypeFor(ctx, t->type()),
|
||||
t->access());
|
||||
}
|
||||
if (auto* s = ty->As<sem::Sampler>()) {
|
||||
return ctx->dst->create<ast::Sampler>(s->kind());
|
||||
}
|
||||
TINT_UNREACHABLE(ctx->dst->Diagnostics())
|
||||
<< "Unhandled type: " << ty->TypeInfo().name;
|
||||
return nullptr;
|
||||
|
||||
Reference in New Issue
Block a user