mirror of
https://github.com/encounter/dawn-cmake.git
synced 2025-12-16 16:37:08 +00:00
reader/spirv: Decompose arrays with strides
Transform any SPIR-V that has an array with a custom stride:
@stride(S) array<T, N>
into:
struct strided_arr {
@size(S) er : T;
};
array<strided_arr, N>
Also remove any @stride decorations that match the default array stride.
Bug: tint:1394
Bug: tint:1381
Change-Id: I8be8f3a76c5335fdb2bc5183388366091dbc7642
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/78781
Reviewed-by: David Neto <dneto@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
committed by
Tint LUCI CQ
parent
de857e1c58
commit
009d129103
@@ -441,6 +441,8 @@ libtint_source_set("libtint_core_all_src") {
|
||||
"transform/combine_samplers.h",
|
||||
"transform/decompose_memory_access.cc",
|
||||
"transform/decompose_memory_access.h",
|
||||
"transform/decompose_strided_array.cc",
|
||||
"transform/decompose_strided_array.h",
|
||||
"transform/decompose_strided_matrix.cc",
|
||||
"transform/decompose_strided_matrix.h",
|
||||
"transform/external_texture_transform.cc",
|
||||
|
||||
@@ -309,6 +309,8 @@ set(TINT_LIB_SRCS
|
||||
transform/canonicalize_entry_point_io.h
|
||||
transform/decompose_memory_access.cc
|
||||
transform/decompose_memory_access.h
|
||||
transform/decompose_strided_array.cc
|
||||
transform/decompose_strided_array.h
|
||||
transform/decompose_strided_matrix.cc
|
||||
transform/decompose_strided_matrix.h
|
||||
transform/external_texture_transform.cc
|
||||
@@ -984,6 +986,7 @@ if(TINT_BUILD_TESTS)
|
||||
transform/canonicalize_entry_point_io_test.cc
|
||||
transform/combine_samplers_test.cc
|
||||
transform/decompose_memory_access_test.cc
|
||||
transform/decompose_strided_array_test.cc
|
||||
transform/decompose_strided_matrix_test.cc
|
||||
transform/external_texture_transform_test.cc
|
||||
transform/first_index_offset_test.cc
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
#include <utility>
|
||||
|
||||
#include "src/reader/spirv/parser_impl.h"
|
||||
#include "src/transform/decompose_strided_array.h"
|
||||
#include "src/transform/decompose_strided_matrix.h"
|
||||
#include "src/transform/manager.h"
|
||||
#include "src/transform/remove_unreachable_statements.h"
|
||||
@@ -54,6 +55,7 @@ Program Parse(const std::vector<uint32_t>& input) {
|
||||
manager.Add<transform::Unshadow>();
|
||||
manager.Add<transform::SimplifyPointers>();
|
||||
manager.Add<transform::DecomposeStridedMatrix>();
|
||||
manager.Add<transform::DecomposeStridedArray>();
|
||||
manager.Add<transform::RemoveUnreachableStatements>();
|
||||
return manager.Run(&program).program;
|
||||
}
|
||||
|
||||
@@ -21,6 +21,7 @@ namespace tint {
|
||||
|
||||
// Forward declarations
|
||||
namespace ast {
|
||||
class Array;
|
||||
class CallExpression;
|
||||
class Expression;
|
||||
class ElseStatement;
|
||||
@@ -60,6 +61,7 @@ class Variable;
|
||||
/// rules will be used to infer the return type based on the argument type.
|
||||
struct TypeMappings {
|
||||
//! @cond Doxygen_Suppress
|
||||
Array* operator()(ast::Array*);
|
||||
Call* operator()(ast::CallExpression*);
|
||||
Expression* operator()(ast::Expression*);
|
||||
ElseStatement* operator()(ast::ElseStatement*);
|
||||
|
||||
162
src/transform/decompose_strided_array.cc
Normal file
162
src/transform/decompose_strided_array.cc
Normal file
@@ -0,0 +1,162 @@
|
||||
// Copyright 2022 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/transform/decompose_strided_array.h"
|
||||
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "src/program_builder.h"
|
||||
#include "src/sem/call.h"
|
||||
#include "src/sem/expression.h"
|
||||
#include "src/sem/member_accessor_expression.h"
|
||||
#include "src/sem/type_constructor.h"
|
||||
#include "src/transform/simplify_pointers.h"
|
||||
#include "src/utils/hash.h"
|
||||
#include "src/utils/map.h"
|
||||
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::transform::DecomposeStridedArray);
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
namespace {
|
||||
|
||||
using DecomposedArrays = std::unordered_map<const sem::Array*, Symbol>;
|
||||
|
||||
} // namespace
|
||||
|
||||
DecomposeStridedArray::DecomposeStridedArray() = default;
|
||||
|
||||
DecomposeStridedArray::~DecomposeStridedArray() = default;
|
||||
|
||||
bool DecomposeStridedArray::ShouldRun(const Program* program,
|
||||
const DataMap&) const {
|
||||
for (auto* node : program->ASTNodes().Objects()) {
|
||||
if (auto* ast = node->As<ast::Array>()) {
|
||||
if (ast::GetAttribute<ast::StrideAttribute>(ast->attributes)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void DecomposeStridedArray::Run(CloneContext& ctx,
|
||||
const DataMap&,
|
||||
DataMap&) const {
|
||||
const auto& sem = ctx.src->Sem();
|
||||
|
||||
static constexpr const char* kMemberName = "el";
|
||||
|
||||
// Maps an array type in the source program to the name of the struct wrapper
|
||||
// type in the target program.
|
||||
std::unordered_map<const sem::Array*, Symbol> decomposed;
|
||||
|
||||
// Find and replace all arrays with a @stride attribute with a array that has
|
||||
// the @stride removed. If the source array stride does not match the natural
|
||||
// stride for the array element type, then replace the array element type with
|
||||
// a structure, holding a single field with a @size attribute equal to the
|
||||
// array stride.
|
||||
ctx.ReplaceAll([&](const ast::Array* ast) -> const ast::Array* {
|
||||
if (auto* arr = sem.Get(ast)) {
|
||||
if (!arr->IsStrideImplicit()) {
|
||||
auto el_ty = utils::GetOrCreate(decomposed, arr, [&] {
|
||||
auto name = ctx.dst->Symbols().New("strided_arr");
|
||||
auto* member_ty = ctx.Clone(ast->type);
|
||||
auto* member = ctx.dst->Member(kMemberName, member_ty,
|
||||
{ctx.dst->MemberSize(arr->Stride())});
|
||||
ctx.dst->Structure(name, {member});
|
||||
return name;
|
||||
});
|
||||
auto* count = ctx.Clone(ast->count);
|
||||
return ctx.dst->ty.array(ctx.dst->ty.type_name(el_ty), count);
|
||||
}
|
||||
if (ast::GetAttribute<ast::StrideAttribute>(ast->attributes)) {
|
||||
// Strip the @stride attribute
|
||||
auto* ty = ctx.Clone(ast->type);
|
||||
auto* count = ctx.Clone(ast->count);
|
||||
return ctx.dst->ty.array(ty, count);
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
});
|
||||
|
||||
// Find all array index-accessors expressions for arrays that have had their
|
||||
// element changed to a single field structure. These expressions are adjusted
|
||||
// to insert an additional member accessor for the single structure field.
|
||||
// Example: `arr[i]` -> `arr[i].el`
|
||||
ctx.ReplaceAll(
|
||||
[&](const ast::IndexAccessorExpression* idx) -> const ast::Expression* {
|
||||
if (auto* ty = ctx.src->TypeOf(idx->object)) {
|
||||
if (auto* arr = ty->UnwrapRef()->As<sem::Array>()) {
|
||||
if (!arr->IsStrideImplicit()) {
|
||||
auto* expr = ctx.CloneWithoutTransform(idx);
|
||||
return ctx.dst->MemberAccessor(expr, kMemberName);
|
||||
}
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
});
|
||||
|
||||
// Find all array type constructor expressions for array types that have had
|
||||
// their element changed to a single field structure. These constructors are
|
||||
// adjusted to wrap each of the arguments with an additional constructor for
|
||||
// the new element structure type.
|
||||
// Example:
|
||||
// `@stride(32) array<i32, 3>(1, 2, 3)`
|
||||
// ->
|
||||
// `array<strided_arr, 3>(strided_arr(1), strided_arr(2), strided_arr(3))`
|
||||
ctx.ReplaceAll(
|
||||
[&](const ast::CallExpression* expr) -> const ast::Expression* {
|
||||
if (!expr->args.empty()) {
|
||||
if (auto* call = sem.Get(expr)) {
|
||||
if (auto* ctor = call->Target()->As<sem::TypeConstructor>()) {
|
||||
if (auto* arr = ctor->ReturnType()->As<sem::Array>()) {
|
||||
// Begin by cloning the array constructor type or name
|
||||
// If this is an unaliased array, this may add a new entry to
|
||||
// decomposed.
|
||||
// If this is an aliased array, decomposed should already be
|
||||
// populated with any strided aliases.
|
||||
ast::CallExpression::Target target;
|
||||
if (expr->target.type) {
|
||||
target.type = ctx.Clone(expr->target.type);
|
||||
} else {
|
||||
target.name = ctx.Clone(expr->target.name);
|
||||
}
|
||||
|
||||
ast::ExpressionList args;
|
||||
if (auto it = decomposed.find(arr); it != decomposed.end()) {
|
||||
args.reserve(expr->args.size());
|
||||
for (auto* arg : expr->args) {
|
||||
args.emplace_back(
|
||||
ctx.dst->Call(it->second, ctx.Clone(arg)));
|
||||
}
|
||||
} else {
|
||||
args = ctx.Clone(expr->args);
|
||||
}
|
||||
|
||||
return target.type ? ctx.dst->Construct(target.type, args)
|
||||
: ctx.dst->Call(target.name, args);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
});
|
||||
ctx.Clone();
|
||||
}
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
61
src/transform/decompose_strided_array.h
Normal file
61
src/transform/decompose_strided_array.h
Normal file
@@ -0,0 +1,61 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef SRC_TRANSFORM_DECOMPOSE_STRIDED_ARRAY_H_
|
||||
#define SRC_TRANSFORM_DECOMPOSE_STRIDED_ARRAY_H_
|
||||
|
||||
#include "src/transform/transform.h"
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
|
||||
/// DecomposeStridedArray transforms replaces arrays with a non-default
|
||||
/// `@stride` attribute with an array of structure elements, where the
|
||||
/// structure contains a single field with an equivalent `@size` attribute.
|
||||
/// `@stride` attributes on arrays that match the default stride are also
|
||||
/// removed.
|
||||
///
|
||||
/// @note Depends on the following transforms to have been run first:
|
||||
/// * SimplifyPointers
|
||||
class DecomposeStridedArray
|
||||
: public Castable<DecomposeStridedArray, Transform> {
|
||||
public:
|
||||
/// Constructor
|
||||
DecomposeStridedArray();
|
||||
|
||||
/// Destructor
|
||||
~DecomposeStridedArray() override;
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program,
|
||||
const DataMap& data = {}) const override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
};
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
|
||||
#endif // SRC_TRANSFORM_DECOMPOSE_STRIDED_ARRAY_H_
|
||||
698
src/transform/decompose_strided_array_test.cc
Normal file
698
src/transform/decompose_strided_array_test.cc
Normal file
@@ -0,0 +1,698 @@
|
||||
// Copyright 2022 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/transform/decompose_strided_array.h"
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "src/program_builder.h"
|
||||
#include "src/transform/simplify_pointers.h"
|
||||
#include "src/transform/test_helper.h"
|
||||
#include "src/transform/unshadow.h"
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
namespace {
|
||||
|
||||
using DecomposeStridedArrayTest = TransformTest;
|
||||
using f32 = ProgramBuilder::f32;
|
||||
|
||||
TEST_F(DecomposeStridedArrayTest, ShouldRunEmptyModule) {
|
||||
ProgramBuilder b;
|
||||
EXPECT_FALSE(ShouldRun<DecomposeStridedArray>(Program(std::move(b))));
|
||||
}
|
||||
|
||||
TEST_F(DecomposeStridedArrayTest, ShouldRunNonStridedArray) {
|
||||
// var<private> arr : array<f32, 4>
|
||||
|
||||
ProgramBuilder b;
|
||||
b.Global("arr", b.ty.array<f32, 4>(), ast::StorageClass::kPrivate);
|
||||
EXPECT_FALSE(ShouldRun<DecomposeStridedArray>(Program(std::move(b))));
|
||||
}
|
||||
|
||||
TEST_F(DecomposeStridedArrayTest, ShouldRunDefaultStridedArray) {
|
||||
// var<private> arr : @stride(4) array<f32, 4>
|
||||
|
||||
ProgramBuilder b;
|
||||
b.Global("arr", b.ty.array<f32, 4>(4), ast::StorageClass::kPrivate);
|
||||
EXPECT_TRUE(ShouldRun<DecomposeStridedArray>(Program(std::move(b))));
|
||||
}
|
||||
|
||||
TEST_F(DecomposeStridedArrayTest, ShouldRunExplicitStridedArray) {
|
||||
// var<private> arr : @stride(16) array<f32, 4>
|
||||
|
||||
ProgramBuilder b;
|
||||
b.Global("arr", b.ty.array<f32, 4>(16), ast::StorageClass::kPrivate);
|
||||
EXPECT_TRUE(ShouldRun<DecomposeStridedArray>(Program(std::move(b))));
|
||||
}
|
||||
|
||||
TEST_F(DecomposeStridedArrayTest, Empty) {
|
||||
auto* src = R"()";
|
||||
auto* expect = src;
|
||||
|
||||
auto got = Run<DecomposeStridedArray>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(DecomposeStridedArrayTest, PrivateDefaultStridedArray) {
|
||||
// var<private> arr : @stride(4) array<f32, 4>
|
||||
//
|
||||
// @stage(compute) @workgroup_size(1)
|
||||
// fn f() {
|
||||
// let a : @stride(4) array<f32, 4> = a;
|
||||
// let b : f32 = arr[1];
|
||||
// }
|
||||
|
||||
ProgramBuilder b;
|
||||
b.Global("arr", b.ty.array<f32, 4>(4), ast::StorageClass::kPrivate);
|
||||
b.Func("f", {}, b.ty.void_(),
|
||||
{
|
||||
b.Decl(b.Const("a", b.ty.array<f32, 4>(4), b.Expr("arr"))),
|
||||
b.Decl(b.Const("b", b.ty.f32(), b.IndexAccessor("arr", 1))),
|
||||
},
|
||||
{
|
||||
b.Stage(ast::PipelineStage::kCompute),
|
||||
b.WorkgroupSize(1),
|
||||
});
|
||||
|
||||
auto* expect = R"(
|
||||
var<private> arr : array<f32, 4>;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn f() {
|
||||
let a : array<f32, 4> = arr;
|
||||
let b : f32 = arr[1];
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(
|
||||
Program(std::move(b)));
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(DecomposeStridedArrayTest, PrivateStridedArray) {
|
||||
// var<private> arr : @stride(32) array<f32, 4>
|
||||
//
|
||||
// @stage(compute) @workgroup_size(1)
|
||||
// fn f() {
|
||||
// let a : @stride(32) array<f32, 4> = a;
|
||||
// let b : f32 = arr[1];
|
||||
// }
|
||||
|
||||
ProgramBuilder b;
|
||||
b.Global("arr", b.ty.array<f32, 4>(32), ast::StorageClass::kPrivate);
|
||||
b.Func("f", {}, b.ty.void_(),
|
||||
{
|
||||
b.Decl(b.Const("a", b.ty.array<f32, 4>(32), b.Expr("arr"))),
|
||||
b.Decl(b.Const("b", b.ty.f32(), b.IndexAccessor("arr", 1))),
|
||||
},
|
||||
{
|
||||
b.Stage(ast::PipelineStage::kCompute),
|
||||
b.WorkgroupSize(1),
|
||||
});
|
||||
|
||||
auto* expect = R"(
|
||||
struct strided_arr {
|
||||
@size(32)
|
||||
el : f32;
|
||||
}
|
||||
|
||||
var<private> arr : array<strided_arr, 4>;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn f() {
|
||||
let a : array<strided_arr, 4> = arr;
|
||||
let b : f32 = arr[1].el;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(
|
||||
Program(std::move(b)));
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(DecomposeStridedArrayTest, ReadUniformStridedArray) {
|
||||
// struct S {
|
||||
// a : @stride(32) array<f32, 4>;
|
||||
// };
|
||||
// @group(0) @binding(0) var<uniform> s : S;
|
||||
//
|
||||
// @stage(compute) @workgroup_size(1)
|
||||
// fn f() {
|
||||
// let a : @stride(32) array<f32, 4> = s.a;
|
||||
// let b : f32 = s.a[1];
|
||||
// }
|
||||
ProgramBuilder b;
|
||||
auto* S = b.Structure("S", {b.Member("a", b.ty.array<f32, 4>(32))});
|
||||
b.Global("s", b.ty.Of(S), ast::StorageClass::kUniform,
|
||||
b.GroupAndBinding(0, 0));
|
||||
b.Func("f", {}, b.ty.void_(),
|
||||
{
|
||||
b.Decl(b.Const("a", b.ty.array<f32, 4>(32),
|
||||
b.MemberAccessor("s", "a"))),
|
||||
b.Decl(b.Const("b", b.ty.f32(),
|
||||
b.IndexAccessor(b.MemberAccessor("s", "a"), 1))),
|
||||
},
|
||||
{
|
||||
b.Stage(ast::PipelineStage::kCompute),
|
||||
b.WorkgroupSize(1),
|
||||
});
|
||||
|
||||
auto* expect = R"(
|
||||
struct strided_arr {
|
||||
@size(32)
|
||||
el : f32;
|
||||
}
|
||||
|
||||
struct S {
|
||||
a : array<strided_arr, 4>;
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var<uniform> s : S;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn f() {
|
||||
let a : array<strided_arr, 4> = s.a;
|
||||
let b : f32 = s.a[1].el;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(
|
||||
Program(std::move(b)));
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(DecomposeStridedArrayTest, ReadUniformDefaultStridedArray) {
|
||||
// struct S {
|
||||
// a : @stride(16) array<vec4<f32>, 4>;
|
||||
// };
|
||||
// @group(0) @binding(0) var<uniform> s : S;
|
||||
//
|
||||
// @stage(compute) @workgroup_size(1)
|
||||
// fn f() {
|
||||
// let a : @stride(16) array<vec4<f32>, 4> = s.a;
|
||||
// let b : f32 = s.a[1][2];
|
||||
// }
|
||||
ProgramBuilder b;
|
||||
auto* S =
|
||||
b.Structure("S", {b.Member("a", b.ty.array(b.ty.vec4<f32>(), 4, 16))});
|
||||
b.Global("s", b.ty.Of(S), ast::StorageClass::kUniform,
|
||||
b.GroupAndBinding(0, 0));
|
||||
b.Func("f", {}, b.ty.void_(),
|
||||
{
|
||||
b.Decl(b.Const("a", b.ty.array(b.ty.vec4<f32>(), 4, 16),
|
||||
b.MemberAccessor("s", "a"))),
|
||||
b.Decl(b.Const(
|
||||
"b", b.ty.f32(),
|
||||
b.IndexAccessor(b.IndexAccessor(b.MemberAccessor("s", "a"), 1),
|
||||
2))),
|
||||
},
|
||||
{
|
||||
b.Stage(ast::PipelineStage::kCompute),
|
||||
b.WorkgroupSize(1),
|
||||
});
|
||||
|
||||
auto* expect =
|
||||
R"(
|
||||
struct S {
|
||||
a : array<vec4<f32>, 4>;
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var<uniform> s : S;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn f() {
|
||||
let a : array<vec4<f32>, 4> = s.a;
|
||||
let b : f32 = s.a[1][2];
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(
|
||||
Program(std::move(b)));
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(DecomposeStridedArrayTest, ReadStorageStridedArray) {
|
||||
// struct S {
|
||||
// a : @stride(32) array<f32, 4>;
|
||||
// };
|
||||
// @group(0) @binding(0) var<storage> s : S;
|
||||
//
|
||||
// @stage(compute) @workgroup_size(1)
|
||||
// fn f() {
|
||||
// let a : @stride(32) array<f32, 4> = s.a;
|
||||
// let b : f32 = s.a[1];
|
||||
// }
|
||||
ProgramBuilder b;
|
||||
auto* S = b.Structure("S", {b.Member("a", b.ty.array<f32, 4>(32))});
|
||||
b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
|
||||
b.GroupAndBinding(0, 0));
|
||||
b.Func("f", {}, b.ty.void_(),
|
||||
{
|
||||
b.Decl(b.Const("a", b.ty.array<f32, 4>(32),
|
||||
b.MemberAccessor("s", "a"))),
|
||||
b.Decl(b.Const("b", b.ty.f32(),
|
||||
b.IndexAccessor(b.MemberAccessor("s", "a"), 1))),
|
||||
},
|
||||
{
|
||||
b.Stage(ast::PipelineStage::kCompute),
|
||||
b.WorkgroupSize(1),
|
||||
});
|
||||
|
||||
auto* expect = R"(
|
||||
struct strided_arr {
|
||||
@size(32)
|
||||
el : f32;
|
||||
}
|
||||
|
||||
struct S {
|
||||
a : array<strided_arr, 4>;
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var<storage> s : S;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn f() {
|
||||
let a : array<strided_arr, 4> = s.a;
|
||||
let b : f32 = s.a[1].el;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(
|
||||
Program(std::move(b)));
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(DecomposeStridedArrayTest, ReadStorageDefaultStridedArray) {
|
||||
// struct S {
|
||||
// a : @stride(4) array<f32, 4>;
|
||||
// };
|
||||
// @group(0) @binding(0) var<storage> s : S;
|
||||
//
|
||||
// @stage(compute) @workgroup_size(1)
|
||||
// fn f() {
|
||||
// let a : @stride(4) array<f32, 4> = s.a;
|
||||
// let b : f32 = s.a[1];
|
||||
// }
|
||||
ProgramBuilder b;
|
||||
auto* S = b.Structure("S", {b.Member("a", b.ty.array<f32, 4>(4))});
|
||||
b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
|
||||
b.GroupAndBinding(0, 0));
|
||||
b.Func("f", {}, b.ty.void_(),
|
||||
{
|
||||
b.Decl(b.Const("a", b.ty.array<f32, 4>(4),
|
||||
b.MemberAccessor("s", "a"))),
|
||||
b.Decl(b.Const("b", b.ty.f32(),
|
||||
b.IndexAccessor(b.MemberAccessor("s", "a"), 1))),
|
||||
},
|
||||
{
|
||||
b.Stage(ast::PipelineStage::kCompute),
|
||||
b.WorkgroupSize(1),
|
||||
});
|
||||
|
||||
auto* expect = R"(
|
||||
struct S {
|
||||
a : array<f32, 4>;
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var<storage> s : S;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn f() {
|
||||
let a : array<f32, 4> = s.a;
|
||||
let b : f32 = s.a[1];
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(
|
||||
Program(std::move(b)));
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(DecomposeStridedArrayTest, WriteStorageStridedArray) {
|
||||
// struct S {
|
||||
// a : @stride(32) array<f32, 4>;
|
||||
// };
|
||||
// @group(0) @binding(0) var<storage, read_write> s : S;
|
||||
//
|
||||
// @stage(compute) @workgroup_size(1)
|
||||
// fn f() {
|
||||
// s.a = @stride(32) array<f32, 4>();
|
||||
// s.a = @stride(32) array<f32, 4>(1.0, 2.0, 3.0, 4.0);
|
||||
// s.a[1] = 5.0;
|
||||
// }
|
||||
ProgramBuilder b;
|
||||
auto* S = b.Structure("S", {b.Member("a", b.ty.array<f32, 4>(32))});
|
||||
b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
|
||||
ast::Access::kReadWrite, b.GroupAndBinding(0, 0));
|
||||
b.Func(
|
||||
"f", {}, b.ty.void_(),
|
||||
{
|
||||
b.Assign(b.MemberAccessor("s", "a"),
|
||||
b.Construct(b.ty.array<f32, 4>(32))),
|
||||
b.Assign(b.MemberAccessor("s", "a"),
|
||||
b.Construct(b.ty.array<f32, 4>(32), 1.0f, 2.0f, 3.0f, 4.0f)),
|
||||
b.Assign(b.IndexAccessor(b.MemberAccessor("s", "a"), 1), 5.0f),
|
||||
},
|
||||
{
|
||||
b.Stage(ast::PipelineStage::kCompute),
|
||||
b.WorkgroupSize(1),
|
||||
});
|
||||
|
||||
auto* expect =
|
||||
R"(
|
||||
struct strided_arr {
|
||||
@size(32)
|
||||
el : f32;
|
||||
}
|
||||
|
||||
struct S {
|
||||
a : array<strided_arr, 4>;
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var<storage, read_write> s : S;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn f() {
|
||||
s.a = array<strided_arr, 4>();
|
||||
s.a = array<strided_arr, 4>(strided_arr(1.0), strided_arr(2.0), strided_arr(3.0), strided_arr(4.0));
|
||||
s.a[1].el = 5.0;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(
|
||||
Program(std::move(b)));
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(DecomposeStridedArrayTest, WriteStorageDefaultStridedArray) {
|
||||
// struct S {
|
||||
// a : @stride(4) array<f32, 4>;
|
||||
// };
|
||||
// @group(0) @binding(0) var<storage, read_write> s : S;
|
||||
//
|
||||
// @stage(compute) @workgroup_size(1)
|
||||
// fn f() {
|
||||
// s.a = @stride(4) array<f32, 4>();
|
||||
// s.a = @stride(4) array<f32, 4>(1.0, 2.0, 3.0, 4.0);
|
||||
// s.a[1] = 5.0;
|
||||
// }
|
||||
ProgramBuilder b;
|
||||
auto* S = b.Structure("S", {b.Member("a", b.ty.array<f32, 4>(4))});
|
||||
b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
|
||||
ast::Access::kReadWrite, b.GroupAndBinding(0, 0));
|
||||
b.Func(
|
||||
"f", {}, b.ty.void_(),
|
||||
{
|
||||
b.Assign(b.MemberAccessor("s", "a"),
|
||||
b.Construct(b.ty.array<f32, 4>(4))),
|
||||
b.Assign(b.MemberAccessor("s", "a"),
|
||||
b.Construct(b.ty.array<f32, 4>(4), 1.0f, 2.0f, 3.0f, 4.0f)),
|
||||
b.Assign(b.IndexAccessor(b.MemberAccessor("s", "a"), 1), 5.0f),
|
||||
},
|
||||
{
|
||||
b.Stage(ast::PipelineStage::kCompute),
|
||||
b.WorkgroupSize(1),
|
||||
});
|
||||
|
||||
auto* expect =
|
||||
R"(
|
||||
struct S {
|
||||
a : array<f32, 4>;
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var<storage, read_write> s : S;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn f() {
|
||||
s.a = array<f32, 4>();
|
||||
s.a = array<f32, 4>(1.0, 2.0, 3.0, 4.0);
|
||||
s.a[1] = 5.0;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(
|
||||
Program(std::move(b)));
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(DecomposeStridedArrayTest, ReadWriteViaPointerLets) {
|
||||
// struct S {
|
||||
// a : @stride(32) array<f32, 4>;
|
||||
// };
|
||||
// @group(0) @binding(0) var<storage, read_write> s : S;
|
||||
//
|
||||
// @stage(compute) @workgroup_size(1)
|
||||
// fn f() {
|
||||
// let a = &s.a;
|
||||
// let b = &*&*(a);
|
||||
// let c = *b;
|
||||
// let d = (*b)[1];
|
||||
// (*b) = @stride(32) array<f32, 4>(1.0, 2.0, 3.0, 4.0);
|
||||
// (*b)[1] = 5.0;
|
||||
// }
|
||||
ProgramBuilder b;
|
||||
auto* S = b.Structure("S", {b.Member("a", b.ty.array<f32, 4>(32))});
|
||||
b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
|
||||
ast::Access::kReadWrite, b.GroupAndBinding(0, 0));
|
||||
b.Func("f", {}, b.ty.void_(),
|
||||
{
|
||||
b.Decl(b.Const("a", nullptr,
|
||||
b.AddressOf(b.MemberAccessor("s", "a")))),
|
||||
b.Decl(b.Const("b", nullptr,
|
||||
b.AddressOf(b.Deref(b.AddressOf(b.Deref("a")))))),
|
||||
b.Decl(b.Const("c", nullptr, b.Deref("b"))),
|
||||
b.Decl(b.Const("d", nullptr, b.IndexAccessor(b.Deref("b"), 1))),
|
||||
b.Assign(b.Deref("b"), b.Construct(b.ty.array<f32, 4>(32), 1.0f,
|
||||
2.0f, 3.0f, 4.0f)),
|
||||
b.Assign(b.IndexAccessor(b.Deref("b"), 1), 5.0f),
|
||||
},
|
||||
{
|
||||
b.Stage(ast::PipelineStage::kCompute),
|
||||
b.WorkgroupSize(1),
|
||||
});
|
||||
|
||||
auto* expect =
|
||||
R"(
|
||||
struct strided_arr {
|
||||
@size(32)
|
||||
el : f32;
|
||||
}
|
||||
|
||||
struct S {
|
||||
a : array<strided_arr, 4>;
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var<storage, read_write> s : S;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn f() {
|
||||
let c = s.a;
|
||||
let d = s.a[1].el;
|
||||
s.a = array<strided_arr, 4>(strided_arr(1.0), strided_arr(2.0), strided_arr(3.0), strided_arr(4.0));
|
||||
s.a[1].el = 5.0;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(
|
||||
Program(std::move(b)));
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(DecomposeStridedArrayTest, PrivateAliasedStridedArray) {
|
||||
// type ARR = @stride(32) array<f32, 4>;
|
||||
// struct S {
|
||||
// a : ARR;
|
||||
// };
|
||||
// @group(0) @binding(0) var<storage, read_write> s : S;
|
||||
//
|
||||
// @stage(compute) @workgroup_size(1)
|
||||
// fn f() {
|
||||
// let a : ARR = s.a;
|
||||
// let b : f32 = s.a[1];
|
||||
// s.a = ARR();
|
||||
// s.a = ARR(1.0, 2.0, 3.0, 4.0);
|
||||
// s.a[1] = 5.0;
|
||||
// }
|
||||
ProgramBuilder b;
|
||||
b.Alias("ARR", b.ty.array<f32, 4>(32));
|
||||
auto* S = b.Structure("S", {b.Member("a", b.ty.type_name("ARR"))});
|
||||
b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
|
||||
ast::Access::kReadWrite, b.GroupAndBinding(0, 0));
|
||||
b.Func(
|
||||
"f", {}, b.ty.void_(),
|
||||
{
|
||||
b.Decl(
|
||||
b.Const("a", b.ty.type_name("ARR"), b.MemberAccessor("s", "a"))),
|
||||
b.Decl(b.Const("b", b.ty.f32(),
|
||||
b.IndexAccessor(b.MemberAccessor("s", "a"), 1))),
|
||||
b.Assign(b.MemberAccessor("s", "a"),
|
||||
b.Construct(b.ty.type_name("ARR"))),
|
||||
b.Assign(b.MemberAccessor("s", "a"),
|
||||
b.Construct(b.ty.type_name("ARR"), 1.0f, 2.0f, 3.0f, 4.0f)),
|
||||
b.Assign(b.IndexAccessor(b.MemberAccessor("s", "a"), 1), 5.0f),
|
||||
},
|
||||
{
|
||||
b.Stage(ast::PipelineStage::kCompute),
|
||||
b.WorkgroupSize(1),
|
||||
});
|
||||
|
||||
auto* expect = R"(
|
||||
struct strided_arr {
|
||||
@size(32)
|
||||
el : f32;
|
||||
}
|
||||
|
||||
type ARR = array<strided_arr, 4>;
|
||||
|
||||
struct S {
|
||||
a : ARR;
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var<storage, read_write> s : S;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn f() {
|
||||
let a : ARR = s.a;
|
||||
let b : f32 = s.a[1].el;
|
||||
s.a = ARR();
|
||||
s.a = ARR(strided_arr(1.0), strided_arr(2.0), strided_arr(3.0), strided_arr(4.0));
|
||||
s.a[1].el = 5.0;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(
|
||||
Program(std::move(b)));
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(DecomposeStridedArrayTest, PrivateNestedStridedArray) {
|
||||
// type ARR_A = @stride(8) array<f32, 2>;
|
||||
// type ARR_B = @stride(128) array<@stride(16) array<ARR_A, 3>, 4>;
|
||||
// struct S {
|
||||
// a : ARR_B;
|
||||
// };
|
||||
// @group(0) @binding(0) var<storage, read_write> s : S;
|
||||
//
|
||||
// @stage(compute) @workgroup_size(1)
|
||||
// fn f() {
|
||||
// let a : ARR_B = s.a;
|
||||
// let b : array<@stride(8) array<f32, 2>, 3> = s.a[3];
|
||||
// let c = s.a[3][2];
|
||||
// let d = s.a[3][2][1];
|
||||
// s.a = ARR_B();
|
||||
// s.a[3][2][1] = 5.0;
|
||||
// }
|
||||
|
||||
ProgramBuilder b;
|
||||
b.Alias("ARR_A", b.ty.array<f32, 2>(8));
|
||||
b.Alias("ARR_B",
|
||||
b.ty.array( //
|
||||
b.ty.array(b.ty.type_name("ARR_A"), 3, 16), //
|
||||
4, 128));
|
||||
auto* S = b.Structure("S", {b.Member("a", b.ty.type_name("ARR_B"))});
|
||||
b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
|
||||
ast::Access::kReadWrite, b.GroupAndBinding(0, 0));
|
||||
b.Func("f", {}, b.ty.void_(),
|
||||
{
|
||||
b.Decl(b.Const("a", b.ty.type_name("ARR_B"),
|
||||
b.MemberAccessor("s", "a"))),
|
||||
b.Decl(b.Const("b", b.ty.array(b.ty.type_name("ARR_A"), 3, 16),
|
||||
b.IndexAccessor( //
|
||||
b.MemberAccessor("s", "a"), //
|
||||
3))),
|
||||
b.Decl(b.Const("c", b.ty.type_name("ARR_A"),
|
||||
b.IndexAccessor( //
|
||||
b.IndexAccessor( //
|
||||
b.MemberAccessor("s", "a"), //
|
||||
3),
|
||||
2))),
|
||||
b.Decl(b.Const("d", b.ty.f32(),
|
||||
b.IndexAccessor( //
|
||||
b.IndexAccessor( //
|
||||
b.IndexAccessor( //
|
||||
b.MemberAccessor("s", "a"), //
|
||||
3),
|
||||
2),
|
||||
1))),
|
||||
b.Assign(b.MemberAccessor("s", "a"),
|
||||
b.Construct(b.ty.type_name("ARR_B"))),
|
||||
b.Assign(b.IndexAccessor( //
|
||||
b.IndexAccessor( //
|
||||
b.IndexAccessor( //
|
||||
b.MemberAccessor("s", "a"), //
|
||||
3),
|
||||
2),
|
||||
1),
|
||||
5.0f),
|
||||
},
|
||||
{
|
||||
b.Stage(ast::PipelineStage::kCompute),
|
||||
b.WorkgroupSize(1),
|
||||
});
|
||||
|
||||
auto* expect =
|
||||
R"(
|
||||
struct strided_arr {
|
||||
@size(8)
|
||||
el : f32;
|
||||
}
|
||||
|
||||
type ARR_A = array<strided_arr, 2>;
|
||||
|
||||
struct strided_arr_1 {
|
||||
@size(128)
|
||||
el : array<ARR_A, 3>;
|
||||
}
|
||||
|
||||
type ARR_B = array<strided_arr_1, 4>;
|
||||
|
||||
struct S {
|
||||
a : ARR_B;
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var<storage, read_write> s : S;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn f() {
|
||||
let a : ARR_B = s.a;
|
||||
let b : array<ARR_A, 3> = s.a[3].el;
|
||||
let c : ARR_A = s.a[3].el[2];
|
||||
let d : f32 = s.a[3].el[2][1].el;
|
||||
s.a = ARR_B();
|
||||
s.a[3].el[2][1].el = 5.0;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(
|
||||
Program(std::move(b)));
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
} // namespace
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
@@ -81,6 +81,15 @@ class TransformTestBase : public BASE {
|
||||
return manager.Run(&program, data);
|
||||
}
|
||||
|
||||
/// @param program the input program
|
||||
/// @param data the optional DataMap to pass to Transform::Run()
|
||||
/// @return true if the transform should be run for the given input.
|
||||
template <typename TRANSFORM>
|
||||
bool ShouldRun(Program&& program, const DataMap& data = {}) {
|
||||
EXPECT_TRUE(program.IsValid()) << program.Diagnostics().str();
|
||||
return TRANSFORM().ShouldRun(&program, data);
|
||||
}
|
||||
|
||||
/// @param in the input WGSL source
|
||||
/// @param data the optional DataMap to pass to Transform::Run()
|
||||
/// @return true if the transform should be run for the given input.
|
||||
@@ -88,8 +97,7 @@ class TransformTestBase : public BASE {
|
||||
bool ShouldRun(std::string in, const DataMap& data = {}) {
|
||||
auto file = std::make_unique<Source::File>("test", in);
|
||||
auto program = reader::wgsl::Parse(file.get());
|
||||
EXPECT_TRUE(program.IsValid()) << program.Diagnostics().str();
|
||||
return TRANSFORM().ShouldRun(&program, data);
|
||||
return ShouldRun<TRANSFORM>(std::move(program), data);
|
||||
}
|
||||
|
||||
/// @param output the output of the transform
|
||||
|
||||
Reference in New Issue
Block a user