transform: Add ModuleScopeVarToEntryPointParam

This is the HandleModuleScopeVars() part of the MSL sanitizer moved
verbatim to a standalone transform. The transform code is unchanged,
but some expected test outputs are different as this is now tested in
isolation instead of along with the rest of the sanitizer transforms.

This is step towards removing the sanitizers completely.

Change-Id: I7be826e2119451fc2ce2891740cc94f978e7d5a1
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/63583
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
This commit is contained in:
James Price 2021-09-07 18:59:21 +00:00
parent b584b374a1
commit 3646400342
9 changed files with 597 additions and 501 deletions

View File

@ -446,6 +446,8 @@ libtint_source_set("libtint_core_all_src") {
"transform/loop_to_for_loop.h", "transform/loop_to_for_loop.h",
"transform/manager.cc", "transform/manager.cc",
"transform/manager.h", "transform/manager.h",
"transform/module_scope_var_to_entry_point_param.cc",
"transform/module_scope_var_to_entry_point_param.h",
"transform/pad_array_elements.cc", "transform/pad_array_elements.cc",
"transform/pad_array_elements.h", "transform/pad_array_elements.h",
"transform/promote_initializers_to_const_var.cc", "transform/promote_initializers_to_const_var.cc",

View File

@ -316,6 +316,8 @@ set(TINT_LIB_SRCS
transform/loop_to_for_loop.h transform/loop_to_for_loop.h
transform/manager.cc transform/manager.cc
transform/manager.h transform/manager.h
transform/module_scope_var_to_entry_point_param.cc
transform/module_scope_var_to_entry_point_param.h
transform/pad_array_elements.cc transform/pad_array_elements.cc
transform/pad_array_elements.h transform/pad_array_elements.h
transform/promote_initializers_to_const_var.cc transform/promote_initializers_to_const_var.cc
@ -937,6 +939,7 @@ if(${TINT_BUILD_TESTS})
transform/for_loop_to_loop_test.cc transform/for_loop_to_loop_test.cc
transform/inline_pointer_lets_test.cc transform/inline_pointer_lets_test.cc
transform/loop_to_for_loop_test.cc transform/loop_to_for_loop_test.cc
transform/module_scope_var_to_entry_point_param_test.cc
transform/pad_array_elements_test.cc transform/pad_array_elements_test.cc
transform/promote_initializers_to_const_var_test.cc transform/promote_initializers_to_const_var_test.cc
transform/renamer_test.cc transform/renamer_test.cc

View File

@ -0,0 +1,202 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/transform/module_scope_var_to_entry_point_param.h"
#include <unordered_map>
#include <utility>
#include <vector>
#include "src/ast/disable_validation_decoration.h"
#include "src/program_builder.h"
#include "src/sem/call.h"
#include "src/sem/function.h"
#include "src/sem/statement.h"
#include "src/sem/variable.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::ModuleScopeVarToEntryPointParam);
namespace tint {
namespace transform {
ModuleScopeVarToEntryPointParam::ModuleScopeVarToEntryPointParam() = default;
ModuleScopeVarToEntryPointParam::~ModuleScopeVarToEntryPointParam() = default;
void ModuleScopeVarToEntryPointParam::Run(CloneContext& ctx,
const DataMap&,
DataMap&) {
// Predetermine the list of function calls that need to be replaced.
using CallList = std::vector<const ast::CallExpression*>;
std::unordered_map<const ast::Function*, CallList> calls_to_replace;
std::vector<ast::Function*> functions_to_process;
// Build a list of functions that transitively reference any private or
// workgroup variables, or texture/sampler variables.
for (auto* func_ast : ctx.src->AST().Functions()) {
auto* func_sem = ctx.src->Sem().Get(func_ast);
bool needs_processing = false;
for (auto* var : func_sem->ReferencedModuleVariables()) {
if (var->StorageClass() == ast::StorageClass::kPrivate ||
var->StorageClass() == ast::StorageClass::kWorkgroup ||
var->StorageClass() == ast::StorageClass::kUniformConstant) {
needs_processing = true;
break;
}
}
if (needs_processing) {
functions_to_process.push_back(func_ast);
// Find all of the calls to this function that will need to be replaced.
for (auto* call : func_sem->CallSites()) {
auto* call_sem = ctx.src->Sem().Get(call);
calls_to_replace[call_sem->Stmt()->Function()].push_back(call);
}
}
}
// Build a list of `&ident` expressions. We'll use this later to avoid
// generating expressions of the form `&*ident`, which break WGSL validation
// rules when this expression is passed to a function.
// TODO(jrprice): We should add support for bidirectional SEM tree traversal
// so that we can do this on the fly instead.
std::unordered_map<ast::IdentifierExpression*, ast::UnaryOpExpression*>
ident_to_address_of;
for (auto* node : ctx.src->ASTNodes().Objects()) {
auto* address_of = node->As<ast::UnaryOpExpression>();
if (!address_of || address_of->op() != ast::UnaryOp::kAddressOf) {
continue;
}
if (auto* ident = address_of->expr()->As<ast::IdentifierExpression>()) {
ident_to_address_of[ident] = address_of;
}
}
for (auto* func_ast : functions_to_process) {
auto* func_sem = ctx.src->Sem().Get(func_ast);
bool is_entry_point = func_ast->IsEntryPoint();
// Map module-scope variables onto their function-scope replacement.
std::unordered_map<const sem::Variable*, Symbol> var_to_symbol;
for (auto* var : func_sem->ReferencedModuleVariables()) {
if (var->StorageClass() != ast::StorageClass::kPrivate &&
var->StorageClass() != ast::StorageClass::kWorkgroup &&
var->StorageClass() != ast::StorageClass::kUniformConstant) {
continue;
}
// This is the symbol for the variable that replaces the module-scope var.
auto new_var_symbol = ctx.dst->Sym();
auto* store_type = CreateASTTypeFor(ctx, var->Type()->UnwrapRef());
if (is_entry_point) {
if (store_type->is_handle()) {
// For a texture or sampler variable, redeclare it as an entry point
// parameter. Disable entry point parameter validation.
auto* disable_validation =
ctx.dst->ASTNodes().Create<ast::DisableValidationDecoration>(
ctx.dst->ID(), ast::DisabledValidation::kEntryPointParameter);
auto decos = ctx.Clone(var->Declaration()->decorations());
decos.push_back(disable_validation);
auto* param = ctx.dst->Param(new_var_symbol, store_type, decos);
ctx.InsertFront(func_ast->params(), param);
} else {
// For a private or workgroup variable, redeclare it at function
// scope. Disable storage class validation on this variable.
auto* disable_validation =
ctx.dst->ASTNodes().Create<ast::DisableValidationDecoration>(
ctx.dst->ID(), ast::DisabledValidation::kIgnoreStorageClass);
auto* constructor = ctx.Clone(var->Declaration()->constructor());
auto* local_var = ctx.dst->Var(
new_var_symbol, store_type, var->StorageClass(), constructor,
ast::DecorationList{disable_validation});
ctx.InsertFront(func_ast->body()->statements(),
ctx.dst->Decl(local_var));
}
} else {
// For a regular function, redeclare the variable as a parameter.
// Use a pointer for non-handle types.
auto* param_type = store_type;
if (!store_type->is_handle()) {
param_type = ctx.dst->ty.pointer(param_type, var->StorageClass());
}
ctx.InsertBack(func_ast->params(),
ctx.dst->Param(new_var_symbol, param_type));
}
// Replace all uses of the module-scope variable.
// For non-entry points, dereference non-handle pointer parameters.
for (auto* user : var->Users()) {
if (user->Stmt()->Function() == func_ast) {
ast::Expression* expr = ctx.dst->Expr(new_var_symbol);
if (!is_entry_point && !store_type->is_handle()) {
// If this identifier is used by an address-of operator, just remove
// the address-of instead of adding a deref, since we already have a
// pointer.
auto* ident = user->Declaration()->As<ast::IdentifierExpression>();
if (ident_to_address_of.count(ident)) {
ctx.Replace(ident_to_address_of[ident], expr);
continue;
}
expr = ctx.dst->Deref(expr);
}
ctx.Replace(user->Declaration(), expr);
}
}
var_to_symbol[var] = new_var_symbol;
}
// Pass the variables as pointers to any functions that need them.
for (auto* call : calls_to_replace[func_ast]) {
auto* target = ctx.src->AST().Functions().Find(call->func()->symbol());
auto* target_sem = ctx.src->Sem().Get(target);
// Add new arguments for any variables that are needed by the callee.
// For entry points, pass non-handle types as pointers.
for (auto* target_var : target_sem->ReferencedModuleVariables()) {
if (target_var->StorageClass() == ast::StorageClass::kPrivate ||
target_var->StorageClass() == ast::StorageClass::kWorkgroup ||
target_var->StorageClass() == ast::StorageClass::kUniformConstant) {
ast::Expression* arg = ctx.dst->Expr(var_to_symbol[target_var]);
if (is_entry_point && !target_var->Type()->UnwrapRef()->is_handle()) {
arg = ctx.dst->AddressOf(arg);
}
ctx.InsertBack(call->params(), arg);
}
}
}
}
// Now remove all module-scope variables with these storage classes.
for (auto* var_ast : ctx.src->AST().GlobalVariables()) {
auto* var_sem = ctx.src->Sem().Get(var_ast);
if (var_sem->StorageClass() == ast::StorageClass::kPrivate ||
var_sem->StorageClass() == ast::StorageClass::kWorkgroup ||
var_sem->StorageClass() == ast::StorageClass::kUniformConstant) {
ctx.Remove(ctx.src->AST().GlobalDeclarations(), var_ast);
}
}
ctx.Clone();
}
} // namespace transform
} // namespace tint

View File

@ -0,0 +1,82 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TRANSFORM_MODULE_SCOPE_VAR_TO_ENTRY_POINT_PARAM_H_
#define SRC_TRANSFORM_MODULE_SCOPE_VAR_TO_ENTRY_POINT_PARAM_H_
#include "src/transform/transform.h"
namespace tint {
namespace transform {
/// Move module-scope variables into the entry point as parameters.
///
/// MSL does not allow private and workgroup variables at module-scope, so we
/// push these declarations into the entry point function and then pass them as
/// pointer parameters to any function that references them.
/// Similarly, texture and sampler types are converted to entry point
/// parameters and passed by value to functions that need them.
///
/// Since WGSL does not allow function-scope variables to have these storage
/// classes, we annotate the new variable declarations with an attribute that
/// bypasses that validation rule.
///
/// Before:
/// ```
/// var<private> v : f32 = 2.0;
///
/// fn foo() {
/// v = v + 1.0;
/// }
///
/// [[stage(compute), workgroup_size(1)]]
/// fn main() {
/// foo();
/// }
/// ```
///
/// After:
/// ```
/// fn foo(v : ptr<private, f32>) {
/// *v = *v + 1.0;
/// }
///
/// [[stage(compute), workgroup_size(1)]]
/// fn main() {
/// var<private> v : f32 = 2.0;
/// foo(&v);
/// }
/// ```
class ModuleScopeVarToEntryPointParam
: public Castable<ModuleScopeVarToEntryPointParam, Transform> {
public:
/// Constructor
ModuleScopeVarToEntryPointParam();
/// Destructor
~ModuleScopeVarToEntryPointParam() override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override;
};
} // namespace transform
} // namespace tint
#endif // SRC_TRANSFORM_MODULE_SCOPE_VAR_TO_ENTRY_POINT_PARAM_H_

View File

@ -0,0 +1,303 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/transform/module_scope_var_to_entry_point_param.h"
#include <utility>
#include "src/transform/test_helper.h"
namespace tint {
namespace transform {
namespace {
using ModuleScopeVarToEntryPointParamTest = TransformTest;
TEST_F(ModuleScopeVarToEntryPointParamTest, Basic) {
auto* src = R"(
var<private> p : f32;
var<workgroup> w : f32;
[[stage(compute), workgroup_size(1)]]
fn main() {
w = p;
}
)";
auto* expect = R"(
[[stage(compute), workgroup_size(1)]]
fn main() {
[[internal(disable_validation__ignore_storage_class)]] var<workgroup> tint_symbol : f32;
[[internal(disable_validation__ignore_storage_class)]] var<private> tint_symbol_1 : f32;
tint_symbol = tint_symbol_1;
}
)";
auto got = Run<ModuleScopeVarToEntryPointParam>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(ModuleScopeVarToEntryPointParamTest, FunctionCalls) {
auto* src = R"(
var<private> p : f32;
var<workgroup> w : f32;
fn no_uses() {
}
fn bar(a : f32, b : f32) {
p = a;
w = b;
}
fn foo(a : f32) {
let b : f32 = 2.0;
bar(a, b);
no_uses();
}
[[stage(compute), workgroup_size(1)]]
fn main() {
foo(1.0);
}
)";
auto* expect = R"(
fn no_uses() {
}
fn bar(a : f32, b : f32, tint_symbol : ptr<private, f32>, tint_symbol_1 : ptr<workgroup, f32>) {
*(tint_symbol) = a;
*(tint_symbol_1) = b;
}
fn foo(a : f32, tint_symbol_2 : ptr<private, f32>, tint_symbol_3 : ptr<workgroup, f32>) {
let b : f32 = 2.0;
bar(a, b, tint_symbol_2, tint_symbol_3);
no_uses();
}
[[stage(compute), workgroup_size(1)]]
fn main() {
[[internal(disable_validation__ignore_storage_class)]] var<private> tint_symbol_4 : f32;
[[internal(disable_validation__ignore_storage_class)]] var<workgroup> tint_symbol_5 : f32;
foo(1.0, &(tint_symbol_4), &(tint_symbol_5));
}
)";
auto got = Run<ModuleScopeVarToEntryPointParam>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(ModuleScopeVarToEntryPointParamTest, Constructors) {
auto* src = R"(
var<private> a : f32 = 1.0;
var<private> b : f32 = f32();
[[stage(compute), workgroup_size(1)]]
fn main() {
let x : f32 = a + b;
}
)";
auto* expect = R"(
[[stage(compute), workgroup_size(1)]]
fn main() {
[[internal(disable_validation__ignore_storage_class)]] var<private> tint_symbol : f32 = 1.0;
[[internal(disable_validation__ignore_storage_class)]] var<private> tint_symbol_1 : f32 = f32();
let x : f32 = (tint_symbol + tint_symbol_1);
}
)";
auto got = Run<ModuleScopeVarToEntryPointParam>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(ModuleScopeVarToEntryPointParamTest, Pointers) {
auto* src = R"(
var<private> p : f32;
var<workgroup> w : f32;
[[stage(compute), workgroup_size(1)]]
fn main() {
let p_ptr : ptr<private, f32> = &p;
let w_ptr : ptr<workgroup, f32> = &w;
let x : f32 = *p_ptr + *w_ptr;
*p_ptr = x;
}
)";
auto* expect = R"(
[[stage(compute), workgroup_size(1)]]
fn main() {
[[internal(disable_validation__ignore_storage_class)]] var<private> tint_symbol : f32;
[[internal(disable_validation__ignore_storage_class)]] var<workgroup> tint_symbol_1 : f32;
let p_ptr : ptr<private, f32> = &(tint_symbol);
let w_ptr : ptr<workgroup, f32> = &(tint_symbol_1);
let x : f32 = (*(p_ptr) + *(w_ptr));
*(p_ptr) = x;
}
)";
auto got = Run<ModuleScopeVarToEntryPointParam>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(ModuleScopeVarToEntryPointParamTest, UnusedVariables) {
auto* src = R"(
var<private> p : f32;
var<workgroup> w : f32;
[[stage(compute), workgroup_size(1)]]
fn main() {
}
)";
auto* expect = R"(
[[stage(compute), workgroup_size(1)]]
fn main() {
}
)";
auto got = Run<ModuleScopeVarToEntryPointParam>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(ModuleScopeVarToEntryPointParamTest, OtherVariables) {
auto* src = R"(
[[block]]
struct S {
a : f32;
};
[[group(0), binding(0)]]
var<uniform> u : S;
[[stage(compute), workgroup_size(1)]]
fn main() {
}
)";
auto* expect = R"(
[[block]]
struct S {
a : f32;
};
[[group(0), binding(0)]] var<uniform> u : S;
[[stage(compute), workgroup_size(1)]]
fn main() {
}
)";
auto got = Run<ModuleScopeVarToEntryPointParam>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(ModuleScopeVarToEntryPointParamTest, HandleTypes_Basic) {
auto* src = R"(
[[group(0), binding(0)]] var t : texture_2d<f32>;
[[group(0), binding(1)]] var s : sampler;
[[stage(compute), workgroup_size(1)]]
fn main() {
ignore(t);
ignore(s);
}
)";
auto* expect = R"(
[[stage(compute), workgroup_size(1)]]
fn main([[group(0), binding(0), internal(disable_validation__entry_point_parameter)]] tint_symbol : texture_2d<f32>, [[group(0), binding(1), internal(disable_validation__entry_point_parameter)]] tint_symbol_1 : sampler) {
ignore(tint_symbol);
ignore(tint_symbol_1);
}
)";
auto got = Run<ModuleScopeVarToEntryPointParam>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(ModuleScopeVarToEntryPointParamTest, HandleTypes_FunctionCalls) {
auto* src = R"(
[[group(0), binding(0)]] var t : texture_2d<f32>;
[[group(0), binding(1)]] var s : sampler;
fn no_uses() {
}
fn bar(a : f32, b : f32) {
ignore(t);
ignore(s);
}
fn foo(a : f32) {
let b : f32 = 2.0;
ignore(t);
bar(a, b);
no_uses();
}
[[stage(compute), workgroup_size(1)]]
fn main() {
foo(1.0);
}
)";
auto* expect = R"(
fn no_uses() {
}
fn bar(a : f32, b : f32, tint_symbol : texture_2d<f32>, tint_symbol_1 : sampler) {
ignore(tint_symbol);
ignore(tint_symbol_1);
}
fn foo(a : f32, tint_symbol_2 : texture_2d<f32>, tint_symbol_3 : sampler) {
let b : f32 = 2.0;
ignore(tint_symbol_2);
bar(a, b, tint_symbol_2, tint_symbol_3);
no_uses();
}
[[stage(compute), workgroup_size(1)]]
fn main([[group(0), binding(0), internal(disable_validation__entry_point_parameter)]] tint_symbol_4 : texture_2d<f32>, [[group(0), binding(1), internal(disable_validation__entry_point_parameter)]] tint_symbol_5 : sampler) {
foo(1.0, tint_symbol_4, tint_symbol_5);
}
)";
auto got = Run<ModuleScopeVarToEntryPointParam>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(ModuleScopeVarToEntryPointParamTest, EmtpyModule) {
auto* src = "";
auto got = Run<ModuleScopeVarToEntryPointParam>(src);
EXPECT_EQ(src, str(got));
}
} // namespace
} // namespace transform
} // namespace tint

View File

@ -30,6 +30,7 @@
#include "src/transform/external_texture_transform.h" #include "src/transform/external_texture_transform.h"
#include "src/transform/inline_pointer_lets.h" #include "src/transform/inline_pointer_lets.h"
#include "src/transform/manager.h" #include "src/transform/manager.h"
#include "src/transform/module_scope_var_to_entry_point_param.h"
#include "src/transform/pad_array_elements.h" #include "src/transform/pad_array_elements.h"
#include "src/transform/promote_initializers_to_const_var.h" #include "src/transform/promote_initializers_to_const_var.h"
#include "src/transform/simplify.h" #include "src/transform/simplify.h"
@ -86,6 +87,7 @@ Output Msl::Run(const Program* in, const DataMap& inputs) {
manager.Add<PromoteInitializersToConstVar>(); manager.Add<PromoteInitializersToConstVar>();
manager.Add<WrapArraysInStructs>(); manager.Add<WrapArraysInStructs>();
manager.Add<PadArrayElements>(); manager.Add<PadArrayElements>();
manager.Add<ModuleScopeVarToEntryPointParam>();
manager.Add<InlinePointerLets>(); manager.Add<InlinePointerLets>();
manager.Add<Simplify>(); manager.Add<Simplify>();
// ArrayLengthFromUniform must come after InlinePointerLets and Simplify, as // ArrayLengthFromUniform must come after InlinePointerLets and Simplify, as
@ -102,9 +104,7 @@ Output Msl::Run(const Program* in, const DataMap& inputs) {
ProgramBuilder builder; ProgramBuilder builder;
CloneContext ctx(&builder, &out.program); CloneContext ctx(&builder, &out.program);
// TODO(jrprice): Consider making this a standalone transform, with target // TODO(jrprice): Move the sanitizer into the backend.
// storage class(es) as transform options.
HandleModuleScopeVariables(ctx);
ctx.Clone(); ctx.Clone();
auto result = std::make_unique<Result>( auto result = std::make_unique<Result>(
@ -114,203 +114,6 @@ Output Msl::Run(const Program* in, const DataMap& inputs) {
return Output{Program(std::move(builder)), std::move(result)}; return Output{Program(std::move(builder)), std::move(result)};
} }
void Msl::HandleModuleScopeVariables(CloneContext& ctx) const {
// MSL does not allow private and workgroup variables at module-scope, so we
// push these declarations into the entry point function and then pass them as
// pointer parameters to any function that references them.
// Similarly, texture and sampler types are converted to entry point
// parameters and passed by value to functions that need them.
//
// Since WGSL does not allow function-scope variables to have these storage
// classes, we annotate the new variable declarations with an attribute that
// bypasses that validation rule.
//
// Before:
// ```
// var<private> v : f32 = 2.0;
//
// fn foo() {
// v = v + 1.0;
// }
//
// [[stage(compute), workgroup_size(1)]]
// fn main() {
// foo();
// }
// ```
//
// After:
// ```
// fn foo(v : ptr<private, f32>) {
// *v = *v + 1.0;
// }
//
// [[stage(compute), workgroup_size(1)]]
// fn main() {
// var<private> v : f32 = 2.0;
// foo(&v);
// }
// ```
// Predetermine the list of function calls that need to be replaced.
using CallList = std::vector<const ast::CallExpression*>;
std::unordered_map<const ast::Function*, CallList> calls_to_replace;
std::vector<ast::Function*> functions_to_process;
// Build a list of functions that transitively reference any private or
// workgroup variables, or texture/sampler variables.
for (auto* func_ast : ctx.src->AST().Functions()) {
auto* func_sem = ctx.src->Sem().Get(func_ast);
bool needs_processing = false;
for (auto* var : func_sem->ReferencedModuleVariables()) {
if (var->StorageClass() == ast::StorageClass::kPrivate ||
var->StorageClass() == ast::StorageClass::kWorkgroup ||
var->StorageClass() == ast::StorageClass::kUniformConstant) {
needs_processing = true;
break;
}
}
if (needs_processing) {
functions_to_process.push_back(func_ast);
// Find all of the calls to this function that will need to be replaced.
for (auto* call : func_sem->CallSites()) {
auto* call_sem = ctx.src->Sem().Get(call);
calls_to_replace[call_sem->Stmt()->Function()].push_back(call);
}
}
}
// Build a list of `&ident` expressions. We'll use this later to avoid
// generating expressions of the form `&*ident`, which break WGSL validation
// rules when this expression is passed to a function.
// TODO(jrprice): We should add support for bidirectional SEM tree traversal
// so that we can do this on the fly instead.
std::unordered_map<ast::IdentifierExpression*, ast::UnaryOpExpression*>
ident_to_address_of;
for (auto* node : ctx.src->ASTNodes().Objects()) {
auto* address_of = node->As<ast::UnaryOpExpression>();
if (!address_of || address_of->op() != ast::UnaryOp::kAddressOf) {
continue;
}
if (auto* ident = address_of->expr()->As<ast::IdentifierExpression>()) {
ident_to_address_of[ident] = address_of;
}
}
for (auto* func_ast : functions_to_process) {
auto* func_sem = ctx.src->Sem().Get(func_ast);
bool is_entry_point = func_ast->IsEntryPoint();
// Map module-scope variables onto their function-scope replacement.
std::unordered_map<const sem::Variable*, Symbol> var_to_symbol;
for (auto* var : func_sem->ReferencedModuleVariables()) {
if (var->StorageClass() != ast::StorageClass::kPrivate &&
var->StorageClass() != ast::StorageClass::kWorkgroup &&
var->StorageClass() != ast::StorageClass::kUniformConstant) {
continue;
}
// This is the symbol for the variable that replaces the module-scope var.
auto new_var_symbol = ctx.dst->Sym();
auto* store_type = CreateASTTypeFor(ctx, var->Type()->UnwrapRef());
if (is_entry_point) {
if (store_type->is_handle()) {
// For a texture or sampler variable, redeclare it as an entry point
// parameter. Disable entry point parameter validation.
auto* disable_validation =
ctx.dst->ASTNodes().Create<ast::DisableValidationDecoration>(
ctx.dst->ID(), ast::DisabledValidation::kEntryPointParameter);
auto decos = ctx.Clone(var->Declaration()->decorations());
decos.push_back(disable_validation);
auto* param = ctx.dst->Param(new_var_symbol, store_type, decos);
ctx.InsertFront(func_ast->params(), param);
} else {
// For a private or workgroup variable, redeclare it at function
// scope. Disable storage class validation on this variable.
auto* disable_validation =
ctx.dst->ASTNodes().Create<ast::DisableValidationDecoration>(
ctx.dst->ID(), ast::DisabledValidation::kIgnoreStorageClass);
auto* constructor = ctx.Clone(var->Declaration()->constructor());
auto* local_var = ctx.dst->Var(
new_var_symbol, store_type, var->StorageClass(), constructor,
ast::DecorationList{disable_validation});
ctx.InsertFront(func_ast->body()->statements(),
ctx.dst->Decl(local_var));
}
} else {
// For a regular function, redeclare the variable as a parameter.
// Use a pointer for non-handle types.
auto* param_type = store_type;
if (!store_type->is_handle()) {
param_type = ctx.dst->ty.pointer(param_type, var->StorageClass());
}
ctx.InsertBack(func_ast->params(),
ctx.dst->Param(new_var_symbol, param_type));
}
// Replace all uses of the module-scope variable.
// For non-entry points, dereference non-handle pointer parameters.
for (auto* user : var->Users()) {
if (user->Stmt()->Function() == func_ast) {
ast::Expression* expr = ctx.dst->Expr(new_var_symbol);
if (!is_entry_point && !store_type->is_handle()) {
// If this identifier is used by an address-of operator, just remove
// the address-of instead of adding a deref, since we already have a
// pointer.
auto* ident = user->Declaration()->As<ast::IdentifierExpression>();
if (ident_to_address_of.count(ident)) {
ctx.Replace(ident_to_address_of[ident], expr);
continue;
}
expr = ctx.dst->Deref(expr);
}
ctx.Replace(user->Declaration(), expr);
}
}
var_to_symbol[var] = new_var_symbol;
}
// Pass the variables as pointers to any functions that need them.
for (auto* call : calls_to_replace[func_ast]) {
auto* target = ctx.src->AST().Functions().Find(call->func()->symbol());
auto* target_sem = ctx.src->Sem().Get(target);
// Add new arguments for any variables that are needed by the callee.
// For entry points, pass non-handle types as pointers.
for (auto* target_var : target_sem->ReferencedModuleVariables()) {
if (target_var->StorageClass() == ast::StorageClass::kPrivate ||
target_var->StorageClass() == ast::StorageClass::kWorkgroup ||
target_var->StorageClass() == ast::StorageClass::kUniformConstant) {
ast::Expression* arg = ctx.dst->Expr(var_to_symbol[target_var]);
if (is_entry_point && !target_var->Type()->UnwrapRef()->is_handle()) {
arg = ctx.dst->AddressOf(arg);
}
ctx.InsertBack(call->params(), arg);
}
}
}
}
// Now remove all module-scope variables with these storage classes.
for (auto* var_ast : ctx.src->AST().GlobalVariables()) {
auto* var_sem = ctx.src->Sem().Get(var_ast);
if (var_sem->StorageClass() == ast::StorageClass::kPrivate ||
var_sem->StorageClass() == ast::StorageClass::kWorkgroup ||
var_sem->StorageClass() == ast::StorageClass::kUniformConstant) {
ctx.Remove(ctx.src->AST().GlobalDeclarations(), var_ast);
}
}
}
Msl::Config::Config(uint32_t buffer_size_ubo_idx, Msl::Config::Config(uint32_t buffer_size_ubo_idx,
uint32_t sample_mask, uint32_t sample_mask,
bool emit_point_size, bool emit_point_size,

View File

@ -86,12 +86,6 @@ class Msl : public Castable<Msl, Transform> {
/// @param data optional extra transform-specific input data /// @param data optional extra transform-specific input data
/// @returns the transformation result /// @returns the transformation result
Output Run(const Program* program, const DataMap& data = {}) override; Output Run(const Program* program, const DataMap& data = {}) override;
private:
/// Pushes module-scope variables with certain storage classes into the entry
/// point function, and passes them as function parameters to any functions
/// that need them.
void HandleModuleScopeVariables(CloneContext& ctx) const;
}; };
} // namespace transform } // namespace transform

View File

@ -22,301 +22,7 @@ namespace {
using MslTest = TransformTest; using MslTest = TransformTest;
TEST_F(MslTest, HandleModuleScopeVariables_Basic) { // TODO(jrprice): Remove this file when we remove the sanitizers.
auto* src = R"(
var<private> p : f32;
var<workgroup> w : f32;
[[stage(compute), workgroup_size(1)]]
fn main() {
w = p;
}
)";
auto* expect = R"(
fn main_inner(local_invocation_index : u32, tint_symbol : ptr<workgroup, f32>, tint_symbol_1 : ptr<private, f32>) {
{
*(tint_symbol) = f32();
}
workgroupBarrier();
*(tint_symbol) = *(tint_symbol_1);
}
[[stage(compute), workgroup_size(1)]]
fn main([[builtin(local_invocation_index)]] local_invocation_index : u32) {
[[internal(disable_validation__ignore_storage_class)]] var<workgroup> tint_symbol_2 : f32;
[[internal(disable_validation__ignore_storage_class)]] var<private> tint_symbol_3 : f32;
main_inner(local_invocation_index, &(tint_symbol_2), &(tint_symbol_3));
}
)";
auto got = Run<Msl>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(MslTest, HandleModuleScopeVariables_FunctionCalls) {
auto* src = R"(
var<private> p : f32;
var<workgroup> w : f32;
fn no_uses() {
}
fn bar(a : f32, b : f32) {
p = a;
w = b;
}
fn foo(a : f32) {
let b : f32 = 2.0;
bar(a, b);
no_uses();
}
[[stage(compute), workgroup_size(1)]]
fn main() {
foo(1.0);
}
)";
auto* expect = R"(
fn no_uses() {
}
fn bar(a : f32, b : f32, tint_symbol : ptr<private, f32>, tint_symbol_1 : ptr<workgroup, f32>) {
*(tint_symbol) = a;
*(tint_symbol_1) = b;
}
fn foo(a : f32, tint_symbol_2 : ptr<private, f32>, tint_symbol_3 : ptr<workgroup, f32>) {
let b : f32 = 2.0;
bar(a, b, tint_symbol_2, tint_symbol_3);
no_uses();
}
fn main_inner(local_invocation_index : u32, tint_symbol_4 : ptr<workgroup, f32>, tint_symbol_5 : ptr<private, f32>) {
{
*(tint_symbol_4) = f32();
}
workgroupBarrier();
foo(1.0, tint_symbol_5, tint_symbol_4);
}
[[stage(compute), workgroup_size(1)]]
fn main([[builtin(local_invocation_index)]] local_invocation_index : u32) {
[[internal(disable_validation__ignore_storage_class)]] var<workgroup> tint_symbol_6 : f32;
[[internal(disable_validation__ignore_storage_class)]] var<private> tint_symbol_7 : f32;
main_inner(local_invocation_index, &(tint_symbol_6), &(tint_symbol_7));
}
)";
auto got = Run<Msl>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(MslTest, HandleModuleScopeVariables_Constructors) {
auto* src = R"(
var<private> a : f32 = 1.0;
var<private> b : f32 = f32();
[[stage(compute), workgroup_size(1)]]
fn main() {
let x : f32 = a + b;
}
)";
auto* expect = R"(
[[stage(compute), workgroup_size(1)]]
fn main() {
[[internal(disable_validation__ignore_storage_class)]] var<private> tint_symbol : f32 = 1.0;
[[internal(disable_validation__ignore_storage_class)]] var<private> tint_symbol_1 : f32 = f32();
let x : f32 = (tint_symbol + tint_symbol_1);
}
)";
auto got = Run<Msl>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(MslTest, HandleModuleScopeVariables_Pointers) {
auto* src = R"(
var<private> p : f32;
var<workgroup> w : f32;
[[stage(compute), workgroup_size(1)]]
fn main() {
let p_ptr : ptr<private, f32> = &p;
let w_ptr : ptr<workgroup, f32> = &w;
let x : f32 = *p_ptr + *w_ptr;
*p_ptr = x;
}
)";
auto* expect = R"(
fn main_inner(local_invocation_index : u32, tint_symbol : ptr<workgroup, f32>, tint_symbol_1 : ptr<private, f32>) {
{
*(tint_symbol) = f32();
}
workgroupBarrier();
let x : f32 = (*(tint_symbol_1) + *(tint_symbol));
*(tint_symbol_1) = x;
}
[[stage(compute), workgroup_size(1)]]
fn main([[builtin(local_invocation_index)]] local_invocation_index : u32) {
[[internal(disable_validation__ignore_storage_class)]] var<workgroup> tint_symbol_2 : f32;
[[internal(disable_validation__ignore_storage_class)]] var<private> tint_symbol_3 : f32;
main_inner(local_invocation_index, &(tint_symbol_2), &(tint_symbol_3));
}
)";
auto got = Run<Msl>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(MslTest, HandleModuleScopeVariables_UnusedVariables) {
auto* src = R"(
var<private> p : f32;
var<workgroup> w : f32;
[[stage(compute), workgroup_size(1)]]
fn main() {
}
)";
auto* expect = R"(
[[stage(compute), workgroup_size(1)]]
fn main() {
}
)";
auto got = Run<Msl>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(MslTest, HandleModuleScopeVariables_OtherVariables) {
auto* src = R"(
[[block]]
struct S {
a : f32;
};
[[group(0), binding(0)]]
var<uniform> u : S;
[[stage(compute), workgroup_size(1)]]
fn main() {
}
)";
auto* expect = R"(
[[block]]
struct S {
a : f32;
};
[[group(0), binding(0)]] var<uniform> u : S;
[[stage(compute), workgroup_size(1)]]
fn main() {
}
)";
auto got = Run<Msl>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(MslTest, HandleModuleScopeVariables_HandleTypes_Basic) {
auto* src = R"(
[[group(0), binding(0)]] var t : texture_2d<f32>;
[[group(0), binding(1)]] var s : sampler;
[[stage(compute), workgroup_size(1)]]
fn main() {
ignore(t);
ignore(s);
}
)";
auto* expect = R"(
[[stage(compute), workgroup_size(1)]]
fn main([[group(0), binding(0), internal(disable_validation__entry_point_parameter)]] tint_symbol : texture_2d<f32>, [[group(0), binding(1), internal(disable_validation__entry_point_parameter)]] tint_symbol_1 : sampler) {
ignore(tint_symbol);
ignore(tint_symbol_1);
}
)";
auto got = Run<Msl>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(MslTest, HandleModuleScopeVariables_HandleTypes_FunctionCalls) {
auto* src = R"(
[[group(0), binding(0)]] var t : texture_2d<f32>;
[[group(0), binding(1)]] var s : sampler;
fn no_uses() {
}
fn bar(a : f32, b : f32) {
ignore(t);
ignore(s);
}
fn foo(a : f32) {
let b : f32 = 2.0;
ignore(t);
bar(a, b);
no_uses();
}
[[stage(compute), workgroup_size(1)]]
fn main() {
foo(1.0);
}
)";
auto* expect = R"(
fn no_uses() {
}
fn bar(a : f32, b : f32, tint_symbol : texture_2d<f32>, tint_symbol_1 : sampler) {
ignore(tint_symbol);
ignore(tint_symbol_1);
}
fn foo(a : f32, tint_symbol_2 : texture_2d<f32>, tint_symbol_3 : sampler) {
let b : f32 = 2.0;
ignore(tint_symbol_2);
bar(a, b, tint_symbol_2, tint_symbol_3);
no_uses();
}
[[stage(compute), workgroup_size(1)]]
fn main([[group(0), binding(0), internal(disable_validation__entry_point_parameter)]] tint_symbol_4 : texture_2d<f32>, [[group(0), binding(1), internal(disable_validation__entry_point_parameter)]] tint_symbol_5 : sampler) {
foo(1.0, tint_symbol_4, tint_symbol_5);
}
)";
auto got = Run<Msl>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(MslTest, HandleModuleScopeVariables_EmtpyModule) {
auto* src = "";
auto got = Run<Msl>(src);
EXPECT_EQ(src, str(got));
}
} // namespace } // namespace
} // namespace transform } // namespace transform

View File

@ -301,6 +301,7 @@ tint_unittests_source_set("tint_unittests_core_src") {
"../src/transform/for_loop_to_loop_test.cc", "../src/transform/for_loop_to_loop_test.cc",
"../src/transform/inline_pointer_lets_test.cc", "../src/transform/inline_pointer_lets_test.cc",
"../src/transform/loop_to_for_loop_test.cc", "../src/transform/loop_to_for_loop_test.cc",
"../src/transform/module_scope_var_to_entry_point_param_test.cc",
"../src/transform/pad_array_elements_test.cc", "../src/transform/pad_array_elements_test.cc",
"../src/transform/promote_initializers_to_const_var_test.cc", "../src/transform/promote_initializers_to_const_var_test.cc",
"../src/transform/renamer_test.cc", "../src/transform/renamer_test.cc",