tint/writer/msl: Generate an array<T,N> helper

And remove the WrapArraysInStructs transform.

Wrapping arrays in structures becomes troublesome for `const` arrays, as
currently WGSL does not allow `const` structures.

MSL 2.0+ has a builtin array<> helper, but we're targetting MSL 1.2, so
we have to emit our own. Fortunately, it can be done with a few lines of
templated code.

This produces significantly cleaner output.

Change-Id: Ifc92ef21e09befa252a07c856c4b5afdc51cc2e4
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/94540
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@chromium.org>
Reviewed-by: David Neto <dneto@google.com>
This commit is contained in:
Ben Clayton
2022-06-24 17:01:59 +00:00
committed by Dawn LUCI CQ
parent 3c054304a8
commit f47887d207
218 changed files with 3637 additions and 2269 deletions

View File

@@ -543,8 +543,6 @@ libtint_source_set("libtint_core_all_src") {
"transform/vertex_pulling.h",
"transform/while_to_loop.cc",
"transform/while_to_loop.h",
"transform/wrap_arrays_in_structs.cc",
"transform/wrap_arrays_in_structs.h",
"transform/zero_init_workgroup_memory.cc",
"transform/zero_init_workgroup_memory.h",
"utils/bitcast.h",
@@ -1199,7 +1197,6 @@ if (tint_build_unittests) {
"transform/vectorize_scalar_matrix_constructors_test.cc",
"transform/vertex_pulling_test.cc",
"transform/while_to_loop_test.cc",
"transform/wrap_arrays_in_structs_test.cc",
"transform/zero_init_workgroup_memory_test.cc",
]
}

View File

@@ -466,8 +466,6 @@ set(TINT_LIB_SRCS
transform/vertex_pulling.h
transform/while_to_loop.cc
transform/while_to_loop.h
transform/wrap_arrays_in_structs.cc
transform/wrap_arrays_in_structs.h
transform/zero_init_workgroup_memory.cc
transform/zero_init_workgroup_memory.h
utils/bitcast.h
@@ -1121,7 +1119,6 @@ if(TINT_BUILD_TESTS)
transform/vectorize_scalar_matrix_constructors_test.cc
transform/vertex_pulling_test.cc
transform/while_to_loop_test.cc
transform/wrap_arrays_in_structs_test.cc
transform/zero_init_workgroup_memory_test.cc
transform/utils/get_insertion_point_test.cc
transform/utils/hoist_to_decl_before_test.cc

View File

@@ -1,158 +0,0 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/wrap_arrays_in_structs.h"
#include <utility>
#include "src/tint/program_builder.h"
#include "src/tint/sem/array.h"
#include "src/tint/sem/call.h"
#include "src/tint/sem/expression.h"
#include "src/tint/sem/type_constructor.h"
#include "src/tint/utils/map.h"
#include "src/tint/utils/transform.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::WrapArraysInStructs);
namespace tint::transform {
WrapArraysInStructs::WrappedArrayInfo::WrappedArrayInfo() = default;
WrapArraysInStructs::WrappedArrayInfo::WrappedArrayInfo(const WrappedArrayInfo&) = default;
WrapArraysInStructs::WrappedArrayInfo::~WrappedArrayInfo() = default;
WrapArraysInStructs::WrapArraysInStructs() = default;
WrapArraysInStructs::~WrapArraysInStructs() = default;
bool WrapArraysInStructs::ShouldRun(const Program* program, const DataMap&) const {
for (auto* node : program->ASTNodes().Objects()) {
if (program->Sem().Get<sem::Array>(node->As<ast::Type>())) {
return true;
}
}
return false;
}
void WrapArraysInStructs::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
auto& sem = ctx.src->Sem();
std::unordered_map<const sem::Array*, WrappedArrayInfo> wrapped_arrays;
auto wrapper = [&](const sem::Array* array) { return WrapArray(ctx, wrapped_arrays, array); };
auto wrapper_typename = [&](const sem::Array* arr) -> ast::TypeName* {
auto info = wrapper(arr);
return info ? ctx.dst->create<ast::TypeName>(info.wrapper_name) : nullptr;
};
// Replace all array types with their corresponding wrapper
ctx.ReplaceAll([&](const ast::Type* ast_type) -> const ast::Type* {
auto* type = ctx.src->TypeOf(ast_type);
if (auto* array = type->UnwrapRef()->As<sem::Array>()) {
return wrapper_typename(array);
}
return nullptr;
});
// Fix up index accessors so `a[1]` becomes `a.arr[1]`
ctx.ReplaceAll(
[&](const ast::IndexAccessorExpression* accessor) -> const ast::IndexAccessorExpression* {
if (auto* array =
::tint::As<sem::Array>(sem.Get(accessor->object)->Type()->UnwrapRef())) {
if (wrapper(array)) {
// Array is wrapped in a structure. Emit a member accessor to get
// to the actual array.
auto* arr = ctx.Clone(accessor->object);
auto* idx = ctx.Clone(accessor->index);
auto* unwrapped = ctx.dst->MemberAccessor(arr, "arr");
return ctx.dst->IndexAccessor(accessor->source, unwrapped, idx);
}
}
return nullptr;
});
// Fix up array constructors so `A(1,2)` becomes `tint_array_wrapper(A(1,2))`
ctx.ReplaceAll([&](const ast::CallExpression* expr) -> const ast::Expression* {
if (auto* call = sem.Get(expr)->UnwrapMaterialize()->As<sem::Call>()) {
if (auto* ctor = call->Target()->As<sem::TypeConstructor>()) {
if (auto* array = ctor->ReturnType()->As<sem::Array>()) {
if (auto w = wrapper(array)) {
// Wrap the array type constructor with another constructor for
// the wrapper
auto* wrapped_array_ty = ctx.dst->ty.type_name(w.wrapper_name);
auto* array_ty = w.array_type(ctx);
auto args = utils::Transform(call->Arguments(),
[&](const tint::sem::Expression* s) {
return ctx.Clone(s->Declaration());
});
auto* arr_ctor = ctx.dst->Construct(array_ty, args);
return ctx.dst->Construct(wrapped_array_ty, arr_ctor);
}
}
}
}
return nullptr;
});
ctx.Clone();
}
WrapArraysInStructs::WrappedArrayInfo WrapArraysInStructs::WrapArray(
CloneContext& ctx,
std::unordered_map<const sem::Array*, WrappedArrayInfo>& wrapped_arrays,
const sem::Array* array) const {
if (array->IsRuntimeSized()) {
return {}; // We don't want to wrap runtime sized arrays
}
return utils::GetOrCreate(wrapped_arrays, array, [&] {
WrappedArrayInfo info;
// Generate a unique name for the array wrapper
info.wrapper_name = ctx.dst->Symbols().New("tint_array_wrapper");
// Examine the element type. Is it also an array?
std::function<const ast::Type*(CloneContext&)> el_type;
if (auto* el_array = array->ElemType()->As<sem::Array>()) {
// Array of array - call WrapArray() on the element type
if (auto el = WrapArray(ctx, wrapped_arrays, el_array)) {
el_type = [=](CloneContext& c) {
return c.dst->create<ast::TypeName>(el.wrapper_name);
};
}
}
// If the element wasn't an array, just create the typical AST type for it
if (!el_type) {
el_type = [=](CloneContext& c) { return CreateASTTypeFor(c, array->ElemType()); };
}
// Construct the single structure field type
info.array_type = [=](CloneContext& c) {
ast::AttributeList attrs;
if (!array->IsStrideImplicit()) {
attrs.emplace_back(c.dst->create<ast::StrideAttribute>(array->Stride()));
}
return c.dst->ty.array(el_type(c), u32(array->Count()), std::move(attrs));
};
// Structure() will create and append the ast::Struct to the
// global declarations of `ctx.dst`. As we haven't finished building the
// current module-scope statement or function, this will be placed
// immediately before the usage.
ctx.dst->Structure(info.wrapper_name, {ctx.dst->Member("arr", info.array_type(ctx))});
return info;
});
}
} // namespace tint::transform

View File

@@ -1,88 +0,0 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TINT_TRANSFORM_WRAP_ARRAYS_IN_STRUCTS_H_
#define SRC_TINT_TRANSFORM_WRAP_ARRAYS_IN_STRUCTS_H_
#include <string>
#include <unordered_map>
#include "src/tint/transform/transform.h"
// Forward declarations
namespace tint::ast {
class Type;
} // namespace tint::ast
namespace tint::transform {
/// WrapArraysInStructs is a transform that replaces all array types with a
/// structure holding a single field of that array type.
/// Array index expressions and constructors are also adjusted to deal with this
/// wrapping.
/// This transform helps with backends that cannot directly return arrays or use
/// them as parameters.
class WrapArraysInStructs : public Castable<WrapArraysInStructs, Transform> {
public:
/// Constructor
WrapArraysInStructs();
/// Destructor
~WrapArraysInStructs() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
private:
struct WrappedArrayInfo {
WrappedArrayInfo();
WrappedArrayInfo(const WrappedArrayInfo&);
~WrappedArrayInfo();
Symbol wrapper_name;
std::function<const ast::Type*(CloneContext&)> array_type;
operator bool() { return wrapper_name.IsValid(); }
};
/// WrapArray wraps the fixed-size array type in a new structure (if it hasn't
/// already been wrapped). WrapArray will recursively wrap arrays-of-arrays.
/// The new structure will be added to module-scope type declarations of
/// `ctx.dst`.
/// @param ctx the CloneContext
/// @param wrapped_arrays a map of src array type to the wrapped structure
/// name
/// @param array the array type
/// @return the name of the structure that wraps the array, or an invalid
/// Symbol if this array should not be wrapped
WrappedArrayInfo WrapArray(
CloneContext& ctx,
std::unordered_map<const sem::Array*, WrappedArrayInfo>& wrapped_arrays,
const sem::Array* array) const;
};
} // namespace tint::transform
#endif // SRC_TINT_TRANSFORM_WRAP_ARRAYS_IN_STRUCTS_H_

View File

@@ -1,422 +0,0 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/wrap_arrays_in_structs.h"
#include <memory>
#include <utility>
#include "src/tint/transform/test_helper.h"
namespace tint::transform {
namespace {
using WrapArraysInStructsTest = TransformTest;
TEST_F(WrapArraysInStructsTest, ShouldRunEmptyModule) {
auto* src = R"()";
EXPECT_FALSE(ShouldRun<WrapArraysInStructs>(src));
}
TEST_F(WrapArraysInStructsTest, ShouldRunHasArray) {
auto* src = R"(
var<private> arr : array<i32, 4>;
)";
EXPECT_TRUE(ShouldRun<WrapArraysInStructs>(src));
}
TEST_F(WrapArraysInStructsTest, EmptyModule) {
auto* src = R"()";
auto* expect = src;
auto got = Run<WrapArraysInStructs>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(WrapArraysInStructsTest, ArrayAsGlobal) {
auto* src = R"(
var<private> arr : array<i32, 4>;
)";
auto* expect = R"(
struct tint_array_wrapper {
arr : array<i32, 4u>,
}
var<private> arr : tint_array_wrapper;
)";
auto got = Run<WrapArraysInStructs>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(WrapArraysInStructsTest, ArrayAsFunctionVar) {
auto* src = R"(
fn f() {
var arr : array<i32, 4>;
let x = arr[3];
}
)";
auto* expect = R"(
struct tint_array_wrapper {
arr : array<i32, 4u>,
}
fn f() {
var arr : tint_array_wrapper;
let x = arr.arr[3];
}
)";
auto got = Run<WrapArraysInStructs>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(WrapArraysInStructsTest, ArrayAsParam) {
auto* src = R"(
fn f(a : array<i32, 4>) -> i32 {
return a[2];
}
)";
auto* expect = R"(
struct tint_array_wrapper {
arr : array<i32, 4u>,
}
fn f(a : tint_array_wrapper) -> i32 {
return a.arr[2];
}
)";
auto got = Run<WrapArraysInStructs>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(WrapArraysInStructsTest, ArrayAsReturn) {
auto* src = R"(
fn f() -> array<i32, 4> {
return array<i32, 4>(1, 2, 3, 4);
}
)";
auto* expect = R"(
struct tint_array_wrapper {
arr : array<i32, 4u>,
}
fn f() -> tint_array_wrapper {
return tint_array_wrapper(array<i32, 4u>(1, 2, 3, 4));
}
)";
auto got = Run<WrapArraysInStructs>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(WrapArraysInStructsTest, ArrayAlias) {
auto* src = R"(
type Inner = array<i32, 2>;
type Array = array<Inner, 2>;
fn f() {
var arr : Array;
arr = Array();
arr = Array(Inner(1, 2), Inner(3, 4));
let vals : Array = Array(Inner(1, 2), Inner(3, 4));
arr = vals;
let x = arr[3];
}
)";
auto* expect = R"(
struct tint_array_wrapper {
arr : array<i32, 2u>,
}
type Inner = tint_array_wrapper;
struct tint_array_wrapper_1 {
arr : array<tint_array_wrapper, 2u>,
}
type Array = tint_array_wrapper_1;
fn f() {
var arr : tint_array_wrapper_1;
arr = tint_array_wrapper_1(array<tint_array_wrapper, 2u>());
arr = tint_array_wrapper_1(array<tint_array_wrapper, 2u>(tint_array_wrapper(array<i32, 2u>(1, 2)), tint_array_wrapper(array<i32, 2u>(3, 4))));
let vals : tint_array_wrapper_1 = tint_array_wrapper_1(array<tint_array_wrapper, 2u>(tint_array_wrapper(array<i32, 2u>(1, 2)), tint_array_wrapper(array<i32, 2u>(3, 4))));
arr = vals;
let x = arr.arr[3];
}
)";
auto got = Run<WrapArraysInStructs>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(WrapArraysInStructsTest, ArrayAlias_OutOfOrder) {
auto* src = R"(
fn f() {
var arr : Array;
arr = Array();
arr = Array(Inner(1, 2), Inner(3, 4));
let vals : Array = Array(Inner(1, 2), Inner(3, 4));
arr = vals;
let x = arr[3];
}
type Array = array<Inner, 2>;
type Inner = array<i32, 2>;
)";
auto* expect = R"(
struct tint_array_wrapper_1 {
arr : array<i32, 2u>,
}
struct tint_array_wrapper {
arr : array<tint_array_wrapper_1, 2u>,
}
fn f() {
var arr : tint_array_wrapper;
arr = tint_array_wrapper(array<tint_array_wrapper_1, 2u>());
arr = tint_array_wrapper(array<tint_array_wrapper_1, 2u>(tint_array_wrapper_1(array<i32, 2u>(1, 2)), tint_array_wrapper_1(array<i32, 2u>(3, 4))));
let vals : tint_array_wrapper = tint_array_wrapper(array<tint_array_wrapper_1, 2u>(tint_array_wrapper_1(array<i32, 2u>(1, 2)), tint_array_wrapper_1(array<i32, 2u>(3, 4))));
arr = vals;
let x = arr.arr[3];
}
type Array = tint_array_wrapper;
type Inner = tint_array_wrapper_1;
)";
auto got = Run<WrapArraysInStructs>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(WrapArraysInStructsTest, ArraysInStruct) {
auto* src = R"(
struct S {
a : array<i32, 4>,
b : array<i32, 8>,
c : array<i32, 4>,
};
)";
auto* expect = R"(
struct tint_array_wrapper {
arr : array<i32, 4u>,
}
struct tint_array_wrapper_1 {
arr : array<i32, 8u>,
}
struct S {
a : tint_array_wrapper,
b : tint_array_wrapper_1,
c : tint_array_wrapper,
}
)";
auto got = Run<WrapArraysInStructs>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(WrapArraysInStructsTest, ArraysOfArraysInStruct) {
auto* src = R"(
struct S {
a : array<i32, 4>,
b : array<array<i32, 4>, 4>,
c : array<array<array<i32, 4>, 4>, 4>,
};
)";
auto* expect = R"(
struct tint_array_wrapper {
arr : array<i32, 4u>,
}
struct tint_array_wrapper_1 {
arr : array<tint_array_wrapper, 4u>,
}
struct tint_array_wrapper_2 {
arr : array<tint_array_wrapper_1, 4u>,
}
struct S {
a : tint_array_wrapper,
b : tint_array_wrapper_1,
c : tint_array_wrapper_2,
}
)";
auto got = Run<WrapArraysInStructs>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(WrapArraysInStructsTest, AccessArraysOfArraysInStruct) {
auto* src = R"(
struct S {
a : array<i32, 4>,
b : array<array<i32, 4>, 4>,
c : array<array<array<i32, 4>, 4>, 4>,
};
fn f(s : S) -> i32 {
return s.a[2] + s.b[1][2] + s.c[3][1][2];
}
)";
auto* expect = R"(
struct tint_array_wrapper {
arr : array<i32, 4u>,
}
struct tint_array_wrapper_1 {
arr : array<tint_array_wrapper, 4u>,
}
struct tint_array_wrapper_2 {
arr : array<tint_array_wrapper_1, 4u>,
}
struct S {
a : tint_array_wrapper,
b : tint_array_wrapper_1,
c : tint_array_wrapper_2,
}
fn f(s : S) -> i32 {
return ((s.a.arr[2] + s.b.arr[1].arr[2]) + s.c.arr[3].arr[1].arr[2]);
}
)";
auto got = Run<WrapArraysInStructs>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(WrapArraysInStructsTest, DeclarationOrder) {
auto* src = R"(
type T0 = i32;
type T1 = array<i32, 1>;
type T2 = i32;
fn f1(a : array<i32, 2>) {
}
type T3 = i32;
fn f2() {
var v : array<i32, 3>;
}
)";
auto* expect = R"(
type T0 = i32;
struct tint_array_wrapper {
arr : array<i32, 1u>,
}
type T1 = tint_array_wrapper;
type T2 = i32;
struct tint_array_wrapper_1 {
arr : array<i32, 2u>,
}
fn f1(a : tint_array_wrapper_1) {
}
type T3 = i32;
struct tint_array_wrapper_2 {
arr : array<i32, 3u>,
}
fn f2() {
var v : tint_array_wrapper_2;
}
)";
auto got = Run<WrapArraysInStructs>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(WrapArraysInStructsTest, DeclarationOrder_OutOfOrder) {
auto* src = R"(
fn f2() {
var v : array<i32, 3>;
}
type T3 = i32;
fn f1(a : array<i32, 2>) {
}
type T2 = i32;
type T1 = array<i32, 1>;
type T0 = i32;
)";
auto* expect = R"(
struct tint_array_wrapper {
arr : array<i32, 3u>,
}
fn f2() {
var v : tint_array_wrapper;
}
type T3 = i32;
struct tint_array_wrapper_1 {
arr : array<i32, 2u>,
}
fn f1(a : tint_array_wrapper_1) {
}
type T2 = i32;
struct tint_array_wrapper_2 {
arr : array<i32, 1u>,
}
type T1 = tint_array_wrapper_2;
type T0 = i32;
)";
auto got = Run<WrapArraysInStructs>(src);
EXPECT_EQ(expect, str(got));
}
} // namespace
} // namespace tint::transform

View File

@@ -71,7 +71,6 @@
#include "src/tint/transform/unshadow.h"
#include "src/tint/transform/unwind_discard_functions.h"
#include "src/tint/transform/vectorize_scalar_matrix_constructors.h"
#include "src/tint/transform/wrap_arrays_in_structs.h"
#include "src/tint/transform/zero_init_workgroup_memory.h"
#include "src/tint/utils/defer.h"
#include "src/tint/utils/map.h"
@@ -208,7 +207,6 @@ SanitizedResult Sanitize(const Program* in, const Options& options) {
manager.Add<transform::PromoteInitializersToConstVar>();
manager.Add<transform::VectorizeScalarMatrixConstructors>();
manager.Add<transform::WrapArraysInStructs>();
manager.Add<transform::RemovePhonies>();
manager.Add<transform::SimplifyPointers>();
// ArrayLengthFromUniform must come after SimplifyPointers, as
@@ -731,13 +729,33 @@ bool GeneratorImpl::EmitTypeConstructor(std::ostream& out,
const sem::TypeConstructor* ctor) {
auto* type = ctor->ReturnType();
if (type->IsAnyOf<sem::Array, sem::Struct>()) {
out << "{";
} else {
if (!EmitType(out, type, "")) {
return false;
}
out << "(";
const char* terminator = ")";
TINT_DEFER(out << terminator);
bool ok = Switch(
type,
[&](const sem::Array*) {
if (!EmitType(out, type, "")) {
return false;
}
out << "{";
terminator = "}";
return true;
},
[&](const sem::Struct*) {
out << "{";
terminator = "}";
return true;
},
[&](Default) {
if (!EmitType(out, type, "")) {
return false;
}
out << "(";
return true;
});
if (!ok) {
return false;
}
int i = 0;
@@ -760,11 +778,6 @@ bool GeneratorImpl::EmitTypeConstructor(std::ostream& out,
i++;
}
if (type->IsAnyOf<sem::Array, sem::Struct>()) {
out << "}";
} else {
out << ")";
}
return true;
}
@@ -1561,10 +1574,9 @@ bool GeneratorImpl::EmitZeroValue(std::ostream& out, const sem::Type* type) {
ScopedParen sp(out);
return EmitZeroValue(out, mat->type());
},
[&](const sem::Array* arr) {
out << "{";
TINT_DEFER(out << "}");
return EmitZeroValue(out, arr->ElemType());
[&](const sem::Array*) {
out << "{}";
return true;
},
[&](const sem::Struct*) {
out << "{}";
@@ -1772,8 +1784,8 @@ bool GeneratorImpl::EmitFunction(const ast::Function* func) {
if (!EmitType(out, type, param_name)) {
return false;
}
// Parameter name is output as part of the type for arrays and pointers.
if (!type->Is<sem::Array>() && !type->Is<sem::Pointer>()) {
// Parameter name is output as part of the type for pointers.
if (!type->Is<sem::Pointer>()) {
out << " " << program_->Symbols().NameFor(v->symbol);
}
}
@@ -1896,8 +1908,8 @@ bool GeneratorImpl::EmitEntryPointFunction(const ast::Function* func) {
if (!EmitType(out, type, param_name)) {
return false;
}
// Parameter name is output as part of the type for arrays and pointers.
if (!type->Is<sem::Array>() && !type->Is<sem::Pointer>()) {
// Parameter name is output as part of the type for pointers.
if (!type->Is<sem::Pointer>()) {
out << " " << param_name;
}
@@ -2412,29 +2424,12 @@ bool GeneratorImpl::EmitType(std::ostream& out,
<< "unhandled atomic type " << atomic->Type()->FriendlyName(builder_.Symbols());
return false;
},
[&](const sem::Array* ary) {
const sem::Type* base_type = ary;
std::vector<uint32_t> sizes;
while (auto* arr = base_type->As<sem::Array>()) {
if (arr->IsRuntimeSized()) {
sizes.push_back(1);
} else {
sizes.push_back(arr->Count());
}
base_type = arr->ElemType();
}
if (!EmitType(out, base_type, "")) {
[&](const sem::Array* arr) {
out << ArrayType() << "<";
if (!EmitType(out, arr->ElemType(), "")) {
return false;
}
if (!name.empty()) {
out << " " << name;
if (name_printed) {
*name_printed = true;
}
}
for (uint32_t size : sizes) {
out << "[" << size << "]";
}
out << ", " << (arr->IsRuntimeSized() ? 1u : arr->Count()) << ">";
return true;
},
[&](const sem::Bool*) {
@@ -2469,22 +2464,12 @@ bool GeneratorImpl::EmitType(std::ostream& out,
return false;
}
out << " ";
if (ptr->StoreType()->Is<sem::Array>()) {
std::string inner = "(*" + name + ")";
if (!EmitType(out, ptr->StoreType(), inner)) {
return false;
}
if (name_printed) {
*name_printed = true;
}
} else {
if (!EmitType(out, ptr->StoreType(), "")) {
return false;
}
out << "* " << name;
if (name_printed) {
*name_printed = true;
}
if (!EmitType(out, ptr->StoreType(), "")) {
return false;
}
out << "* " << name;
if (name_printed) {
*name_printed = true;
}
return true;
},
@@ -2700,7 +2685,7 @@ bool GeneratorImpl::EmitStructType(TextBuffer* b, const sem::Struct* str) {
auto out = line(b);
add_byte_offset_comment(out, msl_offset);
out << "int8_t " << name << "[" << size << "];";
out << ArrayType() << "<int8_t, " << size << "> " << name << ";";
};
b->IncrementIndent();
@@ -2738,11 +2723,7 @@ bool GeneratorImpl::EmitStructType(TextBuffer* b, const sem::Struct* str) {
auto* ty = mem->Type();
// Array member name will be output with the type
if (!ty->Is<sem::Array>()) {
out << " " << mem_name;
}
out << " " << mem_name;
// Emit attributes
if (auto* decl = mem->Declaration()) {
for (auto* attr : decl->attributes) {
@@ -2945,8 +2926,8 @@ bool GeneratorImpl::EmitVar(const ast::Var* var) {
if (!EmitType(out, type, name)) {
return false;
}
// Variable name is output as part of the type for arrays and pointers.
if (!type->Is<sem::Array>() && !type->Is<sem::Pointer>()) {
// Variable name is output as part of the type for pointers.
if (!type->Is<sem::Pointer>()) {
out << " " << name;
}
@@ -2995,8 +2976,8 @@ bool GeneratorImpl::EmitLet(const ast::Let* let) {
return false;
}
// Variable name is output as part of the type for arrays and pointers.
if (!type->Is<sem::Array>() && !type->Is<sem::Pointer>()) {
// Variable name is output as part of the type for pointers.
if (!type->Is<sem::Pointer>()) {
out << " " << name;
}
@@ -3018,9 +2999,7 @@ bool GeneratorImpl::EmitProgramConstVariable(const ast::Let* let) {
if (!EmitType(out, type, program_->Symbols().NameFor(let->symbol))) {
return false;
}
if (!type->Is<sem::Array>()) {
out << " " << program_->Symbols().NameFor(let->symbol);
}
out << " " << program_->Symbols().NameFor(let->symbol);
if (let->constructor != nullptr) {
out << " = ";
@@ -3042,9 +3021,7 @@ bool GeneratorImpl::EmitOverride(const ast::Override* override) {
if (!EmitType(out, type, program_->Symbols().NameFor(override->symbol))) {
return false;
}
if (!type->Is<sem::Array>()) {
out << " " << program_->Symbols().NameFor(override->symbol);
}
out << " " << program_->Symbols().NameFor(override->symbol);
out << " [[function_constant(" << global->ConstantId() << ")]];";
@@ -3117,8 +3094,8 @@ GeneratorImpl::SizeAndAlign GeneratorImpl::MslPackedTypeSizeAndAlign(const sem::
[&](const sem::Array* arr) {
if (!arr->IsStrideImplicit()) {
TINT_ICE(Writer, diagnostics_) << "arrays with explicit strides not "
"exist past the SPIR-V reader";
TINT_ICE(Writer, diagnostics_)
<< "arrays with explicit strides should not exist past the SPIR-V reader";
return SizeAndAlign{};
}
auto num_els = std::max<uint32_t>(arr->Count(), 1);
@@ -3205,4 +3182,25 @@ bool GeneratorImpl::CallBuiltinHelper(std::ostream& out,
return true;
}
const std::string& GeneratorImpl::ArrayType() {
if (array_template_name_.empty()) {
array_template_name_ = UniqueIdentifier("tint_array");
auto* buf = &helpers_;
line(buf) << "template<typename T, size_t N>";
line(buf) << "struct " << array_template_name_ << " {";
line(buf) << " const constant T& operator[](size_t i) const constant"
<< " { return elements[i]; }";
for (auto* space : {"device", "thread", "threadgroup"}) {
line(buf) << " " << space << " T& operator[](size_t i) " << space
<< " { return elements[i]; }";
line(buf) << " const " << space << " T& operator[](size_t i) const " << space
<< " { return elements[i]; }";
}
line(buf) << " T elements[N];";
line(buf) << "};";
line(buf);
}
return array_template_name_;
}
} // namespace tint::writer::msl

View File

@@ -425,6 +425,10 @@ class GeneratorImpl : public TextGenerator {
const sem::Builtin* builtin,
F&& build);
/// @returns the name of the templated tint_array helper type, generating it if this is the
/// first call.
const std::string& ArrayType();
TextBuffer helpers_; // Helper functions emitted at the top of the output
/// @returns the MSL packed type size and alignment in bytes for the given
@@ -439,13 +443,17 @@ class GeneratorImpl : public TextGenerator {
utils::UnorderedKeyWrapper<std::tuple<ast::StorageClass, const sem::Struct*>>;
std::unordered_map<ACEWKeyType, std::string> atomicCompareExchangeWeak_;
/// Unique name of the 'TINT_INVARIANT' preprocessor define. Non-empty only if
/// an invariant attribute has been generated.
/// Unique name of the 'TINT_INVARIANT' preprocessor define.
/// Non-empty only if an invariant attribute has been generated.
std::string invariant_define_name_;
/// True if matrix-packed_vector operator overloads have been generated.
bool matrix_packed_vector_overloads_ = false;
/// Unique name of the tint_array<T, N> template.
/// Non-empty only if the template has been generated.
std::string array_template_name_;
/// A map from entry point name to a list of dynamic workgroup allocations.
/// Each entry in the vector is the size of the workgroup allocation that
/// should be created for that index.

View File

@@ -591,11 +591,20 @@ TEST_F(MslGeneratorImplTest, Emit_Function_WithArrayParams) {
EXPECT_EQ(gen.result(), R"( #include <metal_stdlib>
using namespace metal;
struct tint_array_wrapper {
float arr[5];
};
void my_func(tint_array_wrapper a) {
template<typename T, size_t N>
struct tint_array {
const constant T& operator[](size_t i) const constant { return elements[i]; }
device T& operator[](size_t i) device { return elements[i]; }
const device T& operator[](size_t i) const device { return elements[i]; }
thread T& operator[](size_t i) thread { return elements[i]; }
const thread T& operator[](size_t i) const thread { return elements[i]; }
threadgroup T& operator[](size_t i) threadgroup { return elements[i]; }
const threadgroup T& operator[](size_t i) const threadgroup { return elements[i]; }
T elements[N];
};
void my_func(tint_array<float, 5> a) {
return;
}
@@ -616,12 +625,21 @@ TEST_F(MslGeneratorImplTest, Emit_Function_WithArrayReturn) {
EXPECT_EQ(gen.result(), R"( #include <metal_stdlib>
using namespace metal;
struct tint_array_wrapper {
float arr[5];
};
tint_array_wrapper my_func() {
tint_array_wrapper const tint_symbol = {.arr={}};
template<typename T, size_t N>
struct tint_array {
const constant T& operator[](size_t i) const constant { return elements[i]; }
device T& operator[](size_t i) device { return elements[i]; }
const device T& operator[](size_t i) const device { return elements[i]; }
thread T& operator[](size_t i) thread { return elements[i]; }
const thread T& operator[](size_t i) const thread { return elements[i]; }
threadgroup T& operator[](size_t i) threadgroup { return elements[i]; }
const threadgroup T& operator[](size_t i) const threadgroup { return elements[i]; }
T elements[N];
};
tint_array<float, 5> my_func() {
tint_array<float, 5> const tint_symbol = tint_array<float, 5>{};
return tint_symbol;
}

View File

@@ -28,7 +28,7 @@ TEST_F(MslGeneratorImplTest, Emit_ModuleConstant) {
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.EmitProgramConstVariable(var)) << gen.error();
EXPECT_EQ(gen.result(), "constant float pos[3] = {1.0f, 2.0f, 3.0f};\n");
EXPECT_EQ(gen.result(), "constant tint_array<float, 3> pos = tint_array<float, 3>{1.0f, 2.0f, 3.0f};\n");
}
TEST_F(MslGeneratorImplTest, Emit_SpecConstant) {

View File

@@ -50,12 +50,25 @@ TEST_F(MslSanitizerTest, Call_ArrayLength) {
auto* expect = R"(#include <metal_stdlib>
using namespace metal;
template<typename T, size_t N>
struct tint_array {
const constant T& operator[](size_t i) const constant { return elements[i]; }
device T& operator[](size_t i) device { return elements[i]; }
const device T& operator[](size_t i) const device { return elements[i]; }
thread T& operator[](size_t i) thread { return elements[i]; }
const thread T& operator[](size_t i) const thread { return elements[i]; }
threadgroup T& operator[](size_t i) threadgroup { return elements[i]; }
const threadgroup T& operator[](size_t i) const threadgroup { return elements[i]; }
T elements[N];
};
struct tint_symbol {
/* 0x0000 */ uint4 buffer_size[1];
/* 0x0000 */ tint_array<uint4, 1> buffer_size;
};
struct my_struct {
float a[1];
tint_array<float, 1> a;
};
fragment void a_func(const constant tint_symbol* tint_symbol_2 [[buffer(30)]]) {
@@ -95,13 +108,26 @@ TEST_F(MslSanitizerTest, Call_ArrayLength_OtherMembersInStruct) {
auto* expect = R"(#include <metal_stdlib>
using namespace metal;
template<typename T, size_t N>
struct tint_array {
const constant T& operator[](size_t i) const constant { return elements[i]; }
device T& operator[](size_t i) device { return elements[i]; }
const device T& operator[](size_t i) const device { return elements[i]; }
thread T& operator[](size_t i) thread { return elements[i]; }
const thread T& operator[](size_t i) const thread { return elements[i]; }
threadgroup T& operator[](size_t i) threadgroup { return elements[i]; }
const threadgroup T& operator[](size_t i) const threadgroup { return elements[i]; }
T elements[N];
};
struct tint_symbol {
/* 0x0000 */ uint4 buffer_size[1];
/* 0x0000 */ tint_array<uint4, 1> buffer_size;
};
struct my_struct {
float z;
float a[1];
tint_array<float, 1> a;
};
fragment void a_func(const constant tint_symbol* tint_symbol_2 [[buffer(30)]]) {
@@ -143,12 +169,25 @@ TEST_F(MslSanitizerTest, Call_ArrayLength_ViaLets) {
auto* expect = R"(#include <metal_stdlib>
using namespace metal;
template<typename T, size_t N>
struct tint_array {
const constant T& operator[](size_t i) const constant { return elements[i]; }
device T& operator[](size_t i) device { return elements[i]; }
const device T& operator[](size_t i) const device { return elements[i]; }
thread T& operator[](size_t i) thread { return elements[i]; }
const thread T& operator[](size_t i) const thread { return elements[i]; }
threadgroup T& operator[](size_t i) threadgroup { return elements[i]; }
const threadgroup T& operator[](size_t i) const threadgroup { return elements[i]; }
T elements[N];
};
struct tint_symbol {
/* 0x0000 */ uint4 buffer_size[1];
/* 0x0000 */ tint_array<uint4, 1> buffer_size;
};
struct my_struct {
float a[1];
tint_array<float, 1> a;
};
fragment void a_func(const constant tint_symbol* tint_symbol_2 [[buffer(30)]]) {
@@ -196,12 +235,25 @@ TEST_F(MslSanitizerTest, Call_ArrayLength_ArrayLengthFromUniform) {
auto* expect = R"(#include <metal_stdlib>
using namespace metal;
template<typename T, size_t N>
struct tint_array {
const constant T& operator[](size_t i) const constant { return elements[i]; }
device T& operator[](size_t i) device { return elements[i]; }
const device T& operator[](size_t i) const device { return elements[i]; }
thread T& operator[](size_t i) thread { return elements[i]; }
const thread T& operator[](size_t i) const thread { return elements[i]; }
threadgroup T& operator[](size_t i) threadgroup { return elements[i]; }
const threadgroup T& operator[](size_t i) const threadgroup { return elements[i]; }
T elements[N];
};
struct tint_symbol {
/* 0x0000 */ uint4 buffer_size[2];
/* 0x0000 */ tint_array<uint4, 2> buffer_size;
};
struct my_struct {
float a[1];
tint_array<float, 1> a;
};
fragment void a_func(const constant tint_symbol* tint_symbol_2 [[buffer(29)]]) {

View File

@@ -188,25 +188,34 @@ TEST_F(MslGeneratorImplTest, WorkgroupMatrixInArray) {
EXPECT_EQ(gen.result(), R"(#include <metal_stdlib>
using namespace metal;
struct tint_array_wrapper {
float2x2 arr[4];
template<typename T, size_t N>
struct tint_array {
const constant T& operator[](size_t i) const constant { return elements[i]; }
device T& operator[](size_t i) device { return elements[i]; }
const device T& operator[](size_t i) const device { return elements[i]; }
thread T& operator[](size_t i) thread { return elements[i]; }
const thread T& operator[](size_t i) const thread { return elements[i]; }
threadgroup T& operator[](size_t i) threadgroup { return elements[i]; }
const threadgroup T& operator[](size_t i) const threadgroup { return elements[i]; }
T elements[N];
};
struct tint_symbol_3 {
tint_array_wrapper m;
tint_array<float2x2, 4> m;
};
void comp_main_inner(uint local_invocation_index, threadgroup tint_array_wrapper* const tint_symbol) {
void comp_main_inner(uint local_invocation_index, threadgroup tint_array<float2x2, 4>* const tint_symbol) {
for(uint idx = local_invocation_index; (idx < 4u); idx = (idx + 1u)) {
uint const i = idx;
(*(tint_symbol)).arr[i] = float2x2(float2(0.0f), float2(0.0f));
(*(tint_symbol))[i] = float2x2(float2(0.0f), float2(0.0f));
}
threadgroup_barrier(mem_flags::mem_threadgroup);
tint_array_wrapper const x = *(tint_symbol);
tint_array<float2x2, 4> const x = *(tint_symbol);
}
kernel void comp_main(threadgroup tint_symbol_3* tint_symbol_2 [[threadgroup(0)]], uint local_invocation_index [[thread_index_in_threadgroup]]) {
threadgroup tint_array_wrapper* const tint_symbol_1 = &((*(tint_symbol_2)).m);
threadgroup tint_array<float2x2, 4>* const tint_symbol_1 = &((*(tint_symbol_2)).m);
comp_main_inner(local_invocation_index, tint_symbol_1);
return;
}

View File

@@ -31,6 +31,20 @@ using namespace tint::number_suffixes; // NOLINT
namespace tint::writer::msl {
namespace {
void FormatMSLField(std::stringstream& out,
const char* addr,
const char* type,
size_t array_count,
const char* name) {
out << " /* " << std::string(addr) << " */ ";
if (array_count == 0) {
out << type << " ";
} else {
out << "tint_array<" << type << ", " << std::to_string(array_count) << "> ";
}
out << name << ";\n";
}
#define CHECK_TYPE_SIZE_AND_ALIGN(TYPE, SIZE, ALIGN) \
static_assert(sizeof(TYPE) == SIZE, "Bad type size"); \
static_assert(alignof(TYPE) == ALIGN, "Bad type alignment")
@@ -69,7 +83,7 @@ TEST_F(MslGeneratorImplTest, EmitType_Array) {
std::stringstream out;
ASSERT_TRUE(gen.EmitType(out, program->TypeOf(arr), "ary")) << gen.error();
EXPECT_EQ(out.str(), "bool ary[4]");
EXPECT_EQ(out.str(), "tint_array<bool, 4>");
}
TEST_F(MslGeneratorImplTest, EmitType_ArrayOfArray) {
@@ -81,7 +95,7 @@ TEST_F(MslGeneratorImplTest, EmitType_ArrayOfArray) {
std::stringstream out;
ASSERT_TRUE(gen.EmitType(out, program->TypeOf(b), "ary")) << gen.error();
EXPECT_EQ(out.str(), "bool ary[5][4]");
EXPECT_EQ(out.str(), "tint_array<tint_array<bool, 4>, 5>");
}
TEST_F(MslGeneratorImplTest, EmitType_ArrayOfArrayOfArray) {
@@ -94,7 +108,7 @@ TEST_F(MslGeneratorImplTest, EmitType_ArrayOfArrayOfArray) {
std::stringstream out;
ASSERT_TRUE(gen.EmitType(out, program->TypeOf(c), "ary")) << gen.error();
EXPECT_EQ(out.str(), "bool ary[6][5][4]");
EXPECT_EQ(out.str(), "tint_array<tint_array<tint_array<bool, 4>, 5>, 6>");
}
TEST_F(MslGeneratorImplTest, EmitType_Array_WithoutName) {
@@ -105,7 +119,7 @@ TEST_F(MslGeneratorImplTest, EmitType_Array_WithoutName) {
std::stringstream out;
ASSERT_TRUE(gen.EmitType(out, program->TypeOf(arr), "")) << gen.error();
EXPECT_EQ(out.str(), "bool[4]");
EXPECT_EQ(out.str(), "tint_array<bool, 4>");
}
TEST_F(MslGeneratorImplTest, EmitType_RuntimeArray) {
@@ -116,7 +130,7 @@ TEST_F(MslGeneratorImplTest, EmitType_RuntimeArray) {
std::stringstream out;
ASSERT_TRUE(gen.EmitType(out, program->TypeOf(arr), "ary")) << gen.error();
EXPECT_EQ(out.str(), "bool ary[1]");
EXPECT_EQ(out.str(), "tint_array<bool, 1>");
}
TEST_F(MslGeneratorImplTest, EmitType_Bool) {
@@ -245,54 +259,58 @@ TEST_F(MslGeneratorImplTest, EmitType_Struct_Layout_NonComposites) {
auto* sem_s = program->TypeOf(s)->As<sem::Struct>();
ASSERT_TRUE(gen.EmitStructType(&buf, sem_s)) << gen.error();
// ALL_FIELDS() calls the macro FIELD(ADDR, TYPE, NAME, SUFFIX)
// ALL_FIELDS() calls the macro FIELD(ADDR, TYPE, ARRAY_COUNT, NAME)
// for each field of the structure s.
#define ALL_FIELDS() \
FIELD(0x0000, int, a, /*NO SUFFIX*/) \
FIELD(0x0004, int8_t, tint_pad, [124]) \
FIELD(0x0080, float, b, /*NO SUFFIX*/) \
FIELD(0x0084, int8_t, tint_pad_1, [124]) \
FIELD(0x0100, float2, c, /*NO SUFFIX*/) \
FIELD(0x0108, uint, d, /*NO SUFFIX*/) \
FIELD(0x010c, int8_t, tint_pad_2, [4]) \
FIELD(0x0110, packed_float3, e, /*NO SUFFIX*/) \
FIELD(0x011c, uint, f, /*NO SUFFIX*/) \
FIELD(0x0120, float4, g, /*NO SUFFIX*/) \
FIELD(0x0130, uint, h, /*NO SUFFIX*/) \
FIELD(0x0134, int8_t, tint_pad_3, [4]) \
FIELD(0x0138, float2x2, i, /*NO SUFFIX*/) \
FIELD(0x0148, uint, j, /*NO SUFFIX*/) \
FIELD(0x014c, int8_t, tint_pad_4, [4]) \
FIELD(0x0150, float2x3, k, /*NO SUFFIX*/) \
FIELD(0x0170, uint, l, /*NO SUFFIX*/) \
FIELD(0x0174, int8_t, tint_pad_5, [12]) \
FIELD(0x0180, float2x4, m, /*NO SUFFIX*/) \
FIELD(0x01a0, uint, n, /*NO SUFFIX*/) \
FIELD(0x01a4, int8_t, tint_pad_6, [4]) \
FIELD(0x01a8, float3x2, o, /*NO SUFFIX*/) \
FIELD(0x01c0, uint, p, /*NO SUFFIX*/) \
FIELD(0x01c4, int8_t, tint_pad_7, [12]) \
FIELD(0x01d0, float3x3, q, /*NO SUFFIX*/) \
FIELD(0x0200, uint, r, /*NO SUFFIX*/) \
FIELD(0x0204, int8_t, tint_pad_8, [12]) \
FIELD(0x0210, float3x4, s, /*NO SUFFIX*/) \
FIELD(0x0240, uint, t, /*NO SUFFIX*/) \
FIELD(0x0244, int8_t, tint_pad_9, [4]) \
FIELD(0x0248, float4x2, u, /*NO SUFFIX*/) \
FIELD(0x0268, uint, v, /*NO SUFFIX*/) \
FIELD(0x026c, int8_t, tint_pad_10, [4]) \
FIELD(0x0270, float4x3, w, /*NO SUFFIX*/) \
FIELD(0x02b0, uint, x, /*NO SUFFIX*/) \
FIELD(0x02b4, int8_t, tint_pad_11, [12]) \
FIELD(0x02c0, float4x4, y, /*NO SUFFIX*/) \
FIELD(0x0300, float, z, /*NO SUFFIX*/) \
FIELD(0x0304, int8_t, tint_pad_12, [124])
#define ALL_FIELDS() \
FIELD(0x0000, int, 0, a) \
FIELD(0x0004, int8_t, 124, tint_pad) \
FIELD(0x0080, float, 0, b) \
FIELD(0x0084, int8_t, 124, tint_pad_1) \
FIELD(0x0100, float2, 0, c) \
FIELD(0x0108, uint, 0, d) \
FIELD(0x010c, int8_t, 4, tint_pad_2) \
FIELD(0x0110, packed_float3, 0, e) \
FIELD(0x011c, uint, 0, f) \
FIELD(0x0120, float4, 0, g) \
FIELD(0x0130, uint, 0, h) \
FIELD(0x0134, int8_t, 4, tint_pad_3) \
FIELD(0x0138, float2x2, 0, i) \
FIELD(0x0148, uint, 0, j) \
FIELD(0x014c, int8_t, 4, tint_pad_4) \
FIELD(0x0150, float2x3, 0, k) \
FIELD(0x0170, uint, 0, l) \
FIELD(0x0174, int8_t, 12, tint_pad_5) \
FIELD(0x0180, float2x4, 0, m) \
FIELD(0x01a0, uint, 0, n) \
FIELD(0x01a4, int8_t, 4, tint_pad_6) \
FIELD(0x01a8, float3x2, 0, o) \
FIELD(0x01c0, uint, 0, p) \
FIELD(0x01c4, int8_t, 12, tint_pad_7) \
FIELD(0x01d0, float3x3, 0, q) \
FIELD(0x0200, uint, 0, r) \
FIELD(0x0204, int8_t, 12, tint_pad_8) \
FIELD(0x0210, float3x4, 0, s) \
FIELD(0x0240, uint, 0, t) \
FIELD(0x0244, int8_t, 4, tint_pad_9) \
FIELD(0x0248, float4x2, 0, u) \
FIELD(0x0268, uint, 0, v) \
FIELD(0x026c, int8_t, 4, tint_pad_10) \
FIELD(0x0270, float4x3, 0, w) \
FIELD(0x02b0, uint, 0, x) \
FIELD(0x02b4, int8_t, 12, tint_pad_11) \
FIELD(0x02c0, float4x4, 0, y) \
FIELD(0x0300, float, 0, z) \
FIELD(0x0304, int8_t, 124, tint_pad_12)
// Check that the generated string is as expected.
#define FIELD(ADDR, TYPE, NAME, SUFFIX) " /* " #ADDR " */ " #TYPE " " #NAME #SUFFIX ";\n"
auto* expect = "struct S {\n" ALL_FIELDS() "};\n";
std::stringstream expect;
expect << "struct S {\n";
#define FIELD(ADDR, TYPE, ARRAY_COUNT, NAME) \
FormatMSLField(expect, #ADDR, #TYPE, ARRAY_COUNT, #NAME);
ALL_FIELDS()
#undef FIELD
EXPECT_EQ(buf.String(), expect);
expect << "};\n";
EXPECT_EQ(buf.String(), expect.str());
// 1.4 Metal and C++14
// The Metal programming language is a C++14-based Specification with
@@ -304,12 +322,12 @@ TEST_F(MslGeneratorImplTest, EmitType_Struct_Layout_NonComposites) {
// layout is as expected for C++14 / MSL.
{
struct S {
#define FIELD(ADDR, TYPE, NAME, SUFFIX) TYPE NAME SUFFIX;
#define FIELD(ADDR, TYPE, ARRAY_COUNT, NAME) std::array<TYPE, ARRAY_COUNT ? ARRAY_COUNT : 1> NAME;
ALL_FIELDS()
#undef FIELD
};
#define FIELD(ADDR, TYPE, NAME, SUFFIX) \
#define FIELD(ADDR, TYPE, ARRAY_COUNT, NAME) \
EXPECT_EQ(ADDR, static_cast<int>(offsetof(S, NAME))) << "Field " << #NAME;
ALL_FIELDS()
#undef FIELD
@@ -350,22 +368,26 @@ TEST_F(MslGeneratorImplTest, EmitType_Struct_Layout_Structures) {
auto* sem_s = program->TypeOf(s)->As<sem::Struct>();
ASSERT_TRUE(gen.EmitStructType(&buf, sem_s)) << gen.error();
// ALL_FIELDS() calls the macro FIELD(ADDR, TYPE, NAME, SUFFIX)
// ALL_FIELDS() calls the macro FIELD(ADDR, TYPE, ARRAY_COUNT, NAME)
// for each field of the structure s.
#define ALL_FIELDS() \
FIELD(0x0000, int, a, /*NO SUFFIX*/) \
FIELD(0x0004, int8_t, tint_pad, [508]) \
FIELD(0x0200, inner_x, b, /*NO SUFFIX*/) \
FIELD(0x0600, float, c, /*NO SUFFIX*/) \
FIELD(0x0604, inner_y, d, /*NO SUFFIX*/) \
FIELD(0x0808, float, e, /*NO SUFFIX*/) \
FIELD(0x080c, int8_t, tint_pad_1, [500])
#define ALL_FIELDS() \
FIELD(0x0000, int, 0, a) \
FIELD(0x0004, int8_t, 508, tint_pad) \
FIELD(0x0200, inner_x, 0, b) \
FIELD(0x0600, float, 0, c) \
FIELD(0x0604, inner_y, 0, d) \
FIELD(0x0808, float, 0, e) \
FIELD(0x080c, int8_t, 500, tint_pad_1)
// Check that the generated string is as expected.
#define FIELD(ADDR, TYPE, NAME, SUFFIX) " /* " #ADDR " */ " #TYPE " " #NAME #SUFFIX ";\n"
auto* expect = "struct S {\n" ALL_FIELDS() "};\n";
std::stringstream expect;
expect << "struct S {\n";
#define FIELD(ADDR, TYPE, ARRAY_COUNT, NAME) \
FormatMSLField(expect, #ADDR, #TYPE, ARRAY_COUNT, #NAME);
ALL_FIELDS()
#undef FIELD
EXPECT_EQ(buf.String(), expect);
expect << "};\n";
EXPECT_EQ(buf.String(), expect.str());
// 1.4 Metal and C++14
// The Metal programming language is a C++14-based Specification with
@@ -389,12 +411,12 @@ TEST_F(MslGeneratorImplTest, EmitType_Struct_Layout_Structures) {
CHECK_TYPE_SIZE_AND_ALIGN(inner_y, 516, 4);
struct S {
#define FIELD(ADDR, TYPE, NAME, SUFFIX) TYPE NAME SUFFIX;
#define FIELD(ADDR, TYPE, ARRAY_COUNT, NAME) std::array<TYPE, ARRAY_COUNT ? ARRAY_COUNT : 1> NAME;
ALL_FIELDS()
#undef FIELD
};
#define FIELD(ADDR, TYPE, NAME, SUFFIX) \
#define FIELD(ADDR, TYPE, ARRAY_COUNT, NAME) \
EXPECT_EQ(ADDR, static_cast<int>(offsetof(S, NAME))) << "Field " << #NAME;
ALL_FIELDS()
#undef FIELD
@@ -440,23 +462,27 @@ TEST_F(MslGeneratorImplTest, EmitType_Struct_Layout_ArrayDefaultStride) {
auto* sem_s = program->TypeOf(s)->As<sem::Struct>();
ASSERT_TRUE(gen.EmitStructType(&buf, sem_s)) << gen.error();
// ALL_FIELDS() calls the macro FIELD(ADDR, TYPE, NAME, SUFFIX)
// ALL_FIELDS() calls the macro FIELD(ADDR, TYPE, ARRAY_COUNT, NAME)
// for each field of the structure s.
#define ALL_FIELDS() \
FIELD(0x0000, int, a, /*NO SUFFIX*/) \
FIELD(0x0004, float, b, [7]) \
FIELD(0x0020, float, c, /*NO SUFFIX*/) \
FIELD(0x0024, int8_t, tint_pad, [476]) \
FIELD(0x0200, inner, d, [4]) \
FIELD(0x1200, float, e, /*NO SUFFIX*/) \
FIELD(0x1204, float, f, [1]) \
FIELD(0x1208, int8_t, tint_pad_1, [504])
#define ALL_FIELDS() \
FIELD(0x0000, int, 0, a) \
FIELD(0x0004, float, 7, b) \
FIELD(0x0020, float, 0, c) \
FIELD(0x0024, int8_t, 476, tint_pad) \
FIELD(0x0200, inner, 4, d) \
FIELD(0x1200, float, 0, e) \
FIELD(0x1204, float, 1, f) \
FIELD(0x1208, int8_t, 504, tint_pad_1)
// Check that the generated string is as expected.
#define FIELD(ADDR, TYPE, NAME, SUFFIX) " /* " #ADDR " */ " #TYPE " " #NAME #SUFFIX ";\n"
auto* expect = "struct S {\n" ALL_FIELDS() "};\n";
std::stringstream expect;
expect << "struct S {\n";
#define FIELD(ADDR, TYPE, ARRAY_COUNT, NAME) \
FormatMSLField(expect, #ADDR, #TYPE, ARRAY_COUNT, #NAME);
ALL_FIELDS()
#undef FIELD
EXPECT_EQ(buf.String(), expect);
expect << "};\n";
EXPECT_EQ(buf.String(), expect.str());
// 1.4 Metal and C++14
// The Metal programming language is a C++14-based Specification with
@@ -486,12 +512,12 @@ TEST_F(MslGeneratorImplTest, EmitType_Struct_Layout_ArrayDefaultStride) {
CHECK_TYPE_SIZE_AND_ALIGN(array_z, 4, 4);
struct S {
#define FIELD(ADDR, TYPE, NAME, SUFFIX) TYPE NAME SUFFIX;
#define FIELD(ADDR, TYPE, ARRAY_COUNT, NAME) std::array<TYPE, ARRAY_COUNT ? ARRAY_COUNT : 1> NAME;
ALL_FIELDS()
#undef FIELD
};
#define FIELD(ADDR, TYPE, NAME, SUFFIX) \
#define FIELD(ADDR, TYPE, ARRAY_COUNT, NAME) \
EXPECT_EQ(ADDR, static_cast<int>(offsetof(S, NAME))) << "Field " << #NAME;
ALL_FIELDS()
#undef FIELD
@@ -522,20 +548,24 @@ TEST_F(MslGeneratorImplTest, EmitType_Struct_Layout_ArrayVec3DefaultStride) {
auto* sem_s = program->TypeOf(s)->As<sem::Struct>();
ASSERT_TRUE(gen.EmitStructType(&buf, sem_s)) << gen.error();
// ALL_FIELDS() calls the macro FIELD(ADDR, TYPE, NAME, SUFFIX)
// ALL_FIELDS() calls the macro FIELD(ADDR, TYPE, ARRAY_COUNT, NAME)
// for each field of the structure s.
#define ALL_FIELDS() \
FIELD(0x0000, int, a, /*NO SUFFIX*/) \
FIELD(0x0004, int8_t, tint_pad, [12]) \
FIELD(0x0010, float3, b, [4]) \
FIELD(0x0050, int, c, /*NO SUFFIX*/) \
FIELD(0x0054, int8_t, tint_pad_1, [12])
#define ALL_FIELDS() \
FIELD(0x0000, int, 0, a) \
FIELD(0x0004, int8_t, 12, tint_pad) \
FIELD(0x0010, float3, 4, b) \
FIELD(0x0050, int, 0, c) \
FIELD(0x0054, int8_t, 12, tint_pad_1)
// Check that the generated string is as expected.
#define FIELD(ADDR, TYPE, NAME, SUFFIX) " /* " #ADDR " */ " #TYPE " " #NAME #SUFFIX ";\n"
auto* expect = "struct S {\n" ALL_FIELDS() "};\n";
std::stringstream expect;
expect << "struct S {\n";
#define FIELD(ADDR, TYPE, ARRAY_COUNT, NAME) \
FormatMSLField(expect, #ADDR, #TYPE, ARRAY_COUNT, #NAME);
ALL_FIELDS()
#undef FIELD
EXPECT_EQ(buf.String(), expect);
expect << "};\n";
EXPECT_EQ(buf.String(), expect.str());
}
TEST_F(MslGeneratorImplTest, AttemptTintPadSymbolCollision) {
@@ -583,44 +613,44 @@ TEST_F(MslGeneratorImplTest, AttemptTintPadSymbolCollision) {
ASSERT_TRUE(gen.EmitStructType(&buf, sem_s)) << gen.error();
EXPECT_EQ(buf.String(), R"(struct S {
/* 0x0000 */ int tint_pad_2;
/* 0x0004 */ int8_t tint_pad_10[124];
/* 0x0004 */ tint_array<int8_t, 124> tint_pad_10;
/* 0x0080 */ float tint_pad_20;
/* 0x0084 */ int8_t tint_pad_11[124];
/* 0x0084 */ tint_array<int8_t, 124> tint_pad_11;
/* 0x0100 */ float2 tint_pad_33;
/* 0x0108 */ uint tint_pad_1;
/* 0x010c */ int8_t tint_pad_12[4];
/* 0x010c */ tint_array<int8_t, 4> tint_pad_12;
/* 0x0110 */ packed_float3 tint_pad_3;
/* 0x011c */ uint tint_pad_7;
/* 0x0120 */ float4 tint_pad_25;
/* 0x0130 */ uint tint_pad_5;
/* 0x0134 */ int8_t tint_pad_13[4];
/* 0x0134 */ tint_array<int8_t, 4> tint_pad_13;
/* 0x0138 */ float2x2 tint_pad_27;
/* 0x0148 */ uint tint_pad_24;
/* 0x014c */ int8_t tint_pad_14[4];
/* 0x014c */ tint_array<int8_t, 4> tint_pad_14;
/* 0x0150 */ float2x3 tint_pad_23;
/* 0x0170 */ uint tint_pad;
/* 0x0174 */ int8_t tint_pad_15[12];
/* 0x0174 */ tint_array<int8_t, 12> tint_pad_15;
/* 0x0180 */ float2x4 tint_pad_8;
/* 0x01a0 */ uint tint_pad_26;
/* 0x01a4 */ int8_t tint_pad_16[4];
/* 0x01a4 */ tint_array<int8_t, 4> tint_pad_16;
/* 0x01a8 */ float3x2 tint_pad_29;
/* 0x01c0 */ uint tint_pad_6;
/* 0x01c4 */ int8_t tint_pad_17[12];
/* 0x01c4 */ tint_array<int8_t, 12> tint_pad_17;
/* 0x01d0 */ float3x3 tint_pad_22;
/* 0x0200 */ uint tint_pad_32;
/* 0x0204 */ int8_t tint_pad_18[12];
/* 0x0204 */ tint_array<int8_t, 12> tint_pad_18;
/* 0x0210 */ float3x4 tint_pad_34;
/* 0x0240 */ uint tint_pad_35;
/* 0x0244 */ int8_t tint_pad_19[4];
/* 0x0244 */ tint_array<int8_t, 4> tint_pad_19;
/* 0x0248 */ float4x2 tint_pad_30;
/* 0x0268 */ uint tint_pad_9;
/* 0x026c */ int8_t tint_pad_36[4];
/* 0x026c */ tint_array<int8_t, 4> tint_pad_36;
/* 0x0270 */ float4x3 tint_pad_31;
/* 0x02b0 */ uint tint_pad_28;
/* 0x02b4 */ int8_t tint_pad_37[12];
/* 0x02b4 */ tint_array<int8_t, 12> tint_pad_37;
/* 0x02c0 */ float4x4 tint_pad_4;
/* 0x0300 */ float tint_pad_21;
/* 0x0304 */ int8_t tint_pad_38[124];
/* 0x0304 */ tint_array<int8_t, 124> tint_pad_38;
};
)");
}

View File

@@ -61,7 +61,7 @@ TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Array) {
gen.increment_indent();
ASSERT_TRUE(gen.EmitStatement(stmt)) << gen.error();
EXPECT_EQ(gen.result(), " float a[5] = {0.0f};\n");
EXPECT_EQ(gen.result(), " tint_array<float, 5> a = {};\n");
}
TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Struct) {