diff --git a/src/BUILD.gn b/src/BUILD.gn index b6e4b20301..0ff76ce00d 100644 --- a/src/BUILD.gn +++ b/src/BUILD.gn @@ -584,6 +584,8 @@ libtint_source_set("libtint_core_all_src") { "transform/vertex_pulling.h", "transform/wrap_arrays_in_structs.cc", "transform/wrap_arrays_in_structs.h", + "transform/zero_init_workgroup_memory.cc", + "transform/zero_init_workgroup_memory.h", "utils/enum_set.h", "utils/get_or_create.h", "utils/hash.h", diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 541735223f..fc5d550a70 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -309,6 +309,8 @@ set(TINT_LIB_SRCS transform/vertex_pulling.h transform/wrap_arrays_in_structs.cc transform/wrap_arrays_in_structs.h + transform/zero_init_workgroup_memory.cc + transform/zero_init_workgroup_memory.h sem/bool_type.cc sem/bool_type.h sem/depth_texture_type.cc @@ -873,6 +875,7 @@ if(${TINT_BUILD_TESTS}) transform/test_helper.h transform/vertex_pulling_test.cc transform/wrap_arrays_in_structs_test.cc + transform/zero_init_workgroup_memory_test.cc ) endif() diff --git a/src/transform/zero_init_workgroup_memory.cc b/src/transform/zero_init_workgroup_memory.cc new file mode 100644 index 0000000000..53f96182af --- /dev/null +++ b/src/transform/zero_init_workgroup_memory.cc @@ -0,0 +1,200 @@ +// Copyright 2021 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/transform/zero_init_workgroup_memory.h" + +#include +#include + +#include "src/program_builder.h" +#include "src/sem/atomic_type.h" +#include "src/sem/function.h" +#include "src/sem/variable.h" +#include "src/utils/get_or_create.h" + +namespace tint { +namespace transform { + +// PIMPL state for the ZeroInitWorkgroupMemory transform +struct ZeroInitWorkgroupMemory::State { + /// The clone context + CloneContext& ctx; + /// The built statements + ast::StatementList& stmts; + + /// Zero() generates the statements required to zero initialize the workgroup + /// storage expression of type `ty`. + /// @param ty the expression type + /// @param get_expr a function that builds the AST nodes for the expression + void Zero(const sem::Type* ty, + const std::function& get_expr) { + if (CanZero(ty)) { + auto* var = get_expr(); + auto* zero_init = ctx.dst->Construct(CreateASTTypeFor(&ctx, ty)); + stmts.emplace_back( + ctx.dst->create(var, zero_init)); + return; + } + + if (auto* atomic = ty->As()) { + auto* zero_init = + ctx.dst->Construct(CreateASTTypeFor(&ctx, atomic->Type())); + auto* store = ctx.dst->Call("atomicStore", ctx.dst->AddressOf(get_expr()), + zero_init); + stmts.emplace_back(ctx.dst->create(store)); + return; + } + + if (auto* str = ty->As()) { + for (auto* member : str->Members()) { + auto name = ctx.Clone(member->Declaration()->symbol()); + Zero(member->Type(), + [&] { return ctx.dst->MemberAccessor(get_expr(), name); }); + } + return; + } + + if (auto* arr = ty->As()) { + // TODO(bclayton): If array sizes become pipeline-overridable then this + // will need to emit code for a loop. + // See https://github.com/gpuweb/gpuweb/pull/1792 + for (size_t i = 0; i < arr->Count(); i++) { + Zero(arr->ElemType(), [&] { + return ctx.dst->IndexAccessor(get_expr(), + static_cast(i)); + }); + } + return; + } + + TINT_UNREACHABLE(ctx.dst->Diagnostics()) + << "could not zero workgroup type: " << ty->type_name(); + } + + /// @returns true if the type `ty` can be zeroed with a simple zero-value + /// expression in the form of a type constructor without operands. If + /// CanZero() returns false, then the type needs to be initialized by + /// decomposing the initialization into multiple sub-initializations. + /// @param ty the type to inspect + static bool CanZero(const sem::Type* ty) { + if (ty->Is()) { + return false; + } + if (auto* str = ty->As()) { + for (auto* member : str->Members()) { + if (!CanZero(member->Type())) { + return false; + } + } + } + if (auto* arr = ty->As()) { + if (!CanZero(arr->ElemType())) { + return false; + } + } + return true; + } +}; + +ZeroInitWorkgroupMemory::ZeroInitWorkgroupMemory() = default; + +ZeroInitWorkgroupMemory::~ZeroInitWorkgroupMemory() = default; + +Output ZeroInitWorkgroupMemory::Run(const Program* in, const DataMap&) { + ProgramBuilder out; + CloneContext ctx(&out, in); + + auto& sem = ctx.src->Sem(); + + for (auto* ast_func : in->AST().Functions()) { + if (!ast_func->IsEntryPoint()) { + continue; + } + + // Generate a list of statements to zero initialize each of the workgroup + // storage variables. + ast::StatementList stmts; + auto* func = sem.Get(ast_func); + for (auto* var : func->ReferencedModuleVariables()) { + if (var->StorageClass() != ast::StorageClass::kWorkgroup) { + continue; + } + State{ctx, stmts}.Zero(var->Type()->UnwrapRef(), [&] { + auto var_name = ctx.Clone(var->Declaration()->symbol()); + return ctx.dst->Expr(var_name); + }); + } + + if (stmts.empty()) { + continue; // No workgroup variables to initialize. + } + + // Scan the entry point for an existing local_invocation_index builtin + // parameter + ast::Expression* local_index = nullptr; + for (auto* param : ast_func->params()) { + if (auto* builtin = ast::GetDecoration( + param->decorations())) { + if (builtin->value() == ast::Builtin::kLocalInvocationIndex) { + local_index = ctx.dst->Expr(ctx.Clone(param->symbol())); + break; + } + } + + if (auto* str = sem.Get(param)->Type()->As()) { + for (auto* member : str->Members()) { + if (auto* builtin = ast::GetDecoration( + member->Declaration()->decorations())) { + if (builtin->value() == ast::Builtin::kLocalInvocationIndex) { + auto* param_expr = ctx.dst->Expr(ctx.Clone(param->symbol())); + auto member_name = ctx.Clone(member->Declaration()->symbol()); + local_index = ctx.dst->MemberAccessor(param_expr, member_name); + break; + } + } + } + } + } + if (!local_index) { + // No existing local index parameter. Append one to the entry point. + auto* param = ctx.dst->Param( + ctx.dst->Symbols().New("local_invocation_index"), ctx.dst->ty.u32(), + {ctx.dst->Builtin(ast::Builtin::kLocalInvocationIndex)}); + ctx.InsertBack(ast_func->params(), param); + local_index = ctx.dst->Expr(param->symbol()); + } + + // We only want to zero-initialize the workgroup memory with the first + // shader invocation. Construct an if statement that holds stmts. + // TODO(crbug.com/tint/910): We should attempt to optimize this for arrays. + auto* if_zero_local_index = ctx.dst->create( + ast::BinaryOp::kEqual, local_index, ctx.dst->Expr(0u)); + auto* if_stmt = ctx.dst->If(if_zero_local_index, ctx.dst->Block(stmts)); + + // Insert this if-statement at the top of the entry point. + ctx.InsertFront(ast_func->body()->statements(), if_stmt); + + // Append a single workgroup barrier after the if statement. + ctx.InsertFront( + ast_func->body()->statements(), + ctx.dst->create(ctx.dst->Call("workgroupBarrier"))); + } + + ctx.Clone(); + + return Output(Program(std::move(out))); +} + +} // namespace transform +} // namespace tint diff --git a/src/transform/zero_init_workgroup_memory.h b/src/transform/zero_init_workgroup_memory.h new file mode 100644 index 0000000000..bf846c746e --- /dev/null +++ b/src/transform/zero_init_workgroup_memory.h @@ -0,0 +1,47 @@ +// Copyright 2021 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SRC_TRANSFORM_ZERO_INIT_WORKGROUP_MEMORY_H_ +#define SRC_TRANSFORM_ZERO_INIT_WORKGROUP_MEMORY_H_ + +#include "src/transform/transform.h" + +namespace tint { +namespace transform { + +/// ZeroInitWorkgroupMemory is a transform that injects code at the top of entry +/// points to zero-initialize workgroup memory used by that entry point (and all +/// transitive functions called by that entry point) +class ZeroInitWorkgroupMemory : public Transform { + public: + /// Constructor + ZeroInitWorkgroupMemory(); + + /// Destructor + ~ZeroInitWorkgroupMemory() override; + + /// Runs the transform on `program`, returning the transformation result. + /// @param program the source program to transform + /// @param data optional extra transform-specific input data + /// @returns the transformation result + Output Run(const Program* program, const DataMap& data = {}) override; + + private: + struct State; +}; + +} // namespace transform +} // namespace tint + +#endif // SRC_TRANSFORM_ZERO_INIT_WORKGROUP_MEMORY_H_ diff --git a/src/transform/zero_init_workgroup_memory_test.cc b/src/transform/zero_init_workgroup_memory_test.cc new file mode 100644 index 0000000000..ab1b305f18 --- /dev/null +++ b/src/transform/zero_init_workgroup_memory_test.cc @@ -0,0 +1,563 @@ +// Copyright 2021 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/transform/zero_init_workgroup_memory.h" + +#include + +#include "src/transform/test_helper.h" + +namespace tint { +namespace transform { +namespace { + +using ZeroInitWorkgroupMemoryTest = TransformTest; + +TEST_F(ZeroInitWorkgroupMemoryTest, EmptyModule) { + auto* src = ""; + auto* expect = src; + + auto got = Run(src); + + EXPECT_EQ(expect, str(got)); +} + +TEST_F(ZeroInitWorkgroupMemoryTest, NoWorkgroupVars) { + auto* src = R"( +var v : i32; + +fn f() { + v = 1; +} +)"; + auto* expect = src; + + auto got = Run(src); + + EXPECT_EQ(expect, str(got)); +} + +TEST_F(ZeroInitWorkgroupMemoryTest, UnreferencedWorkgroupVars) { + auto* src = R"( +var a : i32; + +var b : i32; + +var c : i32; + +fn unreferenced() { + b = c; +} + +[[stage(compute), workgroup_size(1)]] +fn f() { +} +)"; + auto* expect = src; + + auto got = Run(src); + + EXPECT_EQ(expect, str(got)); +} + +TEST_F(ZeroInitWorkgroupMemoryTest, SingleWorkgroupVar_ExistingLocalIndex) { + auto* src = R"( +var v : i32; + +[[stage(compute), workgroup_size(1)]] +fn f([[builtin(local_invocation_index)]] idx : u32) { + ignore(v); // Initialization should be inserted above this statement +} +)"; + auto* expect = R"( +var v : i32; + +[[stage(compute), workgroup_size(1)]] +fn f([[builtin(local_invocation_index)]] idx : u32) { + if ((idx == 0u)) { + v = i32(); + } + workgroupBarrier(); + ignore(v); +} +)"; + + auto got = Run(src); + + EXPECT_EQ(expect, str(got)); +} + +TEST_F(ZeroInitWorkgroupMemoryTest, + SingleWorkgroupVar_ExistingLocalIndexInStruct) { + auto* src = R"( +var v : i32; + +struct Params { + [[builtin(local_invocation_index)]] idx : u32; +}; + +[[stage(compute), workgroup_size(1)]] +fn f(params : Params) { + ignore(v); // Initialization should be inserted above this statement +} +)"; + auto* expect = R"( +var v : i32; + +struct Params { + [[builtin(local_invocation_index)]] + idx : u32; +}; + +[[stage(compute), workgroup_size(1)]] +fn f(params : Params) { + if ((params.idx == 0u)) { + v = i32(); + } + workgroupBarrier(); + ignore(v); +} +)"; + + auto got = Run(src); + + EXPECT_EQ(expect, str(got)); +} + +TEST_F(ZeroInitWorkgroupMemoryTest, SingleWorkgroupVar_InjectedLocalIndex) { + auto* src = R"( +var v : i32; + +[[stage(compute), workgroup_size(1)]] +fn f() { + ignore(v); // Initialization should be inserted above this statement +} +)"; + auto* expect = R"( +var v : i32; + +[[stage(compute), workgroup_size(1)]] +fn f([[builtin(local_invocation_index)]] local_invocation_index : u32) { + if ((local_invocation_index == 0u)) { + v = i32(); + } + workgroupBarrier(); + ignore(v); +} +)"; + + auto got = Run(src); + + EXPECT_EQ(expect, str(got)); +} + +TEST_F(ZeroInitWorkgroupMemoryTest, MultipleWorkgroupVar_ExistingLocalIndex) { + auto* src = R"( +struct S { + x : i32; + y : array; +}; + +var a : i32; + +var b : S; + +var c : array; + +[[stage(compute), workgroup_size(1)]] +fn f([[builtin(local_invocation_index)]] idx : u32) { + ignore(a); // Initialization should be inserted above this statement + ignore(b); + ignore(c); +} +)"; + auto* expect = R"( +struct S { + x : i32; + y : array; +}; + +var a : i32; + +var b : S; + +var c : array; + +[[stage(compute), workgroup_size(1)]] +fn f([[builtin(local_invocation_index)]] idx : u32) { + if ((idx == 0u)) { + a = i32(); + b = S(); + c = array(); + } + workgroupBarrier(); + ignore(a); + ignore(b); + ignore(c); +} +)"; + + auto got = Run(src); + + EXPECT_EQ(expect, str(got)); +} + +TEST_F(ZeroInitWorkgroupMemoryTest, MultipleWorkgroupVar_InjectedLocalIndex) { + auto* src = R"( +struct S { + x : i32; + y : array; +}; + +var a : i32; + +var b : S; + +var c : array; + +[[stage(compute), workgroup_size(1)]] +fn f([[builtin(local_invocation_id)]] local_invocation_id : vec3) { + ignore(a); // Initialization should be inserted above this statement + ignore(b); + ignore(c); +} +)"; + auto* expect = R"( +struct S { + x : i32; + y : array; +}; + +var a : i32; + +var b : S; + +var c : array; + +[[stage(compute), workgroup_size(1)]] +fn f([[builtin(local_invocation_id)]] local_invocation_id : vec3, [[builtin(local_invocation_index)]] local_invocation_index : u32) { + if ((local_invocation_index == 0u)) { + a = i32(); + b = S(); + c = array(); + } + workgroupBarrier(); + ignore(a); + ignore(b); + ignore(c); +} +)"; + + auto got = Run(src); + + EXPECT_EQ(expect, str(got)); +} + +TEST_F(ZeroInitWorkgroupMemoryTest, MultipleWorkgroupVar_MultipleEntryPoints) { + auto* src = R"( +struct S { + x : i32; + y : array; +}; + +var a : i32; + +var b : S; + +var c : array; + +[[stage(compute), workgroup_size(1)]] +fn f1() { + ignore(a); // Initialization should be inserted above this statement + ignore(c); +} + +[[stage(compute), workgroup_size(1)]] +fn f2([[builtin(local_invocation_id)]] local_invocation_id : vec3) { + ignore(b); // Initialization should be inserted above this statement +} + +[[stage(compute), workgroup_size(1)]] +fn f3() { + ignore(c); // Initialization should be inserted above this statement + ignore(a); +} +)"; + auto* expect = R"( +struct S { + x : i32; + y : array; +}; + +var a : i32; + +var b : S; + +var c : array; + +[[stage(compute), workgroup_size(1)]] +fn f1([[builtin(local_invocation_index)]] local_invocation_index : u32) { + if ((local_invocation_index == 0u)) { + a = i32(); + c = array(); + } + workgroupBarrier(); + ignore(a); + ignore(c); +} + +[[stage(compute), workgroup_size(1)]] +fn f2([[builtin(local_invocation_id)]] local_invocation_id : vec3, [[builtin(local_invocation_index)]] local_invocation_index_1 : u32) { + if ((local_invocation_index_1 == 0u)) { + b = S(); + } + workgroupBarrier(); + ignore(b); +} + +[[stage(compute), workgroup_size(1)]] +fn f3([[builtin(local_invocation_index)]] local_invocation_index_2 : u32) { + if ((local_invocation_index_2 == 0u)) { + c = array(); + a = i32(); + } + workgroupBarrier(); + ignore(c); + ignore(a); +} +)"; + + auto got = Run(src); + + EXPECT_EQ(expect, str(got)); +} + +TEST_F(ZeroInitWorkgroupMemoryTest, TransitiveUsage) { + auto* src = R"( +var v : i32; + +fn use_v() { + ignore(v); +} + +fn call_use_v() { + use_v(); +} + +[[stage(compute), workgroup_size(1)]] +fn f([[builtin(local_invocation_index)]] idx : u32) { + call_use_v(); // Initialization should be inserted above this statement +} +)"; + auto* expect = R"( +var v : i32; + +fn use_v() { + ignore(v); +} + +fn call_use_v() { + use_v(); +} + +[[stage(compute), workgroup_size(1)]] +fn f([[builtin(local_invocation_index)]] idx : u32) { + if ((idx == 0u)) { + v = i32(); + } + workgroupBarrier(); + call_use_v(); +} +)"; + + auto got = Run(src); + + EXPECT_EQ(expect, str(got)); +} + +TEST_F(ZeroInitWorkgroupMemoryTest, WorkgroupAtomics) { + auto* src = R"( +var i : atomic; +var u : atomic; + +[[stage(compute), workgroup_size(1)]] +fn f() { + ignore(i); // Initialization should be inserted above this statement + ignore(u); +} +)"; + auto* expect = R"( +var i : atomic; + +var u : atomic; + +[[stage(compute), workgroup_size(1)]] +fn f([[builtin(local_invocation_index)]] local_invocation_index : u32) { + if ((local_invocation_index == 0u)) { + atomicStore(&(i), i32()); + atomicStore(&(u), u32()); + } + workgroupBarrier(); + ignore(i); + ignore(u); +} +)"; + + auto got = Run(src); + + EXPECT_EQ(expect, str(got)); +} + +TEST_F(ZeroInitWorkgroupMemoryTest, WorkgroupStructOfAtomics) { + auto* src = R"( +struct S { + a : i32; + i : atomic; + b : f32; + u : atomic; + c : u32; +}; + +var w : S; + +[[stage(compute), workgroup_size(1)]] +fn f() { + ignore(w); // Initialization should be inserted above this statement +} +)"; + auto* expect = R"( +struct S { + a : i32; + i : atomic; + b : f32; + u : atomic; + c : u32; +}; + +var w : S; + +[[stage(compute), workgroup_size(1)]] +fn f([[builtin(local_invocation_index)]] local_invocation_index : u32) { + if ((local_invocation_index == 0u)) { + w.a = i32(); + atomicStore(&(w.i), i32()); + w.b = f32(); + atomicStore(&(w.u), u32()); + w.c = u32(); + } + workgroupBarrier(); + ignore(w); +} +)"; + + auto got = Run(src); + + EXPECT_EQ(expect, str(got)); +} + +TEST_F(ZeroInitWorkgroupMemoryTest, WorkgroupArrayOfAtomics) { + auto* src = R"( +var w : array, 4>; + +[[stage(compute), workgroup_size(1)]] +fn f() { + ignore(w); // Initialization should be inserted above this statement +} +)"; + auto* expect = R"( +var w : array, 4>; + +[[stage(compute), workgroup_size(1)]] +fn f([[builtin(local_invocation_index)]] local_invocation_index : u32) { + if ((local_invocation_index == 0u)) { + atomicStore(&(w[0u]), u32()); + atomicStore(&(w[1u]), u32()); + atomicStore(&(w[2u]), u32()); + atomicStore(&(w[3u]), u32()); + } + workgroupBarrier(); + ignore(w); +} +)"; + + auto got = Run(src); + + EXPECT_EQ(expect, str(got)); +} + +TEST_F(ZeroInitWorkgroupMemoryTest, WorkgroupArrayOfStructOfAtomics) { + auto* src = R"( +struct S { + a : i32; + i : atomic; + b : f32; + u : atomic; + c : u32; +}; + +var w : array; + +[[stage(compute), workgroup_size(1)]] +fn f() { + ignore(w); // Initialization should be inserted above this statement +} +)"; + auto* expect = R"( +struct S { + a : i32; + i : atomic; + b : f32; + u : atomic; + c : u32; +}; + +var w : array; + +[[stage(compute), workgroup_size(1)]] +fn f([[builtin(local_invocation_index)]] local_invocation_index : u32) { + if ((local_invocation_index == 0u)) { + w[0u].a = i32(); + atomicStore(&(w[0u].i), i32()); + w[0u].b = f32(); + atomicStore(&(w[0u].u), u32()); + w[0u].c = u32(); + w[1u].a = i32(); + atomicStore(&(w[1u].i), i32()); + w[1u].b = f32(); + atomicStore(&(w[1u].u), u32()); + w[1u].c = u32(); + w[2u].a = i32(); + atomicStore(&(w[2u].i), i32()); + w[2u].b = f32(); + atomicStore(&(w[2u].u), u32()); + w[2u].c = u32(); + w[3u].a = i32(); + atomicStore(&(w[3u].i), i32()); + w[3u].b = f32(); + atomicStore(&(w[3u].u), u32()); + w[3u].c = u32(); + } + workgroupBarrier(); + ignore(w); +} +)"; + + auto got = Run(src); + + EXPECT_EQ(expect, str(got)); +} + +} // namespace +} // namespace transform +} // namespace tint diff --git a/test/BUILD.gn b/test/BUILD.gn index d4df320b9d..04d781da93 100644 --- a/test/BUILD.gn +++ b/test/BUILD.gn @@ -292,6 +292,7 @@ tint_unittests_source_set("tint_unittests_core_src") { "../src/transform/transform_test.cc", "../src/transform/vertex_pulling_test.cc", "../src/transform/wrap_arrays_in_structs_test.cc", + "../src/transform/zero_init_workgroup_memory_test.cc", "../src/utils/enum_set_test.cc", "../src/utils/get_or_create_test.cc", "../src/utils/hash_test.cc",