Add HLSL/MSL generator options for ArrayLengthFromUniform

ArrayLengthFromUniform is needed for correct bounds checks on
dynamic storage buffers on D3D12. The intrinsic GetDimensions does
not return the actual size of the buffer binding.

ArrayLengthFromUniform is updated to output the indices of the
uniform buffer that are statically used. This allows Dawn to minimize
the amount of data needed to upload into the uniform buffer.
These output indices are returned on the HLSL/MSL generator result.

ArrayLengthFromUniform is also updated to allow only some of the
arrayLength calls to be replaced with uniform buffer loads. For HLSL
output, the remaining arrayLength computations will continue to use
GetDimensions(). For MSL, it is invalid to not specify an index into
the uniform buffer for all storage buffers.

After Dawn is updated to use the array_length_from_uniform option in the
Metal backend, the buffer_size_ubo_index member for MSL output may be
removed.

Bug: dawn:429
Change-Id: I9da4ec4a20882e9f1bfa5bb026725d72529eff26
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/69301
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Austin Eng <enga@chromium.org>
This commit is contained in:
Austin Eng 2021-11-19 04:11:33 +00:00 committed by Tint LUCI CQ
parent a660b510ac
commit 38fa643702
20 changed files with 850 additions and 119 deletions

View File

@ -491,6 +491,8 @@ libtint_source_set("libtint_core_all_src") {
"utils/unique_vector.h", "utils/unique_vector.h",
"writer/append_vector.cc", "writer/append_vector.cc",
"writer/append_vector.h", "writer/append_vector.h",
"writer/array_length_from_uniform_options.cc",
"writer/array_length_from_uniform_options.h",
"writer/float_to_string.cc", "writer/float_to_string.cc",
"writer/float_to_string.h", "writer/float_to_string.h",
"writer/text.cc", "writer/text.cc",

View File

@ -403,6 +403,8 @@ set(TINT_LIB_SRCS
utils/unique_vector.h utils/unique_vector.h
writer/append_vector.cc writer/append_vector.cc
writer/append_vector.h writer/append_vector.h
writer/array_length_from_uniform_options.cc
writer/array_length_from_uniform_options.h
writer/float_to_string.cc writer/float_to_string.cc
writer/float_to_string.h writer/float_to_string.h
writer/text_generator.cc writer/text_generator.cc
@ -998,6 +1000,7 @@ if(${TINT_BUILD_TESTS})
writer/msl/generator_impl_member_accessor_test.cc writer/msl/generator_impl_member_accessor_test.cc
writer/msl/generator_impl_module_constant_test.cc writer/msl/generator_impl_module_constant_test.cc
writer/msl/generator_impl_return_test.cc writer/msl/generator_impl_return_test.cc
writer/msl/generator_impl_sanitizer_test.cc
writer/msl/generator_impl_switch_test.cc writer/msl/generator_impl_switch_test.cc
writer/msl/generator_impl_test.cc writer/msl/generator_impl_test.cc
writer/msl/generator_impl_type_test.cc writer/msl/generator_impl_type_test.cc

View File

@ -35,60 +35,18 @@ namespace transform {
ArrayLengthFromUniform::ArrayLengthFromUniform() = default; ArrayLengthFromUniform::ArrayLengthFromUniform() = default;
ArrayLengthFromUniform::~ArrayLengthFromUniform() = default; ArrayLengthFromUniform::~ArrayLengthFromUniform() = default;
void ArrayLengthFromUniform::Run(CloneContext& ctx, /// Iterate over all arrayLength() intrinsics that operate on
const DataMap& inputs, /// storage buffer variables.
DataMap& outputs) { /// @param ctx the CloneContext.
if (!Requires<InlinePointerLets, Simplify>(ctx)) { /// @param functor of type void(const ast::CallExpression*, const
return; /// sem::VariableUser, const sem::GlobalVariable*). It takes in an
} /// ast::CallExpression of the arrayLength call expression node, a
/// sem::VariableUser of the used storage buffer variable, and the
auto* cfg = inputs.Get<Config>(); /// sem::GlobalVariable for the storage buffer.
if (cfg == nullptr) { template <typename F>
ctx.dst->Diagnostics().add_error( static void IterateArrayLengthOnStorageVar(CloneContext& ctx, F&& functor) {
diag::System::Transform,
"missing transform data for " + std::string(TypeInfo().name));
return;
}
auto& sem = ctx.src->Sem(); auto& sem = ctx.src->Sem();
const char* kBufferSizeMemberName = "buffer_size";
// Determine the size of the buffer size array.
uint32_t max_buffer_size_index = 0;
for (auto& idx : cfg->bindpoint_to_size_index) {
if (idx.second > max_buffer_size_index) {
max_buffer_size_index = idx.second;
}
}
// Get (or create, on first call) the uniform buffer that will receive the
// size of each storage buffer in the module.
const ast::Variable* buffer_size_ubo = nullptr;
auto get_ubo = [&]() {
if (!buffer_size_ubo) {
// Emit an array<vec4<u32>, N>, where N is 1/4 number of elements.
// We do this because UBOs require an element stride that is 16-byte
// aligned.
auto* buffer_size_struct = ctx.dst->Structure(
ctx.dst->Sym(),
{ctx.dst->Member(
kBufferSizeMemberName,
ctx.dst->ty.array(ctx.dst->ty.vec4(ctx.dst->ty.u32()),
(max_buffer_size_index / 4) + 1))},
ast::DecorationList{ctx.dst->create<ast::StructBlockDecoration>()});
buffer_size_ubo = ctx.dst->Global(
ctx.dst->Sym(), ctx.dst->ty.Of(buffer_size_struct),
ast::StorageClass::kUniform,
ast::DecorationList{
ctx.dst->create<ast::GroupDecoration>(cfg->ubo_binding.group),
ctx.dst->create<ast::BindingDecoration>(
cfg->ubo_binding.binding)});
}
return buffer_size_ubo;
};
// Find all calls to the arrayLength() intrinsic. // Find all calls to the arrayLength() intrinsic.
for (auto* node : ctx.src->ASTNodes().Objects()) { for (auto* node : ctx.src->ASTNodes().Objects()) {
auto* call_expr = node->As<ast::CallExpression>(); auto* call_expr = node->As<ast::CallExpression>();
@ -137,23 +95,91 @@ void ArrayLengthFromUniform::Run(CloneContext& ctx,
<< "storage buffer is not a global variable"; << "storage buffer is not a global variable";
break; break;
} }
functor(call_expr, storage_buffer_sem, var);
}
}
void ArrayLengthFromUniform::Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) {
if (!Requires<InlinePointerLets, Simplify>(ctx)) {
return;
}
auto* cfg = inputs.Get<Config>();
if (cfg == nullptr) {
ctx.dst->Diagnostics().add_error(
diag::System::Transform,
"missing transform data for " + std::string(TypeInfo().name));
return;
}
const char* kBufferSizeMemberName = "buffer_size";
// Determine the size of the buffer size array.
uint32_t max_buffer_size_index = 0;
IterateArrayLengthOnStorageVar(
ctx, [&](const ast::CallExpression*, const sem::VariableUser*,
const sem::GlobalVariable* var) {
auto binding = var->BindingPoint();
auto idx_itr = cfg->bindpoint_to_size_index.find(binding);
if (idx_itr == cfg->bindpoint_to_size_index.end()) {
return;
}
if (idx_itr->second > max_buffer_size_index) {
max_buffer_size_index = idx_itr->second;
}
});
// Get (or create, on first call) the uniform buffer that will receive the
// size of each storage buffer in the module.
const ast::Variable* buffer_size_ubo = nullptr;
auto get_ubo = [&]() {
if (!buffer_size_ubo) {
// Emit an array<vec4<u32>, N>, where N is 1/4 number of elements.
// We do this because UBOs require an element stride that is 16-byte
// aligned.
auto* buffer_size_struct = ctx.dst->Structure(
ctx.dst->Sym(),
{ctx.dst->Member(
kBufferSizeMemberName,
ctx.dst->ty.array(ctx.dst->ty.vec4(ctx.dst->ty.u32()),
(max_buffer_size_index / 4) + 1))},
ast::DecorationList{ctx.dst->create<ast::StructBlockDecoration>()});
buffer_size_ubo = ctx.dst->Global(
ctx.dst->Sym(), ctx.dst->ty.Of(buffer_size_struct),
ast::StorageClass::kUniform,
ast::DecorationList{
ctx.dst->create<ast::GroupDecoration>(cfg->ubo_binding.group),
ctx.dst->create<ast::BindingDecoration>(
cfg->ubo_binding.binding)});
}
return buffer_size_ubo;
};
std::unordered_set<uint32_t> used_size_indices;
IterateArrayLengthOnStorageVar(ctx, [&](const ast::CallExpression* call_expr,
const sem::VariableUser*
storage_buffer_sem,
const sem::GlobalVariable* var) {
auto binding = var->BindingPoint(); auto binding = var->BindingPoint();
auto idx_itr = cfg->bindpoint_to_size_index.find(binding); auto idx_itr = cfg->bindpoint_to_size_index.find(binding);
if (idx_itr == cfg->bindpoint_to_size_index.end()) { if (idx_itr == cfg->bindpoint_to_size_index.end()) {
ctx.dst->Diagnostics().add_error( return;
diag::System::Transform,
"missing size index mapping for binding point (" +
std::to_string(binding.group) + "," +
std::to_string(binding.binding) + ")");
continue;
} }
uint32_t size_index = idx_itr->second;
used_size_indices.insert(size_index);
// Load the total storage buffer size from the UBO. // Load the total storage buffer size from the UBO.
uint32_t array_index = idx_itr->second / 4; uint32_t array_index = size_index / 4;
auto* vec_expr = ctx.dst->IndexAccessor( auto* vec_expr = ctx.dst->IndexAccessor(
ctx.dst->MemberAccessor(get_ubo()->symbol, kBufferSizeMemberName), ctx.dst->MemberAccessor(get_ubo()->symbol, kBufferSizeMemberName),
array_index); array_index);
uint32_t vec_index = idx_itr->second % 4; uint32_t vec_index = size_index % 4;
auto* total_storage_buffer_size = auto* total_storage_buffer_size =
ctx.dst->IndexAccessor(vec_expr, vec_index); ctx.dst->IndexAccessor(vec_expr, vec_index);
@ -170,20 +196,23 @@ void ArrayLengthFromUniform::Run(CloneContext& ctx,
ctx.dst->Sub(total_storage_buffer_size, array_offset), array_stride); ctx.dst->Sub(total_storage_buffer_size, array_offset), array_stride);
ctx.Replace(call_expr, array_length); ctx.Replace(call_expr, array_length);
} });
ctx.Clone(); ctx.Clone();
outputs.Add<Result>(buffer_size_ubo ? true : false); outputs.Add<Result>(used_size_indices);
} }
ArrayLengthFromUniform::Config::Config(sem::BindingPoint ubo_bp) ArrayLengthFromUniform::Config::Config(sem::BindingPoint ubo_bp)
: ubo_binding(ubo_bp) {} : ubo_binding(ubo_bp) {}
ArrayLengthFromUniform::Config::Config(const Config&) = default; ArrayLengthFromUniform::Config::Config(const Config&) = default;
ArrayLengthFromUniform::Config& ArrayLengthFromUniform::Config::operator=(
const Config&) = default;
ArrayLengthFromUniform::Config::~Config() = default; ArrayLengthFromUniform::Config::~Config() = default;
ArrayLengthFromUniform::Result::Result(bool needs_sizes) ArrayLengthFromUniform::Result::Result(
: needs_buffer_sizes(needs_sizes) {} std::unordered_set<uint32_t> used_size_indices_in)
: used_size_indices(std::move(used_size_indices_in)) {}
ArrayLengthFromUniform::Result::Result(const Result&) = default; ArrayLengthFromUniform::Result::Result(const Result&) = default;
ArrayLengthFromUniform::Result::~Result() = default; ArrayLengthFromUniform::Result::~Result() = default;

View File

@ -16,6 +16,7 @@
#define SRC_TRANSFORM_ARRAY_LENGTH_FROM_UNIFORM_H_ #define SRC_TRANSFORM_ARRAY_LENGTH_FROM_UNIFORM_H_
#include <unordered_map> #include <unordered_map>
#include <unordered_set>
#include "src/sem/binding_point.h" #include "src/sem/binding_point.h"
#include "src/transform/transform.h" #include "src/transform/transform.h"
@ -66,6 +67,10 @@ class ArrayLengthFromUniform
/// Copy constructor /// Copy constructor
Config(const Config&); Config(const Config&);
/// Copy assignment
/// @return this Config
Config& operator=(const Config&);
/// Destructor /// Destructor
~Config() override; ~Config() override;
@ -79,8 +84,8 @@ class ArrayLengthFromUniform
/// Information produced about what the transform did. /// Information produced about what the transform did.
struct Result : public Castable<Result, transform::Data> { struct Result : public Castable<Result, transform::Data> {
/// Constructor /// Constructor
/// @param needs_sizes True if the transform generated the buffer sizes UBO. /// @param used_size_indices Indices into the UBO that are statically used.
explicit Result(bool needs_sizes); explicit Result(std::unordered_set<uint32_t> used_size_indices);
/// Copy constructor /// Copy constructor
Result(const Result&); Result(const Result&);
@ -88,8 +93,8 @@ class ArrayLengthFromUniform
/// Destructor /// Destructor
~Result() override; ~Result() override;
/// True if the transform generated the buffer sizes UBO. /// Indices into the UBO that are statically used.
const bool needs_buffer_sizes; const std::unordered_set<uint32_t> used_size_indices;
}; };
protected: protected:

View File

@ -110,8 +110,8 @@ fn main() {
Run<InlinePointerLets, Simplify, ArrayLengthFromUniform>(src, data); Run<InlinePointerLets, Simplify, ArrayLengthFromUniform>(src, data);
EXPECT_EQ(expect, str(got)); EXPECT_EQ(expect, str(got));
EXPECT_TRUE( EXPECT_EQ(std::unordered_set<uint32_t>({0}),
got.data.Get<ArrayLengthFromUniform::Result>()->needs_buffer_sizes); got.data.Get<ArrayLengthFromUniform::Result>()->used_size_indices);
} }
TEST_F(ArrayLengthFromUniformTest, WithStride) { TEST_F(ArrayLengthFromUniformTest, WithStride) {
@ -164,8 +164,8 @@ fn main() {
Run<InlinePointerLets, Simplify, ArrayLengthFromUniform>(src, data); Run<InlinePointerLets, Simplify, ArrayLengthFromUniform>(src, data);
EXPECT_EQ(expect, str(got)); EXPECT_EQ(expect, str(got));
EXPECT_TRUE( EXPECT_EQ(std::unordered_set<uint32_t>({0}),
got.data.Get<ArrayLengthFromUniform::Result>()->needs_buffer_sizes); got.data.Get<ArrayLengthFromUniform::Result>()->used_size_indices);
} }
TEST_F(ArrayLengthFromUniformTest, MultipleStorageBuffers) { TEST_F(ArrayLengthFromUniformTest, MultipleStorageBuffers) {
@ -286,8 +286,124 @@ fn main() {
Run<InlinePointerLets, Simplify, ArrayLengthFromUniform>(src, data); Run<InlinePointerLets, Simplify, ArrayLengthFromUniform>(src, data);
EXPECT_EQ(expect, str(got)); EXPECT_EQ(expect, str(got));
EXPECT_TRUE( EXPECT_EQ(std::unordered_set<uint32_t>({0, 1, 2, 3, 4}),
got.data.Get<ArrayLengthFromUniform::Result>()->needs_buffer_sizes); got.data.Get<ArrayLengthFromUniform::Result>()->used_size_indices);
}
TEST_F(ArrayLengthFromUniformTest, MultipleUnusedStorageBuffers) {
auto* src = R"(
[[block]]
struct SB1 {
x : i32;
arr1 : array<i32>;
};
[[block]]
struct SB2 {
x : i32;
arr2 : array<vec4<f32>>;
};
[[block]]
struct SB3 {
x : i32;
arr3 : array<vec4<f32>>;
};
[[block]]
struct SB4 {
x : i32;
arr4 : array<vec4<f32>>;
};
[[block]]
struct SB5 {
x : i32;
arr5 : array<vec4<f32>>;
};
[[group(0), binding(2)]] var<storage, read> sb1 : SB1;
[[group(1), binding(2)]] var<storage, read> sb2 : SB2;
[[group(2), binding(2)]] var<storage, read> sb3 : SB3;
[[group(3), binding(2)]] var<storage, read> sb4 : SB4;
[[group(4), binding(2)]] var<storage, read> sb5 : SB5;
[[stage(compute), workgroup_size(1)]]
fn main() {
var len1 : u32 = arrayLength(&(sb1.arr1));
var len3 : u32 = arrayLength(&(sb3.arr3));
var x : u32 = (len1 + len3);
}
)";
auto* expect = R"(
[[block]]
struct tint_symbol {
buffer_size : array<vec4<u32>, 1u>;
};
[[group(0), binding(30)]] var<uniform> tint_symbol_1 : tint_symbol;
[[block]]
struct SB1 {
x : i32;
arr1 : array<i32>;
};
[[block]]
struct SB2 {
x : i32;
arr2 : array<vec4<f32>>;
};
[[block]]
struct SB3 {
x : i32;
arr3 : array<vec4<f32>>;
};
[[block]]
struct SB4 {
x : i32;
arr4 : array<vec4<f32>>;
};
[[block]]
struct SB5 {
x : i32;
arr5 : array<vec4<f32>>;
};
[[group(0), binding(2)]] var<storage, read> sb1 : SB1;
[[group(1), binding(2)]] var<storage, read> sb2 : SB2;
[[group(2), binding(2)]] var<storage, read> sb3 : SB3;
[[group(3), binding(2)]] var<storage, read> sb4 : SB4;
[[group(4), binding(2)]] var<storage, read> sb5 : SB5;
[[stage(compute), workgroup_size(1)]]
fn main() {
var len1 : u32 = ((tint_symbol_1.buffer_size[0u][0u] - 4u) / 4u);
var len3 : u32 = ((tint_symbol_1.buffer_size[0u][2u] - 16u) / 16u);
var x : u32 = (len1 + len3);
}
)";
ArrayLengthFromUniform::Config cfg({0, 30u});
cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 2u}, 0);
cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{1u, 2u}, 1);
cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{2u, 2u}, 2);
cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{3u, 2u}, 3);
cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{4u, 2u}, 4);
DataMap data;
data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));
auto got =
Run<InlinePointerLets, Simplify, ArrayLengthFromUniform>(src, data);
EXPECT_EQ(expect, str(got));
EXPECT_EQ(std::unordered_set<uint32_t>({0, 2}),
got.data.Get<ArrayLengthFromUniform::Result>()->used_size_indices);
} }
TEST_F(ArrayLengthFromUniformTest, NoArrayLengthCalls) { TEST_F(ArrayLengthFromUniformTest, NoArrayLengthCalls) {
@ -316,8 +432,8 @@ fn main() {
Run<InlinePointerLets, Simplify, ArrayLengthFromUniform>(src, data); Run<InlinePointerLets, Simplify, ArrayLengthFromUniform>(src, data);
EXPECT_EQ(src, str(got)); EXPECT_EQ(src, str(got));
EXPECT_FALSE( EXPECT_EQ(std::unordered_set<uint32_t>(),
got.data.Get<ArrayLengthFromUniform::Result>()->needs_buffer_sizes); got.data.Get<ArrayLengthFromUniform::Result>()->used_size_indices);
} }
TEST_F(ArrayLengthFromUniformTest, MissingBindingPointToIndexMapping) { TEST_F(ArrayLengthFromUniformTest, MissingBindingPointToIndexMapping) {
@ -346,7 +462,37 @@ fn main() {
} }
)"; )";
auto* expect = "error: missing size index mapping for binding point (1,2)"; auto* expect = R"(
[[block]]
struct tint_symbol {
buffer_size : array<vec4<u32>, 1u>;
};
[[group(0), binding(30)]] var<uniform> tint_symbol_1 : tint_symbol;
[[block]]
struct SB1 {
x : i32;
arr1 : array<i32>;
};
[[block]]
struct SB2 {
x : i32;
arr2 : array<vec4<f32>>;
};
[[group(0), binding(2)]] var<storage, read> sb1 : SB1;
[[group(1), binding(2)]] var<storage, read> sb2 : SB2;
[[stage(compute), workgroup_size(1)]]
fn main() {
var len1 : u32 = ((tint_symbol_1.buffer_size[0u][0u] - 4u) / 4u);
var len2 : u32 = arrayLength(&(sb2.arr2));
var x : u32 = (len1 + len2);
}
)";
ArrayLengthFromUniform::Config cfg({0, 30u}); ArrayLengthFromUniform::Config cfg({0, 30u});
cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 2}, 0); cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 2}, 0);
@ -358,6 +504,8 @@ fn main() {
Run<InlinePointerLets, Simplify, ArrayLengthFromUniform>(src, data); Run<InlinePointerLets, Simplify, ArrayLengthFromUniform>(src, data);
EXPECT_EQ(expect, str(got)); EXPECT_EQ(expect, str(got));
EXPECT_EQ(std::unordered_set<uint32_t>({0}),
got.data.Get<ArrayLengthFromUniform::Result>()->used_size_indices);
} }
} // namespace } // namespace

View File

@ -0,0 +1,30 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/writer/array_length_from_uniform_options.h"
namespace tint {
namespace writer {
ArrayLengthFromUniformOptions::ArrayLengthFromUniformOptions() = default;
ArrayLengthFromUniformOptions::~ArrayLengthFromUniformOptions() = default;
ArrayLengthFromUniformOptions::ArrayLengthFromUniformOptions(
const ArrayLengthFromUniformOptions&) = default;
ArrayLengthFromUniformOptions& ArrayLengthFromUniformOptions::operator=(
const ArrayLengthFromUniformOptions&) = default;
ArrayLengthFromUniformOptions::ArrayLengthFromUniformOptions(
ArrayLengthFromUniformOptions&&) = default;
} // namespace writer
} // namespace tint

View File

@ -0,0 +1,52 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_WRITER_ARRAY_LENGTH_FROM_UNIFORM_OPTIONS_H_
#define SRC_WRITER_ARRAY_LENGTH_FROM_UNIFORM_OPTIONS_H_
#include <unordered_map>
#include "src/sem/binding_point.h"
namespace tint {
namespace writer {
/// Options used to specify a mapping of binding points to indices into a UBO
/// from which to load buffer sizes.
struct ArrayLengthFromUniformOptions {
/// Constructor
ArrayLengthFromUniformOptions();
/// Destructor
~ArrayLengthFromUniformOptions();
/// Copy constructor
ArrayLengthFromUniformOptions(const ArrayLengthFromUniformOptions&);
/// Copy assignment
/// @returns this ArrayLengthFromUniformOptions
ArrayLengthFromUniformOptions& operator=(
const ArrayLengthFromUniformOptions&);
/// Move constructor
ArrayLengthFromUniformOptions(ArrayLengthFromUniformOptions&&);
/// The binding point to use to generate a uniform buffer from which to read
/// buffer sizes.
sem::BindingPoint ubo_binding;
/// The mapping from storage buffer binding points to the index into the
/// uniform buffer where the length of the buffer is stored.
std::unordered_map<sem::BindingPoint, uint32_t> bindpoint_to_size_index;
};
} // namespace writer
} // namespace tint
#endif // SRC_WRITER_ARRAY_LENGTH_FROM_UNIFORM_OPTIONS_H_

View File

@ -20,6 +20,11 @@ namespace tint {
namespace writer { namespace writer {
namespace hlsl { namespace hlsl {
Options::Options() = default;
Options::~Options() = default;
Options::Options(const Options&) = default;
Options& Options::operator=(const Options&) = default;
Result::Result() = default; Result::Result() = default;
Result::~Result() = default; Result::~Result() = default;
Result::Result(const Result&) = default; Result::Result(const Result&) = default;
@ -29,7 +34,8 @@ Result Generate(const Program* program, const Options& options) {
// Sanitize the program. // Sanitize the program.
auto sanitized_result = Sanitize(program, options.root_constant_binding_point, auto sanitized_result = Sanitize(program, options.root_constant_binding_point,
options.disable_workgroup_init); options.disable_workgroup_init,
options.array_length_from_uniform);
if (!sanitized_result.program.IsValid()) { if (!sanitized_result.program.IsValid()) {
result.success = false; result.success = false;
result.error = sanitized_result.program.Diagnostics().str(); result.error = sanitized_result.program.Diagnostics().str();
@ -50,6 +56,9 @@ Result Generate(const Program* program, const Options& options) {
} }
} }
result.used_array_length_from_uniform_indices =
std::move(sanitized_result.used_array_length_from_uniform_indices);
return result; return result;
} }

View File

@ -17,11 +17,13 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include <unordered_set>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "src/ast/pipeline_stage.h" #include "src/ast/pipeline_stage.h"
#include "src/sem/binding_point.h" #include "src/sem/binding_point.h"
#include "src/writer/array_length_from_uniform_options.h"
#include "src/writer/text.h" #include "src/writer/text.h"
namespace tint { namespace tint {
@ -37,10 +39,23 @@ class GeneratorImpl;
/// Configuration options used for generating HLSL. /// Configuration options used for generating HLSL.
struct Options { struct Options {
/// Constructor
Options();
/// Destructor
~Options();
/// Copy constructor
Options(const Options&);
/// Copy assignment
/// @returns this Options
Options& operator=(const Options&);
/// The binding point to use for information passed via root constants. /// The binding point to use for information passed via root constants.
sem::BindingPoint root_constant_binding_point; sem::BindingPoint root_constant_binding_point;
/// Set to `true` to disable workgroup memory zero initialization /// Set to `true` to disable workgroup memory zero initialization
bool disable_workgroup_init = false; bool disable_workgroup_init = false;
/// Options used to specify a mapping of binding points to indices into a UBO
/// from which to load buffer sizes.
ArrayLengthFromUniformOptions array_length_from_uniform = {};
}; };
/// The result produced when generating HLSL. /// The result produced when generating HLSL.
@ -65,6 +80,10 @@ struct Result {
/// The list of entry points in the generated HLSL. /// The list of entry points in the generated HLSL.
std::vector<std::pair<std::string, ast::PipelineStage>> entry_points; std::vector<std::pair<std::string, ast::PipelineStage>> entry_points;
/// Indices into the array_length_from_uniform binding that are statically
/// used.
std::unordered_set<uint32_t> used_array_length_from_uniform_indices;
}; };
/// Generate HLSL for a program, according to a set of configuration options. /// Generate HLSL for a program, according to a set of configuration options.

View File

@ -45,6 +45,7 @@
#include "src/sem/type_conversion.h" #include "src/sem/type_conversion.h"
#include "src/sem/variable.h" #include "src/sem/variable.h"
#include "src/transform/add_empty_entry_point.h" #include "src/transform/add_empty_entry_point.h"
#include "src/transform/array_length_from_uniform.h"
#include "src/transform/calculate_array_length.h" #include "src/transform/calculate_array_length.h"
#include "src/transform/canonicalize_entry_point_io.h" #include "src/transform/canonicalize_entry_point_io.h"
#include "src/transform/decompose_memory_access.h" #include "src/transform/decompose_memory_access.h"
@ -124,12 +125,24 @@ const char* LoopAttribute() {
} // namespace } // namespace
SanitizedResult Sanitize(const Program* in, SanitizedResult::SanitizedResult() = default;
sem::BindingPoint root_constant_binding_point, SanitizedResult::~SanitizedResult() = default;
bool disable_workgroup_init) { SanitizedResult::SanitizedResult(SanitizedResult&&) = default;
SanitizedResult Sanitize(
const Program* in,
sem::BindingPoint root_constant_binding_point,
bool disable_workgroup_init,
const ArrayLengthFromUniformOptions& array_length_from_uniform) {
transform::Manager manager; transform::Manager manager;
transform::DataMap data; transform::DataMap data;
// Build the config for the internal ArrayLengthFromUniform transform.
transform::ArrayLengthFromUniform::Config array_length_from_uniform_cfg(
array_length_from_uniform.ubo_binding);
array_length_from_uniform_cfg.bindpoint_to_size_index =
array_length_from_uniform.bindpoint_to_size_index;
// Attempt to convert `loop`s into for-loops. This is to try and massage the // Attempt to convert `loop`s into for-loops. This is to try and massage the
// output into something that will not cause FXC to choke or misbehave. // output into something that will not cause FXC to choke or misbehave.
manager.Add<transform::FoldTrivialSingleUseLets>(); manager.Add<transform::FoldTrivialSingleUseLets>();
@ -149,6 +162,11 @@ SanitizedResult Sanitize(const Program* in,
// Simplify cleans up messy `*(&(expr))` expressions from InlinePointerLets. // Simplify cleans up messy `*(&(expr))` expressions from InlinePointerLets.
manager.Add<transform::Simplify>(); manager.Add<transform::Simplify>();
manager.Add<transform::RemovePhonies>(); manager.Add<transform::RemovePhonies>();
// ArrayLengthFromUniform must come after InlinePointerLets and Simplify, as
// it assumes that the form of the array length argument is &var.array.
manager.Add<transform::ArrayLengthFromUniform>();
data.Add<transform::ArrayLengthFromUniform::Config>(
std::move(array_length_from_uniform_cfg));
// DecomposeMemoryAccess must come after: // DecomposeMemoryAccess must come after:
// * InlinePointerLets, as we cannot take the address of calls to // * InlinePointerLets, as we cannot take the address of calls to
// DecomposeMemoryAccess::Intrinsic. // DecomposeMemoryAccess::Intrinsic.
@ -171,8 +189,13 @@ SanitizedResult Sanitize(const Program* in,
data.Add<transform::NumWorkgroupsFromUniform::Config>( data.Add<transform::NumWorkgroupsFromUniform::Config>(
root_constant_binding_point); root_constant_binding_point);
auto out = manager.Run(in, data);
SanitizedResult result; SanitizedResult result;
result.program = std::move(manager.Run(in, data).program); result.program = std::move(out.program);
result.used_array_length_from_uniform_indices =
std::move(out.data.Get<transform::ArrayLengthFromUniform::Result>()
->used_size_indices);
return result; return result;
} }

View File

@ -36,6 +36,7 @@
#include "src/sem/binding_point.h" #include "src/sem/binding_point.h"
#include "src/transform/decompose_memory_access.h" #include "src/transform/decompose_memory_access.h"
#include "src/utils/hash.h" #include "src/utils/hash.h"
#include "src/writer/array_length_from_uniform_options.h"
#include "src/writer/text_generator.h" #include "src/writer/text_generator.h"
namespace tint { namespace tint {
@ -53,8 +54,18 @@ namespace hlsl {
/// The result of sanitizing a program for generation. /// The result of sanitizing a program for generation.
struct SanitizedResult { struct SanitizedResult {
/// Constructor
SanitizedResult();
/// Destructor
~SanitizedResult();
/// Move constructor
SanitizedResult(SanitizedResult&&);
/// The sanitized program. /// The sanitized program.
Program program; Program program;
/// Indices into the array_length_from_uniform binding that are statically
/// used.
std::unordered_set<uint32_t> used_array_length_from_uniform_indices;
}; };
/// Sanitize a program in preparation for generating HLSL. /// Sanitize a program in preparation for generating HLSL.
@ -62,9 +73,11 @@ struct SanitizedResult {
/// that will be passed via root constants /// that will be passed via root constants
/// @param disable_workgroup_init `true` to disable workgroup memory zero /// @param disable_workgroup_init `true` to disable workgroup memory zero
/// @returns the sanitized program and any supplementary information /// @returns the sanitized program and any supplementary information
SanitizedResult Sanitize(const Program* program, SanitizedResult Sanitize(
sem::BindingPoint root_constant_binding_point = {}, const Program* program,
bool disable_workgroup_init = false); sem::BindingPoint root_constant_binding_point = {},
bool disable_workgroup_init = false,
const ArrayLengthFromUniformOptions& array_length_from_uniform = {});
/// Implementation class for HLSL generator /// Implementation class for HLSL generator
class GeneratorImpl : public TextGenerator { class GeneratorImpl : public TextGenerator {

View File

@ -144,6 +144,58 @@ void a_func() {
EXPECT_EQ(expect, got); EXPECT_EQ(expect, got);
} }
TEST_F(HlslSanitizerTest, Call_ArrayLength_ArrayLengthFromUniform) {
auto* s = Structure("my_struct", {Member(0, "a", ty.array<f32>(4))},
{create<ast::StructBlockDecoration>()});
Global("b", ty.Of(s), ast::StorageClass::kStorage, ast::Access::kRead,
ast::DecorationList{
create<ast::BindingDecoration>(1),
create<ast::GroupDecoration>(2),
});
Global("c", ty.Of(s), ast::StorageClass::kStorage, ast::Access::kRead,
ast::DecorationList{
create<ast::BindingDecoration>(2),
create<ast::GroupDecoration>(2),
});
Func("a_func", ast::VariableList{}, ty.void_(),
ast::StatementList{
Decl(Var(
"len", ty.u32(), ast::StorageClass::kNone,
Add(Call("arrayLength", AddressOf(MemberAccessor("b", "a"))),
Call("arrayLength", AddressOf(MemberAccessor("c", "a")))))),
},
ast::DecorationList{
Stage(ast::PipelineStage::kFragment),
});
Options options;
options.array_length_from_uniform.ubo_binding = {3, 4};
options.array_length_from_uniform.bindpoint_to_size_index.emplace(
sem::BindingPoint{2, 2}, 7u);
GeneratorImpl& gen = SanitizeAndBuild(options);
ASSERT_TRUE(gen.Generate()) << gen.error();
auto got = gen.result();
auto* expect = R"(cbuffer cbuffer_tint_symbol_1 : register(b4, space3) {
uint4 tint_symbol_1[2];
};
ByteAddressBuffer b : register(t1, space2);
ByteAddressBuffer c : register(t2, space2);
void a_func() {
uint tint_symbol_4 = 0u;
b.GetDimensions(tint_symbol_4);
const uint tint_symbol_5 = ((tint_symbol_4 - 0u) / 4u);
uint len = (tint_symbol_5 + ((tint_symbol_1[1].w - 0u) / 4u));
return;
}
)";
EXPECT_EQ(expect, got);
}
TEST_F(HlslSanitizerTest, PromoteArrayInitializerToConstVar) { TEST_F(HlslSanitizerTest, PromoteArrayInitializerToConstVar) {
auto* array_init = array<i32, 4>(1, 2, 3, 4); auto* array_init = array<i32, 4>(1, 2, 3, 4);
auto* array_index = IndexAccessor(array_init, 3); auto* array_index = IndexAccessor(array_init, 3);

View File

@ -22,6 +22,7 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "src/transform/manager.h" #include "src/transform/manager.h"
#include "src/transform/renamer.h" #include "src/transform/renamer.h"
#include "src/writer/hlsl/generator.h"
#include "src/writer/hlsl/generator_impl.h" #include "src/writer/hlsl/generator_impl.h"
namespace tint { namespace tint {
@ -58,10 +59,11 @@ class TestHelperBase : public BODY, public ProgramBuilder {
/// Builds the program, runs the program through the HLSL sanitizer /// Builds the program, runs the program through the HLSL sanitizer
/// and returns a GeneratorImpl from the sanitized program. /// and returns a GeneratorImpl from the sanitized program.
/// @param options The HLSL generator options.
/// @note The generator is only built once. Multiple calls to Build() will /// @note The generator is only built once. Multiple calls to Build() will
/// return the same GeneratorImpl without rebuilding. /// return the same GeneratorImpl without rebuilding.
/// @return the built generator /// @return the built generator
GeneratorImpl& SanitizeAndBuild() { GeneratorImpl& SanitizeAndBuild(const Options& options = {}) {
if (gen_) { if (gen_) {
return *gen_; return *gen_;
} }
@ -76,7 +78,9 @@ class TestHelperBase : public BODY, public ProgramBuilder {
<< formatter.format(program->Diagnostics()); << formatter.format(program->Diagnostics());
}(); }();
auto sanitized_result = Sanitize(program.get()); auto sanitized_result = Sanitize(
program.get(), options.root_constant_binding_point,
options.disable_workgroup_init, options.array_length_from_uniform);
[&]() { [&]() {
ASSERT_TRUE(sanitized_result.program.IsValid()) ASSERT_TRUE(sanitized_result.program.IsValid())
<< formatter.format(sanitized_result.program.Diagnostics()); << formatter.format(sanitized_result.program.Diagnostics());

View File

@ -14,12 +14,19 @@
#include "src/writer/msl/generator.h" #include "src/writer/msl/generator.h"
#include <utility>
#include "src/writer/msl/generator_impl.h" #include "src/writer/msl/generator_impl.h"
namespace tint { namespace tint {
namespace writer { namespace writer {
namespace msl { namespace msl {
Options::Options() = default;
Options::~Options() = default;
Options::Options(const Options&) = default;
Options& Options::operator=(const Options&) = default;
Result::Result() = default; Result::Result() = default;
Result::~Result() = default; Result::~Result() = default;
Result::Result(const Result&) = default; Result::Result(const Result&) = default;
@ -30,7 +37,8 @@ Result Generate(const Program* program, const Options& options) {
// Sanitize the program. // Sanitize the program.
auto sanitized_result = Sanitize( auto sanitized_result = Sanitize(
program, options.buffer_size_ubo_index, options.fixed_sample_mask, program, options.buffer_size_ubo_index, options.fixed_sample_mask,
options.emit_vertex_point_size, options.disable_workgroup_init); options.emit_vertex_point_size, options.disable_workgroup_init,
options.array_length_from_uniform);
if (!sanitized_result.program.IsValid()) { if (!sanitized_result.program.IsValid()) {
result.success = false; result.success = false;
result.error = sanitized_result.program.Diagnostics().str(); result.error = sanitized_result.program.Diagnostics().str();
@ -38,6 +46,8 @@ Result Generate(const Program* program, const Options& options) {
} }
result.needs_storage_buffer_sizes = result.needs_storage_buffer_sizes =
sanitized_result.needs_storage_buffer_sizes; sanitized_result.needs_storage_buffer_sizes;
result.used_array_length_from_uniform_indices =
std::move(sanitized_result.used_array_length_from_uniform_indices);
// Generate the MSL code. // Generate the MSL code.
auto impl = std::make_unique<GeneratorImpl>(&sanitized_result.program); auto impl = std::make_unique<GeneratorImpl>(&sanitized_result.program);

View File

@ -18,8 +18,10 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <unordered_set>
#include <vector> #include <vector>
#include "src/writer/array_length_from_uniform_options.h"
#include "src/writer/text.h" #include "src/writer/text.h"
namespace tint { namespace tint {
@ -34,6 +36,16 @@ class GeneratorImpl;
/// Configuration options used for generating MSL. /// Configuration options used for generating MSL.
struct Options { struct Options {
/// Constructor
Options();
/// Destructor
~Options();
/// Copy constructor
Options(const Options&);
/// Copy assignment
/// @returns this Options
Options& operator=(const Options&);
/// The index to use when generating a UBO to receive storage buffer sizes. /// The index to use when generating a UBO to receive storage buffer sizes.
/// Defaults to 30, which is the last valid buffer slot. /// Defaults to 30, which is the last valid buffer slot.
uint32_t buffer_size_ubo_index = 30; uint32_t buffer_size_ubo_index = 30;
@ -48,6 +60,10 @@ struct Options {
/// Set to `true` to disable workgroup memory zero initialization /// Set to `true` to disable workgroup memory zero initialization
bool disable_workgroup_init = false; bool disable_workgroup_init = false;
/// Options used to specify a mapping of binding points to indices into a UBO
/// from which to load buffer sizes.
ArrayLengthFromUniformOptions array_length_from_uniform = {};
}; };
/// The result produced when generating MSL. /// The result produced when generating MSL.
@ -80,6 +96,10 @@ struct Result {
/// Each entry in the vector is the size of the workgroup allocation that /// Each entry in the vector is the size of the workgroup allocation that
/// should be created for that index. /// should be created for that index.
std::unordered_map<std::string, std::vector<uint32_t>> workgroup_allocations; std::unordered_map<std::string, std::vector<uint32_t>> workgroup_allocations;
/// Indices into the array_length_from_uniform binding that are statically
/// used.
std::unordered_set<uint32_t> used_array_length_from_uniform_indices;
}; };
/// Generate MSL for a program, according to a set of configuration options. The /// Generate MSL for a program, according to a set of configuration options. The

View File

@ -112,31 +112,47 @@ class ScopedBitCast {
}; };
} // namespace } // namespace
SanitizedResult Sanitize(const Program* in, SanitizedResult::SanitizedResult() = default;
uint32_t buffer_size_ubo_index, SanitizedResult::~SanitizedResult() = default;
uint32_t fixed_sample_mask, SanitizedResult::SanitizedResult(SanitizedResult&&) = default;
bool emit_vertex_point_size,
bool disable_workgroup_init) { SanitizedResult Sanitize(
const Program* in,
uint32_t buffer_size_ubo_index,
uint32_t fixed_sample_mask,
bool emit_vertex_point_size,
bool disable_workgroup_init,
const ArrayLengthFromUniformOptions& array_length_from_uniform) {
transform::Manager manager; transform::Manager manager;
transform::DataMap internal_inputs; transform::DataMap internal_inputs;
// Build the configs for the internal transforms. // Build the config for the internal ArrayLengthFromUniform transform.
auto array_length_from_uniform_cfg = transform::ArrayLengthFromUniform::Config array_length_from_uniform_cfg(
transform::ArrayLengthFromUniform::Config( array_length_from_uniform.ubo_binding);
sem::BindingPoint{0, buffer_size_ubo_index}); if (!array_length_from_uniform.bindpoint_to_size_index.empty()) {
// If |array_length_from_uniform| bindings are provided, use that config.
array_length_from_uniform_cfg.bindpoint_to_size_index =
array_length_from_uniform.bindpoint_to_size_index;
} else {
// If the binding map is empty, use the deprecated |buffer_size_ubo_index|
// and automatically choose indices using the binding numbers.
array_length_from_uniform_cfg = transform::ArrayLengthFromUniform::Config(
sem::BindingPoint{0, buffer_size_ubo_index});
// Use the SSBO binding numbers as the indices for the buffer size lookups.
for (auto* var : in->AST().GlobalVariables()) {
auto* global = in->Sem().Get<sem::GlobalVariable>(var);
if (global && global->StorageClass() == ast::StorageClass::kStorage) {
array_length_from_uniform_cfg.bindpoint_to_size_index.emplace(
global->BindingPoint(), global->BindingPoint().binding);
}
}
}
// Build the configs for the internal CanonicalizeEntryPointIO transform.
auto entry_point_io_cfg = transform::CanonicalizeEntryPointIO::Config( auto entry_point_io_cfg = transform::CanonicalizeEntryPointIO::Config(
transform::CanonicalizeEntryPointIO::ShaderStyle::kMsl, fixed_sample_mask, transform::CanonicalizeEntryPointIO::ShaderStyle::kMsl, fixed_sample_mask,
emit_vertex_point_size); emit_vertex_point_size);
// Use the SSBO binding numbers as the indices for the buffer size lookups.
for (auto* var : in->AST().GlobalVariables()) {
auto* global = in->Sem().Get<sem::GlobalVariable>(var);
if (global && global->StorageClass() == ast::StorageClass::kStorage) {
array_length_from_uniform_cfg.bindpoint_to_size_index.emplace(
global->BindingPoint(), global->BindingPoint().binding);
}
}
if (!disable_workgroup_init) { if (!disable_workgroup_init) {
// ZeroInitWorkgroupMemory must come before CanonicalizeEntryPointIO as // ZeroInitWorkgroupMemory must come before CanonicalizeEntryPointIO as
// ZeroInitWorkgroupMemory may inject new builtin parameters. // ZeroInitWorkgroupMemory may inject new builtin parameters.
@ -160,13 +176,18 @@ SanitizedResult Sanitize(const Program* in,
internal_inputs.Add<transform::CanonicalizeEntryPointIO::Config>( internal_inputs.Add<transform::CanonicalizeEntryPointIO::Config>(
std::move(entry_point_io_cfg)); std::move(entry_point_io_cfg));
auto out = manager.Run(in, internal_inputs); auto out = manager.Run(in, internal_inputs);
if (!out.program.IsValid()) {
return {std::move(out.program)};
}
return {std::move(out.program), SanitizedResult result;
out.data.Get<transform::ArrayLengthFromUniform::Result>() result.program = std::move(out.program);
->needs_buffer_sizes}; if (!result.program.IsValid()) {
return result;
}
result.used_array_length_from_uniform_indices =
std::move(out.data.Get<transform::ArrayLengthFromUniform::Result>()
->used_size_indices);
result.needs_storage_buffer_sizes =
!result.used_array_length_from_uniform_indices.empty();
return result;
} }
GeneratorImpl::GeneratorImpl(const Program* program) : TextGenerator(program) {} GeneratorImpl::GeneratorImpl(const Program* program) : TextGenerator(program) {}
@ -1314,6 +1335,13 @@ std::string GeneratorImpl::generate_builtin_name(
case sem::IntrinsicType::kUnpack2x16unorm: case sem::IntrinsicType::kUnpack2x16unorm:
out += "unpack_unorm2x16_to_float"; out += "unpack_unorm2x16_to_float";
break; break;
case sem::IntrinsicType::kArrayLength:
diagnostics_.add_error(
diag::System::Writer,
"Unable to translate builtin: " + std::string(intrinsic->str()) +
"\nDid you forget to pass array_length_from_uniform generator "
"options?");
return "";
default: default:
diagnostics_.add_error( diagnostics_.add_error(
diag::System::Writer, diag::System::Writer,

View File

@ -17,6 +17,7 @@
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <unordered_set>
#include <vector> #include <vector>
#include "src/ast/assignment_statement.h" #include "src/ast/assignment_statement.h"
@ -37,6 +38,7 @@
#include "src/program.h" #include "src/program.h"
#include "src/scope_stack.h" #include "src/scope_stack.h"
#include "src/sem/struct.h" #include "src/sem/struct.h"
#include "src/writer/array_length_from_uniform_options.h"
#include "src/writer/text_generator.h" #include "src/writer/text_generator.h"
namespace tint { namespace tint {
@ -54,10 +56,20 @@ namespace msl {
/// The result of sanitizing a program for generation. /// The result of sanitizing a program for generation.
struct SanitizedResult { struct SanitizedResult {
/// Constructor
SanitizedResult();
/// Destructor
~SanitizedResult();
/// Move constructor
SanitizedResult(SanitizedResult&&);
/// The sanitized program. /// The sanitized program.
Program program; Program program;
/// True if the shader needs a UBO of buffer sizes. /// True if the shader needs a UBO of buffer sizes.
bool needs_storage_buffer_sizes = false; bool needs_storage_buffer_sizes = false;
/// Indices into the array_length_from_uniform binding that are statically
/// used.
std::unordered_set<uint32_t> used_array_length_from_uniform_indices;
}; };
/// Sanitize a program in preparation for generating MSL. /// Sanitize a program in preparation for generating MSL.
@ -66,11 +78,13 @@ struct SanitizedResult {
/// @param emit_vertex_point_size `true` to emit a vertex point size builtin /// @param emit_vertex_point_size `true` to emit a vertex point size builtin
/// @param disable_workgroup_init `true` to disable workgroup memory zero /// @param disable_workgroup_init `true` to disable workgroup memory zero
/// @returns the sanitized program and any supplementary information /// @returns the sanitized program and any supplementary information
SanitizedResult Sanitize(const Program* program, SanitizedResult Sanitize(
uint32_t buffer_size_ubo_index, const Program* program,
uint32_t fixed_sample_mask = 0xFFFFFFFF, uint32_t buffer_size_ubo_index,
bool emit_vertex_point_size = false, uint32_t fixed_sample_mask = 0xFFFFFFFF,
bool disable_workgroup_init = false); bool emit_vertex_point_size = false,
bool disable_workgroup_init = false,
const ArrayLengthFromUniformOptions& array_length_from_uniform = {});
/// Implementation class for MSL generator /// Implementation class for MSL generator
class GeneratorImpl : public TextGenerator { class GeneratorImpl : public TextGenerator {

View File

@ -0,0 +1,264 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gmock/gmock.h"
#include "src/ast/call_statement.h"
#include "src/ast/stage_decoration.h"
#include "src/ast/struct_block_decoration.h"
#include "src/ast/variable_decl_statement.h"
#include "src/writer/msl/test_helper.h"
namespace tint {
namespace writer {
namespace msl {
namespace {
using ::testing::HasSubstr;
using MslSanitizerTest = TestHelper;
TEST_F(MslSanitizerTest, Call_ArrayLength) {
auto* s = Structure("my_struct", {Member(0, "a", ty.array<f32>(4))},
{create<ast::StructBlockDecoration>()});
Global("b", ty.Of(s), ast::StorageClass::kStorage, ast::Access::kRead,
ast::DecorationList{
create<ast::BindingDecoration>(1),
create<ast::GroupDecoration>(2),
});
Func("a_func", ast::VariableList{}, ty.void_(),
ast::StatementList{
Decl(Var("len", ty.u32(), ast::StorageClass::kNone,
Call("arrayLength", AddressOf(MemberAccessor("b", "a"))))),
},
ast::DecorationList{
Stage(ast::PipelineStage::kFragment),
});
GeneratorImpl& gen = SanitizeAndBuild();
ASSERT_TRUE(gen.Generate()) << gen.error();
auto got = gen.result();
auto* expect = R"(#include <metal_stdlib>
using namespace metal;
struct tint_symbol {
/* 0x0000 */ uint4 buffer_size[1];
};
struct my_struct {
float a[1];
};
fragment void a_func(const constant tint_symbol* tint_symbol_2 [[buffer(30)]]) {
uint len = (((*(tint_symbol_2)).buffer_size[0u][1u] - 0u) / 4u);
return;
}
)";
EXPECT_EQ(expect, got);
}
TEST_F(MslSanitizerTest, Call_ArrayLength_OtherMembersInStruct) {
auto* s = Structure("my_struct",
{
Member(0, "z", ty.f32()),
Member(4, "a", ty.array<f32>(4)),
},
{create<ast::StructBlockDecoration>()});
Global("b", ty.Of(s), ast::StorageClass::kStorage, ast::Access::kRead,
ast::DecorationList{
create<ast::BindingDecoration>(1),
create<ast::GroupDecoration>(2),
});
Func("a_func", ast::VariableList{}, ty.void_(),
ast::StatementList{
Decl(Var("len", ty.u32(), ast::StorageClass::kNone,
Call("arrayLength", AddressOf(MemberAccessor("b", "a"))))),
},
ast::DecorationList{
Stage(ast::PipelineStage::kFragment),
});
GeneratorImpl& gen = SanitizeAndBuild();
ASSERT_TRUE(gen.Generate()) << gen.error();
auto got = gen.result();
auto* expect = R"(#include <metal_stdlib>
using namespace metal;
struct tint_symbol {
/* 0x0000 */ uint4 buffer_size[1];
};
struct my_struct {
float z;
float a[1];
};
fragment void a_func(const constant tint_symbol* tint_symbol_2 [[buffer(30)]]) {
uint len = (((*(tint_symbol_2)).buffer_size[0u][1u] - 4u) / 4u);
return;
}
)";
EXPECT_EQ(expect, got);
}
TEST_F(MslSanitizerTest, Call_ArrayLength_ViaLets) {
auto* s = Structure("my_struct", {Member(0, "a", ty.array<f32>(4))},
{create<ast::StructBlockDecoration>()});
Global("b", ty.Of(s), ast::StorageClass::kStorage, ast::Access::kRead,
ast::DecorationList{
create<ast::BindingDecoration>(1),
create<ast::GroupDecoration>(2),
});
auto* p = Const("p", nullptr, AddressOf("b"));
auto* p2 = Const("p2", nullptr, AddressOf(MemberAccessor(Deref(p), "a")));
Func("a_func", ast::VariableList{}, ty.void_(),
ast::StatementList{
Decl(p),
Decl(p2),
Decl(Var("len", ty.u32(), ast::StorageClass::kNone,
Call("arrayLength", p2))),
},
ast::DecorationList{
Stage(ast::PipelineStage::kFragment),
});
GeneratorImpl& gen = SanitizeAndBuild();
ASSERT_TRUE(gen.Generate()) << gen.error();
auto got = gen.result();
auto* expect = R"(#include <metal_stdlib>
using namespace metal;
struct tint_symbol {
/* 0x0000 */ uint4 buffer_size[1];
};
struct my_struct {
float a[1];
};
fragment void a_func(const constant tint_symbol* tint_symbol_2 [[buffer(30)]]) {
uint len = (((*(tint_symbol_2)).buffer_size[0u][1u] - 0u) / 4u);
return;
}
)";
EXPECT_EQ(expect, got);
}
TEST_F(MslSanitizerTest, Call_ArrayLength_ArrayLengthFromUniform) {
auto* s = Structure("my_struct", {Member(0, "a", ty.array<f32>(4))},
{create<ast::StructBlockDecoration>()});
Global("b", ty.Of(s), ast::StorageClass::kStorage, ast::Access::kRead,
ast::DecorationList{
create<ast::BindingDecoration>(1),
create<ast::GroupDecoration>(0),
});
Global("c", ty.Of(s), ast::StorageClass::kStorage, ast::Access::kRead,
ast::DecorationList{
create<ast::BindingDecoration>(2),
create<ast::GroupDecoration>(0),
});
Func("a_func", ast::VariableList{}, ty.void_(),
ast::StatementList{
Decl(Var(
"len", ty.u32(), ast::StorageClass::kNone,
Add(Call("arrayLength", AddressOf(MemberAccessor("b", "a"))),
Call("arrayLength", AddressOf(MemberAccessor("c", "a")))))),
},
ast::DecorationList{
Stage(ast::PipelineStage::kFragment),
});
Options options;
options.array_length_from_uniform.ubo_binding = {0, 29};
options.array_length_from_uniform.bindpoint_to_size_index.emplace(
sem::BindingPoint{0, 1}, 7u);
options.array_length_from_uniform.bindpoint_to_size_index.emplace(
sem::BindingPoint{0, 2}, 2u);
GeneratorImpl& gen = SanitizeAndBuild(options);
ASSERT_TRUE(gen.Generate()) << gen.error();
auto got = gen.result();
auto* expect = R"(#include <metal_stdlib>
using namespace metal;
struct tint_symbol {
/* 0x0000 */ uint4 buffer_size[2];
};
struct my_struct {
float a[1];
};
fragment void a_func(const constant tint_symbol* tint_symbol_2 [[buffer(29)]]) {
uint len = ((((*(tint_symbol_2)).buffer_size[1u][3u] - 0u) / 4u) + (((*(tint_symbol_2)).buffer_size[0u][2u] - 0u) / 4u));
return;
}
)";
EXPECT_EQ(expect, got);
}
TEST_F(MslSanitizerTest,
Call_ArrayLength_ArrayLengthFromUniformMissingBinding) {
auto* s = Structure("my_struct", {Member(0, "a", ty.array<f32>(4))},
{create<ast::StructBlockDecoration>()});
Global("b", ty.Of(s), ast::StorageClass::kStorage, ast::Access::kRead,
ast::DecorationList{
create<ast::BindingDecoration>(1),
create<ast::GroupDecoration>(0),
});
Global("c", ty.Of(s), ast::StorageClass::kStorage, ast::Access::kRead,
ast::DecorationList{
create<ast::BindingDecoration>(2),
create<ast::GroupDecoration>(0),
});
Func("a_func", ast::VariableList{}, ty.void_(),
ast::StatementList{
Decl(Var(
"len", ty.u32(), ast::StorageClass::kNone,
Add(Call("arrayLength", AddressOf(MemberAccessor("b", "a"))),
Call("arrayLength", AddressOf(MemberAccessor("c", "a")))))),
},
ast::DecorationList{
Stage(ast::PipelineStage::kFragment),
});
Options options;
options.array_length_from_uniform.ubo_binding = {0, 29};
options.array_length_from_uniform.bindpoint_to_size_index.emplace(
sem::BindingPoint{0, 2}, 2u);
GeneratorImpl& gen = SanitizeAndBuild(options);
ASSERT_FALSE(gen.Generate());
EXPECT_THAT(gen.error(),
HasSubstr("Unable to translate builtin: arrayLength"));
}
} // namespace
} // namespace msl
} // namespace writer
} // namespace tint

View File

@ -21,6 +21,7 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "src/program_builder.h" #include "src/program_builder.h"
#include "src/writer/msl/generator.h"
#include "src/writer/msl/generator_impl.h" #include "src/writer/msl/generator_impl.h"
namespace tint { namespace tint {
@ -57,10 +58,11 @@ class TestHelperBase : public BASE, public ProgramBuilder {
/// Builds the program, runs the program through the transform::Msl sanitizer /// Builds the program, runs the program through the transform::Msl sanitizer
/// and returns a GeneratorImpl from the sanitized program. /// and returns a GeneratorImpl from the sanitized program.
/// @param options The MSL generator options.
/// @note The generator is only built once. Multiple calls to Build() will /// @note The generator is only built once. Multiple calls to Build() will
/// return the same GeneratorImpl without rebuilding. /// return the same GeneratorImpl without rebuilding.
/// @return the built generator /// @return the built generator
GeneratorImpl& SanitizeAndBuild() { GeneratorImpl& SanitizeAndBuild(const Options& options = {}) {
if (gen_) { if (gen_) {
return *gen_; return *gen_;
} }
@ -74,7 +76,10 @@ class TestHelperBase : public BASE, public ProgramBuilder {
<< diag::Formatter().format(program->Diagnostics()); << diag::Formatter().format(program->Diagnostics());
}(); }();
auto result = Sanitize(program.get(), 30); auto result = Sanitize(
program.get(), options.buffer_size_ubo_index, options.fixed_sample_mask,
options.emit_vertex_point_size, options.disable_workgroup_init,
options.array_length_from_uniform);
[&]() { [&]() {
ASSERT_TRUE(result.program.IsValid()) ASSERT_TRUE(result.program.IsValid())
<< diag::Formatter().format(result.program.Diagnostics()); << diag::Formatter().format(result.program.Diagnostics());

View File

@ -591,6 +591,7 @@ tint_unittests_source_set("tint_unittests_msl_writer_src") {
"../src/writer/msl/generator_impl_member_accessor_test.cc", "../src/writer/msl/generator_impl_member_accessor_test.cc",
"../src/writer/msl/generator_impl_module_constant_test.cc", "../src/writer/msl/generator_impl_module_constant_test.cc",
"../src/writer/msl/generator_impl_return_test.cc", "../src/writer/msl/generator_impl_return_test.cc",
"../src/writer/msl/generator_impl_sanitizer_test.cc",
"../src/writer/msl/generator_impl_switch_test.cc", "../src/writer/msl/generator_impl_switch_test.cc",
"../src/writer/msl/generator_impl_test.cc", "../src/writer/msl/generator_impl_test.cc",
"../src/writer/msl/generator_impl_type_test.cc", "../src/writer/msl/generator_impl_type_test.cc",