diff --git a/src/BUILD.gn b/src/BUILD.gn index d6c95c8274..fb1bcfe667 100644 --- a/src/BUILD.gn +++ b/src/BUILD.gn @@ -491,6 +491,8 @@ libtint_source_set("libtint_core_all_src") { "utils/unique_vector.h", "writer/append_vector.cc", "writer/append_vector.h", + "writer/array_length_from_uniform_options.cc", + "writer/array_length_from_uniform_options.h", "writer/float_to_string.cc", "writer/float_to_string.h", "writer/text.cc", diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index b86cb98da6..9b746da875 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -403,6 +403,8 @@ set(TINT_LIB_SRCS utils/unique_vector.h writer/append_vector.cc writer/append_vector.h + writer/array_length_from_uniform_options.cc + writer/array_length_from_uniform_options.h writer/float_to_string.cc writer/float_to_string.h writer/text_generator.cc @@ -998,6 +1000,7 @@ if(${TINT_BUILD_TESTS}) writer/msl/generator_impl_member_accessor_test.cc writer/msl/generator_impl_module_constant_test.cc writer/msl/generator_impl_return_test.cc + writer/msl/generator_impl_sanitizer_test.cc writer/msl/generator_impl_switch_test.cc writer/msl/generator_impl_test.cc writer/msl/generator_impl_type_test.cc diff --git a/src/transform/array_length_from_uniform.cc b/src/transform/array_length_from_uniform.cc index 65f8f0ad3e..a30e550ca0 100644 --- a/src/transform/array_length_from_uniform.cc +++ b/src/transform/array_length_from_uniform.cc @@ -35,60 +35,18 @@ namespace transform { ArrayLengthFromUniform::ArrayLengthFromUniform() = default; ArrayLengthFromUniform::~ArrayLengthFromUniform() = default; -void ArrayLengthFromUniform::Run(CloneContext& ctx, - const DataMap& inputs, - DataMap& outputs) { - if (!Requires(ctx)) { - return; - } - - auto* cfg = inputs.Get(); - if (cfg == nullptr) { - ctx.dst->Diagnostics().add_error( - diag::System::Transform, - "missing transform data for " + std::string(TypeInfo().name)); - return; - } - +/// Iterate over all arrayLength() intrinsics that operate on +/// storage buffer variables. +/// @param ctx the CloneContext. +/// @param functor of type void(const ast::CallExpression*, const +/// sem::VariableUser, const sem::GlobalVariable*). It takes in an +/// ast::CallExpression of the arrayLength call expression node, a +/// sem::VariableUser of the used storage buffer variable, and the +/// sem::GlobalVariable for the storage buffer. +template +static void IterateArrayLengthOnStorageVar(CloneContext& ctx, F&& functor) { auto& sem = ctx.src->Sem(); - const char* kBufferSizeMemberName = "buffer_size"; - - // Determine the size of the buffer size array. - uint32_t max_buffer_size_index = 0; - for (auto& idx : cfg->bindpoint_to_size_index) { - if (idx.second > max_buffer_size_index) { - max_buffer_size_index = idx.second; - } - } - - // Get (or create, on first call) the uniform buffer that will receive the - // size of each storage buffer in the module. - const ast::Variable* buffer_size_ubo = nullptr; - auto get_ubo = [&]() { - if (!buffer_size_ubo) { - // Emit an array, N>, where N is 1/4 number of elements. - // We do this because UBOs require an element stride that is 16-byte - // aligned. - auto* buffer_size_struct = ctx.dst->Structure( - ctx.dst->Sym(), - {ctx.dst->Member( - kBufferSizeMemberName, - ctx.dst->ty.array(ctx.dst->ty.vec4(ctx.dst->ty.u32()), - (max_buffer_size_index / 4) + 1))}, - - ast::DecorationList{ctx.dst->create()}); - buffer_size_ubo = ctx.dst->Global( - ctx.dst->Sym(), ctx.dst->ty.Of(buffer_size_struct), - ast::StorageClass::kUniform, - ast::DecorationList{ - ctx.dst->create(cfg->ubo_binding.group), - ctx.dst->create( - cfg->ubo_binding.binding)}); - } - return buffer_size_ubo; - }; - // Find all calls to the arrayLength() intrinsic. for (auto* node : ctx.src->ASTNodes().Objects()) { auto* call_expr = node->As(); @@ -137,23 +95,91 @@ void ArrayLengthFromUniform::Run(CloneContext& ctx, << "storage buffer is not a global variable"; break; } + functor(call_expr, storage_buffer_sem, var); + } +} + +void ArrayLengthFromUniform::Run(CloneContext& ctx, + const DataMap& inputs, + DataMap& outputs) { + if (!Requires(ctx)) { + return; + } + + auto* cfg = inputs.Get(); + if (cfg == nullptr) { + ctx.dst->Diagnostics().add_error( + diag::System::Transform, + "missing transform data for " + std::string(TypeInfo().name)); + return; + } + + const char* kBufferSizeMemberName = "buffer_size"; + + // Determine the size of the buffer size array. + uint32_t max_buffer_size_index = 0; + + IterateArrayLengthOnStorageVar( + ctx, [&](const ast::CallExpression*, const sem::VariableUser*, + const sem::GlobalVariable* var) { + auto binding = var->BindingPoint(); + auto idx_itr = cfg->bindpoint_to_size_index.find(binding); + if (idx_itr == cfg->bindpoint_to_size_index.end()) { + return; + } + if (idx_itr->second > max_buffer_size_index) { + max_buffer_size_index = idx_itr->second; + } + }); + + // Get (or create, on first call) the uniform buffer that will receive the + // size of each storage buffer in the module. + const ast::Variable* buffer_size_ubo = nullptr; + auto get_ubo = [&]() { + if (!buffer_size_ubo) { + // Emit an array, N>, where N is 1/4 number of elements. + // We do this because UBOs require an element stride that is 16-byte + // aligned. + auto* buffer_size_struct = ctx.dst->Structure( + ctx.dst->Sym(), + {ctx.dst->Member( + kBufferSizeMemberName, + ctx.dst->ty.array(ctx.dst->ty.vec4(ctx.dst->ty.u32()), + (max_buffer_size_index / 4) + 1))}, + + ast::DecorationList{ctx.dst->create()}); + buffer_size_ubo = ctx.dst->Global( + ctx.dst->Sym(), ctx.dst->ty.Of(buffer_size_struct), + ast::StorageClass::kUniform, + ast::DecorationList{ + ctx.dst->create(cfg->ubo_binding.group), + ctx.dst->create( + cfg->ubo_binding.binding)}); + } + return buffer_size_ubo; + }; + + std::unordered_set used_size_indices; + + IterateArrayLengthOnStorageVar(ctx, [&](const ast::CallExpression* call_expr, + const sem::VariableUser* + storage_buffer_sem, + const sem::GlobalVariable* var) { auto binding = var->BindingPoint(); auto idx_itr = cfg->bindpoint_to_size_index.find(binding); if (idx_itr == cfg->bindpoint_to_size_index.end()) { - ctx.dst->Diagnostics().add_error( - diag::System::Transform, - "missing size index mapping for binding point (" + - std::to_string(binding.group) + "," + - std::to_string(binding.binding) + ")"); - continue; + return; } + uint32_t size_index = idx_itr->second; + used_size_indices.insert(size_index); + // Load the total storage buffer size from the UBO. - uint32_t array_index = idx_itr->second / 4; + uint32_t array_index = size_index / 4; auto* vec_expr = ctx.dst->IndexAccessor( ctx.dst->MemberAccessor(get_ubo()->symbol, kBufferSizeMemberName), array_index); - uint32_t vec_index = idx_itr->second % 4; + uint32_t vec_index = size_index % 4; auto* total_storage_buffer_size = ctx.dst->IndexAccessor(vec_expr, vec_index); @@ -170,20 +196,23 @@ void ArrayLengthFromUniform::Run(CloneContext& ctx, ctx.dst->Sub(total_storage_buffer_size, array_offset), array_stride); ctx.Replace(call_expr, array_length); - } + }); ctx.Clone(); - outputs.Add(buffer_size_ubo ? true : false); + outputs.Add(used_size_indices); } ArrayLengthFromUniform::Config::Config(sem::BindingPoint ubo_bp) : ubo_binding(ubo_bp) {} ArrayLengthFromUniform::Config::Config(const Config&) = default; +ArrayLengthFromUniform::Config& ArrayLengthFromUniform::Config::operator=( + const Config&) = default; ArrayLengthFromUniform::Config::~Config() = default; -ArrayLengthFromUniform::Result::Result(bool needs_sizes) - : needs_buffer_sizes(needs_sizes) {} +ArrayLengthFromUniform::Result::Result( + std::unordered_set used_size_indices_in) + : used_size_indices(std::move(used_size_indices_in)) {} ArrayLengthFromUniform::Result::Result(const Result&) = default; ArrayLengthFromUniform::Result::~Result() = default; diff --git a/src/transform/array_length_from_uniform.h b/src/transform/array_length_from_uniform.h index 306ed8a231..7063a27f5d 100644 --- a/src/transform/array_length_from_uniform.h +++ b/src/transform/array_length_from_uniform.h @@ -16,6 +16,7 @@ #define SRC_TRANSFORM_ARRAY_LENGTH_FROM_UNIFORM_H_ #include +#include #include "src/sem/binding_point.h" #include "src/transform/transform.h" @@ -66,6 +67,10 @@ class ArrayLengthFromUniform /// Copy constructor Config(const Config&); + /// Copy assignment + /// @return this Config + Config& operator=(const Config&); + /// Destructor ~Config() override; @@ -79,8 +84,8 @@ class ArrayLengthFromUniform /// Information produced about what the transform did. struct Result : public Castable { /// Constructor - /// @param needs_sizes True if the transform generated the buffer sizes UBO. - explicit Result(bool needs_sizes); + /// @param used_size_indices Indices into the UBO that are statically used. + explicit Result(std::unordered_set used_size_indices); /// Copy constructor Result(const Result&); @@ -88,8 +93,8 @@ class ArrayLengthFromUniform /// Destructor ~Result() override; - /// True if the transform generated the buffer sizes UBO. - const bool needs_buffer_sizes; + /// Indices into the UBO that are statically used. + const std::unordered_set used_size_indices; }; protected: diff --git a/src/transform/array_length_from_uniform_test.cc b/src/transform/array_length_from_uniform_test.cc index c41ba2e932..fd4eef922b 100644 --- a/src/transform/array_length_from_uniform_test.cc +++ b/src/transform/array_length_from_uniform_test.cc @@ -110,8 +110,8 @@ fn main() { Run(src, data); EXPECT_EQ(expect, str(got)); - EXPECT_TRUE( - got.data.Get()->needs_buffer_sizes); + EXPECT_EQ(std::unordered_set({0}), + got.data.Get()->used_size_indices); } TEST_F(ArrayLengthFromUniformTest, WithStride) { @@ -164,8 +164,8 @@ fn main() { Run(src, data); EXPECT_EQ(expect, str(got)); - EXPECT_TRUE( - got.data.Get()->needs_buffer_sizes); + EXPECT_EQ(std::unordered_set({0}), + got.data.Get()->used_size_indices); } TEST_F(ArrayLengthFromUniformTest, MultipleStorageBuffers) { @@ -286,8 +286,124 @@ fn main() { Run(src, data); EXPECT_EQ(expect, str(got)); - EXPECT_TRUE( - got.data.Get()->needs_buffer_sizes); + EXPECT_EQ(std::unordered_set({0, 1, 2, 3, 4}), + got.data.Get()->used_size_indices); +} + +TEST_F(ArrayLengthFromUniformTest, MultipleUnusedStorageBuffers) { + auto* src = R"( +[[block]] +struct SB1 { + x : i32; + arr1 : array; +}; +[[block]] +struct SB2 { + x : i32; + arr2 : array>; +}; +[[block]] +struct SB3 { + x : i32; + arr3 : array>; +}; +[[block]] +struct SB4 { + x : i32; + arr4 : array>; +}; +[[block]] +struct SB5 { + x : i32; + arr5 : array>; +}; + +[[group(0), binding(2)]] var sb1 : SB1; +[[group(1), binding(2)]] var sb2 : SB2; +[[group(2), binding(2)]] var sb3 : SB3; +[[group(3), binding(2)]] var sb4 : SB4; +[[group(4), binding(2)]] var sb5 : SB5; + +[[stage(compute), workgroup_size(1)]] +fn main() { + var len1 : u32 = arrayLength(&(sb1.arr1)); + var len3 : u32 = arrayLength(&(sb3.arr3)); + var x : u32 = (len1 + len3); +} +)"; + + auto* expect = R"( +[[block]] +struct tint_symbol { + buffer_size : array, 1u>; +}; + +[[group(0), binding(30)]] var tint_symbol_1 : tint_symbol; + +[[block]] +struct SB1 { + x : i32; + arr1 : array; +}; + +[[block]] +struct SB2 { + x : i32; + arr2 : array>; +}; + +[[block]] +struct SB3 { + x : i32; + arr3 : array>; +}; + +[[block]] +struct SB4 { + x : i32; + arr4 : array>; +}; + +[[block]] +struct SB5 { + x : i32; + arr5 : array>; +}; + +[[group(0), binding(2)]] var sb1 : SB1; + +[[group(1), binding(2)]] var sb2 : SB2; + +[[group(2), binding(2)]] var sb3 : SB3; + +[[group(3), binding(2)]] var sb4 : SB4; + +[[group(4), binding(2)]] var sb5 : SB5; + +[[stage(compute), workgroup_size(1)]] +fn main() { + var len1 : u32 = ((tint_symbol_1.buffer_size[0u][0u] - 4u) / 4u); + var len3 : u32 = ((tint_symbol_1.buffer_size[0u][2u] - 16u) / 16u); + var x : u32 = (len1 + len3); +} +)"; + + ArrayLengthFromUniform::Config cfg({0, 30u}); + cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 2u}, 0); + cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{1u, 2u}, 1); + cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{2u, 2u}, 2); + cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{3u, 2u}, 3); + cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{4u, 2u}, 4); + + DataMap data; + data.Add(std::move(cfg)); + + auto got = + Run(src, data); + + EXPECT_EQ(expect, str(got)); + EXPECT_EQ(std::unordered_set({0, 2}), + got.data.Get()->used_size_indices); } TEST_F(ArrayLengthFromUniformTest, NoArrayLengthCalls) { @@ -316,8 +432,8 @@ fn main() { Run(src, data); EXPECT_EQ(src, str(got)); - EXPECT_FALSE( - got.data.Get()->needs_buffer_sizes); + EXPECT_EQ(std::unordered_set(), + got.data.Get()->used_size_indices); } TEST_F(ArrayLengthFromUniformTest, MissingBindingPointToIndexMapping) { @@ -346,7 +462,37 @@ fn main() { } )"; - auto* expect = "error: missing size index mapping for binding point (1,2)"; + auto* expect = R"( +[[block]] +struct tint_symbol { + buffer_size : array, 1u>; +}; + +[[group(0), binding(30)]] var tint_symbol_1 : tint_symbol; + +[[block]] +struct SB1 { + x : i32; + arr1 : array; +}; + +[[block]] +struct SB2 { + x : i32; + arr2 : array>; +}; + +[[group(0), binding(2)]] var sb1 : SB1; + +[[group(1), binding(2)]] var sb2 : SB2; + +[[stage(compute), workgroup_size(1)]] +fn main() { + var len1 : u32 = ((tint_symbol_1.buffer_size[0u][0u] - 4u) / 4u); + var len2 : u32 = arrayLength(&(sb2.arr2)); + var x : u32 = (len1 + len2); +} +)"; ArrayLengthFromUniform::Config cfg({0, 30u}); cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 2}, 0); @@ -358,6 +504,8 @@ fn main() { Run(src, data); EXPECT_EQ(expect, str(got)); + EXPECT_EQ(std::unordered_set({0}), + got.data.Get()->used_size_indices); } } // namespace diff --git a/src/writer/array_length_from_uniform_options.cc b/src/writer/array_length_from_uniform_options.cc new file mode 100644 index 0000000000..b848a484e3 --- /dev/null +++ b/src/writer/array_length_from_uniform_options.cc @@ -0,0 +1,30 @@ +// Copyright 2021 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/writer/array_length_from_uniform_options.h" + +namespace tint { +namespace writer { + +ArrayLengthFromUniformOptions::ArrayLengthFromUniformOptions() = default; +ArrayLengthFromUniformOptions::~ArrayLengthFromUniformOptions() = default; +ArrayLengthFromUniformOptions::ArrayLengthFromUniformOptions( + const ArrayLengthFromUniformOptions&) = default; +ArrayLengthFromUniformOptions& ArrayLengthFromUniformOptions::operator=( + const ArrayLengthFromUniformOptions&) = default; +ArrayLengthFromUniformOptions::ArrayLengthFromUniformOptions( + ArrayLengthFromUniformOptions&&) = default; + +} // namespace writer +} // namespace tint diff --git a/src/writer/array_length_from_uniform_options.h b/src/writer/array_length_from_uniform_options.h new file mode 100644 index 0000000000..ae92d4f028 --- /dev/null +++ b/src/writer/array_length_from_uniform_options.h @@ -0,0 +1,52 @@ +// Copyright 2021 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SRC_WRITER_ARRAY_LENGTH_FROM_UNIFORM_OPTIONS_H_ +#define SRC_WRITER_ARRAY_LENGTH_FROM_UNIFORM_OPTIONS_H_ + +#include + +#include "src/sem/binding_point.h" + +namespace tint { +namespace writer { + +/// Options used to specify a mapping of binding points to indices into a UBO +/// from which to load buffer sizes. +struct ArrayLengthFromUniformOptions { + /// Constructor + ArrayLengthFromUniformOptions(); + /// Destructor + ~ArrayLengthFromUniformOptions(); + /// Copy constructor + ArrayLengthFromUniformOptions(const ArrayLengthFromUniformOptions&); + /// Copy assignment + /// @returns this ArrayLengthFromUniformOptions + ArrayLengthFromUniformOptions& operator=( + const ArrayLengthFromUniformOptions&); + /// Move constructor + ArrayLengthFromUniformOptions(ArrayLengthFromUniformOptions&&); + + /// The binding point to use to generate a uniform buffer from which to read + /// buffer sizes. + sem::BindingPoint ubo_binding; + /// The mapping from storage buffer binding points to the index into the + /// uniform buffer where the length of the buffer is stored. + std::unordered_map bindpoint_to_size_index; +}; + +} // namespace writer +} // namespace tint + +#endif // SRC_WRITER_ARRAY_LENGTH_FROM_UNIFORM_OPTIONS_H_ diff --git a/src/writer/hlsl/generator.cc b/src/writer/hlsl/generator.cc index ce172f9a99..727b5d5525 100644 --- a/src/writer/hlsl/generator.cc +++ b/src/writer/hlsl/generator.cc @@ -20,6 +20,11 @@ namespace tint { namespace writer { namespace hlsl { +Options::Options() = default; +Options::~Options() = default; +Options::Options(const Options&) = default; +Options& Options::operator=(const Options&) = default; + Result::Result() = default; Result::~Result() = default; Result::Result(const Result&) = default; @@ -29,7 +34,8 @@ Result Generate(const Program* program, const Options& options) { // Sanitize the program. auto sanitized_result = Sanitize(program, options.root_constant_binding_point, - options.disable_workgroup_init); + options.disable_workgroup_init, + options.array_length_from_uniform); if (!sanitized_result.program.IsValid()) { result.success = false; result.error = sanitized_result.program.Diagnostics().str(); @@ -50,6 +56,9 @@ Result Generate(const Program* program, const Options& options) { } } + result.used_array_length_from_uniform_indices = + std::move(sanitized_result.used_array_length_from_uniform_indices); + return result; } diff --git a/src/writer/hlsl/generator.h b/src/writer/hlsl/generator.h index 693002980c..4dacbd4cda 100644 --- a/src/writer/hlsl/generator.h +++ b/src/writer/hlsl/generator.h @@ -17,11 +17,13 @@ #include #include +#include #include #include #include "src/ast/pipeline_stage.h" #include "src/sem/binding_point.h" +#include "src/writer/array_length_from_uniform_options.h" #include "src/writer/text.h" namespace tint { @@ -37,10 +39,23 @@ class GeneratorImpl; /// Configuration options used for generating HLSL. struct Options { + /// Constructor + Options(); + /// Destructor + ~Options(); + /// Copy constructor + Options(const Options&); + /// Copy assignment + /// @returns this Options + Options& operator=(const Options&); + /// The binding point to use for information passed via root constants. sem::BindingPoint root_constant_binding_point; /// Set to `true` to disable workgroup memory zero initialization bool disable_workgroup_init = false; + /// Options used to specify a mapping of binding points to indices into a UBO + /// from which to load buffer sizes. + ArrayLengthFromUniformOptions array_length_from_uniform = {}; }; /// The result produced when generating HLSL. @@ -65,6 +80,10 @@ struct Result { /// The list of entry points in the generated HLSL. std::vector> entry_points; + + /// Indices into the array_length_from_uniform binding that are statically + /// used. + std::unordered_set used_array_length_from_uniform_indices; }; /// Generate HLSL for a program, according to a set of configuration options. diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc index 3d3895a56e..c60be69618 100644 --- a/src/writer/hlsl/generator_impl.cc +++ b/src/writer/hlsl/generator_impl.cc @@ -45,6 +45,7 @@ #include "src/sem/type_conversion.h" #include "src/sem/variable.h" #include "src/transform/add_empty_entry_point.h" +#include "src/transform/array_length_from_uniform.h" #include "src/transform/calculate_array_length.h" #include "src/transform/canonicalize_entry_point_io.h" #include "src/transform/decompose_memory_access.h" @@ -124,12 +125,24 @@ const char* LoopAttribute() { } // namespace -SanitizedResult Sanitize(const Program* in, - sem::BindingPoint root_constant_binding_point, - bool disable_workgroup_init) { +SanitizedResult::SanitizedResult() = default; +SanitizedResult::~SanitizedResult() = default; +SanitizedResult::SanitizedResult(SanitizedResult&&) = default; + +SanitizedResult Sanitize( + const Program* in, + sem::BindingPoint root_constant_binding_point, + bool disable_workgroup_init, + const ArrayLengthFromUniformOptions& array_length_from_uniform) { transform::Manager manager; transform::DataMap data; + // Build the config for the internal ArrayLengthFromUniform transform. + transform::ArrayLengthFromUniform::Config array_length_from_uniform_cfg( + array_length_from_uniform.ubo_binding); + array_length_from_uniform_cfg.bindpoint_to_size_index = + array_length_from_uniform.bindpoint_to_size_index; + // Attempt to convert `loop`s into for-loops. This is to try and massage the // output into something that will not cause FXC to choke or misbehave. manager.Add(); @@ -149,6 +162,11 @@ SanitizedResult Sanitize(const Program* in, // Simplify cleans up messy `*(&(expr))` expressions from InlinePointerLets. manager.Add(); manager.Add(); + // ArrayLengthFromUniform must come after InlinePointerLets and Simplify, as + // it assumes that the form of the array length argument is &var.array. + manager.Add(); + data.Add( + std::move(array_length_from_uniform_cfg)); // DecomposeMemoryAccess must come after: // * InlinePointerLets, as we cannot take the address of calls to // DecomposeMemoryAccess::Intrinsic. @@ -171,8 +189,13 @@ SanitizedResult Sanitize(const Program* in, data.Add( root_constant_binding_point); + auto out = manager.Run(in, data); + SanitizedResult result; - result.program = std::move(manager.Run(in, data).program); + result.program = std::move(out.program); + result.used_array_length_from_uniform_indices = + std::move(out.data.Get() + ->used_size_indices); return result; } diff --git a/src/writer/hlsl/generator_impl.h b/src/writer/hlsl/generator_impl.h index 7c56b10aa9..b23897d6ef 100644 --- a/src/writer/hlsl/generator_impl.h +++ b/src/writer/hlsl/generator_impl.h @@ -36,6 +36,7 @@ #include "src/sem/binding_point.h" #include "src/transform/decompose_memory_access.h" #include "src/utils/hash.h" +#include "src/writer/array_length_from_uniform_options.h" #include "src/writer/text_generator.h" namespace tint { @@ -53,8 +54,18 @@ namespace hlsl { /// The result of sanitizing a program for generation. struct SanitizedResult { + /// Constructor + SanitizedResult(); + /// Destructor + ~SanitizedResult(); + /// Move constructor + SanitizedResult(SanitizedResult&&); + /// The sanitized program. Program program; + /// Indices into the array_length_from_uniform binding that are statically + /// used. + std::unordered_set used_array_length_from_uniform_indices; }; /// Sanitize a program in preparation for generating HLSL. @@ -62,9 +73,11 @@ struct SanitizedResult { /// that will be passed via root constants /// @param disable_workgroup_init `true` to disable workgroup memory zero /// @returns the sanitized program and any supplementary information -SanitizedResult Sanitize(const Program* program, - sem::BindingPoint root_constant_binding_point = {}, - bool disable_workgroup_init = false); +SanitizedResult Sanitize( + const Program* program, + sem::BindingPoint root_constant_binding_point = {}, + bool disable_workgroup_init = false, + const ArrayLengthFromUniformOptions& array_length_from_uniform = {}); /// Implementation class for HLSL generator class GeneratorImpl : public TextGenerator { diff --git a/src/writer/hlsl/generator_impl_sanitizer_test.cc b/src/writer/hlsl/generator_impl_sanitizer_test.cc index c3692a7683..fa6d00bc31 100644 --- a/src/writer/hlsl/generator_impl_sanitizer_test.cc +++ b/src/writer/hlsl/generator_impl_sanitizer_test.cc @@ -144,6 +144,58 @@ void a_func() { EXPECT_EQ(expect, got); } +TEST_F(HlslSanitizerTest, Call_ArrayLength_ArrayLengthFromUniform) { + auto* s = Structure("my_struct", {Member(0, "a", ty.array(4))}, + {create()}); + Global("b", ty.Of(s), ast::StorageClass::kStorage, ast::Access::kRead, + ast::DecorationList{ + create(1), + create(2), + }); + Global("c", ty.Of(s), ast::StorageClass::kStorage, ast::Access::kRead, + ast::DecorationList{ + create(2), + create(2), + }); + + Func("a_func", ast::VariableList{}, ty.void_(), + ast::StatementList{ + Decl(Var( + "len", ty.u32(), ast::StorageClass::kNone, + Add(Call("arrayLength", AddressOf(MemberAccessor("b", "a"))), + Call("arrayLength", AddressOf(MemberAccessor("c", "a")))))), + }, + ast::DecorationList{ + Stage(ast::PipelineStage::kFragment), + }); + + Options options; + options.array_length_from_uniform.ubo_binding = {3, 4}; + options.array_length_from_uniform.bindpoint_to_size_index.emplace( + sem::BindingPoint{2, 2}, 7u); + GeneratorImpl& gen = SanitizeAndBuild(options); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + + auto got = gen.result(); + auto* expect = R"(cbuffer cbuffer_tint_symbol_1 : register(b4, space3) { + uint4 tint_symbol_1[2]; +}; + +ByteAddressBuffer b : register(t1, space2); +ByteAddressBuffer c : register(t2, space2); + +void a_func() { + uint tint_symbol_4 = 0u; + b.GetDimensions(tint_symbol_4); + const uint tint_symbol_5 = ((tint_symbol_4 - 0u) / 4u); + uint len = (tint_symbol_5 + ((tint_symbol_1[1].w - 0u) / 4u)); + return; +} +)"; + EXPECT_EQ(expect, got); +} + TEST_F(HlslSanitizerTest, PromoteArrayInitializerToConstVar) { auto* array_init = array(1, 2, 3, 4); auto* array_index = IndexAccessor(array_init, 3); diff --git a/src/writer/hlsl/test_helper.h b/src/writer/hlsl/test_helper.h index 7523c82cbb..e55b9c1e95 100644 --- a/src/writer/hlsl/test_helper.h +++ b/src/writer/hlsl/test_helper.h @@ -22,6 +22,7 @@ #include "gtest/gtest.h" #include "src/transform/manager.h" #include "src/transform/renamer.h" +#include "src/writer/hlsl/generator.h" #include "src/writer/hlsl/generator_impl.h" namespace tint { @@ -58,10 +59,11 @@ class TestHelperBase : public BODY, public ProgramBuilder { /// Builds the program, runs the program through the HLSL sanitizer /// and returns a GeneratorImpl from the sanitized program. + /// @param options The HLSL generator options. /// @note The generator is only built once. Multiple calls to Build() will /// return the same GeneratorImpl without rebuilding. /// @return the built generator - GeneratorImpl& SanitizeAndBuild() { + GeneratorImpl& SanitizeAndBuild(const Options& options = {}) { if (gen_) { return *gen_; } @@ -76,7 +78,9 @@ class TestHelperBase : public BODY, public ProgramBuilder { << formatter.format(program->Diagnostics()); }(); - auto sanitized_result = Sanitize(program.get()); + auto sanitized_result = Sanitize( + program.get(), options.root_constant_binding_point, + options.disable_workgroup_init, options.array_length_from_uniform); [&]() { ASSERT_TRUE(sanitized_result.program.IsValid()) << formatter.format(sanitized_result.program.Diagnostics()); diff --git a/src/writer/msl/generator.cc b/src/writer/msl/generator.cc index cc1aa5c4a3..92741e8d75 100644 --- a/src/writer/msl/generator.cc +++ b/src/writer/msl/generator.cc @@ -14,12 +14,19 @@ #include "src/writer/msl/generator.h" +#include + #include "src/writer/msl/generator_impl.h" namespace tint { namespace writer { namespace msl { +Options::Options() = default; +Options::~Options() = default; +Options::Options(const Options&) = default; +Options& Options::operator=(const Options&) = default; + Result::Result() = default; Result::~Result() = default; Result::Result(const Result&) = default; @@ -30,7 +37,8 @@ Result Generate(const Program* program, const Options& options) { // Sanitize the program. auto sanitized_result = Sanitize( program, options.buffer_size_ubo_index, options.fixed_sample_mask, - options.emit_vertex_point_size, options.disable_workgroup_init); + options.emit_vertex_point_size, options.disable_workgroup_init, + options.array_length_from_uniform); if (!sanitized_result.program.IsValid()) { result.success = false; result.error = sanitized_result.program.Diagnostics().str(); @@ -38,6 +46,8 @@ Result Generate(const Program* program, const Options& options) { } result.needs_storage_buffer_sizes = sanitized_result.needs_storage_buffer_sizes; + result.used_array_length_from_uniform_indices = + std::move(sanitized_result.used_array_length_from_uniform_indices); // Generate the MSL code. auto impl = std::make_unique(&sanitized_result.program); diff --git a/src/writer/msl/generator.h b/src/writer/msl/generator.h index af812d7161..63d793957c 100644 --- a/src/writer/msl/generator.h +++ b/src/writer/msl/generator.h @@ -18,8 +18,10 @@ #include #include #include +#include #include +#include "src/writer/array_length_from_uniform_options.h" #include "src/writer/text.h" namespace tint { @@ -34,6 +36,16 @@ class GeneratorImpl; /// Configuration options used for generating MSL. struct Options { + /// Constructor + Options(); + /// Destructor + ~Options(); + /// Copy constructor + Options(const Options&); + /// Copy assignment + /// @returns this Options + Options& operator=(const Options&); + /// The index to use when generating a UBO to receive storage buffer sizes. /// Defaults to 30, which is the last valid buffer slot. uint32_t buffer_size_ubo_index = 30; @@ -48,6 +60,10 @@ struct Options { /// Set to `true` to disable workgroup memory zero initialization bool disable_workgroup_init = false; + + /// Options used to specify a mapping of binding points to indices into a UBO + /// from which to load buffer sizes. + ArrayLengthFromUniformOptions array_length_from_uniform = {}; }; /// The result produced when generating MSL. @@ -80,6 +96,10 @@ struct Result { /// Each entry in the vector is the size of the workgroup allocation that /// should be created for that index. std::unordered_map> workgroup_allocations; + + /// Indices into the array_length_from_uniform binding that are statically + /// used. + std::unordered_set used_array_length_from_uniform_indices; }; /// Generate MSL for a program, according to a set of configuration options. The diff --git a/src/writer/msl/generator_impl.cc b/src/writer/msl/generator_impl.cc index f7ec04f414..4cfc22e9fa 100644 --- a/src/writer/msl/generator_impl.cc +++ b/src/writer/msl/generator_impl.cc @@ -112,31 +112,47 @@ class ScopedBitCast { }; } // namespace -SanitizedResult Sanitize(const Program* in, - uint32_t buffer_size_ubo_index, - uint32_t fixed_sample_mask, - bool emit_vertex_point_size, - bool disable_workgroup_init) { +SanitizedResult::SanitizedResult() = default; +SanitizedResult::~SanitizedResult() = default; +SanitizedResult::SanitizedResult(SanitizedResult&&) = default; + +SanitizedResult Sanitize( + const Program* in, + uint32_t buffer_size_ubo_index, + uint32_t fixed_sample_mask, + bool emit_vertex_point_size, + bool disable_workgroup_init, + const ArrayLengthFromUniformOptions& array_length_from_uniform) { transform::Manager manager; transform::DataMap internal_inputs; - // Build the configs for the internal transforms. - auto array_length_from_uniform_cfg = - transform::ArrayLengthFromUniform::Config( - sem::BindingPoint{0, buffer_size_ubo_index}); + // Build the config for the internal ArrayLengthFromUniform transform. + transform::ArrayLengthFromUniform::Config array_length_from_uniform_cfg( + array_length_from_uniform.ubo_binding); + if (!array_length_from_uniform.bindpoint_to_size_index.empty()) { + // If |array_length_from_uniform| bindings are provided, use that config. + array_length_from_uniform_cfg.bindpoint_to_size_index = + array_length_from_uniform.bindpoint_to_size_index; + } else { + // If the binding map is empty, use the deprecated |buffer_size_ubo_index| + // and automatically choose indices using the binding numbers. + array_length_from_uniform_cfg = transform::ArrayLengthFromUniform::Config( + sem::BindingPoint{0, buffer_size_ubo_index}); + // Use the SSBO binding numbers as the indices for the buffer size lookups. + for (auto* var : in->AST().GlobalVariables()) { + auto* global = in->Sem().Get(var); + if (global && global->StorageClass() == ast::StorageClass::kStorage) { + array_length_from_uniform_cfg.bindpoint_to_size_index.emplace( + global->BindingPoint(), global->BindingPoint().binding); + } + } + } + + // Build the configs for the internal CanonicalizeEntryPointIO transform. auto entry_point_io_cfg = transform::CanonicalizeEntryPointIO::Config( transform::CanonicalizeEntryPointIO::ShaderStyle::kMsl, fixed_sample_mask, emit_vertex_point_size); - // Use the SSBO binding numbers as the indices for the buffer size lookups. - for (auto* var : in->AST().GlobalVariables()) { - auto* global = in->Sem().Get(var); - if (global && global->StorageClass() == ast::StorageClass::kStorage) { - array_length_from_uniform_cfg.bindpoint_to_size_index.emplace( - global->BindingPoint(), global->BindingPoint().binding); - } - } - if (!disable_workgroup_init) { // ZeroInitWorkgroupMemory must come before CanonicalizeEntryPointIO as // ZeroInitWorkgroupMemory may inject new builtin parameters. @@ -160,13 +176,18 @@ SanitizedResult Sanitize(const Program* in, internal_inputs.Add( std::move(entry_point_io_cfg)); auto out = manager.Run(in, internal_inputs); - if (!out.program.IsValid()) { - return {std::move(out.program)}; - } - return {std::move(out.program), - out.data.Get() - ->needs_buffer_sizes}; + SanitizedResult result; + result.program = std::move(out.program); + if (!result.program.IsValid()) { + return result; + } + result.used_array_length_from_uniform_indices = + std::move(out.data.Get() + ->used_size_indices); + result.needs_storage_buffer_sizes = + !result.used_array_length_from_uniform_indices.empty(); + return result; } GeneratorImpl::GeneratorImpl(const Program* program) : TextGenerator(program) {} @@ -1314,6 +1335,13 @@ std::string GeneratorImpl::generate_builtin_name( case sem::IntrinsicType::kUnpack2x16unorm: out += "unpack_unorm2x16_to_float"; break; + case sem::IntrinsicType::kArrayLength: + diagnostics_.add_error( + diag::System::Writer, + "Unable to translate builtin: " + std::string(intrinsic->str()) + + "\nDid you forget to pass array_length_from_uniform generator " + "options?"); + return ""; default: diagnostics_.add_error( diag::System::Writer, diff --git a/src/writer/msl/generator_impl.h b/src/writer/msl/generator_impl.h index 8b73a31c5c..710b8e2375 100644 --- a/src/writer/msl/generator_impl.h +++ b/src/writer/msl/generator_impl.h @@ -17,6 +17,7 @@ #include #include +#include #include #include "src/ast/assignment_statement.h" @@ -37,6 +38,7 @@ #include "src/program.h" #include "src/scope_stack.h" #include "src/sem/struct.h" +#include "src/writer/array_length_from_uniform_options.h" #include "src/writer/text_generator.h" namespace tint { @@ -54,10 +56,20 @@ namespace msl { /// The result of sanitizing a program for generation. struct SanitizedResult { + /// Constructor + SanitizedResult(); + /// Destructor + ~SanitizedResult(); + /// Move constructor + SanitizedResult(SanitizedResult&&); + /// The sanitized program. Program program; /// True if the shader needs a UBO of buffer sizes. bool needs_storage_buffer_sizes = false; + /// Indices into the array_length_from_uniform binding that are statically + /// used. + std::unordered_set used_array_length_from_uniform_indices; }; /// Sanitize a program in preparation for generating MSL. @@ -66,11 +78,13 @@ struct SanitizedResult { /// @param emit_vertex_point_size `true` to emit a vertex point size builtin /// @param disable_workgroup_init `true` to disable workgroup memory zero /// @returns the sanitized program and any supplementary information -SanitizedResult Sanitize(const Program* program, - uint32_t buffer_size_ubo_index, - uint32_t fixed_sample_mask = 0xFFFFFFFF, - bool emit_vertex_point_size = false, - bool disable_workgroup_init = false); +SanitizedResult Sanitize( + const Program* program, + uint32_t buffer_size_ubo_index, + uint32_t fixed_sample_mask = 0xFFFFFFFF, + bool emit_vertex_point_size = false, + bool disable_workgroup_init = false, + const ArrayLengthFromUniformOptions& array_length_from_uniform = {}); /// Implementation class for MSL generator class GeneratorImpl : public TextGenerator { diff --git a/src/writer/msl/generator_impl_sanitizer_test.cc b/src/writer/msl/generator_impl_sanitizer_test.cc new file mode 100644 index 0000000000..c8fb11653e --- /dev/null +++ b/src/writer/msl/generator_impl_sanitizer_test.cc @@ -0,0 +1,264 @@ +// Copyright 2021 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gmock/gmock.h" +#include "src/ast/call_statement.h" +#include "src/ast/stage_decoration.h" +#include "src/ast/struct_block_decoration.h" +#include "src/ast/variable_decl_statement.h" +#include "src/writer/msl/test_helper.h" + +namespace tint { +namespace writer { +namespace msl { +namespace { + +using ::testing::HasSubstr; + +using MslSanitizerTest = TestHelper; + +TEST_F(MslSanitizerTest, Call_ArrayLength) { + auto* s = Structure("my_struct", {Member(0, "a", ty.array(4))}, + {create()}); + Global("b", ty.Of(s), ast::StorageClass::kStorage, ast::Access::kRead, + ast::DecorationList{ + create(1), + create(2), + }); + + Func("a_func", ast::VariableList{}, ty.void_(), + ast::StatementList{ + Decl(Var("len", ty.u32(), ast::StorageClass::kNone, + Call("arrayLength", AddressOf(MemberAccessor("b", "a"))))), + }, + ast::DecorationList{ + Stage(ast::PipelineStage::kFragment), + }); + + GeneratorImpl& gen = SanitizeAndBuild(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + + auto got = gen.result(); + auto* expect = R"(#include + +using namespace metal; +struct tint_symbol { + /* 0x0000 */ uint4 buffer_size[1]; +}; +struct my_struct { + float a[1]; +}; + +fragment void a_func(const constant tint_symbol* tint_symbol_2 [[buffer(30)]]) { + uint len = (((*(tint_symbol_2)).buffer_size[0u][1u] - 0u) / 4u); + return; +} + +)"; + EXPECT_EQ(expect, got); +} + +TEST_F(MslSanitizerTest, Call_ArrayLength_OtherMembersInStruct) { + auto* s = Structure("my_struct", + { + Member(0, "z", ty.f32()), + Member(4, "a", ty.array(4)), + }, + {create()}); + Global("b", ty.Of(s), ast::StorageClass::kStorage, ast::Access::kRead, + ast::DecorationList{ + create(1), + create(2), + }); + + Func("a_func", ast::VariableList{}, ty.void_(), + ast::StatementList{ + Decl(Var("len", ty.u32(), ast::StorageClass::kNone, + Call("arrayLength", AddressOf(MemberAccessor("b", "a"))))), + }, + ast::DecorationList{ + Stage(ast::PipelineStage::kFragment), + }); + + GeneratorImpl& gen = SanitizeAndBuild(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + + auto got = gen.result(); + auto* expect = R"(#include + +using namespace metal; +struct tint_symbol { + /* 0x0000 */ uint4 buffer_size[1]; +}; +struct my_struct { + float z; + float a[1]; +}; + +fragment void a_func(const constant tint_symbol* tint_symbol_2 [[buffer(30)]]) { + uint len = (((*(tint_symbol_2)).buffer_size[0u][1u] - 4u) / 4u); + return; +} + +)"; + + EXPECT_EQ(expect, got); +} + +TEST_F(MslSanitizerTest, Call_ArrayLength_ViaLets) { + auto* s = Structure("my_struct", {Member(0, "a", ty.array(4))}, + {create()}); + Global("b", ty.Of(s), ast::StorageClass::kStorage, ast::Access::kRead, + ast::DecorationList{ + create(1), + create(2), + }); + + auto* p = Const("p", nullptr, AddressOf("b")); + auto* p2 = Const("p2", nullptr, AddressOf(MemberAccessor(Deref(p), "a"))); + + Func("a_func", ast::VariableList{}, ty.void_(), + ast::StatementList{ + Decl(p), + Decl(p2), + Decl(Var("len", ty.u32(), ast::StorageClass::kNone, + Call("arrayLength", p2))), + }, + ast::DecorationList{ + Stage(ast::PipelineStage::kFragment), + }); + + GeneratorImpl& gen = SanitizeAndBuild(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + + auto got = gen.result(); + auto* expect = R"(#include + +using namespace metal; +struct tint_symbol { + /* 0x0000 */ uint4 buffer_size[1]; +}; +struct my_struct { + float a[1]; +}; + +fragment void a_func(const constant tint_symbol* tint_symbol_2 [[buffer(30)]]) { + uint len = (((*(tint_symbol_2)).buffer_size[0u][1u] - 0u) / 4u); + return; +} + +)"; + + EXPECT_EQ(expect, got); +} + +TEST_F(MslSanitizerTest, Call_ArrayLength_ArrayLengthFromUniform) { + auto* s = Structure("my_struct", {Member(0, "a", ty.array(4))}, + {create()}); + Global("b", ty.Of(s), ast::StorageClass::kStorage, ast::Access::kRead, + ast::DecorationList{ + create(1), + create(0), + }); + Global("c", ty.Of(s), ast::StorageClass::kStorage, ast::Access::kRead, + ast::DecorationList{ + create(2), + create(0), + }); + + Func("a_func", ast::VariableList{}, ty.void_(), + ast::StatementList{ + Decl(Var( + "len", ty.u32(), ast::StorageClass::kNone, + Add(Call("arrayLength", AddressOf(MemberAccessor("b", "a"))), + Call("arrayLength", AddressOf(MemberAccessor("c", "a")))))), + }, + ast::DecorationList{ + Stage(ast::PipelineStage::kFragment), + }); + + Options options; + options.array_length_from_uniform.ubo_binding = {0, 29}; + options.array_length_from_uniform.bindpoint_to_size_index.emplace( + sem::BindingPoint{0, 1}, 7u); + options.array_length_from_uniform.bindpoint_to_size_index.emplace( + sem::BindingPoint{0, 2}, 2u); + GeneratorImpl& gen = SanitizeAndBuild(options); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + + auto got = gen.result(); + auto* expect = R"(#include + +using namespace metal; +struct tint_symbol { + /* 0x0000 */ uint4 buffer_size[2]; +}; +struct my_struct { + float a[1]; +}; + +fragment void a_func(const constant tint_symbol* tint_symbol_2 [[buffer(29)]]) { + uint len = ((((*(tint_symbol_2)).buffer_size[1u][3u] - 0u) / 4u) + (((*(tint_symbol_2)).buffer_size[0u][2u] - 0u) / 4u)); + return; +} + +)"; + EXPECT_EQ(expect, got); +} + +TEST_F(MslSanitizerTest, + Call_ArrayLength_ArrayLengthFromUniformMissingBinding) { + auto* s = Structure("my_struct", {Member(0, "a", ty.array(4))}, + {create()}); + Global("b", ty.Of(s), ast::StorageClass::kStorage, ast::Access::kRead, + ast::DecorationList{ + create(1), + create(0), + }); + Global("c", ty.Of(s), ast::StorageClass::kStorage, ast::Access::kRead, + ast::DecorationList{ + create(2), + create(0), + }); + + Func("a_func", ast::VariableList{}, ty.void_(), + ast::StatementList{ + Decl(Var( + "len", ty.u32(), ast::StorageClass::kNone, + Add(Call("arrayLength", AddressOf(MemberAccessor("b", "a"))), + Call("arrayLength", AddressOf(MemberAccessor("c", "a")))))), + }, + ast::DecorationList{ + Stage(ast::PipelineStage::kFragment), + }); + + Options options; + options.array_length_from_uniform.ubo_binding = {0, 29}; + options.array_length_from_uniform.bindpoint_to_size_index.emplace( + sem::BindingPoint{0, 2}, 2u); + GeneratorImpl& gen = SanitizeAndBuild(options); + + ASSERT_FALSE(gen.Generate()); + EXPECT_THAT(gen.error(), + HasSubstr("Unable to translate builtin: arrayLength")); +} + +} // namespace +} // namespace msl +} // namespace writer +} // namespace tint diff --git a/src/writer/msl/test_helper.h b/src/writer/msl/test_helper.h index 762b2b6528..8b4cf3b2f2 100644 --- a/src/writer/msl/test_helper.h +++ b/src/writer/msl/test_helper.h @@ -21,6 +21,7 @@ #include "gtest/gtest.h" #include "src/program_builder.h" +#include "src/writer/msl/generator.h" #include "src/writer/msl/generator_impl.h" namespace tint { @@ -57,10 +58,11 @@ class TestHelperBase : public BASE, public ProgramBuilder { /// Builds the program, runs the program through the transform::Msl sanitizer /// and returns a GeneratorImpl from the sanitized program. + /// @param options The MSL generator options. /// @note The generator is only built once. Multiple calls to Build() will /// return the same GeneratorImpl without rebuilding. /// @return the built generator - GeneratorImpl& SanitizeAndBuild() { + GeneratorImpl& SanitizeAndBuild(const Options& options = {}) { if (gen_) { return *gen_; } @@ -74,7 +76,10 @@ class TestHelperBase : public BASE, public ProgramBuilder { << diag::Formatter().format(program->Diagnostics()); }(); - auto result = Sanitize(program.get(), 30); + auto result = Sanitize( + program.get(), options.buffer_size_ubo_index, options.fixed_sample_mask, + options.emit_vertex_point_size, options.disable_workgroup_init, + options.array_length_from_uniform); [&]() { ASSERT_TRUE(result.program.IsValid()) << diag::Formatter().format(result.program.Diagnostics()); diff --git a/test/BUILD.gn b/test/BUILD.gn index 907018a0c1..099765a883 100644 --- a/test/BUILD.gn +++ b/test/BUILD.gn @@ -591,6 +591,7 @@ tint_unittests_source_set("tint_unittests_msl_writer_src") { "../src/writer/msl/generator_impl_member_accessor_test.cc", "../src/writer/msl/generator_impl_module_constant_test.cc", "../src/writer/msl/generator_impl_return_test.cc", + "../src/writer/msl/generator_impl_sanitizer_test.cc", "../src/writer/msl/generator_impl_switch_test.cc", "../src/writer/msl/generator_impl_test.cc", "../src/writer/msl/generator_impl_type_test.cc",