mirror of
https://github.com/encounter/dawn-cmake.git
synced 2025-12-13 15:16:16 +00:00
writer/msl: Handle private and workgroup variables
Add a transform that pushes these into the entry point and then passes them by pointer to any functions that need them. Since WGSL does not allow non-function storage class at function-scope, add a DisableValidation attribute to bypass this check. Fixed: tint/726 Change-Id: Ic1f4cd691a54c19e77a60e8ba178508e4249bfd9 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/51962 Kokoro: Kokoro <noreply+kokoro@google.com> Commit-Queue: James Price <jrprice@google.com> Auto-Submit: James Price <jrprice@google.com> Reviewed-by: Ben Clayton <bclayton@google.com>
This commit is contained in:
committed by
Tint LUCI CQ
parent
61e573663d
commit
7a47fa8495
@@ -14,8 +14,16 @@
|
||||
|
||||
#include "src/transform/msl.h"
|
||||
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "src/ast/disable_validation_decoration.h"
|
||||
#include "src/program_builder.h"
|
||||
#include "src/sem/call.h"
|
||||
#include "src/sem/function.h"
|
||||
#include "src/sem/statement.h"
|
||||
#include "src/sem/variable.h"
|
||||
#include "src/transform/canonicalize_entry_point_io.h"
|
||||
#include "src/transform/external_texture_transform.h"
|
||||
#include "src/transform/manager.h"
|
||||
@@ -34,7 +42,162 @@ Output Msl::Run(const Program* in, const DataMap& data) {
|
||||
if (!out.program.IsValid()) {
|
||||
return out;
|
||||
}
|
||||
return Output{Program(std::move(out.program))};
|
||||
|
||||
ProgramBuilder builder;
|
||||
CloneContext ctx(&builder, &out.program);
|
||||
// TODO(jrprice): Consider making this a standalone transform, with target
|
||||
// storage class(es) as transform options.
|
||||
HandlePrivateAndWorkgroupVariables(ctx);
|
||||
ctx.Clone();
|
||||
return Output{Program(std::move(builder))};
|
||||
}
|
||||
|
||||
void Msl::HandlePrivateAndWorkgroupVariables(CloneContext& ctx) const {
|
||||
// MSL does not allow private and workgroup variables at module-scope, so we
|
||||
// push these declarations into the entry point function and then pass them as
|
||||
// pointer parameters to any function that references them.
|
||||
//
|
||||
// Since WGSL does not allow function-scope variables to have these storage
|
||||
// classes, we annotate the new variable declarations with an attribute that
|
||||
// bypasses that validation rule.
|
||||
//
|
||||
// Before:
|
||||
// ```
|
||||
// var<private> v : f32 = 2.0;
|
||||
//
|
||||
// fn foo() {
|
||||
// v = v + 1.0;
|
||||
// }
|
||||
//
|
||||
// [[stage(compute)]]
|
||||
// fn main() {
|
||||
// foo();
|
||||
// }
|
||||
// ```
|
||||
//
|
||||
// After:
|
||||
// ```
|
||||
// fn foo(v : ptr<private, f32>) {
|
||||
// *v = *v + 1.0;
|
||||
// }
|
||||
//
|
||||
// [[stage(compute)]]
|
||||
// fn main() {
|
||||
// var<private> v : f32 = 2.0;
|
||||
// let v_ptr : ptr<private, f32> = &f32;
|
||||
// foo(v_ptr);
|
||||
// }
|
||||
// ```
|
||||
|
||||
// Predetermine the list of function calls that need to be replaced.
|
||||
using CallList = std::vector<const ast::CallExpression*>;
|
||||
std::unordered_map<const ast::Function*, CallList> calls_to_replace;
|
||||
|
||||
std::vector<ast::Function*> functions_to_process;
|
||||
|
||||
// Build a list of functions that transitively reference any private or
|
||||
// workgroup variables.
|
||||
for (auto* func_ast : ctx.src->AST().Functions()) {
|
||||
auto* func_sem = ctx.src->Sem().Get(func_ast);
|
||||
|
||||
bool needs_processing = false;
|
||||
for (auto* var : func_sem->ReferencedModuleVariables()) {
|
||||
if (var->StorageClass() == ast::StorageClass::kPrivate ||
|
||||
var->StorageClass() == ast::StorageClass::kWorkgroup) {
|
||||
needs_processing = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (needs_processing) {
|
||||
functions_to_process.push_back(func_ast);
|
||||
|
||||
// Find all of the calls to this function that will need to be replaced.
|
||||
for (auto* call : func_sem->CallSites()) {
|
||||
auto* call_sem = ctx.src->Sem().Get(call);
|
||||
calls_to_replace[call_sem->Stmt()->Function()].push_back(call);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (auto* func_ast : functions_to_process) {
|
||||
auto* func_sem = ctx.src->Sem().Get(func_ast);
|
||||
|
||||
// Map module-scope variables onto their function-scope replacement.
|
||||
std::unordered_map<const sem::Variable*, Symbol> var_to_symbol;
|
||||
|
||||
for (auto* var : func_sem->ReferencedModuleVariables()) {
|
||||
if (var->StorageClass() != ast::StorageClass::kPrivate &&
|
||||
var->StorageClass() != ast::StorageClass::kWorkgroup) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// This is the symbol for the pointer that replaces the module-scope var.
|
||||
auto new_var_symbol = ctx.dst->Sym();
|
||||
|
||||
auto* store_type = CreateASTTypeFor(&ctx, var->Type()->UnwrapRef());
|
||||
|
||||
if (func_ast->IsEntryPoint()) {
|
||||
// For an entry point, redeclare the variable at function-scope.
|
||||
// Disable storage class validation on this variable.
|
||||
auto* disable_validation =
|
||||
ctx.dst->ASTNodes().Create<ast::DisableValidationDecoration>(
|
||||
ctx.dst->ID(),
|
||||
ast::DisabledValidation::kFunctionVarStorageClass);
|
||||
auto* constructor = ctx.Clone(var->Declaration()->constructor());
|
||||
auto* local_var =
|
||||
ctx.dst->Var(ctx.dst->Sym(), store_type, var->StorageClass(),
|
||||
constructor, {disable_validation});
|
||||
ctx.InsertBefore(func_ast->body()->statements(),
|
||||
*func_ast->body()->begin(), ctx.dst->Decl(local_var));
|
||||
|
||||
// Now take the address of the variable.
|
||||
auto* ptr = ctx.dst->Const(new_var_symbol, nullptr,
|
||||
ctx.dst->AddressOf(local_var));
|
||||
ctx.InsertBefore(func_ast->body()->statements(),
|
||||
*func_ast->body()->begin(), ctx.dst->Decl(ptr));
|
||||
} else {
|
||||
// For a regular function, redeclare the variable as a pointer function
|
||||
// parameter.
|
||||
auto* ptr_type = ctx.dst->ty.pointer(store_type, var->StorageClass());
|
||||
ctx.InsertBack(func_ast->params(),
|
||||
ctx.dst->Param(new_var_symbol, ptr_type));
|
||||
}
|
||||
|
||||
// Replace all uses of the module-scope variable with the pointer
|
||||
// replacement (dereferenced).
|
||||
for (auto* user : var->Users()) {
|
||||
if (user->Stmt()->Function() == func_ast) {
|
||||
ctx.Replace(user->Declaration(), ctx.dst->Deref(new_var_symbol));
|
||||
}
|
||||
}
|
||||
|
||||
var_to_symbol[var] = new_var_symbol;
|
||||
}
|
||||
|
||||
// Pass the pointers through to any functions that need them.
|
||||
for (auto* call : calls_to_replace[func_ast]) {
|
||||
auto* target = ctx.src->AST().Functions().Find(call->func()->symbol());
|
||||
auto* target_sem = ctx.src->Sem().Get(target);
|
||||
|
||||
// Add new arguments for any referenced private and workgroup variables.
|
||||
for (auto* target_var : target_sem->ReferencedModuleVariables()) {
|
||||
if (target_var->StorageClass() == ast::StorageClass::kPrivate ||
|
||||
target_var->StorageClass() == ast::StorageClass::kWorkgroup) {
|
||||
ctx.InsertBack(call->params(),
|
||||
ctx.dst->Expr(var_to_symbol[target_var]));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Now remove all module-scope private and workgroup variables.
|
||||
for (auto* var : ctx.src->AST().GlobalVariables()) {
|
||||
if (var->declared_storage_class() == ast::StorageClass::kPrivate ||
|
||||
var->declared_storage_class() == ast::StorageClass::kWorkgroup) {
|
||||
ctx.Remove(ctx.src->AST().GlobalDeclarations(), var);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace transform
|
||||
|
||||
@@ -34,6 +34,12 @@ class Msl : public Transform {
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns the transformation result
|
||||
Output Run(const Program* program, const DataMap& data = {}) override;
|
||||
|
||||
private:
|
||||
/// Pushes module-scope variables with private or workgroup storage classes
|
||||
/// into the entry point function, and passes them as function parameters to
|
||||
/// any functions that need them.
|
||||
void HandlePrivateAndWorkgroupVariables(CloneContext& ctx) const;
|
||||
};
|
||||
|
||||
} // namespace transform
|
||||
|
||||
@@ -20,6 +20,210 @@ namespace tint {
|
||||
namespace transform {
|
||||
namespace {
|
||||
|
||||
using MslTest = TransformTest;
|
||||
|
||||
TEST_F(MslTest, HandlePrivateAndWorkgroupVariables_Basic) {
|
||||
auto* src = R"(
|
||||
var<private> p : f32;
|
||||
var<workgroup> w : f32;
|
||||
|
||||
[[stage(compute)]]
|
||||
fn main() {
|
||||
w = p;
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
[[stage(compute)]]
|
||||
fn main() {
|
||||
[[internal(disable_validation__function_var_storage_class)]] var<workgroup> tint_symbol_1 : f32;
|
||||
let tint_symbol = &(tint_symbol_1);
|
||||
[[internal(disable_validation__function_var_storage_class)]] var<private> tint_symbol_3 : f32;
|
||||
let tint_symbol_2 = &(tint_symbol_3);
|
||||
*(tint_symbol) = *(tint_symbol_2);
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Msl>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(MslTest, HandlePrivateAndWorkgroupVariables_FunctionCalls) {
|
||||
auto* src = R"(
|
||||
var<private> p : f32;
|
||||
var<workgroup> w : f32;
|
||||
|
||||
fn no_uses() {
|
||||
}
|
||||
|
||||
fn bar(a : f32, b : f32) {
|
||||
p = a;
|
||||
w = b;
|
||||
}
|
||||
|
||||
fn foo(a : f32) {
|
||||
let b : f32 = 2.0;
|
||||
bar(a, b);
|
||||
no_uses();
|
||||
}
|
||||
|
||||
[[stage(compute)]]
|
||||
fn main() {
|
||||
foo(1.0);
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn no_uses() {
|
||||
}
|
||||
|
||||
fn bar(a : f32, b : f32, tint_symbol : ptr<private, f32>, tint_symbol_1 : ptr<workgroup, f32>) {
|
||||
*(tint_symbol) = a;
|
||||
*(tint_symbol_1) = b;
|
||||
}
|
||||
|
||||
fn foo(a : f32, tint_symbol_2 : ptr<private, f32>, tint_symbol_3 : ptr<workgroup, f32>) {
|
||||
let b : f32 = 2.0;
|
||||
bar(a, b, tint_symbol_2, tint_symbol_3);
|
||||
no_uses();
|
||||
}
|
||||
|
||||
[[stage(compute)]]
|
||||
fn main() {
|
||||
[[internal(disable_validation__function_var_storage_class)]] var<private> tint_symbol_5 : f32;
|
||||
let tint_symbol_4 = &(tint_symbol_5);
|
||||
[[internal(disable_validation__function_var_storage_class)]] var<workgroup> tint_symbol_7 : f32;
|
||||
let tint_symbol_6 = &(tint_symbol_7);
|
||||
foo(1.0, tint_symbol_4, tint_symbol_6);
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Msl>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(MslTest, HandlePrivateAndWorkgroupVariables_Constructors) {
|
||||
auto* src = R"(
|
||||
var<private> a : f32 = 1.0;
|
||||
var<private> b : f32 = f32();
|
||||
|
||||
[[stage(compute)]]
|
||||
fn main() {
|
||||
let x : f32 = a + b;
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
[[stage(compute)]]
|
||||
fn main() {
|
||||
[[internal(disable_validation__function_var_storage_class)]] var<private> tint_symbol_1 : f32 = 1.0;
|
||||
let tint_symbol = &(tint_symbol_1);
|
||||
[[internal(disable_validation__function_var_storage_class)]] var<private> tint_symbol_3 : f32 = f32();
|
||||
let tint_symbol_2 = &(tint_symbol_3);
|
||||
let x : f32 = (*(tint_symbol) + *(tint_symbol_2));
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Msl>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(MslTest, HandlePrivateAndWorkgroupVariables_Pointers) {
|
||||
auto* src = R"(
|
||||
var<private> p : f32;
|
||||
var<workgroup> w : f32;
|
||||
|
||||
[[stage(compute)]]
|
||||
fn main() {
|
||||
let p_ptr : ptr<private, f32> = &p;
|
||||
let w_ptr : ptr<workgroup, f32> = &w;
|
||||
let x : f32 = *p_ptr + *w_ptr;
|
||||
*p_ptr = x;
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
[[stage(compute)]]
|
||||
fn main() {
|
||||
[[internal(disable_validation__function_var_storage_class)]] var<private> tint_symbol_1 : f32;
|
||||
let tint_symbol = &(tint_symbol_1);
|
||||
[[internal(disable_validation__function_var_storage_class)]] var<workgroup> tint_symbol_3 : f32;
|
||||
let tint_symbol_2 = &(tint_symbol_3);
|
||||
let p_ptr : ptr<private, f32> = &(*(tint_symbol));
|
||||
let w_ptr : ptr<workgroup, f32> = &(*(tint_symbol_2));
|
||||
let x : f32 = (*(p_ptr) + *(w_ptr));
|
||||
*(p_ptr) = x;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Msl>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(MslTest, HandlePrivateAndWorkgroupVariables_UnusedVariables) {
|
||||
auto* src = R"(
|
||||
var<private> p : f32;
|
||||
var<workgroup> w : f32;
|
||||
|
||||
[[stage(compute)]]
|
||||
fn main() {
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
[[stage(compute)]]
|
||||
fn main() {
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Msl>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(MslTest, HandlePrivateAndWorkgroupVariables_OtherVariables) {
|
||||
auto* src = R"(
|
||||
[[block]]
|
||||
struct S {
|
||||
};
|
||||
|
||||
[[group(0), binding(0)]]
|
||||
var<uniform> u : S;
|
||||
|
||||
[[stage(compute)]]
|
||||
fn main() {
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
[[block]]
|
||||
struct S {
|
||||
};
|
||||
|
||||
[[group(0), binding(0)]] var<uniform> u : S;
|
||||
|
||||
[[stage(compute)]]
|
||||
fn main() {
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Msl>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(MslTest, HandlePrivateAndWorkgroupVariables_EmtpyModule) {
|
||||
auto* src = "";
|
||||
|
||||
auto got = Run<Msl>(src);
|
||||
|
||||
EXPECT_EQ(src, str(got));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
|
||||
Reference in New Issue
Block a user