diff --git a/src/BUILD.gn b/src/BUILD.gn index 7b0aa75a77..8a7dc89a6a 100644 --- a/src/BUILD.gn +++ b/src/BUILD.gn @@ -403,6 +403,8 @@ source_set("libtint_core_src") { "transform/binding_remapper.h", "transform/bound_array_accessors.cc", "transform/bound_array_accessors.h", + "transform/calculate_array_length.cc", + "transform/calculate_array_length.h", "transform/canonicalize_entry_point_io.cc", "transform/canonicalize_entry_point_io.h", "transform/decompose_storage_access.cc", diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 15426878b3..b920e8c56c 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -218,6 +218,8 @@ set(TINT_LIB_SRCS transform/binding_remapper.h transform/bound_array_accessors.cc transform/bound_array_accessors.h + transform/calculate_array_length.cc + transform/calculate_array_length.h transform/canonicalize_entry_point_io.cc transform/canonicalize_entry_point_io.h transform/decompose_storage_access.cc @@ -731,6 +733,7 @@ if(${TINT_BUILD_TESTS}) list(APPEND TINT_TEST_SRCS transform/binding_remapper_test.cc transform/bound_array_accessors_test.cc + transform/calculate_array_length_test.cc transform/canonicalize_entry_point_io_test.cc transform/decompose_storage_access_test.cc transform/emit_vertex_point_size_test.cc diff --git a/src/transform/calculate_array_length.cc b/src/transform/calculate_array_length.cc new file mode 100644 index 0000000000..3d57d9cfbe --- /dev/null +++ b/src/transform/calculate_array_length.cc @@ -0,0 +1,236 @@ +// Copyright 2021 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/transform/calculate_array_length.h" + +#include +#include + +#include "src/ast/call_statement.h" +#include "src/program_builder.h" +#include "src/semantic/call.h" +#include "src/semantic/statement.h" +#include "src/semantic/struct.h" +#include "src/semantic/variable.h" +#include "src/utils/get_or_create.h" +#include "src/utils/hash.h" + +TINT_INSTANTIATE_TYPEINFO( + tint::transform::CalculateArrayLength::BufferSizeIntrinsic); + +namespace tint { +namespace transform { + +namespace { + +/// ArrayUsage describes a runtime array usage. +/// It is used as a key by the array_length_by_usage map. +struct ArrayUsage { + ast::BlockStatement const* const block; + semantic::Node const* const buffer; + bool operator==(const ArrayUsage& rhs) const { + return block == rhs.block && buffer == rhs.buffer; + } + struct Hasher { + inline std::size_t operator()(const ArrayUsage& u) const { + return utils::Hash(u.block, u.buffer); + } + }; +}; + +} // namespace + +CalculateArrayLength::BufferSizeIntrinsic::BufferSizeIntrinsic() = default; +CalculateArrayLength::BufferSizeIntrinsic::~BufferSizeIntrinsic() = default; +std::string CalculateArrayLength::BufferSizeIntrinsic::Name() const { + return "intrinsic_buffer_size"; +} + +CalculateArrayLength::BufferSizeIntrinsic* +CalculateArrayLength::BufferSizeIntrinsic::Clone(CloneContext* ctx) const { + return ctx->dst->ASTNodes() + .Create(); +} + +CalculateArrayLength::CalculateArrayLength() = default; +CalculateArrayLength::~CalculateArrayLength() = default; + +Transform::Output CalculateArrayLength::Run(const Program* in, const DataMap&) { + ProgramBuilder out; + CloneContext ctx(&out, in); + + auto& sem = ctx.src->Sem(); + + // get_buffer_size_intrinsic() emits the function decorated with + // BufferSizeIntrinsic that is transformed by the HLSL writer into a call to + // [RW]ByteAddressBuffer.GetDimensions(). + std::unordered_map buffer_size_intrinsics; + auto get_buffer_size_intrinsic = [&](type::Struct* buffer_type) { + return utils::GetOrCreate(buffer_size_intrinsics, buffer_type, [&] { + auto name = ctx.dst->Symbols().New(); + auto* func = ctx.dst->create( + name, + ast::VariableList{ + // Note: The buffer parameter requires the kStorage StorageClass + // in order for HLSL to emit this as a ByteAddressBuffer. + ctx.dst->create( + ctx.dst->Sym("buffer"), ast::StorageClass::kStorage, + ctx.Clone(buffer_type), true, nullptr, ast::DecorationList{}), + ctx.dst->Param("result", + ctx.dst->ty.pointer(ctx.dst->ty.u32(), + ast::StorageClass::kFunction)), + }, + ctx.dst->ty.void_(), nullptr, + ast::DecorationList{ + ctx.dst->ASTNodes().Create(), + }, + ast::DecorationList{}); + ctx.InsertAfter(ctx.src->AST().GlobalDeclarations(), buffer_type, func); + return name; + }); + }; + + std::unordered_map + array_length_by_usage; + + // Find all the arrayLength() calls... + for (auto* node : ctx.src->ASTNodes().Objects()) { + if (auto* call_expr = node->As()) { + auto* call = sem.Get(call_expr); + if (auto* intrinsic = call->Target()->As()) { + if (intrinsic->Type() == semantic::IntrinsicType::kArrayLength) { + // We're dealing with an arrayLength() call + + // https://gpuweb.github.io/gpuweb/wgsl.html#array-types states: + // + // * The last member of the structure type defining the store type for + // a variable in the storage storage class may be a runtime-sized + // array. + // * A runtime-sized array must not be used as the store type or + // contained within a store type in any other cases. + // * The type of an expression must not be a runtime-sized array type. + // arrayLength() + // + // We can assume that the arrayLength() call has a single argument of + // the form: arrayLength(X.Y) where X is an expression that resolves + // to the storage buffer structure, and Y is the runtime sized array. + auto* array_expr = call_expr->params()[0]; + auto* accessor = array_expr->As(); + if (!accessor) { + TINT_ICE(ctx.dst->Diagnostics()) + << "arrayLength() expected ast::MemberAccessorExpression, got " + << array_expr->TypeInfo().name; + break; + } + auto* storage_buffer_expr = accessor->structure(); + auto* storage_buffer_sem = sem.Get(storage_buffer_expr); + auto* storage_buffer_type = + storage_buffer_sem->Type()->UnwrapAll()->As(); + + // Generate BufferSizeIntrinsic for this storage type if we haven't + // already + auto buffer_size = get_buffer_size_intrinsic(storage_buffer_type); + + if (!storage_buffer_type) { + TINT_ICE(ctx.dst->Diagnostics()) + << "arrayLength(X.Y) expected X to be type::Struct, got " + << storage_buffer_type->FriendlyName(ctx.src->Symbols()); + break; + } + + // Find the current statement block + auto* block = call->Stmt()->Block(); + if (!block) { + TINT_ICE(ctx.dst->Diagnostics()) + << "arrayLength() statement is outside a BlockStatement"; + break; + } + + // If the storage_buffer_expr is resolves to a variable (typically + // true) then key the array_length from the variable. If not, key off + // the expression semantic node, which will be unique per call to + // arrayLength(). + const semantic::Node* storage_buffer_usage = storage_buffer_sem; + if (auto* user = storage_buffer_sem->As()) { + storage_buffer_usage = user->Variable(); + } + + auto array_length = utils::GetOrCreate( + array_length_by_usage, {block, storage_buffer_usage}, [&] { + // First time this array length is used for this block. + // Let's calculate it. + + // Semantic info for the storage buffer structure + auto* storage_buffer_type_sem = + ctx.src->Sem().Get(storage_buffer_type); + // Semantic info for the runtime array structure member + auto* array_member_sem = + storage_buffer_type_sem->Members().back(); + + // Construct the variable that'll hold the result of + // RWByteAddressBuffer.GetDimensions() + auto* buffer_size_result = + ctx.dst->create(ctx.dst->Var( + ctx.dst->Symbols().New(), ctx.dst->ty.u32(), + ast::StorageClass::kFunction, ctx.dst->Expr(0u))); + + // Call storage_buffer.GetDimensions(buffer_size_result) + auto* call_get_dims = + ctx.dst->create(ctx.dst->Call( + // BufferSizeIntrinsic(X, ARGS...) is + // translated to: + // X.GetDimensions(ARGS..) by the writer + buffer_size, ctx.Clone(storage_buffer_expr), + buffer_size_result->variable()->symbol())); + + // Calculate actual array length + // total_storage_buffer_size - array_offset + // array_length = ---------------------------------------- + // array_stride + auto name = ctx.dst->Symbols().New(); + uint32_t array_offset = array_member_sem->Offset(); + uint32_t array_stride = array_member_sem->Size(); + auto* array_length_var = + ctx.dst->create(ctx.dst->Const( + name, ctx.dst->ty.u32(), + ctx.dst->Div( + ctx.dst->Sub( + buffer_size_result->variable()->symbol(), + array_offset), + array_stride))); + + // Insert the array length calculations at the top of the block + ctx.InsertBefore(block->statements(), *block->begin(), + buffer_size_result); + ctx.InsertBefore(block->statements(), *block->begin(), + call_get_dims); + ctx.InsertBefore(block->statements(), *block->begin(), + array_length_var); + return name; + }); + + // Replace the call to arrayLength() with the array length variable + ctx.Replace(call_expr, ctx.dst->Expr(array_length)); + } + } + } + } + + ctx.Clone(); + + return Output{Program(std::move(out))}; +} + +} // namespace transform +} // namespace tint diff --git a/src/transform/calculate_array_length.h b/src/transform/calculate_array_length.h new file mode 100644 index 0000000000..fa96d81b6b --- /dev/null +++ b/src/transform/calculate_array_length.h @@ -0,0 +1,68 @@ +// Copyright 2021 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SRC_TRANSFORM_CALCULATE_ARRAY_LENGTH_H_ +#define SRC_TRANSFORM_CALCULATE_ARRAY_LENGTH_H_ + +#include + +#include "src/ast/internal_decoration.h" +#include "src/transform/transform.h" + +namespace tint { + +// Forward declarations +class CloneContext; + +namespace transform { + +/// CalculateArrayLength is a transform used to replace calls to arrayLength() +/// with a value calculated from the size of the storage buffer. +class CalculateArrayLength : public Transform { + public: + /// BufferSizeIntrinsic is an InternalDecoration that's applied to intrinsic + /// functions used to obtain the runtime size of a storage buffer. + class BufferSizeIntrinsic + : public Castable { + public: + /// Constructor + BufferSizeIntrinsic(); + /// Destructor + ~BufferSizeIntrinsic() override; + + /// @return "buffer_size" + std::string Name() const override; + + /// Performs a deep clone of this object using the CloneContext `ctx`. + /// @param ctx the clone context + /// @return the newly cloned object + BufferSizeIntrinsic* Clone(CloneContext* ctx) const override; + }; + + /// Constructor + CalculateArrayLength(); + /// Destructor + ~CalculateArrayLength() override; + + /// Runs the transform on `program`, returning the transformation result. + /// @param program the source program to transform + /// @param data optional extra transform-specific data + /// @returns the transformation result + Output Run(const Program* program, const DataMap& data = {}) override; +}; + +} // namespace transform +} // namespace tint + +#endif // SRC_TRANSFORM_CALCULATE_ARRAY_LENGTH_H_ diff --git a/src/transform/calculate_array_length_test.cc b/src/transform/calculate_array_length_test.cc new file mode 100644 index 0000000000..42b2f101fa --- /dev/null +++ b/src/transform/calculate_array_length_test.cc @@ -0,0 +1,284 @@ +// Copyright 2021 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/transform/calculate_array_length.h" + +#include "src/transform/test_helper.h" + +namespace tint { +namespace transform { +namespace { + +using CalculateArrayLengthTest = TransformTest; + +TEST_F(CalculateArrayLengthTest, Basic) { + auto* src = R"( +[[block]] +struct SB { + x : i32; + arr : array; +}; + +var sb : SB; + +[[stage(vertex)]] +fn main() { + var len : u32 = arrayLength(sb.arr); +} +)"; + + auto* expect = R"( +[[block]] +struct SB { + x : i32; + arr : array; +}; + +[[internal(intrinsic_buffer_size)]] +fn tint_symbol_1(buffer : SB, result : ptr) + +var sb : SB; + +[[stage(vertex)]] +fn main() { + var tint_symbol_7 : u32 = 0u; + tint_symbol_1(sb, tint_symbol_7); + let tint_symbol_9 : u32 = ((tint_symbol_7 - 4u) / 4u); + var len : u32 = tint_symbol_9; +} +)"; + + auto got = Run(src); + + EXPECT_EQ(expect, str(got)); +} + +TEST_F(CalculateArrayLengthTest, InSameBlock) { + auto* src = R"( +[[block]] +struct SB { + x : i32; + arr : array; +}; + +var sb : SB; + +[[stage(vertex)]] +fn main() { + var a : u32 = arrayLength(sb.arr); + var b : u32 = arrayLength(sb.arr); + var c : u32 = arrayLength(sb.arr); +} +)"; + + auto* expect = R"( +[[block]] +struct SB { + x : i32; + arr : array; +}; + +[[internal(intrinsic_buffer_size)]] +fn tint_symbol_1(buffer : SB, result : ptr) + +var sb : SB; + +[[stage(vertex)]] +fn main() { + var tint_symbol_7 : u32 = 0u; + tint_symbol_1(sb, tint_symbol_7); + let tint_symbol_9 : u32 = ((tint_symbol_7 - 4u) / 4u); + var a : u32 = tint_symbol_9; + var b : u32 = tint_symbol_9; + var c : u32 = tint_symbol_9; +} +)"; + + auto got = Run(src); + + EXPECT_EQ(expect, str(got)); +} + +TEST_F(CalculateArrayLengthTest, WithStride) { + auto* src = R"( +[[block]] +struct SB { + x : i32; + y : f32; + arr : [[stride(64)]] array; +}; + +var sb : SB; + +[[stage(vertex)]] +fn main() { + var len : u32 = arrayLength(sb.arr); +} +)"; + + auto* expect = R"( +[[block]] +struct SB { + x : i32; + y : f32; + arr : [[stride(64)]] array; +}; + +[[internal(intrinsic_buffer_size)]] +fn tint_symbol_1(buffer : SB, result : ptr) + +var sb : SB; + +[[stage(vertex)]] +fn main() { + var tint_symbol_8 : u32 = 0u; + tint_symbol_1(sb, tint_symbol_8); + let tint_symbol_10 : u32 = ((tint_symbol_8 - 8u) / 64u); + var len : u32 = tint_symbol_10; +} +)"; + + auto got = Run(src); + + EXPECT_EQ(expect, str(got)); +} + +TEST_F(CalculateArrayLengthTest, Nested) { + auto* src = R"( +[[block]] +struct SB { + x : i32; + arr : array; +}; + +var sb : SB; + +[[stage(vertex)]] +fn main() { + if (true) { + var len : u32 = arrayLength(sb.arr); + } else { + if (true) { + var len : u32 = arrayLength(sb.arr); + } + } +} +)"; + + auto* expect = R"( +[[block]] +struct SB { + x : i32; + arr : array; +}; + +[[internal(intrinsic_buffer_size)]] +fn tint_symbol_1(buffer : SB, result : ptr) + +var sb : SB; + +[[stage(vertex)]] +fn main() { + if (true) { + var tint_symbol_7 : u32 = 0u; + tint_symbol_1(sb, tint_symbol_7); + let tint_symbol_9 : u32 = ((tint_symbol_7 - 4u) / 4u); + var len : u32 = tint_symbol_9; + } else { + if (true) { + var tint_symbol_10 : u32 = 0u; + tint_symbol_1(sb, tint_symbol_10); + let tint_symbol_11 : u32 = ((tint_symbol_10 - 4u) / 4u); + var len : u32 = tint_symbol_11; + } + } +} +)"; + + auto got = Run(src); + + EXPECT_EQ(expect, str(got)); +} + +TEST_F(CalculateArrayLengthTest, MultipleStorageBuffers) { + auto* src = R"( +[[block]] +struct SB1 { + x : i32; + arr1 : array; +}; + +[[block]] +struct SB2 { + x : i32; + arr2 : array>; +}; + +var sb1 : SB1; + +var sb2 : SB2; + +[[stage(vertex)]] +fn main() { + var len1 : u32 = arrayLength(sb1.arr1); + var len2 : u32 = arrayLength(sb2.arr2); + var x : u32 = (len1 + len2); +} +)"; + + auto* expect = R"( +[[block]] +struct SB1 { + x : i32; + arr1 : array; +}; + +[[internal(intrinsic_buffer_size)]] +fn tint_symbol_1(buffer : SB1, result : ptr) + +[[block]] +struct SB2 { + x : i32; + arr2 : array>; +}; + +[[internal(intrinsic_buffer_size)]] +fn tint_symbol_10(buffer : SB2, result : ptr) + +var sb1 : SB1; + +var sb2 : SB2; + +[[stage(vertex)]] +fn main() { + var tint_symbol_7 : u32 = 0u; + tint_symbol_1(sb1, tint_symbol_7); + let tint_symbol_9 : u32 = ((tint_symbol_7 - 4u) / 4u); + var tint_symbol_13 : u32 = 0u; + tint_symbol_10(sb2, tint_symbol_13); + let tint_symbol_15 : u32 = ((tint_symbol_13 - 16u) / 16u); + var len1 : u32 = tint_symbol_9; + var len2 : u32 = tint_symbol_15; + var x : u32 = (len1 + len2); +} +)"; + + auto got = Run(src); + + EXPECT_EQ(expect, str(got)); +} + +} // namespace +} // namespace transform +} // namespace tint diff --git a/test/BUILD.gn b/test/BUILD.gn index ad3fe6e989..6827f4b54b 100644 --- a/test/BUILD.gn +++ b/test/BUILD.gn @@ -194,6 +194,7 @@ source_set("tint_unittests_core_src") { "../src/traits_test.cc", "../src/transform/binding_remapper_test.cc", "../src/transform/bound_array_accessors_test.cc", + "../src/transform/calculate_array_length_test.cc", "../src/transform/canonicalize_entry_point_io_test.cc", "../src/transform/decompose_storage_access_test.cc", "../src/transform/emit_vertex_point_size_test.cc",