Add HLSL/MSL generator options for ArrayLengthFromUniform

ArrayLengthFromUniform is needed for correct bounds checks on
dynamic storage buffers on D3D12. The intrinsic GetDimensions does
not return the actual size of the buffer binding.

ArrayLengthFromUniform is updated to output the indices of the
uniform buffer that are statically used. This allows Dawn to minimize
the amount of data needed to upload into the uniform buffer.
These output indices are returned on the HLSL/MSL generator result.

ArrayLengthFromUniform is also updated to allow only some of the
arrayLength calls to be replaced with uniform buffer loads. For HLSL
output, the remaining arrayLength computations will continue to use
GetDimensions(). For MSL, it is invalid to not specify an index into
the uniform buffer for all storage buffers.

After Dawn is updated to use the array_length_from_uniform option in the
Metal backend, the buffer_size_ubo_index member for MSL output may be
removed.

Bug: dawn:429
Change-Id: I9da4ec4a20882e9f1bfa5bb026725d72529eff26
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/69301
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Austin Eng <enga@chromium.org>
This commit is contained in:
Austin Eng 2021-11-19 04:11:33 +00:00 committed by Tint LUCI CQ
parent a660b510ac
commit 38fa643702
20 changed files with 850 additions and 119 deletions

View File

@ -491,6 +491,8 @@ libtint_source_set("libtint_core_all_src") {
"utils/unique_vector.h",
"writer/append_vector.cc",
"writer/append_vector.h",
"writer/array_length_from_uniform_options.cc",
"writer/array_length_from_uniform_options.h",
"writer/float_to_string.cc",
"writer/float_to_string.h",
"writer/text.cc",

View File

@ -403,6 +403,8 @@ set(TINT_LIB_SRCS
utils/unique_vector.h
writer/append_vector.cc
writer/append_vector.h
writer/array_length_from_uniform_options.cc
writer/array_length_from_uniform_options.h
writer/float_to_string.cc
writer/float_to_string.h
writer/text_generator.cc
@ -998,6 +1000,7 @@ if(${TINT_BUILD_TESTS})
writer/msl/generator_impl_member_accessor_test.cc
writer/msl/generator_impl_module_constant_test.cc
writer/msl/generator_impl_return_test.cc
writer/msl/generator_impl_sanitizer_test.cc
writer/msl/generator_impl_switch_test.cc
writer/msl/generator_impl_test.cc
writer/msl/generator_impl_type_test.cc

View File

@ -35,60 +35,18 @@ namespace transform {
ArrayLengthFromUniform::ArrayLengthFromUniform() = default;
ArrayLengthFromUniform::~ArrayLengthFromUniform() = default;
void ArrayLengthFromUniform::Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) {
if (!Requires<InlinePointerLets, Simplify>(ctx)) {
return;
}
auto* cfg = inputs.Get<Config>();
if (cfg == nullptr) {
ctx.dst->Diagnostics().add_error(
diag::System::Transform,
"missing transform data for " + std::string(TypeInfo().name));
return;
}
/// Iterate over all arrayLength() intrinsics that operate on
/// storage buffer variables.
/// @param ctx the CloneContext.
/// @param functor of type void(const ast::CallExpression*, const
/// sem::VariableUser, const sem::GlobalVariable*). It takes in an
/// ast::CallExpression of the arrayLength call expression node, a
/// sem::VariableUser of the used storage buffer variable, and the
/// sem::GlobalVariable for the storage buffer.
template <typename F>
static void IterateArrayLengthOnStorageVar(CloneContext& ctx, F&& functor) {
auto& sem = ctx.src->Sem();
const char* kBufferSizeMemberName = "buffer_size";
// Determine the size of the buffer size array.
uint32_t max_buffer_size_index = 0;
for (auto& idx : cfg->bindpoint_to_size_index) {
if (idx.second > max_buffer_size_index) {
max_buffer_size_index = idx.second;
}
}
// Get (or create, on first call) the uniform buffer that will receive the
// size of each storage buffer in the module.
const ast::Variable* buffer_size_ubo = nullptr;
auto get_ubo = [&]() {
if (!buffer_size_ubo) {
// Emit an array<vec4<u32>, N>, where N is 1/4 number of elements.
// We do this because UBOs require an element stride that is 16-byte
// aligned.
auto* buffer_size_struct = ctx.dst->Structure(
ctx.dst->Sym(),
{ctx.dst->Member(
kBufferSizeMemberName,
ctx.dst->ty.array(ctx.dst->ty.vec4(ctx.dst->ty.u32()),
(max_buffer_size_index / 4) + 1))},
ast::DecorationList{ctx.dst->create<ast::StructBlockDecoration>()});
buffer_size_ubo = ctx.dst->Global(
ctx.dst->Sym(), ctx.dst->ty.Of(buffer_size_struct),
ast::StorageClass::kUniform,
ast::DecorationList{
ctx.dst->create<ast::GroupDecoration>(cfg->ubo_binding.group),
ctx.dst->create<ast::BindingDecoration>(
cfg->ubo_binding.binding)});
}
return buffer_size_ubo;
};
// Find all calls to the arrayLength() intrinsic.
for (auto* node : ctx.src->ASTNodes().Objects()) {
auto* call_expr = node->As<ast::CallExpression>();
@ -137,23 +95,91 @@ void ArrayLengthFromUniform::Run(CloneContext& ctx,
<< "storage buffer is not a global variable";
break;
}
functor(call_expr, storage_buffer_sem, var);
}
}
void ArrayLengthFromUniform::Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) {
if (!Requires<InlinePointerLets, Simplify>(ctx)) {
return;
}
auto* cfg = inputs.Get<Config>();
if (cfg == nullptr) {
ctx.dst->Diagnostics().add_error(
diag::System::Transform,
"missing transform data for " + std::string(TypeInfo().name));
return;
}
const char* kBufferSizeMemberName = "buffer_size";
// Determine the size of the buffer size array.
uint32_t max_buffer_size_index = 0;
IterateArrayLengthOnStorageVar(
ctx, [&](const ast::CallExpression*, const sem::VariableUser*,
const sem::GlobalVariable* var) {
auto binding = var->BindingPoint();
auto idx_itr = cfg->bindpoint_to_size_index.find(binding);
if (idx_itr == cfg->bindpoint_to_size_index.end()) {
return;
}
if (idx_itr->second > max_buffer_size_index) {
max_buffer_size_index = idx_itr->second;
}
});
// Get (or create, on first call) the uniform buffer that will receive the
// size of each storage buffer in the module.
const ast::Variable* buffer_size_ubo = nullptr;
auto get_ubo = [&]() {
if (!buffer_size_ubo) {
// Emit an array<vec4<u32>, N>, where N is 1/4 number of elements.
// We do this because UBOs require an element stride that is 16-byte
// aligned.
auto* buffer_size_struct = ctx.dst->Structure(
ctx.dst->Sym(),
{ctx.dst->Member(
kBufferSizeMemberName,
ctx.dst->ty.array(ctx.dst->ty.vec4(ctx.dst->ty.u32()),
(max_buffer_size_index / 4) + 1))},
ast::DecorationList{ctx.dst->create<ast::StructBlockDecoration>()});
buffer_size_ubo = ctx.dst->Global(
ctx.dst->Sym(), ctx.dst->ty.Of(buffer_size_struct),
ast::StorageClass::kUniform,
ast::DecorationList{
ctx.dst->create<ast::GroupDecoration>(cfg->ubo_binding.group),
ctx.dst->create<ast::BindingDecoration>(
cfg->ubo_binding.binding)});
}
return buffer_size_ubo;
};
std::unordered_set<uint32_t> used_size_indices;
IterateArrayLengthOnStorageVar(ctx, [&](const ast::CallExpression* call_expr,
const sem::VariableUser*
storage_buffer_sem,
const sem::GlobalVariable* var) {
auto binding = var->BindingPoint();
auto idx_itr = cfg->bindpoint_to_size_index.find(binding);
if (idx_itr == cfg->bindpoint_to_size_index.end()) {
ctx.dst->Diagnostics().add_error(
diag::System::Transform,
"missing size index mapping for binding point (" +
std::to_string(binding.group) + "," +
std::to_string(binding.binding) + ")");
continue;
return;
}
uint32_t size_index = idx_itr->second;
used_size_indices.insert(size_index);
// Load the total storage buffer size from the UBO.
uint32_t array_index = idx_itr->second / 4;
uint32_t array_index = size_index / 4;
auto* vec_expr = ctx.dst->IndexAccessor(
ctx.dst->MemberAccessor(get_ubo()->symbol, kBufferSizeMemberName),
array_index);
uint32_t vec_index = idx_itr->second % 4;
uint32_t vec_index = size_index % 4;
auto* total_storage_buffer_size =
ctx.dst->IndexAccessor(vec_expr, vec_index);
@ -170,20 +196,23 @@ void ArrayLengthFromUniform::Run(CloneContext& ctx,
ctx.dst->Sub(total_storage_buffer_size, array_offset), array_stride);
ctx.Replace(call_expr, array_length);
}
});
ctx.Clone();
outputs.Add<Result>(buffer_size_ubo ? true : false);
outputs.Add<Result>(used_size_indices);
}
ArrayLengthFromUniform::Config::Config(sem::BindingPoint ubo_bp)
: ubo_binding(ubo_bp) {}
ArrayLengthFromUniform::Config::Config(const Config&) = default;
ArrayLengthFromUniform::Config& ArrayLengthFromUniform::Config::operator=(
const Config&) = default;
ArrayLengthFromUniform::Config::~Config() = default;
ArrayLengthFromUniform::Result::Result(bool needs_sizes)
: needs_buffer_sizes(needs_sizes) {}
ArrayLengthFromUniform::Result::Result(
std::unordered_set<uint32_t> used_size_indices_in)
: used_size_indices(std::move(used_size_indices_in)) {}
ArrayLengthFromUniform::Result::Result(const Result&) = default;
ArrayLengthFromUniform::Result::~Result() = default;

View File

@ -16,6 +16,7 @@
#define SRC_TRANSFORM_ARRAY_LENGTH_FROM_UNIFORM_H_
#include <unordered_map>
#include <unordered_set>
#include "src/sem/binding_point.h"
#include "src/transform/transform.h"
@ -66,6 +67,10 @@ class ArrayLengthFromUniform
/// Copy constructor
Config(const Config&);
/// Copy assignment
/// @return this Config
Config& operator=(const Config&);
/// Destructor
~Config() override;
@ -79,8 +84,8 @@ class ArrayLengthFromUniform
/// Information produced about what the transform did.
struct Result : public Castable<Result, transform::Data> {
/// Constructor
/// @param needs_sizes True if the transform generated the buffer sizes UBO.
explicit Result(bool needs_sizes);
/// @param used_size_indices Indices into the UBO that are statically used.
explicit Result(std::unordered_set<uint32_t> used_size_indices);
/// Copy constructor
Result(const Result&);
@ -88,8 +93,8 @@ class ArrayLengthFromUniform
/// Destructor
~Result() override;
/// True if the transform generated the buffer sizes UBO.
const bool needs_buffer_sizes;
/// Indices into the UBO that are statically used.
const std::unordered_set<uint32_t> used_size_indices;
};
protected:

View File

@ -110,8 +110,8 @@ fn main() {
Run<InlinePointerLets, Simplify, ArrayLengthFromUniform>(src, data);
EXPECT_EQ(expect, str(got));
EXPECT_TRUE(
got.data.Get<ArrayLengthFromUniform::Result>()->needs_buffer_sizes);
EXPECT_EQ(std::unordered_set<uint32_t>({0}),
got.data.Get<ArrayLengthFromUniform::Result>()->used_size_indices);
}
TEST_F(ArrayLengthFromUniformTest, WithStride) {
@ -164,8 +164,8 @@ fn main() {
Run<InlinePointerLets, Simplify, ArrayLengthFromUniform>(src, data);
EXPECT_EQ(expect, str(got));
EXPECT_TRUE(
got.data.Get<ArrayLengthFromUniform::Result>()->needs_buffer_sizes);
EXPECT_EQ(std::unordered_set<uint32_t>({0}),
got.data.Get<ArrayLengthFromUniform::Result>()->used_size_indices);
}
TEST_F(ArrayLengthFromUniformTest, MultipleStorageBuffers) {
@ -286,8 +286,124 @@ fn main() {
Run<InlinePointerLets, Simplify, ArrayLengthFromUniform>(src, data);
EXPECT_EQ(expect, str(got));
EXPECT_TRUE(
got.data.Get<ArrayLengthFromUniform::Result>()->needs_buffer_sizes);
EXPECT_EQ(std::unordered_set<uint32_t>({0, 1, 2, 3, 4}),
got.data.Get<ArrayLengthFromUniform::Result>()->used_size_indices);
}
TEST_F(ArrayLengthFromUniformTest, MultipleUnusedStorageBuffers) {
auto* src = R"(
[[block]]
struct SB1 {
x : i32;
arr1 : array<i32>;
};
[[block]]
struct SB2 {
x : i32;
arr2 : array<vec4<f32>>;
};
[[block]]
struct SB3 {
x : i32;
arr3 : array<vec4<f32>>;
};
[[block]]
struct SB4 {
x : i32;
arr4 : array<vec4<f32>>;
};
[[block]]
struct SB5 {
x : i32;
arr5 : array<vec4<f32>>;
};
[[group(0), binding(2)]] var<storage, read> sb1 : SB1;
[[group(1), binding(2)]] var<storage, read> sb2 : SB2;
[[group(2), binding(2)]] var<storage, read> sb3 : SB3;
[[group(3), binding(2)]] var<storage, read> sb4 : SB4;
[[group(4), binding(2)]] var<storage, read> sb5 : SB5;
[[stage(compute), workgroup_size(1)]]
fn main() {
var len1 : u32 = arrayLength(&(sb1.arr1));
var len3 : u32 = arrayLength(&(sb3.arr3));
var x : u32 = (len1 + len3);
}
)";
auto* expect = R"(
[[block]]
struct tint_symbol {
buffer_size : array<vec4<u32>, 1u>;
};
[[group(0), binding(30)]] var<uniform> tint_symbol_1 : tint_symbol;
[[block]]
struct SB1 {
x : i32;
arr1 : array<i32>;
};
[[block]]
struct SB2 {
x : i32;
arr2 : array<vec4<f32>>;
};
[[block]]
struct SB3 {
x : i32;
arr3 : array<vec4<f32>>;
};
[[block]]
struct SB4 {
x : i32;
arr4 : array<vec4<f32>>;
};
[[block]]
struct SB5 {
x : i32;
arr5 : array<vec4<f32>>;
};
[[group(0), binding(2)]] var<storage, read> sb1 : SB1;
[[group(1), binding(2)]] var<storage, read> sb2 : SB2;
[[group(2), binding(2)]] var<storage, read> sb3 : SB3;
[[group(3), binding(2)]] var<storage, read> sb4 : SB4;
[[group(4), binding(2)]] var<storage, read> sb5 : SB5;
[[stage(compute), workgroup_size(1)]]
fn main() {
var len1 : u32 = ((tint_symbol_1.buffer_size[0u][0u] - 4u) / 4u);
var len3 : u32 = ((tint_symbol_1.buffer_size[0u][2u] - 16u) / 16u);
var x : u32 = (len1 + len3);
}
)";
ArrayLengthFromUniform::Config cfg({0, 30u});
cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 2u}, 0);
cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{1u, 2u}, 1);
cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{2u, 2u}, 2);
cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{3u, 2u}, 3);
cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{4u, 2u}, 4);
DataMap data;
data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));
auto got =
Run<InlinePointerLets, Simplify, ArrayLengthFromUniform>(src, data);
EXPECT_EQ(expect, str(got));
EXPECT_EQ(std::unordered_set<uint32_t>({0, 2}),
got.data.Get<ArrayLengthFromUniform::Result>()->used_size_indices);
}
TEST_F(ArrayLengthFromUniformTest, NoArrayLengthCalls) {
@ -316,8 +432,8 @@ fn main() {
Run<InlinePointerLets, Simplify, ArrayLengthFromUniform>(src, data);
EXPECT_EQ(src, str(got));
EXPECT_FALSE(
got.data.Get<ArrayLengthFromUniform::Result>()->needs_buffer_sizes);
EXPECT_EQ(std::unordered_set<uint32_t>(),
got.data.Get<ArrayLengthFromUniform::Result>()->used_size_indices);
}
TEST_F(ArrayLengthFromUniformTest, MissingBindingPointToIndexMapping) {
@ -346,7 +462,37 @@ fn main() {
}
)";
auto* expect = "error: missing size index mapping for binding point (1,2)";
auto* expect = R"(
[[block]]
struct tint_symbol {
buffer_size : array<vec4<u32>, 1u>;
};
[[group(0), binding(30)]] var<uniform> tint_symbol_1 : tint_symbol;
[[block]]
struct SB1 {
x : i32;
arr1 : array<i32>;
};
[[block]]
struct SB2 {
x : i32;
arr2 : array<vec4<f32>>;
};
[[group(0), binding(2)]] var<storage, read> sb1 : SB1;
[[group(1), binding(2)]] var<storage, read> sb2 : SB2;
[[stage(compute), workgroup_size(1)]]
fn main() {
var len1 : u32 = ((tint_symbol_1.buffer_size[0u][0u] - 4u) / 4u);
var len2 : u32 = arrayLength(&(sb2.arr2));
var x : u32 = (len1 + len2);
}
)";
ArrayLengthFromUniform::Config cfg({0, 30u});
cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 2}, 0);
@ -358,6 +504,8 @@ fn main() {
Run<InlinePointerLets, Simplify, ArrayLengthFromUniform>(src, data);
EXPECT_EQ(expect, str(got));
EXPECT_EQ(std::unordered_set<uint32_t>({0}),
got.data.Get<ArrayLengthFromUniform::Result>()->used_size_indices);
}
} // namespace

View File

@ -0,0 +1,30 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/writer/array_length_from_uniform_options.h"
namespace tint {
namespace writer {
ArrayLengthFromUniformOptions::ArrayLengthFromUniformOptions() = default;
ArrayLengthFromUniformOptions::~ArrayLengthFromUniformOptions() = default;
ArrayLengthFromUniformOptions::ArrayLengthFromUniformOptions(
const ArrayLengthFromUniformOptions&) = default;
ArrayLengthFromUniformOptions& ArrayLengthFromUniformOptions::operator=(
const ArrayLengthFromUniformOptions&) = default;
ArrayLengthFromUniformOptions::ArrayLengthFromUniformOptions(
ArrayLengthFromUniformOptions&&) = default;
} // namespace writer
} // namespace tint

View File

@ -0,0 +1,52 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_WRITER_ARRAY_LENGTH_FROM_UNIFORM_OPTIONS_H_
#define SRC_WRITER_ARRAY_LENGTH_FROM_UNIFORM_OPTIONS_H_
#include <unordered_map>
#include "src/sem/binding_point.h"
namespace tint {
namespace writer {
/// Options used to specify a mapping of binding points to indices into a UBO
/// from which to load buffer sizes.
struct ArrayLengthFromUniformOptions {
/// Constructor
ArrayLengthFromUniformOptions();
/// Destructor
~ArrayLengthFromUniformOptions();
/// Copy constructor
ArrayLengthFromUniformOptions(const ArrayLengthFromUniformOptions&);
/// Copy assignment
/// @returns this ArrayLengthFromUniformOptions
ArrayLengthFromUniformOptions& operator=(
const ArrayLengthFromUniformOptions&);
/// Move constructor
ArrayLengthFromUniformOptions(ArrayLengthFromUniformOptions&&);
/// The binding point to use to generate a uniform buffer from which to read
/// buffer sizes.
sem::BindingPoint ubo_binding;
/// The mapping from storage buffer binding points to the index into the
/// uniform buffer where the length of the buffer is stored.
std::unordered_map<sem::BindingPoint, uint32_t> bindpoint_to_size_index;
};
} // namespace writer
} // namespace tint
#endif // SRC_WRITER_ARRAY_LENGTH_FROM_UNIFORM_OPTIONS_H_

View File

@ -20,6 +20,11 @@ namespace tint {
namespace writer {
namespace hlsl {
Options::Options() = default;
Options::~Options() = default;
Options::Options(const Options&) = default;
Options& Options::operator=(const Options&) = default;
Result::Result() = default;
Result::~Result() = default;
Result::Result(const Result&) = default;
@ -29,7 +34,8 @@ Result Generate(const Program* program, const Options& options) {
// Sanitize the program.
auto sanitized_result = Sanitize(program, options.root_constant_binding_point,
options.disable_workgroup_init);
options.disable_workgroup_init,
options.array_length_from_uniform);
if (!sanitized_result.program.IsValid()) {
result.success = false;
result.error = sanitized_result.program.Diagnostics().str();
@ -50,6 +56,9 @@ Result Generate(const Program* program, const Options& options) {
}
}
result.used_array_length_from_uniform_indices =
std::move(sanitized_result.used_array_length_from_uniform_indices);
return result;
}

View File

@ -17,11 +17,13 @@
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "src/ast/pipeline_stage.h"
#include "src/sem/binding_point.h"
#include "src/writer/array_length_from_uniform_options.h"
#include "src/writer/text.h"
namespace tint {
@ -37,10 +39,23 @@ class GeneratorImpl;
/// Configuration options used for generating HLSL.
struct Options {
/// Constructor
Options();
/// Destructor
~Options();
/// Copy constructor
Options(const Options&);
/// Copy assignment
/// @returns this Options
Options& operator=(const Options&);
/// The binding point to use for information passed via root constants.
sem::BindingPoint root_constant_binding_point;
/// Set to `true` to disable workgroup memory zero initialization
bool disable_workgroup_init = false;
/// Options used to specify a mapping of binding points to indices into a UBO
/// from which to load buffer sizes.
ArrayLengthFromUniformOptions array_length_from_uniform = {};
};
/// The result produced when generating HLSL.
@ -65,6 +80,10 @@ struct Result {
/// The list of entry points in the generated HLSL.
std::vector<std::pair<std::string, ast::PipelineStage>> entry_points;
/// Indices into the array_length_from_uniform binding that are statically
/// used.
std::unordered_set<uint32_t> used_array_length_from_uniform_indices;
};
/// Generate HLSL for a program, according to a set of configuration options.

View File

@ -45,6 +45,7 @@
#include "src/sem/type_conversion.h"
#include "src/sem/variable.h"
#include "src/transform/add_empty_entry_point.h"
#include "src/transform/array_length_from_uniform.h"
#include "src/transform/calculate_array_length.h"
#include "src/transform/canonicalize_entry_point_io.h"
#include "src/transform/decompose_memory_access.h"
@ -124,12 +125,24 @@ const char* LoopAttribute() {
} // namespace
SanitizedResult Sanitize(const Program* in,
sem::BindingPoint root_constant_binding_point,
bool disable_workgroup_init) {
SanitizedResult::SanitizedResult() = default;
SanitizedResult::~SanitizedResult() = default;
SanitizedResult::SanitizedResult(SanitizedResult&&) = default;
SanitizedResult Sanitize(
const Program* in,
sem::BindingPoint root_constant_binding_point,
bool disable_workgroup_init,
const ArrayLengthFromUniformOptions& array_length_from_uniform) {
transform::Manager manager;
transform::DataMap data;
// Build the config for the internal ArrayLengthFromUniform transform.
transform::ArrayLengthFromUniform::Config array_length_from_uniform_cfg(
array_length_from_uniform.ubo_binding);
array_length_from_uniform_cfg.bindpoint_to_size_index =
array_length_from_uniform.bindpoint_to_size_index;
// Attempt to convert `loop`s into for-loops. This is to try and massage the
// output into something that will not cause FXC to choke or misbehave.
manager.Add<transform::FoldTrivialSingleUseLets>();
@ -149,6 +162,11 @@ SanitizedResult Sanitize(const Program* in,
// Simplify cleans up messy `*(&(expr))` expressions from InlinePointerLets.
manager.Add<transform::Simplify>();
manager.Add<transform::RemovePhonies>();
// ArrayLengthFromUniform must come after InlinePointerLets and Simplify, as
// it assumes that the form of the array length argument is &var.array.
manager.Add<transform::ArrayLengthFromUniform>();
data.Add<transform::ArrayLengthFromUniform::Config>(
std::move(array_length_from_uniform_cfg));
// DecomposeMemoryAccess must come after:
// * InlinePointerLets, as we cannot take the address of calls to
// DecomposeMemoryAccess::Intrinsic.
@ -171,8 +189,13 @@ SanitizedResult Sanitize(const Program* in,
data.Add<transform::NumWorkgroupsFromUniform::Config>(
root_constant_binding_point);
auto out = manager.Run(in, data);
SanitizedResult result;
result.program = std::move(manager.Run(in, data).program);
result.program = std::move(out.program);
result.used_array_length_from_uniform_indices =
std::move(out.data.Get<transform::ArrayLengthFromUniform::Result>()
->used_size_indices);
return result;
}

View File

@ -36,6 +36,7 @@
#include "src/sem/binding_point.h"
#include "src/transform/decompose_memory_access.h"
#include "src/utils/hash.h"
#include "src/writer/array_length_from_uniform_options.h"
#include "src/writer/text_generator.h"
namespace tint {
@ -53,8 +54,18 @@ namespace hlsl {
/// The result of sanitizing a program for generation.
struct SanitizedResult {
/// Constructor
SanitizedResult();
/// Destructor
~SanitizedResult();
/// Move constructor
SanitizedResult(SanitizedResult&&);
/// The sanitized program.
Program program;
/// Indices into the array_length_from_uniform binding that are statically
/// used.
std::unordered_set<uint32_t> used_array_length_from_uniform_indices;
};
/// Sanitize a program in preparation for generating HLSL.
@ -62,9 +73,11 @@ struct SanitizedResult {
/// that will be passed via root constants
/// @param disable_workgroup_init `true` to disable workgroup memory zero
/// @returns the sanitized program and any supplementary information
SanitizedResult Sanitize(const Program* program,
sem::BindingPoint root_constant_binding_point = {},
bool disable_workgroup_init = false);
SanitizedResult Sanitize(
const Program* program,
sem::BindingPoint root_constant_binding_point = {},
bool disable_workgroup_init = false,
const ArrayLengthFromUniformOptions& array_length_from_uniform = {});
/// Implementation class for HLSL generator
class GeneratorImpl : public TextGenerator {

View File

@ -144,6 +144,58 @@ void a_func() {
EXPECT_EQ(expect, got);
}
TEST_F(HlslSanitizerTest, Call_ArrayLength_ArrayLengthFromUniform) {
auto* s = Structure("my_struct", {Member(0, "a", ty.array<f32>(4))},
{create<ast::StructBlockDecoration>()});
Global("b", ty.Of(s), ast::StorageClass::kStorage, ast::Access::kRead,
ast::DecorationList{
create<ast::BindingDecoration>(1),
create<ast::GroupDecoration>(2),
});
Global("c", ty.Of(s), ast::StorageClass::kStorage, ast::Access::kRead,
ast::DecorationList{
create<ast::BindingDecoration>(2),
create<ast::GroupDecoration>(2),
});
Func("a_func", ast::VariableList{}, ty.void_(),
ast::StatementList{
Decl(Var(
"len", ty.u32(), ast::StorageClass::kNone,
Add(Call("arrayLength", AddressOf(MemberAccessor("b", "a"))),
Call("arrayLength", AddressOf(MemberAccessor("c", "a")))))),
},
ast::DecorationList{
Stage(ast::PipelineStage::kFragment),
});
Options options;
options.array_length_from_uniform.ubo_binding = {3, 4};
options.array_length_from_uniform.bindpoint_to_size_index.emplace(
sem::BindingPoint{2, 2}, 7u);
GeneratorImpl& gen = SanitizeAndBuild(options);
ASSERT_TRUE(gen.Generate()) << gen.error();
auto got = gen.result();
auto* expect = R"(cbuffer cbuffer_tint_symbol_1 : register(b4, space3) {
uint4 tint_symbol_1[2];
};
ByteAddressBuffer b : register(t1, space2);
ByteAddressBuffer c : register(t2, space2);
void a_func() {
uint tint_symbol_4 = 0u;
b.GetDimensions(tint_symbol_4);
const uint tint_symbol_5 = ((tint_symbol_4 - 0u) / 4u);
uint len = (tint_symbol_5 + ((tint_symbol_1[1].w - 0u) / 4u));
return;
}
)";
EXPECT_EQ(expect, got);
}
TEST_F(HlslSanitizerTest, PromoteArrayInitializerToConstVar) {
auto* array_init = array<i32, 4>(1, 2, 3, 4);
auto* array_index = IndexAccessor(array_init, 3);

View File

@ -22,6 +22,7 @@
#include "gtest/gtest.h"
#include "src/transform/manager.h"
#include "src/transform/renamer.h"
#include "src/writer/hlsl/generator.h"
#include "src/writer/hlsl/generator_impl.h"
namespace tint {
@ -58,10 +59,11 @@ class TestHelperBase : public BODY, public ProgramBuilder {
/// Builds the program, runs the program through the HLSL sanitizer
/// and returns a GeneratorImpl from the sanitized program.
/// @param options The HLSL generator options.
/// @note The generator is only built once. Multiple calls to Build() will
/// return the same GeneratorImpl without rebuilding.
/// @return the built generator
GeneratorImpl& SanitizeAndBuild() {
GeneratorImpl& SanitizeAndBuild(const Options& options = {}) {
if (gen_) {
return *gen_;
}
@ -76,7 +78,9 @@ class TestHelperBase : public BODY, public ProgramBuilder {
<< formatter.format(program->Diagnostics());
}();
auto sanitized_result = Sanitize(program.get());
auto sanitized_result = Sanitize(
program.get(), options.root_constant_binding_point,
options.disable_workgroup_init, options.array_length_from_uniform);
[&]() {
ASSERT_TRUE(sanitized_result.program.IsValid())
<< formatter.format(sanitized_result.program.Diagnostics());

View File

@ -14,12 +14,19 @@
#include "src/writer/msl/generator.h"
#include <utility>
#include "src/writer/msl/generator_impl.h"
namespace tint {
namespace writer {
namespace msl {
Options::Options() = default;
Options::~Options() = default;
Options::Options(const Options&) = default;
Options& Options::operator=(const Options&) = default;
Result::Result() = default;
Result::~Result() = default;
Result::Result(const Result&) = default;
@ -30,7 +37,8 @@ Result Generate(const Program* program, const Options& options) {
// Sanitize the program.
auto sanitized_result = Sanitize(
program, options.buffer_size_ubo_index, options.fixed_sample_mask,
options.emit_vertex_point_size, options.disable_workgroup_init);
options.emit_vertex_point_size, options.disable_workgroup_init,
options.array_length_from_uniform);
if (!sanitized_result.program.IsValid()) {
result.success = false;
result.error = sanitized_result.program.Diagnostics().str();
@ -38,6 +46,8 @@ Result Generate(const Program* program, const Options& options) {
}
result.needs_storage_buffer_sizes =
sanitized_result.needs_storage_buffer_sizes;
result.used_array_length_from_uniform_indices =
std::move(sanitized_result.used_array_length_from_uniform_indices);
// Generate the MSL code.
auto impl = std::make_unique<GeneratorImpl>(&sanitized_result.program);

View File

@ -18,8 +18,10 @@
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "src/writer/array_length_from_uniform_options.h"
#include "src/writer/text.h"
namespace tint {
@ -34,6 +36,16 @@ class GeneratorImpl;
/// Configuration options used for generating MSL.
struct Options {
/// Constructor
Options();
/// Destructor
~Options();
/// Copy constructor
Options(const Options&);
/// Copy assignment
/// @returns this Options
Options& operator=(const Options&);
/// The index to use when generating a UBO to receive storage buffer sizes.
/// Defaults to 30, which is the last valid buffer slot.
uint32_t buffer_size_ubo_index = 30;
@ -48,6 +60,10 @@ struct Options {
/// Set to `true` to disable workgroup memory zero initialization
bool disable_workgroup_init = false;
/// Options used to specify a mapping of binding points to indices into a UBO
/// from which to load buffer sizes.
ArrayLengthFromUniformOptions array_length_from_uniform = {};
};
/// The result produced when generating MSL.
@ -80,6 +96,10 @@ struct Result {
/// Each entry in the vector is the size of the workgroup allocation that
/// should be created for that index.
std::unordered_map<std::string, std::vector<uint32_t>> workgroup_allocations;
/// Indices into the array_length_from_uniform binding that are statically
/// used.
std::unordered_set<uint32_t> used_array_length_from_uniform_indices;
};
/// Generate MSL for a program, according to a set of configuration options. The

View File

@ -112,31 +112,47 @@ class ScopedBitCast {
};
} // namespace
SanitizedResult Sanitize(const Program* in,
uint32_t buffer_size_ubo_index,
uint32_t fixed_sample_mask,
bool emit_vertex_point_size,
bool disable_workgroup_init) {
SanitizedResult::SanitizedResult() = default;
SanitizedResult::~SanitizedResult() = default;
SanitizedResult::SanitizedResult(SanitizedResult&&) = default;
SanitizedResult Sanitize(
const Program* in,
uint32_t buffer_size_ubo_index,
uint32_t fixed_sample_mask,
bool emit_vertex_point_size,
bool disable_workgroup_init,
const ArrayLengthFromUniformOptions& array_length_from_uniform) {
transform::Manager manager;
transform::DataMap internal_inputs;
// Build the configs for the internal transforms.
auto array_length_from_uniform_cfg =
transform::ArrayLengthFromUniform::Config(
sem::BindingPoint{0, buffer_size_ubo_index});
// Build the config for the internal ArrayLengthFromUniform transform.
transform::ArrayLengthFromUniform::Config array_length_from_uniform_cfg(
array_length_from_uniform.ubo_binding);
if (!array_length_from_uniform.bindpoint_to_size_index.empty()) {
// If |array_length_from_uniform| bindings are provided, use that config.
array_length_from_uniform_cfg.bindpoint_to_size_index =
array_length_from_uniform.bindpoint_to_size_index;
} else {
// If the binding map is empty, use the deprecated |buffer_size_ubo_index|
// and automatically choose indices using the binding numbers.
array_length_from_uniform_cfg = transform::ArrayLengthFromUniform::Config(
sem::BindingPoint{0, buffer_size_ubo_index});
// Use the SSBO binding numbers as the indices for the buffer size lookups.
for (auto* var : in->AST().GlobalVariables()) {
auto* global = in->Sem().Get<sem::GlobalVariable>(var);
if (global && global->StorageClass() == ast::StorageClass::kStorage) {
array_length_from_uniform_cfg.bindpoint_to_size_index.emplace(
global->BindingPoint(), global->BindingPoint().binding);
}
}
}
// Build the configs for the internal CanonicalizeEntryPointIO transform.
auto entry_point_io_cfg = transform::CanonicalizeEntryPointIO::Config(
transform::CanonicalizeEntryPointIO::ShaderStyle::kMsl, fixed_sample_mask,
emit_vertex_point_size);
// Use the SSBO binding numbers as the indices for the buffer size lookups.
for (auto* var : in->AST().GlobalVariables()) {
auto* global = in->Sem().Get<sem::GlobalVariable>(var);
if (global && global->StorageClass() == ast::StorageClass::kStorage) {
array_length_from_uniform_cfg.bindpoint_to_size_index.emplace(
global->BindingPoint(), global->BindingPoint().binding);
}
}
if (!disable_workgroup_init) {
// ZeroInitWorkgroupMemory must come before CanonicalizeEntryPointIO as
// ZeroInitWorkgroupMemory may inject new builtin parameters.
@ -160,13 +176,18 @@ SanitizedResult Sanitize(const Program* in,
internal_inputs.Add<transform::CanonicalizeEntryPointIO::Config>(
std::move(entry_point_io_cfg));
auto out = manager.Run(in, internal_inputs);
if (!out.program.IsValid()) {
return {std::move(out.program)};
}
return {std::move(out.program),
out.data.Get<transform::ArrayLengthFromUniform::Result>()
->needs_buffer_sizes};
SanitizedResult result;
result.program = std::move(out.program);
if (!result.program.IsValid()) {
return result;
}
result.used_array_length_from_uniform_indices =
std::move(out.data.Get<transform::ArrayLengthFromUniform::Result>()
->used_size_indices);
result.needs_storage_buffer_sizes =
!result.used_array_length_from_uniform_indices.empty();
return result;
}
GeneratorImpl::GeneratorImpl(const Program* program) : TextGenerator(program) {}
@ -1314,6 +1335,13 @@ std::string GeneratorImpl::generate_builtin_name(
case sem::IntrinsicType::kUnpack2x16unorm:
out += "unpack_unorm2x16_to_float";
break;
case sem::IntrinsicType::kArrayLength:
diagnostics_.add_error(
diag::System::Writer,
"Unable to translate builtin: " + std::string(intrinsic->str()) +
"\nDid you forget to pass array_length_from_uniform generator "
"options?");
return "";
default:
diagnostics_.add_error(
diag::System::Writer,

View File

@ -17,6 +17,7 @@
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "src/ast/assignment_statement.h"
@ -37,6 +38,7 @@
#include "src/program.h"
#include "src/scope_stack.h"
#include "src/sem/struct.h"
#include "src/writer/array_length_from_uniform_options.h"
#include "src/writer/text_generator.h"
namespace tint {
@ -54,10 +56,20 @@ namespace msl {
/// The result of sanitizing a program for generation.
struct SanitizedResult {
/// Constructor
SanitizedResult();
/// Destructor
~SanitizedResult();
/// Move constructor
SanitizedResult(SanitizedResult&&);
/// The sanitized program.
Program program;
/// True if the shader needs a UBO of buffer sizes.
bool needs_storage_buffer_sizes = false;
/// Indices into the array_length_from_uniform binding that are statically
/// used.
std::unordered_set<uint32_t> used_array_length_from_uniform_indices;
};
/// Sanitize a program in preparation for generating MSL.
@ -66,11 +78,13 @@ struct SanitizedResult {
/// @param emit_vertex_point_size `true` to emit a vertex point size builtin
/// @param disable_workgroup_init `true` to disable workgroup memory zero
/// @returns the sanitized program and any supplementary information
SanitizedResult Sanitize(const Program* program,
uint32_t buffer_size_ubo_index,
uint32_t fixed_sample_mask = 0xFFFFFFFF,
bool emit_vertex_point_size = false,
bool disable_workgroup_init = false);
SanitizedResult Sanitize(
const Program* program,
uint32_t buffer_size_ubo_index,
uint32_t fixed_sample_mask = 0xFFFFFFFF,
bool emit_vertex_point_size = false,
bool disable_workgroup_init = false,
const ArrayLengthFromUniformOptions& array_length_from_uniform = {});
/// Implementation class for MSL generator
class GeneratorImpl : public TextGenerator {

View File

@ -0,0 +1,264 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gmock/gmock.h"
#include "src/ast/call_statement.h"
#include "src/ast/stage_decoration.h"
#include "src/ast/struct_block_decoration.h"
#include "src/ast/variable_decl_statement.h"
#include "src/writer/msl/test_helper.h"
namespace tint {
namespace writer {
namespace msl {
namespace {
using ::testing::HasSubstr;
using MslSanitizerTest = TestHelper;
TEST_F(MslSanitizerTest, Call_ArrayLength) {
auto* s = Structure("my_struct", {Member(0, "a", ty.array<f32>(4))},
{create<ast::StructBlockDecoration>()});
Global("b", ty.Of(s), ast::StorageClass::kStorage, ast::Access::kRead,
ast::DecorationList{
create<ast::BindingDecoration>(1),
create<ast::GroupDecoration>(2),
});
Func("a_func", ast::VariableList{}, ty.void_(),
ast::StatementList{
Decl(Var("len", ty.u32(), ast::StorageClass::kNone,
Call("arrayLength", AddressOf(MemberAccessor("b", "a"))))),
},
ast::DecorationList{
Stage(ast::PipelineStage::kFragment),
});
GeneratorImpl& gen = SanitizeAndBuild();
ASSERT_TRUE(gen.Generate()) << gen.error();
auto got = gen.result();
auto* expect = R"(#include <metal_stdlib>
using namespace metal;
struct tint_symbol {
/* 0x0000 */ uint4 buffer_size[1];
};
struct my_struct {
float a[1];
};
fragment void a_func(const constant tint_symbol* tint_symbol_2 [[buffer(30)]]) {
uint len = (((*(tint_symbol_2)).buffer_size[0u][1u] - 0u) / 4u);
return;
}
)";
EXPECT_EQ(expect, got);
}
TEST_F(MslSanitizerTest, Call_ArrayLength_OtherMembersInStruct) {
auto* s = Structure("my_struct",
{
Member(0, "z", ty.f32()),
Member(4, "a", ty.array<f32>(4)),
},
{create<ast::StructBlockDecoration>()});
Global("b", ty.Of(s), ast::StorageClass::kStorage, ast::Access::kRead,
ast::DecorationList{
create<ast::BindingDecoration>(1),
create<ast::GroupDecoration>(2),
});
Func("a_func", ast::VariableList{}, ty.void_(),
ast::StatementList{
Decl(Var("len", ty.u32(), ast::StorageClass::kNone,
Call("arrayLength", AddressOf(MemberAccessor("b", "a"))))),
},
ast::DecorationList{
Stage(ast::PipelineStage::kFragment),
});
GeneratorImpl& gen = SanitizeAndBuild();
ASSERT_TRUE(gen.Generate()) << gen.error();
auto got = gen.result();
auto* expect = R"(#include <metal_stdlib>
using namespace metal;
struct tint_symbol {
/* 0x0000 */ uint4 buffer_size[1];
};
struct my_struct {
float z;
float a[1];
};
fragment void a_func(const constant tint_symbol* tint_symbol_2 [[buffer(30)]]) {
uint len = (((*(tint_symbol_2)).buffer_size[0u][1u] - 4u) / 4u);
return;
}
)";
EXPECT_EQ(expect, got);
}
TEST_F(MslSanitizerTest, Call_ArrayLength_ViaLets) {
auto* s = Structure("my_struct", {Member(0, "a", ty.array<f32>(4))},
{create<ast::StructBlockDecoration>()});
Global("b", ty.Of(s), ast::StorageClass::kStorage, ast::Access::kRead,
ast::DecorationList{
create<ast::BindingDecoration>(1),
create<ast::GroupDecoration>(2),
});
auto* p = Const("p", nullptr, AddressOf("b"));
auto* p2 = Const("p2", nullptr, AddressOf(MemberAccessor(Deref(p), "a")));
Func("a_func", ast::VariableList{}, ty.void_(),
ast::StatementList{
Decl(p),
Decl(p2),
Decl(Var("len", ty.u32(), ast::StorageClass::kNone,
Call("arrayLength", p2))),
},
ast::DecorationList{
Stage(ast::PipelineStage::kFragment),
});
GeneratorImpl& gen = SanitizeAndBuild();
ASSERT_TRUE(gen.Generate()) << gen.error();
auto got = gen.result();
auto* expect = R"(#include <metal_stdlib>
using namespace metal;
struct tint_symbol {
/* 0x0000 */ uint4 buffer_size[1];
};
struct my_struct {
float a[1];
};
fragment void a_func(const constant tint_symbol* tint_symbol_2 [[buffer(30)]]) {
uint len = (((*(tint_symbol_2)).buffer_size[0u][1u] - 0u) / 4u);
return;
}
)";
EXPECT_EQ(expect, got);
}
TEST_F(MslSanitizerTest, Call_ArrayLength_ArrayLengthFromUniform) {
auto* s = Structure("my_struct", {Member(0, "a", ty.array<f32>(4))},
{create<ast::StructBlockDecoration>()});
Global("b", ty.Of(s), ast::StorageClass::kStorage, ast::Access::kRead,
ast::DecorationList{
create<ast::BindingDecoration>(1),
create<ast::GroupDecoration>(0),
});
Global("c", ty.Of(s), ast::StorageClass::kStorage, ast::Access::kRead,
ast::DecorationList{
create<ast::BindingDecoration>(2),
create<ast::GroupDecoration>(0),
});
Func("a_func", ast::VariableList{}, ty.void_(),
ast::StatementList{
Decl(Var(
"len", ty.u32(), ast::StorageClass::kNone,
Add(Call("arrayLength", AddressOf(MemberAccessor("b", "a"))),
Call("arrayLength", AddressOf(MemberAccessor("c", "a")))))),
},
ast::DecorationList{
Stage(ast::PipelineStage::kFragment),
});
Options options;
options.array_length_from_uniform.ubo_binding = {0, 29};
options.array_length_from_uniform.bindpoint_to_size_index.emplace(
sem::BindingPoint{0, 1}, 7u);
options.array_length_from_uniform.bindpoint_to_size_index.emplace(
sem::BindingPoint{0, 2}, 2u);
GeneratorImpl& gen = SanitizeAndBuild(options);
ASSERT_TRUE(gen.Generate()) << gen.error();
auto got = gen.result();
auto* expect = R"(#include <metal_stdlib>
using namespace metal;
struct tint_symbol {
/* 0x0000 */ uint4 buffer_size[2];
};
struct my_struct {
float a[1];
};
fragment void a_func(const constant tint_symbol* tint_symbol_2 [[buffer(29)]]) {
uint len = ((((*(tint_symbol_2)).buffer_size[1u][3u] - 0u) / 4u) + (((*(tint_symbol_2)).buffer_size[0u][2u] - 0u) / 4u));
return;
}
)";
EXPECT_EQ(expect, got);
}
TEST_F(MslSanitizerTest,
Call_ArrayLength_ArrayLengthFromUniformMissingBinding) {
auto* s = Structure("my_struct", {Member(0, "a", ty.array<f32>(4))},
{create<ast::StructBlockDecoration>()});
Global("b", ty.Of(s), ast::StorageClass::kStorage, ast::Access::kRead,
ast::DecorationList{
create<ast::BindingDecoration>(1),
create<ast::GroupDecoration>(0),
});
Global("c", ty.Of(s), ast::StorageClass::kStorage, ast::Access::kRead,
ast::DecorationList{
create<ast::BindingDecoration>(2),
create<ast::GroupDecoration>(0),
});
Func("a_func", ast::VariableList{}, ty.void_(),
ast::StatementList{
Decl(Var(
"len", ty.u32(), ast::StorageClass::kNone,
Add(Call("arrayLength", AddressOf(MemberAccessor("b", "a"))),
Call("arrayLength", AddressOf(MemberAccessor("c", "a")))))),
},
ast::DecorationList{
Stage(ast::PipelineStage::kFragment),
});
Options options;
options.array_length_from_uniform.ubo_binding = {0, 29};
options.array_length_from_uniform.bindpoint_to_size_index.emplace(
sem::BindingPoint{0, 2}, 2u);
GeneratorImpl& gen = SanitizeAndBuild(options);
ASSERT_FALSE(gen.Generate());
EXPECT_THAT(gen.error(),
HasSubstr("Unable to translate builtin: arrayLength"));
}
} // namespace
} // namespace msl
} // namespace writer
} // namespace tint

View File

@ -21,6 +21,7 @@
#include "gtest/gtest.h"
#include "src/program_builder.h"
#include "src/writer/msl/generator.h"
#include "src/writer/msl/generator_impl.h"
namespace tint {
@ -57,10 +58,11 @@ class TestHelperBase : public BASE, public ProgramBuilder {
/// Builds the program, runs the program through the transform::Msl sanitizer
/// and returns a GeneratorImpl from the sanitized program.
/// @param options The MSL generator options.
/// @note The generator is only built once. Multiple calls to Build() will
/// return the same GeneratorImpl without rebuilding.
/// @return the built generator
GeneratorImpl& SanitizeAndBuild() {
GeneratorImpl& SanitizeAndBuild(const Options& options = {}) {
if (gen_) {
return *gen_;
}
@ -74,7 +76,10 @@ class TestHelperBase : public BASE, public ProgramBuilder {
<< diag::Formatter().format(program->Diagnostics());
}();
auto result = Sanitize(program.get(), 30);
auto result = Sanitize(
program.get(), options.buffer_size_ubo_index, options.fixed_sample_mask,
options.emit_vertex_point_size, options.disable_workgroup_init,
options.array_length_from_uniform);
[&]() {
ASSERT_TRUE(result.program.IsValid())
<< diag::Formatter().format(result.program.Diagnostics());

View File

@ -591,6 +591,7 @@ tint_unittests_source_set("tint_unittests_msl_writer_src") {
"../src/writer/msl/generator_impl_member_accessor_test.cc",
"../src/writer/msl/generator_impl_module_constant_test.cc",
"../src/writer/msl/generator_impl_return_test.cc",
"../src/writer/msl/generator_impl_sanitizer_test.cc",
"../src/writer/msl/generator_impl_switch_test.cc",
"../src/writer/msl/generator_impl_test.cc",
"../src/writer/msl/generator_impl_type_test.cc",