From 38fa643702d04f096fbccdaba3d19f24c1b29ffa Mon Sep 17 00:00:00 2001 From: Austin Eng Date: Fri, 19 Nov 2021 04:11:33 +0000 Subject: [PATCH] Add HLSL/MSL generator options for ArrayLengthFromUniform ArrayLengthFromUniform is needed for correct bounds checks on dynamic storage buffers on D3D12. The intrinsic GetDimensions does not return the actual size of the buffer binding. ArrayLengthFromUniform is updated to output the indices of the uniform buffer that are statically used. This allows Dawn to minimize the amount of data needed to upload into the uniform buffer. These output indices are returned on the HLSL/MSL generator result. ArrayLengthFromUniform is also updated to allow only some of the arrayLength calls to be replaced with uniform buffer loads. For HLSL output, the remaining arrayLength computations will continue to use GetDimensions(). For MSL, it is invalid to not specify an index into the uniform buffer for all storage buffers. After Dawn is updated to use the array_length_from_uniform option in the Metal backend, the buffer_size_ubo_index member for MSL output may be removed. Bug: dawn:429 Change-Id: I9da4ec4a20882e9f1bfa5bb026725d72529eff26 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/69301 Kokoro: Kokoro Reviewed-by: James Price Commit-Queue: Austin Eng --- src/BUILD.gn | 2 + src/CMakeLists.txt | 3 + src/transform/array_length_from_uniform.cc | 157 ++++++----- src/transform/array_length_from_uniform.h | 13 +- .../array_length_from_uniform_test.cc | 166 ++++++++++- .../array_length_from_uniform_options.cc | 30 ++ .../array_length_from_uniform_options.h | 52 ++++ src/writer/hlsl/generator.cc | 11 +- src/writer/hlsl/generator.h | 19 ++ src/writer/hlsl/generator_impl.cc | 31 +- src/writer/hlsl/generator_impl.h | 19 +- .../hlsl/generator_impl_sanitizer_test.cc | 52 ++++ src/writer/hlsl/test_helper.h | 8 +- src/writer/msl/generator.cc | 12 +- src/writer/msl/generator.h | 20 ++ src/writer/msl/generator_impl.cc | 76 +++-- src/writer/msl/generator_impl.h | 24 +- .../msl/generator_impl_sanitizer_test.cc | 264 ++++++++++++++++++ src/writer/msl/test_helper.h | 9 +- test/BUILD.gn | 1 + 20 files changed, 850 insertions(+), 119 deletions(-) create mode 100644 src/writer/array_length_from_uniform_options.cc create mode 100644 src/writer/array_length_from_uniform_options.h create mode 100644 src/writer/msl/generator_impl_sanitizer_test.cc diff --git a/src/BUILD.gn b/src/BUILD.gn index d6c95c8274..fb1bcfe667 100644 --- a/src/BUILD.gn +++ b/src/BUILD.gn @@ -491,6 +491,8 @@ libtint_source_set("libtint_core_all_src") { "utils/unique_vector.h", "writer/append_vector.cc", "writer/append_vector.h", + "writer/array_length_from_uniform_options.cc", + "writer/array_length_from_uniform_options.h", "writer/float_to_string.cc", "writer/float_to_string.h", "writer/text.cc", diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index b86cb98da6..9b746da875 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -403,6 +403,8 @@ set(TINT_LIB_SRCS utils/unique_vector.h writer/append_vector.cc writer/append_vector.h + writer/array_length_from_uniform_options.cc + writer/array_length_from_uniform_options.h writer/float_to_string.cc writer/float_to_string.h writer/text_generator.cc @@ -998,6 +1000,7 @@ if(${TINT_BUILD_TESTS}) writer/msl/generator_impl_member_accessor_test.cc writer/msl/generator_impl_module_constant_test.cc writer/msl/generator_impl_return_test.cc + writer/msl/generator_impl_sanitizer_test.cc writer/msl/generator_impl_switch_test.cc writer/msl/generator_impl_test.cc writer/msl/generator_impl_type_test.cc diff --git a/src/transform/array_length_from_uniform.cc b/src/transform/array_length_from_uniform.cc index 65f8f0ad3e..a30e550ca0 100644 --- a/src/transform/array_length_from_uniform.cc +++ b/src/transform/array_length_from_uniform.cc @@ -35,60 +35,18 @@ namespace transform { ArrayLengthFromUniform::ArrayLengthFromUniform() = default; ArrayLengthFromUniform::~ArrayLengthFromUniform() = default; -void ArrayLengthFromUniform::Run(CloneContext& ctx, - const DataMap& inputs, - DataMap& outputs) { - if (!Requires(ctx)) { - return; - } - - auto* cfg = inputs.Get(); - if (cfg == nullptr) { - ctx.dst->Diagnostics().add_error( - diag::System::Transform, - "missing transform data for " + std::string(TypeInfo().name)); - return; - } - +/// Iterate over all arrayLength() intrinsics that operate on +/// storage buffer variables. +/// @param ctx the CloneContext. +/// @param functor of type void(const ast::CallExpression*, const +/// sem::VariableUser, const sem::GlobalVariable*). It takes in an +/// ast::CallExpression of the arrayLength call expression node, a +/// sem::VariableUser of the used storage buffer variable, and the +/// sem::GlobalVariable for the storage buffer. +template +static void IterateArrayLengthOnStorageVar(CloneContext& ctx, F&& functor) { auto& sem = ctx.src->Sem(); - const char* kBufferSizeMemberName = "buffer_size"; - - // Determine the size of the buffer size array. - uint32_t max_buffer_size_index = 0; - for (auto& idx : cfg->bindpoint_to_size_index) { - if (idx.second > max_buffer_size_index) { - max_buffer_size_index = idx.second; - } - } - - // Get (or create, on first call) the uniform buffer that will receive the - // size of each storage buffer in the module. - const ast::Variable* buffer_size_ubo = nullptr; - auto get_ubo = [&]() { - if (!buffer_size_ubo) { - // Emit an array, N>, where N is 1/4 number of elements. - // We do this because UBOs require an element stride that is 16-byte - // aligned. - auto* buffer_size_struct = ctx.dst->Structure( - ctx.dst->Sym(), - {ctx.dst->Member( - kBufferSizeMemberName, - ctx.dst->ty.array(ctx.dst->ty.vec4(ctx.dst->ty.u32()), - (max_buffer_size_index / 4) + 1))}, - - ast::DecorationList{ctx.dst->create()}); - buffer_size_ubo = ctx.dst->Global( - ctx.dst->Sym(), ctx.dst->ty.Of(buffer_size_struct), - ast::StorageClass::kUniform, - ast::DecorationList{ - ctx.dst->create(cfg->ubo_binding.group), - ctx.dst->create( - cfg->ubo_binding.binding)}); - } - return buffer_size_ubo; - }; - // Find all calls to the arrayLength() intrinsic. for (auto* node : ctx.src->ASTNodes().Objects()) { auto* call_expr = node->As(); @@ -137,23 +95,91 @@ void ArrayLengthFromUniform::Run(CloneContext& ctx, << "storage buffer is not a global variable"; break; } + functor(call_expr, storage_buffer_sem, var); + } +} + +void ArrayLengthFromUniform::Run(CloneContext& ctx, + const DataMap& inputs, + DataMap& outputs) { + if (!Requires(ctx)) { + return; + } + + auto* cfg = inputs.Get(); + if (cfg == nullptr) { + ctx.dst->Diagnostics().add_error( + diag::System::Transform, + "missing transform data for " + std::string(TypeInfo().name)); + return; + } + + const char* kBufferSizeMemberName = "buffer_size"; + + // Determine the size of the buffer size array. + uint32_t max_buffer_size_index = 0; + + IterateArrayLengthOnStorageVar( + ctx, [&](const ast::CallExpression*, const sem::VariableUser*, + const sem::GlobalVariable* var) { + auto binding = var->BindingPoint(); + auto idx_itr = cfg->bindpoint_to_size_index.find(binding); + if (idx_itr == cfg->bindpoint_to_size_index.end()) { + return; + } + if (idx_itr->second > max_buffer_size_index) { + max_buffer_size_index = idx_itr->second; + } + }); + + // Get (or create, on first call) the uniform buffer that will receive the + // size of each storage buffer in the module. + const ast::Variable* buffer_size_ubo = nullptr; + auto get_ubo = [&]() { + if (!buffer_size_ubo) { + // Emit an array, N>, where N is 1/4 number of elements. + // We do this because UBOs require an element stride that is 16-byte + // aligned. + auto* buffer_size_struct = ctx.dst->Structure( + ctx.dst->Sym(), + {ctx.dst->Member( + kBufferSizeMemberName, + ctx.dst->ty.array(ctx.dst->ty.vec4(ctx.dst->ty.u32()), + (max_buffer_size_index / 4) + 1))}, + + ast::DecorationList{ctx.dst->create()}); + buffer_size_ubo = ctx.dst->Global( + ctx.dst->Sym(), ctx.dst->ty.Of(buffer_size_struct), + ast::StorageClass::kUniform, + ast::DecorationList{ + ctx.dst->create(cfg->ubo_binding.group), + ctx.dst->create( + cfg->ubo_binding.binding)}); + } + return buffer_size_ubo; + }; + + std::unordered_set used_size_indices; + + IterateArrayLengthOnStorageVar(ctx, [&](const ast::CallExpression* call_expr, + const sem::VariableUser* + storage_buffer_sem, + const sem::GlobalVariable* var) { auto binding = var->BindingPoint(); auto idx_itr = cfg->bindpoint_to_size_index.find(binding); if (idx_itr == cfg->bindpoint_to_size_index.end()) { - ctx.dst->Diagnostics().add_error( - diag::System::Transform, - "missing size index mapping for binding point (" + - std::to_string(binding.group) + "," + - std::to_string(binding.binding) + ")"); - continue; + return; } + uint32_t size_index = idx_itr->second; + used_size_indices.insert(size_index); + // Load the total storage buffer size from the UBO. - uint32_t array_index = idx_itr->second / 4; + uint32_t array_index = size_index / 4; auto* vec_expr = ctx.dst->IndexAccessor( ctx.dst->MemberAccessor(get_ubo()->symbol, kBufferSizeMemberName), array_index); - uint32_t vec_index = idx_itr->second % 4; + uint32_t vec_index = size_index % 4; auto* total_storage_buffer_size = ctx.dst->IndexAccessor(vec_expr, vec_index); @@ -170,20 +196,23 @@ void ArrayLengthFromUniform::Run(CloneContext& ctx, ctx.dst->Sub(total_storage_buffer_size, array_offset), array_stride); ctx.Replace(call_expr, array_length); - } + }); ctx.Clone(); - outputs.Add(buffer_size_ubo ? true : false); + outputs.Add(used_size_indices); } ArrayLengthFromUniform::Config::Config(sem::BindingPoint ubo_bp) : ubo_binding(ubo_bp) {} ArrayLengthFromUniform::Config::Config(const Config&) = default; +ArrayLengthFromUniform::Config& ArrayLengthFromUniform::Config::operator=( + const Config&) = default; ArrayLengthFromUniform::Config::~Config() = default; -ArrayLengthFromUniform::Result::Result(bool needs_sizes) - : needs_buffer_sizes(needs_sizes) {} +ArrayLengthFromUniform::Result::Result( + std::unordered_set used_size_indices_in) + : used_size_indices(std::move(used_size_indices_in)) {} ArrayLengthFromUniform::Result::Result(const Result&) = default; ArrayLengthFromUniform::Result::~Result() = default; diff --git a/src/transform/array_length_from_uniform.h b/src/transform/array_length_from_uniform.h index 306ed8a231..7063a27f5d 100644 --- a/src/transform/array_length_from_uniform.h +++ b/src/transform/array_length_from_uniform.h @@ -16,6 +16,7 @@ #define SRC_TRANSFORM_ARRAY_LENGTH_FROM_UNIFORM_H_ #include +#include #include "src/sem/binding_point.h" #include "src/transform/transform.h" @@ -66,6 +67,10 @@ class ArrayLengthFromUniform /// Copy constructor Config(const Config&); + /// Copy assignment + /// @return this Config + Config& operator=(const Config&); + /// Destructor ~Config() override; @@ -79,8 +84,8 @@ class ArrayLengthFromUniform /// Information produced about what the transform did. struct Result : public Castable { /// Constructor - /// @param needs_sizes True if the transform generated the buffer sizes UBO. - explicit Result(bool needs_sizes); + /// @param used_size_indices Indices into the UBO that are statically used. + explicit Result(std::unordered_set used_size_indices); /// Copy constructor Result(const Result&); @@ -88,8 +93,8 @@ class ArrayLengthFromUniform /// Destructor ~Result() override; - /// True if the transform generated the buffer sizes UBO. - const bool needs_buffer_sizes; + /// Indices into the UBO that are statically used. + const std::unordered_set used_size_indices; }; protected: diff --git a/src/transform/array_length_from_uniform_test.cc b/src/transform/array_length_from_uniform_test.cc index c41ba2e932..fd4eef922b 100644 --- a/src/transform/array_length_from_uniform_test.cc +++ b/src/transform/array_length_from_uniform_test.cc @@ -110,8 +110,8 @@ fn main() { Run(src, data); EXPECT_EQ(expect, str(got)); - EXPECT_TRUE( - got.data.Get()->needs_buffer_sizes); + EXPECT_EQ(std::unordered_set({0}), + got.data.Get()->used_size_indices); } TEST_F(ArrayLengthFromUniformTest, WithStride) { @@ -164,8 +164,8 @@ fn main() { Run(src, data); EXPECT_EQ(expect, str(got)); - EXPECT_TRUE( - got.data.Get()->needs_buffer_sizes); + EXPECT_EQ(std::unordered_set({0}), + got.data.Get()->used_size_indices); } TEST_F(ArrayLengthFromUniformTest, MultipleStorageBuffers) { @@ -286,8 +286,124 @@ fn main() { Run(src, data); EXPECT_EQ(expect, str(got)); - EXPECT_TRUE( - got.data.Get()->needs_buffer_sizes); + EXPECT_EQ(std::unordered_set({0, 1, 2, 3, 4}), + got.data.Get()->used_size_indices); +} + +TEST_F(ArrayLengthFromUniformTest, MultipleUnusedStorageBuffers) { + auto* src = R"( +[[block]] +struct SB1 { + x : i32; + arr1 : array; +}; +[[block]] +struct SB2 { + x : i32; + arr2 : array>; +}; +[[block]] +struct SB3 { + x : i32; + arr3 : array>; +}; +[[block]] +struct SB4 { + x : i32; + arr4 : array>; +}; +[[block]] +struct SB5 { + x : i32; + arr5 : array>; +}; + +[[group(0), binding(2)]] var sb1 : SB1; +[[group(1), binding(2)]] var sb2 : SB2; +[[group(2), binding(2)]] var sb3 : SB3; +[[group(3), binding(2)]] var sb4 : SB4; +[[group(4), binding(2)]] var sb5 : SB5; + +[[stage(compute), workgroup_size(1)]] +fn main() { + var len1 : u32 = arrayLength(&(sb1.arr1)); + var len3 : u32 = arrayLength(&(sb3.arr3)); + var x : u32 = (len1 + len3); +} +)"; + + auto* expect = R"( +[[block]] +struct tint_symbol { + buffer_size : array, 1u>; +}; + +[[group(0), binding(30)]] var tint_symbol_1 : tint_symbol; + +[[block]] +struct SB1 { + x : i32; + arr1 : array; +}; + +[[block]] +struct SB2 { + x : i32; + arr2 : array>; +}; + +[[block]] +struct SB3 { + x : i32; + arr3 : array>; +}; + +[[block]] +struct SB4 { + x : i32; + arr4 : array>; +}; + +[[block]] +struct SB5 { + x : i32; + arr5 : array>; +}; + +[[group(0), binding(2)]] var sb1 : SB1; + +[[group(1), binding(2)]] var sb2 : SB2; + +[[group(2), binding(2)]] var sb3 : SB3; + +[[group(3), binding(2)]] var sb4 : SB4; + +[[group(4), binding(2)]] var sb5 : SB5; + +[[stage(compute), workgroup_size(1)]] +fn main() { + var len1 : u32 = ((tint_symbol_1.buffer_size[0u][0u] - 4u) / 4u); + var len3 : u32 = ((tint_symbol_1.buffer_size[0u][2u] - 16u) / 16u); + var x : u32 = (len1 + len3); +} +)"; + + ArrayLengthFromUniform::Config cfg({0, 30u}); + cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 2u}, 0); + cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{1u, 2u}, 1); + cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{2u, 2u}, 2); + cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{3u, 2u}, 3); + cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{4u, 2u}, 4); + + DataMap data; + data.Add(std::move(cfg)); + + auto got = + Run(src, data); + + EXPECT_EQ(expect, str(got)); + EXPECT_EQ(std::unordered_set({0, 2}), + got.data.Get()->used_size_indices); } TEST_F(ArrayLengthFromUniformTest, NoArrayLengthCalls) { @@ -316,8 +432,8 @@ fn main() { Run(src, data); EXPECT_EQ(src, str(got)); - EXPECT_FALSE( - got.data.Get()->needs_buffer_sizes); + EXPECT_EQ(std::unordered_set(), + got.data.Get()->used_size_indices); } TEST_F(ArrayLengthFromUniformTest, MissingBindingPointToIndexMapping) { @@ -346,7 +462,37 @@ fn main() { } )"; - auto* expect = "error: missing size index mapping for binding point (1,2)"; + auto* expect = R"( +[[block]] +struct tint_symbol { + buffer_size : array, 1u>; +}; + +[[group(0), binding(30)]] var tint_symbol_1 : tint_symbol; + +[[block]] +struct SB1 { + x : i32; + arr1 : array; +}; + +[[block]] +struct SB2 { + x : i32; + arr2 : array>; +}; + +[[group(0), binding(2)]] var sb1 : SB1; + +[[group(1), binding(2)]] var sb2 : SB2; + +[[stage(compute), workgroup_size(1)]] +fn main() { + var len1 : u32 = ((tint_symbol_1.buffer_size[0u][0u] - 4u) / 4u); + var len2 : u32 = arrayLength(&(sb2.arr2)); + var x : u32 = (len1 + len2); +} +)"; ArrayLengthFromUniform::Config cfg({0, 30u}); cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 2}, 0); @@ -358,6 +504,8 @@ fn main() { Run(src, data); EXPECT_EQ(expect, str(got)); + EXPECT_EQ(std::unordered_set({0}), + got.data.Get()->used_size_indices); } } // namespace diff --git a/src/writer/array_length_from_uniform_options.cc b/src/writer/array_length_from_uniform_options.cc new file mode 100644 index 0000000000..b848a484e3 --- /dev/null +++ b/src/writer/array_length_from_uniform_options.cc @@ -0,0 +1,30 @@ +// Copyright 2021 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/writer/array_length_from_uniform_options.h" + +namespace tint { +namespace writer { + +ArrayLengthFromUniformOptions::ArrayLengthFromUniformOptions() = default; +ArrayLengthFromUniformOptions::~ArrayLengthFromUniformOptions() = default; +ArrayLengthFromUniformOptions::ArrayLengthFromUniformOptions( + const ArrayLengthFromUniformOptions&) = default; +ArrayLengthFromUniformOptions& ArrayLengthFromUniformOptions::operator=( + const ArrayLengthFromUniformOptions&) = default; +ArrayLengthFromUniformOptions::ArrayLengthFromUniformOptions( + ArrayLengthFromUniformOptions&&) = default; + +} // namespace writer +} // namespace tint diff --git a/src/writer/array_length_from_uniform_options.h b/src/writer/array_length_from_uniform_options.h new file mode 100644 index 0000000000..ae92d4f028 --- /dev/null +++ b/src/writer/array_length_from_uniform_options.h @@ -0,0 +1,52 @@ +// Copyright 2021 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SRC_WRITER_ARRAY_LENGTH_FROM_UNIFORM_OPTIONS_H_ +#define SRC_WRITER_ARRAY_LENGTH_FROM_UNIFORM_OPTIONS_H_ + +#include + +#include "src/sem/binding_point.h" + +namespace tint { +namespace writer { + +/// Options used to specify a mapping of binding points to indices into a UBO +/// from which to load buffer sizes. +struct ArrayLengthFromUniformOptions { + /// Constructor + ArrayLengthFromUniformOptions(); + /// Destructor + ~ArrayLengthFromUniformOptions(); + /// Copy constructor + ArrayLengthFromUniformOptions(const ArrayLengthFromUniformOptions&); + /// Copy assignment + /// @returns this ArrayLengthFromUniformOptions + ArrayLengthFromUniformOptions& operator=( + const ArrayLengthFromUniformOptions&); + /// Move constructor + ArrayLengthFromUniformOptions(ArrayLengthFromUniformOptions&&); + + /// The binding point to use to generate a uniform buffer from which to read + /// buffer sizes. + sem::BindingPoint ubo_binding; + /// The mapping from storage buffer binding points to the index into the + /// uniform buffer where the length of the buffer is stored. + std::unordered_map bindpoint_to_size_index; +}; + +} // namespace writer +} // namespace tint + +#endif // SRC_WRITER_ARRAY_LENGTH_FROM_UNIFORM_OPTIONS_H_ diff --git a/src/writer/hlsl/generator.cc b/src/writer/hlsl/generator.cc index ce172f9a99..727b5d5525 100644 --- a/src/writer/hlsl/generator.cc +++ b/src/writer/hlsl/generator.cc @@ -20,6 +20,11 @@ namespace tint { namespace writer { namespace hlsl { +Options::Options() = default; +Options::~Options() = default; +Options::Options(const Options&) = default; +Options& Options::operator=(const Options&) = default; + Result::Result() = default; Result::~Result() = default; Result::Result(const Result&) = default; @@ -29,7 +34,8 @@ Result Generate(const Program* program, const Options& options) { // Sanitize the program. auto sanitized_result = Sanitize(program, options.root_constant_binding_point, - options.disable_workgroup_init); + options.disable_workgroup_init, + options.array_length_from_uniform); if (!sanitized_result.program.IsValid()) { result.success = false; result.error = sanitized_result.program.Diagnostics().str(); @@ -50,6 +56,9 @@ Result Generate(const Program* program, const Options& options) { } } + result.used_array_length_from_uniform_indices = + std::move(sanitized_result.used_array_length_from_uniform_indices); + return result; } diff --git a/src/writer/hlsl/generator.h b/src/writer/hlsl/generator.h index 693002980c..4dacbd4cda 100644 --- a/src/writer/hlsl/generator.h +++ b/src/writer/hlsl/generator.h @@ -17,11 +17,13 @@ #include #include +#include #include #include #include "src/ast/pipeline_stage.h" #include "src/sem/binding_point.h" +#include "src/writer/array_length_from_uniform_options.h" #include "src/writer/text.h" namespace tint { @@ -37,10 +39,23 @@ class GeneratorImpl; /// Configuration options used for generating HLSL. struct Options { + /// Constructor + Options(); + /// Destructor + ~Options(); + /// Copy constructor + Options(const Options&); + /// Copy assignment + /// @returns this Options + Options& operator=(const Options&); + /// The binding point to use for information passed via root constants. sem::BindingPoint root_constant_binding_point; /// Set to `true` to disable workgroup memory zero initialization bool disable_workgroup_init = false; + /// Options used to specify a mapping of binding points to indices into a UBO + /// from which to load buffer sizes. + ArrayLengthFromUniformOptions array_length_from_uniform = {}; }; /// The result produced when generating HLSL. @@ -65,6 +80,10 @@ struct Result { /// The list of entry points in the generated HLSL. std::vector> entry_points; + + /// Indices into the array_length_from_uniform binding that are statically + /// used. + std::unordered_set used_array_length_from_uniform_indices; }; /// Generate HLSL for a program, according to a set of configuration options. diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc index 3d3895a56e..c60be69618 100644 --- a/src/writer/hlsl/generator_impl.cc +++ b/src/writer/hlsl/generator_impl.cc @@ -45,6 +45,7 @@ #include "src/sem/type_conversion.h" #include "src/sem/variable.h" #include "src/transform/add_empty_entry_point.h" +#include "src/transform/array_length_from_uniform.h" #include "src/transform/calculate_array_length.h" #include "src/transform/canonicalize_entry_point_io.h" #include "src/transform/decompose_memory_access.h" @@ -124,12 +125,24 @@ const char* LoopAttribute() { } // namespace -SanitizedResult Sanitize(const Program* in, - sem::BindingPoint root_constant_binding_point, - bool disable_workgroup_init) { +SanitizedResult::SanitizedResult() = default; +SanitizedResult::~SanitizedResult() = default; +SanitizedResult::SanitizedResult(SanitizedResult&&) = default; + +SanitizedResult Sanitize( + const Program* in, + sem::BindingPoint root_constant_binding_point, + bool disable_workgroup_init, + const ArrayLengthFromUniformOptions& array_length_from_uniform) { transform::Manager manager; transform::DataMap data; + // Build the config for the internal ArrayLengthFromUniform transform. + transform::ArrayLengthFromUniform::Config array_length_from_uniform_cfg( + array_length_from_uniform.ubo_binding); + array_length_from_uniform_cfg.bindpoint_to_size_index = + array_length_from_uniform.bindpoint_to_size_index; + // Attempt to convert `loop`s into for-loops. This is to try and massage the // output into something that will not cause FXC to choke or misbehave. manager.Add(); @@ -149,6 +162,11 @@ SanitizedResult Sanitize(const Program* in, // Simplify cleans up messy `*(&(expr))` expressions from InlinePointerLets. manager.Add(); manager.Add(); + // ArrayLengthFromUniform must come after InlinePointerLets and Simplify, as + // it assumes that the form of the array length argument is &var.array. + manager.Add(); + data.Add( + std::move(array_length_from_uniform_cfg)); // DecomposeMemoryAccess must come after: // * InlinePointerLets, as we cannot take the address of calls to // DecomposeMemoryAccess::Intrinsic. @@ -171,8 +189,13 @@ SanitizedResult Sanitize(const Program* in, data.Add( root_constant_binding_point); + auto out = manager.Run(in, data); + SanitizedResult result; - result.program = std::move(manager.Run(in, data).program); + result.program = std::move(out.program); + result.used_array_length_from_uniform_indices = + std::move(out.data.Get() + ->used_size_indices); return result; } diff --git a/src/writer/hlsl/generator_impl.h b/src/writer/hlsl/generator_impl.h index 7c56b10aa9..b23897d6ef 100644 --- a/src/writer/hlsl/generator_impl.h +++ b/src/writer/hlsl/generator_impl.h @@ -36,6 +36,7 @@ #include "src/sem/binding_point.h" #include "src/transform/decompose_memory_access.h" #include "src/utils/hash.h" +#include "src/writer/array_length_from_uniform_options.h" #include "src/writer/text_generator.h" namespace tint { @@ -53,8 +54,18 @@ namespace hlsl { /// The result of sanitizing a program for generation. struct SanitizedResult { + /// Constructor + SanitizedResult(); + /// Destructor + ~SanitizedResult(); + /// Move constructor + SanitizedResult(SanitizedResult&&); + /// The sanitized program. Program program; + /// Indices into the array_length_from_uniform binding that are statically + /// used. + std::unordered_set used_array_length_from_uniform_indices; }; /// Sanitize a program in preparation for generating HLSL. @@ -62,9 +73,11 @@ struct SanitizedResult { /// that will be passed via root constants /// @param disable_workgroup_init `true` to disable workgroup memory zero /// @returns the sanitized program and any supplementary information -SanitizedResult Sanitize(const Program* program, - sem::BindingPoint root_constant_binding_point = {}, - bool disable_workgroup_init = false); +SanitizedResult Sanitize( + const Program* program, + sem::BindingPoint root_constant_binding_point = {}, + bool disable_workgroup_init = false, + const ArrayLengthFromUniformOptions& array_length_from_uniform = {}); /// Implementation class for HLSL generator class GeneratorImpl : public TextGenerator { diff --git a/src/writer/hlsl/generator_impl_sanitizer_test.cc b/src/writer/hlsl/generator_impl_sanitizer_test.cc index c3692a7683..fa6d00bc31 100644 --- a/src/writer/hlsl/generator_impl_sanitizer_test.cc +++ b/src/writer/hlsl/generator_impl_sanitizer_test.cc @@ -144,6 +144,58 @@ void a_func() { EXPECT_EQ(expect, got); } +TEST_F(HlslSanitizerTest, Call_ArrayLength_ArrayLengthFromUniform) { + auto* s = Structure("my_struct", {Member(0, "a", ty.array(4))}, + {create()}); + Global("b", ty.Of(s), ast::StorageClass::kStorage, ast::Access::kRead, + ast::DecorationList{ + create(1), + create(2), + }); + Global("c", ty.Of(s), ast::StorageClass::kStorage, ast::Access::kRead, + ast::DecorationList{ + create(2), + create(2), + }); + + Func("a_func", ast::VariableList{}, ty.void_(), + ast::StatementList{ + Decl(Var( + "len", ty.u32(), ast::StorageClass::kNone, + Add(Call("arrayLength", AddressOf(MemberAccessor("b", "a"))), + Call("arrayLength", AddressOf(MemberAccessor("c", "a")))))), + }, + ast::DecorationList{ + Stage(ast::PipelineStage::kFragment), + }); + + Options options; + options.array_length_from_uniform.ubo_binding = {3, 4}; + options.array_length_from_uniform.bindpoint_to_size_index.emplace( + sem::BindingPoint{2, 2}, 7u); + GeneratorImpl& gen = SanitizeAndBuild(options); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + + auto got = gen.result(); + auto* expect = R"(cbuffer cbuffer_tint_symbol_1 : register(b4, space3) { + uint4 tint_symbol_1[2]; +}; + +ByteAddressBuffer b : register(t1, space2); +ByteAddressBuffer c : register(t2, space2); + +void a_func() { + uint tint_symbol_4 = 0u; + b.GetDimensions(tint_symbol_4); + const uint tint_symbol_5 = ((tint_symbol_4 - 0u) / 4u); + uint len = (tint_symbol_5 + ((tint_symbol_1[1].w - 0u) / 4u)); + return; +} +)"; + EXPECT_EQ(expect, got); +} + TEST_F(HlslSanitizerTest, PromoteArrayInitializerToConstVar) { auto* array_init = array(1, 2, 3, 4); auto* array_index = IndexAccessor(array_init, 3); diff --git a/src/writer/hlsl/test_helper.h b/src/writer/hlsl/test_helper.h index 7523c82cbb..e55b9c1e95 100644 --- a/src/writer/hlsl/test_helper.h +++ b/src/writer/hlsl/test_helper.h @@ -22,6 +22,7 @@ #include "gtest/gtest.h" #include "src/transform/manager.h" #include "src/transform/renamer.h" +#include "src/writer/hlsl/generator.h" #include "src/writer/hlsl/generator_impl.h" namespace tint { @@ -58,10 +59,11 @@ class TestHelperBase : public BODY, public ProgramBuilder { /// Builds the program, runs the program through the HLSL sanitizer /// and returns a GeneratorImpl from the sanitized program. + /// @param options The HLSL generator options. /// @note The generator is only built once. Multiple calls to Build() will /// return the same GeneratorImpl without rebuilding. /// @return the built generator - GeneratorImpl& SanitizeAndBuild() { + GeneratorImpl& SanitizeAndBuild(const Options& options = {}) { if (gen_) { return *gen_; } @@ -76,7 +78,9 @@ class TestHelperBase : public BODY, public ProgramBuilder { << formatter.format(program->Diagnostics()); }(); - auto sanitized_result = Sanitize(program.get()); + auto sanitized_result = Sanitize( + program.get(), options.root_constant_binding_point, + options.disable_workgroup_init, options.array_length_from_uniform); [&]() { ASSERT_TRUE(sanitized_result.program.IsValid()) << formatter.format(sanitized_result.program.Diagnostics()); diff --git a/src/writer/msl/generator.cc b/src/writer/msl/generator.cc index cc1aa5c4a3..92741e8d75 100644 --- a/src/writer/msl/generator.cc +++ b/src/writer/msl/generator.cc @@ -14,12 +14,19 @@ #include "src/writer/msl/generator.h" +#include + #include "src/writer/msl/generator_impl.h" namespace tint { namespace writer { namespace msl { +Options::Options() = default; +Options::~Options() = default; +Options::Options(const Options&) = default; +Options& Options::operator=(const Options&) = default; + Result::Result() = default; Result::~Result() = default; Result::Result(const Result&) = default; @@ -30,7 +37,8 @@ Result Generate(const Program* program, const Options& options) { // Sanitize the program. auto sanitized_result = Sanitize( program, options.buffer_size_ubo_index, options.fixed_sample_mask, - options.emit_vertex_point_size, options.disable_workgroup_init); + options.emit_vertex_point_size, options.disable_workgroup_init, + options.array_length_from_uniform); if (!sanitized_result.program.IsValid()) { result.success = false; result.error = sanitized_result.program.Diagnostics().str(); @@ -38,6 +46,8 @@ Result Generate(const Program* program, const Options& options) { } result.needs_storage_buffer_sizes = sanitized_result.needs_storage_buffer_sizes; + result.used_array_length_from_uniform_indices = + std::move(sanitized_result.used_array_length_from_uniform_indices); // Generate the MSL code. auto impl = std::make_unique(&sanitized_result.program); diff --git a/src/writer/msl/generator.h b/src/writer/msl/generator.h index af812d7161..63d793957c 100644 --- a/src/writer/msl/generator.h +++ b/src/writer/msl/generator.h @@ -18,8 +18,10 @@ #include #include #include +#include #include +#include "src/writer/array_length_from_uniform_options.h" #include "src/writer/text.h" namespace tint { @@ -34,6 +36,16 @@ class GeneratorImpl; /// Configuration options used for generating MSL. struct Options { + /// Constructor + Options(); + /// Destructor + ~Options(); + /// Copy constructor + Options(const Options&); + /// Copy assignment + /// @returns this Options + Options& operator=(const Options&); + /// The index to use when generating a UBO to receive storage buffer sizes. /// Defaults to 30, which is the last valid buffer slot. uint32_t buffer_size_ubo_index = 30; @@ -48,6 +60,10 @@ struct Options { /// Set to `true` to disable workgroup memory zero initialization bool disable_workgroup_init = false; + + /// Options used to specify a mapping of binding points to indices into a UBO + /// from which to load buffer sizes. + ArrayLengthFromUniformOptions array_length_from_uniform = {}; }; /// The result produced when generating MSL. @@ -80,6 +96,10 @@ struct Result { /// Each entry in the vector is the size of the workgroup allocation that /// should be created for that index. std::unordered_map> workgroup_allocations; + + /// Indices into the array_length_from_uniform binding that are statically + /// used. + std::unordered_set used_array_length_from_uniform_indices; }; /// Generate MSL for a program, according to a set of configuration options. The diff --git a/src/writer/msl/generator_impl.cc b/src/writer/msl/generator_impl.cc index f7ec04f414..4cfc22e9fa 100644 --- a/src/writer/msl/generator_impl.cc +++ b/src/writer/msl/generator_impl.cc @@ -112,31 +112,47 @@ class ScopedBitCast { }; } // namespace -SanitizedResult Sanitize(const Program* in, - uint32_t buffer_size_ubo_index, - uint32_t fixed_sample_mask, - bool emit_vertex_point_size, - bool disable_workgroup_init) { +SanitizedResult::SanitizedResult() = default; +SanitizedResult::~SanitizedResult() = default; +SanitizedResult::SanitizedResult(SanitizedResult&&) = default; + +SanitizedResult Sanitize( + const Program* in, + uint32_t buffer_size_ubo_index, + uint32_t fixed_sample_mask, + bool emit_vertex_point_size, + bool disable_workgroup_init, + const ArrayLengthFromUniformOptions& array_length_from_uniform) { transform::Manager manager; transform::DataMap internal_inputs; - // Build the configs for the internal transforms. - auto array_length_from_uniform_cfg = - transform::ArrayLengthFromUniform::Config( - sem::BindingPoint{0, buffer_size_ubo_index}); + // Build the config for the internal ArrayLengthFromUniform transform. + transform::ArrayLengthFromUniform::Config array_length_from_uniform_cfg( + array_length_from_uniform.ubo_binding); + if (!array_length_from_uniform.bindpoint_to_size_index.empty()) { + // If |array_length_from_uniform| bindings are provided, use that config. + array_length_from_uniform_cfg.bindpoint_to_size_index = + array_length_from_uniform.bindpoint_to_size_index; + } else { + // If the binding map is empty, use the deprecated |buffer_size_ubo_index| + // and automatically choose indices using the binding numbers. + array_length_from_uniform_cfg = transform::ArrayLengthFromUniform::Config( + sem::BindingPoint{0, buffer_size_ubo_index}); + // Use the SSBO binding numbers as the indices for the buffer size lookups. + for (auto* var : in->AST().GlobalVariables()) { + auto* global = in->Sem().Get(var); + if (global && global->StorageClass() == ast::StorageClass::kStorage) { + array_length_from_uniform_cfg.bindpoint_to_size_index.emplace( + global->BindingPoint(), global->BindingPoint().binding); + } + } + } + + // Build the configs for the internal CanonicalizeEntryPointIO transform. auto entry_point_io_cfg = transform::CanonicalizeEntryPointIO::Config( transform::CanonicalizeEntryPointIO::ShaderStyle::kMsl, fixed_sample_mask, emit_vertex_point_size); - // Use the SSBO binding numbers as the indices for the buffer size lookups. - for (auto* var : in->AST().GlobalVariables()) { - auto* global = in->Sem().Get(var); - if (global && global->StorageClass() == ast::StorageClass::kStorage) { - array_length_from_uniform_cfg.bindpoint_to_size_index.emplace( - global->BindingPoint(), global->BindingPoint().binding); - } - } - if (!disable_workgroup_init) { // ZeroInitWorkgroupMemory must come before CanonicalizeEntryPointIO as // ZeroInitWorkgroupMemory may inject new builtin parameters. @@ -160,13 +176,18 @@ SanitizedResult Sanitize(const Program* in, internal_inputs.Add( std::move(entry_point_io_cfg)); auto out = manager.Run(in, internal_inputs); - if (!out.program.IsValid()) { - return {std::move(out.program)}; - } - return {std::move(out.program), - out.data.Get() - ->needs_buffer_sizes}; + SanitizedResult result; + result.program = std::move(out.program); + if (!result.program.IsValid()) { + return result; + } + result.used_array_length_from_uniform_indices = + std::move(out.data.Get() + ->used_size_indices); + result.needs_storage_buffer_sizes = + !result.used_array_length_from_uniform_indices.empty(); + return result; } GeneratorImpl::GeneratorImpl(const Program* program) : TextGenerator(program) {} @@ -1314,6 +1335,13 @@ std::string GeneratorImpl::generate_builtin_name( case sem::IntrinsicType::kUnpack2x16unorm: out += "unpack_unorm2x16_to_float"; break; + case sem::IntrinsicType::kArrayLength: + diagnostics_.add_error( + diag::System::Writer, + "Unable to translate builtin: " + std::string(intrinsic->str()) + + "\nDid you forget to pass array_length_from_uniform generator " + "options?"); + return ""; default: diagnostics_.add_error( diag::System::Writer, diff --git a/src/writer/msl/generator_impl.h b/src/writer/msl/generator_impl.h index 8b73a31c5c..710b8e2375 100644 --- a/src/writer/msl/generator_impl.h +++ b/src/writer/msl/generator_impl.h @@ -17,6 +17,7 @@ #include #include +#include #include #include "src/ast/assignment_statement.h" @@ -37,6 +38,7 @@ #include "src/program.h" #include "src/scope_stack.h" #include "src/sem/struct.h" +#include "src/writer/array_length_from_uniform_options.h" #include "src/writer/text_generator.h" namespace tint { @@ -54,10 +56,20 @@ namespace msl { /// The result of sanitizing a program for generation. struct SanitizedResult { + /// Constructor + SanitizedResult(); + /// Destructor + ~SanitizedResult(); + /// Move constructor + SanitizedResult(SanitizedResult&&); + /// The sanitized program. Program program; /// True if the shader needs a UBO of buffer sizes. bool needs_storage_buffer_sizes = false; + /// Indices into the array_length_from_uniform binding that are statically + /// used. + std::unordered_set used_array_length_from_uniform_indices; }; /// Sanitize a program in preparation for generating MSL. @@ -66,11 +78,13 @@ struct SanitizedResult { /// @param emit_vertex_point_size `true` to emit a vertex point size builtin /// @param disable_workgroup_init `true` to disable workgroup memory zero /// @returns the sanitized program and any supplementary information -SanitizedResult Sanitize(const Program* program, - uint32_t buffer_size_ubo_index, - uint32_t fixed_sample_mask = 0xFFFFFFFF, - bool emit_vertex_point_size = false, - bool disable_workgroup_init = false); +SanitizedResult Sanitize( + const Program* program, + uint32_t buffer_size_ubo_index, + uint32_t fixed_sample_mask = 0xFFFFFFFF, + bool emit_vertex_point_size = false, + bool disable_workgroup_init = false, + const ArrayLengthFromUniformOptions& array_length_from_uniform = {}); /// Implementation class for MSL generator class GeneratorImpl : public TextGenerator { diff --git a/src/writer/msl/generator_impl_sanitizer_test.cc b/src/writer/msl/generator_impl_sanitizer_test.cc new file mode 100644 index 0000000000..c8fb11653e --- /dev/null +++ b/src/writer/msl/generator_impl_sanitizer_test.cc @@ -0,0 +1,264 @@ +// Copyright 2021 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gmock/gmock.h" +#include "src/ast/call_statement.h" +#include "src/ast/stage_decoration.h" +#include "src/ast/struct_block_decoration.h" +#include "src/ast/variable_decl_statement.h" +#include "src/writer/msl/test_helper.h" + +namespace tint { +namespace writer { +namespace msl { +namespace { + +using ::testing::HasSubstr; + +using MslSanitizerTest = TestHelper; + +TEST_F(MslSanitizerTest, Call_ArrayLength) { + auto* s = Structure("my_struct", {Member(0, "a", ty.array(4))}, + {create()}); + Global("b", ty.Of(s), ast::StorageClass::kStorage, ast::Access::kRead, + ast::DecorationList{ + create(1), + create(2), + }); + + Func("a_func", ast::VariableList{}, ty.void_(), + ast::StatementList{ + Decl(Var("len", ty.u32(), ast::StorageClass::kNone, + Call("arrayLength", AddressOf(MemberAccessor("b", "a"))))), + }, + ast::DecorationList{ + Stage(ast::PipelineStage::kFragment), + }); + + GeneratorImpl& gen = SanitizeAndBuild(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + + auto got = gen.result(); + auto* expect = R"(#include + +using namespace metal; +struct tint_symbol { + /* 0x0000 */ uint4 buffer_size[1]; +}; +struct my_struct { + float a[1]; +}; + +fragment void a_func(const constant tint_symbol* tint_symbol_2 [[buffer(30)]]) { + uint len = (((*(tint_symbol_2)).buffer_size[0u][1u] - 0u) / 4u); + return; +} + +)"; + EXPECT_EQ(expect, got); +} + +TEST_F(MslSanitizerTest, Call_ArrayLength_OtherMembersInStruct) { + auto* s = Structure("my_struct", + { + Member(0, "z", ty.f32()), + Member(4, "a", ty.array(4)), + }, + {create()}); + Global("b", ty.Of(s), ast::StorageClass::kStorage, ast::Access::kRead, + ast::DecorationList{ + create(1), + create(2), + }); + + Func("a_func", ast::VariableList{}, ty.void_(), + ast::StatementList{ + Decl(Var("len", ty.u32(), ast::StorageClass::kNone, + Call("arrayLength", AddressOf(MemberAccessor("b", "a"))))), + }, + ast::DecorationList{ + Stage(ast::PipelineStage::kFragment), + }); + + GeneratorImpl& gen = SanitizeAndBuild(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + + auto got = gen.result(); + auto* expect = R"(#include + +using namespace metal; +struct tint_symbol { + /* 0x0000 */ uint4 buffer_size[1]; +}; +struct my_struct { + float z; + float a[1]; +}; + +fragment void a_func(const constant tint_symbol* tint_symbol_2 [[buffer(30)]]) { + uint len = (((*(tint_symbol_2)).buffer_size[0u][1u] - 4u) / 4u); + return; +} + +)"; + + EXPECT_EQ(expect, got); +} + +TEST_F(MslSanitizerTest, Call_ArrayLength_ViaLets) { + auto* s = Structure("my_struct", {Member(0, "a", ty.array(4))}, + {create()}); + Global("b", ty.Of(s), ast::StorageClass::kStorage, ast::Access::kRead, + ast::DecorationList{ + create(1), + create(2), + }); + + auto* p = Const("p", nullptr, AddressOf("b")); + auto* p2 = Const("p2", nullptr, AddressOf(MemberAccessor(Deref(p), "a"))); + + Func("a_func", ast::VariableList{}, ty.void_(), + ast::StatementList{ + Decl(p), + Decl(p2), + Decl(Var("len", ty.u32(), ast::StorageClass::kNone, + Call("arrayLength", p2))), + }, + ast::DecorationList{ + Stage(ast::PipelineStage::kFragment), + }); + + GeneratorImpl& gen = SanitizeAndBuild(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + + auto got = gen.result(); + auto* expect = R"(#include + +using namespace metal; +struct tint_symbol { + /* 0x0000 */ uint4 buffer_size[1]; +}; +struct my_struct { + float a[1]; +}; + +fragment void a_func(const constant tint_symbol* tint_symbol_2 [[buffer(30)]]) { + uint len = (((*(tint_symbol_2)).buffer_size[0u][1u] - 0u) / 4u); + return; +} + +)"; + + EXPECT_EQ(expect, got); +} + +TEST_F(MslSanitizerTest, Call_ArrayLength_ArrayLengthFromUniform) { + auto* s = Structure("my_struct", {Member(0, "a", ty.array(4))}, + {create()}); + Global("b", ty.Of(s), ast::StorageClass::kStorage, ast::Access::kRead, + ast::DecorationList{ + create(1), + create(0), + }); + Global("c", ty.Of(s), ast::StorageClass::kStorage, ast::Access::kRead, + ast::DecorationList{ + create(2), + create(0), + }); + + Func("a_func", ast::VariableList{}, ty.void_(), + ast::StatementList{ + Decl(Var( + "len", ty.u32(), ast::StorageClass::kNone, + Add(Call("arrayLength", AddressOf(MemberAccessor("b", "a"))), + Call("arrayLength", AddressOf(MemberAccessor("c", "a")))))), + }, + ast::DecorationList{ + Stage(ast::PipelineStage::kFragment), + }); + + Options options; + options.array_length_from_uniform.ubo_binding = {0, 29}; + options.array_length_from_uniform.bindpoint_to_size_index.emplace( + sem::BindingPoint{0, 1}, 7u); + options.array_length_from_uniform.bindpoint_to_size_index.emplace( + sem::BindingPoint{0, 2}, 2u); + GeneratorImpl& gen = SanitizeAndBuild(options); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + + auto got = gen.result(); + auto* expect = R"(#include + +using namespace metal; +struct tint_symbol { + /* 0x0000 */ uint4 buffer_size[2]; +}; +struct my_struct { + float a[1]; +}; + +fragment void a_func(const constant tint_symbol* tint_symbol_2 [[buffer(29)]]) { + uint len = ((((*(tint_symbol_2)).buffer_size[1u][3u] - 0u) / 4u) + (((*(tint_symbol_2)).buffer_size[0u][2u] - 0u) / 4u)); + return; +} + +)"; + EXPECT_EQ(expect, got); +} + +TEST_F(MslSanitizerTest, + Call_ArrayLength_ArrayLengthFromUniformMissingBinding) { + auto* s = Structure("my_struct", {Member(0, "a", ty.array(4))}, + {create()}); + Global("b", ty.Of(s), ast::StorageClass::kStorage, ast::Access::kRead, + ast::DecorationList{ + create(1), + create(0), + }); + Global("c", ty.Of(s), ast::StorageClass::kStorage, ast::Access::kRead, + ast::DecorationList{ + create(2), + create(0), + }); + + Func("a_func", ast::VariableList{}, ty.void_(), + ast::StatementList{ + Decl(Var( + "len", ty.u32(), ast::StorageClass::kNone, + Add(Call("arrayLength", AddressOf(MemberAccessor("b", "a"))), + Call("arrayLength", AddressOf(MemberAccessor("c", "a")))))), + }, + ast::DecorationList{ + Stage(ast::PipelineStage::kFragment), + }); + + Options options; + options.array_length_from_uniform.ubo_binding = {0, 29}; + options.array_length_from_uniform.bindpoint_to_size_index.emplace( + sem::BindingPoint{0, 2}, 2u); + GeneratorImpl& gen = SanitizeAndBuild(options); + + ASSERT_FALSE(gen.Generate()); + EXPECT_THAT(gen.error(), + HasSubstr("Unable to translate builtin: arrayLength")); +} + +} // namespace +} // namespace msl +} // namespace writer +} // namespace tint diff --git a/src/writer/msl/test_helper.h b/src/writer/msl/test_helper.h index 762b2b6528..8b4cf3b2f2 100644 --- a/src/writer/msl/test_helper.h +++ b/src/writer/msl/test_helper.h @@ -21,6 +21,7 @@ #include "gtest/gtest.h" #include "src/program_builder.h" +#include "src/writer/msl/generator.h" #include "src/writer/msl/generator_impl.h" namespace tint { @@ -57,10 +58,11 @@ class TestHelperBase : public BASE, public ProgramBuilder { /// Builds the program, runs the program through the transform::Msl sanitizer /// and returns a GeneratorImpl from the sanitized program. + /// @param options The MSL generator options. /// @note The generator is only built once. Multiple calls to Build() will /// return the same GeneratorImpl without rebuilding. /// @return the built generator - GeneratorImpl& SanitizeAndBuild() { + GeneratorImpl& SanitizeAndBuild(const Options& options = {}) { if (gen_) { return *gen_; } @@ -74,7 +76,10 @@ class TestHelperBase : public BASE, public ProgramBuilder { << diag::Formatter().format(program->Diagnostics()); }(); - auto result = Sanitize(program.get(), 30); + auto result = Sanitize( + program.get(), options.buffer_size_ubo_index, options.fixed_sample_mask, + options.emit_vertex_point_size, options.disable_workgroup_init, + options.array_length_from_uniform); [&]() { ASSERT_TRUE(result.program.IsValid()) << diag::Formatter().format(result.program.Diagnostics()); diff --git a/test/BUILD.gn b/test/BUILD.gn index 907018a0c1..099765a883 100644 --- a/test/BUILD.gn +++ b/test/BUILD.gn @@ -591,6 +591,7 @@ tint_unittests_source_set("tint_unittests_msl_writer_src") { "../src/writer/msl/generator_impl_member_accessor_test.cc", "../src/writer/msl/generator_impl_module_constant_test.cc", "../src/writer/msl/generator_impl_return_test.cc", + "../src/writer/msl/generator_impl_sanitizer_test.cc", "../src/writer/msl/generator_impl_switch_test.cc", "../src/writer/msl/generator_impl_test.cc", "../src/writer/msl/generator_impl_type_test.cc",