mirror of
https://github.com/encounter/dawn-cmake.git
synced 2025-12-21 10:49:14 +00:00
tint->dawn: Shuffle source tree in preperation of merging repos
docs/ -> docs/tint/ fuzzers/ -> src/tint/fuzzers/ samples/ -> src/tint/cmd/ src/ -> src/tint/ test/ -> test/tint/ BUG=tint:1418,tint:1433 Change-Id: Id2aa79f989aef3245b80ef4aa37a27ff16cd700b Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/80482 Kokoro: Kokoro <noreply+kokoro@google.com> Reviewed-by: Ben Clayton <bclayton@google.com> Commit-Queue: Ryan Harrison <rharrison@chromium.org>
This commit is contained in:
committed by
Tint LUCI CQ
parent
38f1e9c75c
commit
dbc13af287
51
src/tint/transform/add_empty_entry_point.cc
Normal file
51
src/tint/transform/add_empty_entry_point.cc
Normal file
@@ -0,0 +1,51 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/tint/transform/add_empty_entry_point.h"
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "src/tint/program_builder.h"
|
||||
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::transform::AddEmptyEntryPoint);
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
|
||||
AddEmptyEntryPoint::AddEmptyEntryPoint() = default;
|
||||
|
||||
AddEmptyEntryPoint::~AddEmptyEntryPoint() = default;
|
||||
|
||||
bool AddEmptyEntryPoint::ShouldRun(const Program* program,
|
||||
const DataMap&) const {
|
||||
for (auto* func : program->AST().Functions()) {
|
||||
if (func->IsEntryPoint()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void AddEmptyEntryPoint::Run(CloneContext& ctx,
|
||||
const DataMap&,
|
||||
DataMap&) const {
|
||||
ctx.dst->Func(ctx.dst->Symbols().New("unused_entry_point"), {},
|
||||
ctx.dst->ty.void_(), {},
|
||||
{ctx.dst->Stage(ast::PipelineStage::kCompute),
|
||||
ctx.dst->WorkgroupSize(1)});
|
||||
ctx.Clone();
|
||||
}
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
52
src/tint/transform/add_empty_entry_point.h
Normal file
52
src/tint/transform/add_empty_entry_point.h
Normal file
@@ -0,0 +1,52 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef SRC_TINT_TRANSFORM_ADD_EMPTY_ENTRY_POINT_H_
|
||||
#define SRC_TINT_TRANSFORM_ADD_EMPTY_ENTRY_POINT_H_
|
||||
|
||||
#include "src/tint/transform/transform.h"
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
|
||||
/// Add an empty entry point to the module, if no other entry points exist.
|
||||
class AddEmptyEntryPoint : public Castable<AddEmptyEntryPoint, Transform> {
|
||||
public:
|
||||
/// Constructor
|
||||
AddEmptyEntryPoint();
|
||||
/// Destructor
|
||||
~AddEmptyEntryPoint() override;
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program,
|
||||
const DataMap& data = {}) const override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
};
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
|
||||
#endif // SRC_TINT_TRANSFORM_ADD_EMPTY_ENTRY_POINT_H_
|
||||
88
src/tint/transform/add_empty_entry_point_test.cc
Normal file
88
src/tint/transform/add_empty_entry_point_test.cc
Normal file
@@ -0,0 +1,88 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/tint/transform/add_empty_entry_point.h"
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "src/tint/transform/test_helper.h"
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
namespace {
|
||||
|
||||
using AddEmptyEntryPointTest = TransformTest;
|
||||
|
||||
TEST_F(AddEmptyEntryPointTest, ShouldRunEmptyModule) {
|
||||
auto* src = R"()";
|
||||
|
||||
EXPECT_TRUE(ShouldRun<AddEmptyEntryPoint>(src));
|
||||
}
|
||||
|
||||
TEST_F(AddEmptyEntryPointTest, ShouldRunExistingEntryPoint) {
|
||||
auto* src = R"(
|
||||
[[stage(compute), workgroup_size(1)]]
|
||||
fn existing() {}
|
||||
)";
|
||||
|
||||
EXPECT_FALSE(ShouldRun<AddEmptyEntryPoint>(src));
|
||||
}
|
||||
|
||||
TEST_F(AddEmptyEntryPointTest, EmptyModule) {
|
||||
auto* src = R"()";
|
||||
|
||||
auto* expect = R"(
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn unused_entry_point() {
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<AddEmptyEntryPoint>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(AddEmptyEntryPointTest, ExistingEntryPoint) {
|
||||
auto* src = R"(
|
||||
@stage(fragment)
|
||||
fn main() {
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = src;
|
||||
|
||||
auto got = Run<AddEmptyEntryPoint>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(AddEmptyEntryPointTest, NameClash) {
|
||||
auto* src = R"(var<private> unused_entry_point : f32;)";
|
||||
|
||||
auto* expect = R"(
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn unused_entry_point_1() {
|
||||
}
|
||||
|
||||
var<private> unused_entry_point : f32;
|
||||
)";
|
||||
|
||||
auto got = Run<AddEmptyEntryPoint>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
122
src/tint/transform/add_spirv_block_attribute.cc
Normal file
122
src/tint/transform/add_spirv_block_attribute.cc
Normal file
@@ -0,0 +1,122 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/tint/transform/add_spirv_block_attribute.h"
|
||||
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <utility>
|
||||
|
||||
#include "src/tint/program_builder.h"
|
||||
#include "src/tint/sem/variable.h"
|
||||
#include "src/tint/utils/map.h"
|
||||
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::transform::AddSpirvBlockAttribute);
|
||||
TINT_INSTANTIATE_TYPEINFO(
|
||||
tint::transform::AddSpirvBlockAttribute::SpirvBlockAttribute);
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
|
||||
AddSpirvBlockAttribute::AddSpirvBlockAttribute() = default;
|
||||
|
||||
AddSpirvBlockAttribute::~AddSpirvBlockAttribute() = default;
|
||||
|
||||
void AddSpirvBlockAttribute::Run(CloneContext& ctx,
|
||||
const DataMap&,
|
||||
DataMap&) const {
|
||||
auto& sem = ctx.src->Sem();
|
||||
|
||||
// Collect the set of structs that are nested in other types.
|
||||
std::unordered_set<const sem::Struct*> nested_structs;
|
||||
for (auto* node : ctx.src->ASTNodes().Objects()) {
|
||||
if (auto* arr = sem.Get<sem::Array>(node->As<ast::Array>())) {
|
||||
if (auto* nested_str = arr->ElemType()->As<sem::Struct>()) {
|
||||
nested_structs.insert(nested_str);
|
||||
}
|
||||
} else if (auto* str = sem.Get<sem::Struct>(node->As<ast::Struct>())) {
|
||||
for (auto* member : str->Members()) {
|
||||
if (auto* nested_str = member->Type()->As<sem::Struct>()) {
|
||||
nested_structs.insert(nested_str);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// A map from a type in the source program to a block-decorated wrapper that
|
||||
// contains it in the destination program.
|
||||
std::unordered_map<const sem::Type*, const ast::Struct*> wrapper_structs;
|
||||
|
||||
// Process global variables that are buffers.
|
||||
for (auto* var : ctx.src->AST().GlobalVariables()) {
|
||||
auto* sem_var = sem.Get<sem::GlobalVariable>(var);
|
||||
if (var->declared_storage_class != ast::StorageClass::kStorage &&
|
||||
var->declared_storage_class != ast::StorageClass::kUniform) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto* ty = sem.Get(var->type);
|
||||
auto* str = ty->As<sem::Struct>();
|
||||
if (!str || nested_structs.count(str)) {
|
||||
const char* kMemberName = "inner";
|
||||
|
||||
// This is a non-struct or a struct that is nested somewhere else, so we
|
||||
// need to wrap it first.
|
||||
auto* wrapper = utils::GetOrCreate(wrapper_structs, ty, [&]() {
|
||||
auto* block =
|
||||
ctx.dst->ASTNodes().Create<SpirvBlockAttribute>(ctx.dst->ID());
|
||||
auto wrapper_name = ctx.src->Symbols().NameFor(var->symbol) + "_block";
|
||||
auto* ret = ctx.dst->create<ast::Struct>(
|
||||
ctx.dst->Symbols().New(wrapper_name),
|
||||
ast::StructMemberList{
|
||||
ctx.dst->Member(kMemberName, CreateASTTypeFor(ctx, ty))},
|
||||
ast::AttributeList{block});
|
||||
ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), var, ret);
|
||||
return ret;
|
||||
});
|
||||
ctx.Replace(var->type, ctx.dst->ty.Of(wrapper));
|
||||
|
||||
// Insert a member accessor to get the original type from the wrapper at
|
||||
// any usage of the original variable.
|
||||
for (auto* user : sem_var->Users()) {
|
||||
ctx.Replace(
|
||||
user->Declaration(),
|
||||
ctx.dst->MemberAccessor(ctx.Clone(var->symbol), kMemberName));
|
||||
}
|
||||
} else {
|
||||
// Add a block attribute to this struct directly.
|
||||
auto* block =
|
||||
ctx.dst->ASTNodes().Create<SpirvBlockAttribute>(ctx.dst->ID());
|
||||
ctx.InsertFront(str->Declaration()->attributes, block);
|
||||
}
|
||||
}
|
||||
|
||||
ctx.Clone();
|
||||
}
|
||||
|
||||
AddSpirvBlockAttribute::SpirvBlockAttribute::SpirvBlockAttribute(ProgramID pid)
|
||||
: Base(pid) {}
|
||||
AddSpirvBlockAttribute::SpirvBlockAttribute::~SpirvBlockAttribute() = default;
|
||||
std::string AddSpirvBlockAttribute::SpirvBlockAttribute::InternalName() const {
|
||||
return "spirv_block";
|
||||
}
|
||||
|
||||
const AddSpirvBlockAttribute::SpirvBlockAttribute*
|
||||
AddSpirvBlockAttribute::SpirvBlockAttribute::Clone(CloneContext* ctx) const {
|
||||
return ctx->dst->ASTNodes()
|
||||
.Create<AddSpirvBlockAttribute::SpirvBlockAttribute>(ctx->dst->ID());
|
||||
}
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
76
src/tint/transform/add_spirv_block_attribute.h
Normal file
76
src/tint/transform/add_spirv_block_attribute.h
Normal file
@@ -0,0 +1,76 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef SRC_TINT_TRANSFORM_ADD_SPIRV_BLOCK_ATTRIBUTE_H_
|
||||
#define SRC_TINT_TRANSFORM_ADD_SPIRV_BLOCK_ATTRIBUTE_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "src/tint/ast/internal_attribute.h"
|
||||
#include "src/tint/transform/transform.h"
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
|
||||
/// AddSpirvBlockAttribute is a transform that adds an
|
||||
/// `@internal(spirv_block)` attribute to any structure that is used as the
|
||||
/// store type of a buffer. If that structure is nested inside another structure
|
||||
/// or an array, then it is wrapped inside another structure which gets the
|
||||
/// `@internal(spirv_block)` attribute instead.
|
||||
class AddSpirvBlockAttribute
|
||||
: public Castable<AddSpirvBlockAttribute, Transform> {
|
||||
public:
|
||||
/// SpirvBlockAttribute is an InternalAttribute that is used to decorate a
|
||||
// structure that needs a SPIR-V block attribute.
|
||||
class SpirvBlockAttribute
|
||||
: public Castable<SpirvBlockAttribute, ast::InternalAttribute> {
|
||||
public:
|
||||
/// Constructor
|
||||
/// @param program_id the identifier of the program that owns this node
|
||||
explicit SpirvBlockAttribute(ProgramID program_id);
|
||||
/// Destructor
|
||||
~SpirvBlockAttribute() override;
|
||||
|
||||
/// @return a short description of the internal attribute which will be
|
||||
/// displayed as `@internal(<name>)`
|
||||
std::string InternalName() const override;
|
||||
|
||||
/// Performs a deep clone of this object using the CloneContext `ctx`.
|
||||
/// @param ctx the clone context
|
||||
/// @return the newly cloned object
|
||||
const SpirvBlockAttribute* Clone(CloneContext* ctx) const override;
|
||||
};
|
||||
|
||||
/// Constructor
|
||||
AddSpirvBlockAttribute();
|
||||
|
||||
/// Destructor
|
||||
~AddSpirvBlockAttribute() override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
};
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
|
||||
#endif // SRC_TINT_TRANSFORM_ADD_SPIRV_BLOCK_ATTRIBUTE_H_
|
||||
615
src/tint/transform/add_spirv_block_attribute_test.cc
Normal file
615
src/tint/transform/add_spirv_block_attribute_test.cc
Normal file
@@ -0,0 +1,615 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/tint/transform/add_spirv_block_attribute.h"
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "src/tint/transform/test_helper.h"
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
namespace {
|
||||
|
||||
using AddSpirvBlockAttributeTest = TransformTest;
|
||||
|
||||
TEST_F(AddSpirvBlockAttributeTest, EmptyModule) {
|
||||
auto* src = "";
|
||||
auto* expect = "";
|
||||
|
||||
auto got = Run<AddSpirvBlockAttribute>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(AddSpirvBlockAttributeTest, Noop_UsedForPrivateVar) {
|
||||
auto* src = R"(
|
||||
struct S {
|
||||
f : f32;
|
||||
}
|
||||
|
||||
var<private> p : S;
|
||||
|
||||
@stage(fragment)
|
||||
fn main() {
|
||||
p.f = 1.0;
|
||||
}
|
||||
)";
|
||||
auto* expect = src;
|
||||
|
||||
auto got = Run<AddSpirvBlockAttribute>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(AddSpirvBlockAttributeTest, Noop_UsedForShaderIO) {
|
||||
auto* src = R"(
|
||||
struct S {
|
||||
@location(0)
|
||||
f : f32;
|
||||
}
|
||||
|
||||
@stage(fragment)
|
||||
fn main() -> S {
|
||||
return S();
|
||||
}
|
||||
)";
|
||||
auto* expect = src;
|
||||
|
||||
auto got = Run<AddSpirvBlockAttribute>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(AddSpirvBlockAttributeTest, BasicScalar) {
|
||||
auto* src = R"(
|
||||
@group(0) @binding(0)
|
||||
var<uniform> u : f32;
|
||||
|
||||
@stage(fragment)
|
||||
fn main() {
|
||||
let f = u;
|
||||
}
|
||||
)";
|
||||
auto* expect = R"(
|
||||
@internal(spirv_block)
|
||||
struct u_block {
|
||||
inner : f32;
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var<uniform> u : u_block;
|
||||
|
||||
@stage(fragment)
|
||||
fn main() {
|
||||
let f = u.inner;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<AddSpirvBlockAttribute>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(AddSpirvBlockAttributeTest, BasicArray) {
|
||||
auto* src = R"(
|
||||
@group(0) @binding(0)
|
||||
var<uniform> u : array<vec4<f32>, 4u>;
|
||||
|
||||
@stage(fragment)
|
||||
fn main() {
|
||||
let a = u;
|
||||
}
|
||||
)";
|
||||
auto* expect = R"(
|
||||
@internal(spirv_block)
|
||||
struct u_block {
|
||||
inner : array<vec4<f32>, 4u>;
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var<uniform> u : u_block;
|
||||
|
||||
@stage(fragment)
|
||||
fn main() {
|
||||
let a = u.inner;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<AddSpirvBlockAttribute>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(AddSpirvBlockAttributeTest, BasicArray_Alias) {
|
||||
auto* src = R"(
|
||||
type Numbers = array<vec4<f32>, 4u>;
|
||||
|
||||
@group(0) @binding(0)
|
||||
var<uniform> u : Numbers;
|
||||
|
||||
@stage(fragment)
|
||||
fn main() {
|
||||
let a = u;
|
||||
}
|
||||
)";
|
||||
auto* expect = R"(
|
||||
type Numbers = array<vec4<f32>, 4u>;
|
||||
|
||||
@internal(spirv_block)
|
||||
struct u_block {
|
||||
inner : array<vec4<f32>, 4u>;
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var<uniform> u : u_block;
|
||||
|
||||
@stage(fragment)
|
||||
fn main() {
|
||||
let a = u.inner;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<AddSpirvBlockAttribute>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(AddSpirvBlockAttributeTest, BasicStruct) {
|
||||
auto* src = R"(
|
||||
struct S {
|
||||
f : f32;
|
||||
};
|
||||
|
||||
@group(0) @binding(0)
|
||||
var<uniform> u : S;
|
||||
|
||||
@stage(fragment)
|
||||
fn main() {
|
||||
let f = u.f;
|
||||
}
|
||||
)";
|
||||
auto* expect = R"(
|
||||
@internal(spirv_block)
|
||||
struct S {
|
||||
f : f32;
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var<uniform> u : S;
|
||||
|
||||
@stage(fragment)
|
||||
fn main() {
|
||||
let f = u.f;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<AddSpirvBlockAttribute>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(AddSpirvBlockAttributeTest, Nested_OuterBuffer_InnerNotBuffer) {
|
||||
auto* src = R"(
|
||||
struct Inner {
|
||||
f : f32;
|
||||
};
|
||||
|
||||
struct Outer {
|
||||
i : Inner;
|
||||
};
|
||||
|
||||
@group(0) @binding(0)
|
||||
var<uniform> u : Outer;
|
||||
|
||||
@stage(fragment)
|
||||
fn main() {
|
||||
let f = u.i.f;
|
||||
}
|
||||
)";
|
||||
auto* expect = R"(
|
||||
struct Inner {
|
||||
f : f32;
|
||||
}
|
||||
|
||||
@internal(spirv_block)
|
||||
struct Outer {
|
||||
i : Inner;
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var<uniform> u : Outer;
|
||||
|
||||
@stage(fragment)
|
||||
fn main() {
|
||||
let f = u.i.f;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<AddSpirvBlockAttribute>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(AddSpirvBlockAttributeTest, Nested_OuterBuffer_InnerBuffer) {
|
||||
auto* src = R"(
|
||||
struct Inner {
|
||||
f : f32;
|
||||
};
|
||||
|
||||
struct Outer {
|
||||
i : Inner;
|
||||
};
|
||||
|
||||
@group(0) @binding(0)
|
||||
var<uniform> u0 : Outer;
|
||||
|
||||
@group(0) @binding(1)
|
||||
var<uniform> u1 : Inner;
|
||||
|
||||
@stage(fragment)
|
||||
fn main() {
|
||||
let f0 = u0.i.f;
|
||||
let f1 = u1.f;
|
||||
}
|
||||
)";
|
||||
auto* expect = R"(
|
||||
struct Inner {
|
||||
f : f32;
|
||||
}
|
||||
|
||||
@internal(spirv_block)
|
||||
struct Outer {
|
||||
i : Inner;
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var<uniform> u0 : Outer;
|
||||
|
||||
@internal(spirv_block)
|
||||
struct u1_block {
|
||||
inner : Inner;
|
||||
}
|
||||
|
||||
@group(0) @binding(1) var<uniform> u1 : u1_block;
|
||||
|
||||
@stage(fragment)
|
||||
fn main() {
|
||||
let f0 = u0.i.f;
|
||||
let f1 = u1.inner.f;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<AddSpirvBlockAttribute>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(AddSpirvBlockAttributeTest, Nested_OuterNotBuffer_InnerBuffer) {
|
||||
auto* src = R"(
|
||||
struct Inner {
|
||||
f : f32;
|
||||
};
|
||||
|
||||
struct Outer {
|
||||
i : Inner;
|
||||
};
|
||||
|
||||
var<private> p : Outer;
|
||||
|
||||
@group(0) @binding(1)
|
||||
var<uniform> u : Inner;
|
||||
|
||||
@stage(fragment)
|
||||
fn main() {
|
||||
let f0 = p.i.f;
|
||||
let f1 = u.f;
|
||||
}
|
||||
)";
|
||||
auto* expect = R"(
|
||||
struct Inner {
|
||||
f : f32;
|
||||
}
|
||||
|
||||
struct Outer {
|
||||
i : Inner;
|
||||
}
|
||||
|
||||
var<private> p : Outer;
|
||||
|
||||
@internal(spirv_block)
|
||||
struct u_block {
|
||||
inner : Inner;
|
||||
}
|
||||
|
||||
@group(0) @binding(1) var<uniform> u : u_block;
|
||||
|
||||
@stage(fragment)
|
||||
fn main() {
|
||||
let f0 = p.i.f;
|
||||
let f1 = u.inner.f;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<AddSpirvBlockAttribute>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(AddSpirvBlockAttributeTest, Nested_InnerUsedForMultipleBuffers) {
|
||||
auto* src = R"(
|
||||
struct Inner {
|
||||
f : f32;
|
||||
};
|
||||
|
||||
struct S {
|
||||
i : Inner;
|
||||
};
|
||||
|
||||
@group(0) @binding(0)
|
||||
var<uniform> u0 : S;
|
||||
|
||||
@group(0) @binding(1)
|
||||
var<uniform> u1 : Inner;
|
||||
|
||||
@group(0) @binding(2)
|
||||
var<uniform> u2 : Inner;
|
||||
|
||||
@stage(fragment)
|
||||
fn main() {
|
||||
let f0 = u0.i.f;
|
||||
let f1 = u1.f;
|
||||
let f2 = u2.f;
|
||||
}
|
||||
)";
|
||||
auto* expect = R"(
|
||||
struct Inner {
|
||||
f : f32;
|
||||
}
|
||||
|
||||
@internal(spirv_block)
|
||||
struct S {
|
||||
i : Inner;
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var<uniform> u0 : S;
|
||||
|
||||
@internal(spirv_block)
|
||||
struct u1_block {
|
||||
inner : Inner;
|
||||
}
|
||||
|
||||
@group(0) @binding(1) var<uniform> u1 : u1_block;
|
||||
|
||||
@group(0) @binding(2) var<uniform> u2 : u1_block;
|
||||
|
||||
@stage(fragment)
|
||||
fn main() {
|
||||
let f0 = u0.i.f;
|
||||
let f1 = u1.inner.f;
|
||||
let f2 = u2.inner.f;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<AddSpirvBlockAttribute>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(AddSpirvBlockAttributeTest, StructInArray) {
|
||||
auto* src = R"(
|
||||
struct S {
|
||||
f : f32;
|
||||
};
|
||||
|
||||
@group(0) @binding(0)
|
||||
var<uniform> u : S;
|
||||
|
||||
@stage(fragment)
|
||||
fn main() {
|
||||
let f = u.f;
|
||||
let a = array<S, 4>();
|
||||
}
|
||||
)";
|
||||
auto* expect = R"(
|
||||
struct S {
|
||||
f : f32;
|
||||
}
|
||||
|
||||
@internal(spirv_block)
|
||||
struct u_block {
|
||||
inner : S;
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var<uniform> u : u_block;
|
||||
|
||||
@stage(fragment)
|
||||
fn main() {
|
||||
let f = u.inner.f;
|
||||
let a = array<S, 4>();
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<AddSpirvBlockAttribute>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(AddSpirvBlockAttributeTest, StructInArray_MultipleBuffers) {
|
||||
auto* src = R"(
|
||||
struct S {
|
||||
f : f32;
|
||||
};
|
||||
|
||||
@group(0) @binding(0)
|
||||
var<uniform> u0 : S;
|
||||
|
||||
@group(0) @binding(1)
|
||||
var<uniform> u1 : S;
|
||||
|
||||
@stage(fragment)
|
||||
fn main() {
|
||||
let f0 = u0.f;
|
||||
let f1 = u1.f;
|
||||
let a = array<S, 4>();
|
||||
}
|
||||
)";
|
||||
auto* expect = R"(
|
||||
struct S {
|
||||
f : f32;
|
||||
}
|
||||
|
||||
@internal(spirv_block)
|
||||
struct u0_block {
|
||||
inner : S;
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var<uniform> u0 : u0_block;
|
||||
|
||||
@group(0) @binding(1) var<uniform> u1 : u0_block;
|
||||
|
||||
@stage(fragment)
|
||||
fn main() {
|
||||
let f0 = u0.inner.f;
|
||||
let f1 = u1.inner.f;
|
||||
let a = array<S, 4>();
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<AddSpirvBlockAttribute>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(AddSpirvBlockAttributeTest, Aliases_Nested_OuterBuffer_InnerBuffer) {
|
||||
auto* src = R"(
|
||||
struct Inner {
|
||||
f : f32;
|
||||
};
|
||||
|
||||
type MyInner = Inner;
|
||||
|
||||
struct Outer {
|
||||
i : MyInner;
|
||||
};
|
||||
|
||||
type MyOuter = Outer;
|
||||
|
||||
@group(0) @binding(0)
|
||||
var<uniform> u0 : MyOuter;
|
||||
|
||||
@group(0) @binding(1)
|
||||
var<uniform> u1 : MyInner;
|
||||
|
||||
@stage(fragment)
|
||||
fn main() {
|
||||
let f0 = u0.i.f;
|
||||
let f1 = u1.f;
|
||||
}
|
||||
)";
|
||||
auto* expect = R"(
|
||||
struct Inner {
|
||||
f : f32;
|
||||
}
|
||||
|
||||
type MyInner = Inner;
|
||||
|
||||
@internal(spirv_block)
|
||||
struct Outer {
|
||||
i : MyInner;
|
||||
}
|
||||
|
||||
type MyOuter = Outer;
|
||||
|
||||
@group(0) @binding(0) var<uniform> u0 : MyOuter;
|
||||
|
||||
@internal(spirv_block)
|
||||
struct u1_block {
|
||||
inner : Inner;
|
||||
}
|
||||
|
||||
@group(0) @binding(1) var<uniform> u1 : u1_block;
|
||||
|
||||
@stage(fragment)
|
||||
fn main() {
|
||||
let f0 = u0.i.f;
|
||||
let f1 = u1.inner.f;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<AddSpirvBlockAttribute>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(AddSpirvBlockAttributeTest,
|
||||
Aliases_Nested_OuterBuffer_InnerBuffer_OutOfOrder) {
|
||||
auto* src = R"(
|
||||
@stage(fragment)
|
||||
fn main() {
|
||||
let f0 = u0.i.f;
|
||||
let f1 = u1.f;
|
||||
}
|
||||
|
||||
@group(0) @binding(1)
|
||||
var<uniform> u1 : MyInner;
|
||||
|
||||
type MyInner = Inner;
|
||||
|
||||
@group(0) @binding(0)
|
||||
var<uniform> u0 : MyOuter;
|
||||
|
||||
type MyOuter = Outer;
|
||||
|
||||
struct Outer {
|
||||
i : MyInner;
|
||||
};
|
||||
|
||||
struct Inner {
|
||||
f : f32;
|
||||
};
|
||||
)";
|
||||
auto* expect = R"(
|
||||
@stage(fragment)
|
||||
fn main() {
|
||||
let f0 = u0.i.f;
|
||||
let f1 = u1.inner.f;
|
||||
}
|
||||
|
||||
@internal(spirv_block)
|
||||
struct u1_block {
|
||||
inner : Inner;
|
||||
}
|
||||
|
||||
@group(0) @binding(1) var<uniform> u1 : u1_block;
|
||||
|
||||
type MyInner = Inner;
|
||||
|
||||
@group(0) @binding(0) var<uniform> u0 : MyOuter;
|
||||
|
||||
type MyOuter = Outer;
|
||||
|
||||
@internal(spirv_block)
|
||||
struct Outer {
|
||||
i : MyInner;
|
||||
}
|
||||
|
||||
struct Inner {
|
||||
f : f32;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<AddSpirvBlockAttribute>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
233
src/tint/transform/array_length_from_uniform.cc
Normal file
233
src/tint/transform/array_length_from_uniform.cc
Normal file
@@ -0,0 +1,233 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/tint/transform/array_length_from_uniform.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "src/tint/program_builder.h"
|
||||
#include "src/tint/sem/call.h"
|
||||
#include "src/tint/sem/function.h"
|
||||
#include "src/tint/sem/variable.h"
|
||||
#include "src/tint/transform/simplify_pointers.h"
|
||||
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::transform::ArrayLengthFromUniform);
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::transform::ArrayLengthFromUniform::Config);
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::transform::ArrayLengthFromUniform::Result);
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
|
||||
ArrayLengthFromUniform::ArrayLengthFromUniform() = default;
|
||||
ArrayLengthFromUniform::~ArrayLengthFromUniform() = default;
|
||||
|
||||
/// Iterate over all arrayLength() builtins that operate on
|
||||
/// storage buffer variables.
|
||||
/// @param ctx the CloneContext.
|
||||
/// @param functor of type void(const ast::CallExpression*, const
|
||||
/// sem::VariableUser, const sem::GlobalVariable*). It takes in an
|
||||
/// ast::CallExpression of the arrayLength call expression node, a
|
||||
/// sem::VariableUser of the used storage buffer variable, and the
|
||||
/// sem::GlobalVariable for the storage buffer.
|
||||
template <typename F>
|
||||
static void IterateArrayLengthOnStorageVar(CloneContext& ctx, F&& functor) {
|
||||
auto& sem = ctx.src->Sem();
|
||||
|
||||
// Find all calls to the arrayLength() builtin.
|
||||
for (auto* node : ctx.src->ASTNodes().Objects()) {
|
||||
auto* call_expr = node->As<ast::CallExpression>();
|
||||
if (!call_expr) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto* call = sem.Get(call_expr);
|
||||
auto* builtin = call->Target()->As<sem::Builtin>();
|
||||
if (!builtin || builtin->Type() != sem::BuiltinType::kArrayLength) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Get the storage buffer that contains the runtime array.
|
||||
// Since we require SimplifyPointers, we can assume that the arrayLength()
|
||||
// call has one of two forms:
|
||||
// arrayLength(&struct_var.array_member)
|
||||
// arrayLength(&array_var)
|
||||
auto* param = call_expr->args[0]->As<ast::UnaryOpExpression>();
|
||||
if (!param || param->op != ast::UnaryOp::kAddressOf) {
|
||||
TINT_ICE(Transform, ctx.dst->Diagnostics())
|
||||
<< "expected form of arrayLength argument to be &array_var or "
|
||||
"&struct_var.array_member";
|
||||
break;
|
||||
}
|
||||
auto* storage_buffer_expr = param->expr;
|
||||
if (auto* accessor = param->expr->As<ast::MemberAccessorExpression>()) {
|
||||
storage_buffer_expr = accessor->structure;
|
||||
}
|
||||
auto* storage_buffer_sem = sem.Get<sem::VariableUser>(storage_buffer_expr);
|
||||
if (!storage_buffer_sem) {
|
||||
TINT_ICE(Transform, ctx.dst->Diagnostics())
|
||||
<< "expected form of arrayLength argument to be &array_var or "
|
||||
"&struct_var.array_member";
|
||||
break;
|
||||
}
|
||||
|
||||
// Get the index to use for the buffer size array.
|
||||
auto* var = tint::As<sem::GlobalVariable>(storage_buffer_sem->Variable());
|
||||
if (!var) {
|
||||
TINT_ICE(Transform, ctx.dst->Diagnostics())
|
||||
<< "storage buffer is not a global variable";
|
||||
break;
|
||||
}
|
||||
functor(call_expr, storage_buffer_sem, var);
|
||||
}
|
||||
}
|
||||
|
||||
bool ArrayLengthFromUniform::ShouldRun(const Program* program,
|
||||
const DataMap&) const {
|
||||
for (auto* fn : program->AST().Functions()) {
|
||||
if (auto* sem_fn = program->Sem().Get(fn)) {
|
||||
for (auto* builtin : sem_fn->DirectlyCalledBuiltins()) {
|
||||
if (builtin->Type() == sem::BuiltinType::kArrayLength) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void ArrayLengthFromUniform::Run(CloneContext& ctx,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const {
|
||||
auto* cfg = inputs.Get<Config>();
|
||||
if (cfg == nullptr) {
|
||||
ctx.dst->Diagnostics().add_error(
|
||||
diag::System::Transform,
|
||||
"missing transform data for " + std::string(TypeInfo().name));
|
||||
return;
|
||||
}
|
||||
|
||||
const char* kBufferSizeMemberName = "buffer_size";
|
||||
|
||||
// Determine the size of the buffer size array.
|
||||
uint32_t max_buffer_size_index = 0;
|
||||
|
||||
IterateArrayLengthOnStorageVar(
|
||||
ctx, [&](const ast::CallExpression*, const sem::VariableUser*,
|
||||
const sem::GlobalVariable* var) {
|
||||
auto binding = var->BindingPoint();
|
||||
auto idx_itr = cfg->bindpoint_to_size_index.find(binding);
|
||||
if (idx_itr == cfg->bindpoint_to_size_index.end()) {
|
||||
return;
|
||||
}
|
||||
if (idx_itr->second > max_buffer_size_index) {
|
||||
max_buffer_size_index = idx_itr->second;
|
||||
}
|
||||
});
|
||||
|
||||
// Get (or create, on first call) the uniform buffer that will receive the
|
||||
// size of each storage buffer in the module.
|
||||
const ast::Variable* buffer_size_ubo = nullptr;
|
||||
auto get_ubo = [&]() {
|
||||
if (!buffer_size_ubo) {
|
||||
// Emit an array<vec4<u32>, N>, where N is 1/4 number of elements.
|
||||
// We do this because UBOs require an element stride that is 16-byte
|
||||
// aligned.
|
||||
auto* buffer_size_struct = ctx.dst->Structure(
|
||||
ctx.dst->Sym(),
|
||||
{ctx.dst->Member(
|
||||
kBufferSizeMemberName,
|
||||
ctx.dst->ty.array(ctx.dst->ty.vec4(ctx.dst->ty.u32()),
|
||||
(max_buffer_size_index / 4) + 1))});
|
||||
buffer_size_ubo = ctx.dst->Global(
|
||||
ctx.dst->Sym(), ctx.dst->ty.Of(buffer_size_struct),
|
||||
ast::StorageClass::kUniform,
|
||||
ast::AttributeList{ctx.dst->GroupAndBinding(
|
||||
cfg->ubo_binding.group, cfg->ubo_binding.binding)});
|
||||
}
|
||||
return buffer_size_ubo;
|
||||
};
|
||||
|
||||
std::unordered_set<uint32_t> used_size_indices;
|
||||
|
||||
IterateArrayLengthOnStorageVar(
|
||||
ctx, [&](const ast::CallExpression* call_expr,
|
||||
const sem::VariableUser* storage_buffer_sem,
|
||||
const sem::GlobalVariable* var) {
|
||||
auto binding = var->BindingPoint();
|
||||
auto idx_itr = cfg->bindpoint_to_size_index.find(binding);
|
||||
if (idx_itr == cfg->bindpoint_to_size_index.end()) {
|
||||
return;
|
||||
}
|
||||
|
||||
uint32_t size_index = idx_itr->second;
|
||||
used_size_indices.insert(size_index);
|
||||
|
||||
// Load the total storage buffer size from the UBO.
|
||||
uint32_t array_index = size_index / 4;
|
||||
auto* vec_expr = ctx.dst->IndexAccessor(
|
||||
ctx.dst->MemberAccessor(get_ubo()->symbol, kBufferSizeMemberName),
|
||||
array_index);
|
||||
uint32_t vec_index = size_index % 4;
|
||||
auto* total_storage_buffer_size =
|
||||
ctx.dst->IndexAccessor(vec_expr, vec_index);
|
||||
|
||||
// Calculate actual array length
|
||||
// total_storage_buffer_size - array_offset
|
||||
// array_length = ----------------------------------------
|
||||
// array_stride
|
||||
const ast::Expression* total_size = total_storage_buffer_size;
|
||||
auto* storage_buffer_type = storage_buffer_sem->Type()->UnwrapRef();
|
||||
const sem::Array* array_type = nullptr;
|
||||
if (auto* str = storage_buffer_type->As<sem::Struct>()) {
|
||||
// The variable is a struct, so subtract the byte offset of the array
|
||||
// member.
|
||||
auto* array_member_sem = str->Members().back();
|
||||
array_type = array_member_sem->Type()->As<sem::Array>();
|
||||
total_size = ctx.dst->Sub(total_storage_buffer_size,
|
||||
array_member_sem->Offset());
|
||||
} else if (auto* arr = storage_buffer_type->As<sem::Array>()) {
|
||||
array_type = arr;
|
||||
} else {
|
||||
TINT_ICE(Transform, ctx.dst->Diagnostics())
|
||||
<< "expected form of arrayLength argument to be &array_var or "
|
||||
"&struct_var.array_member";
|
||||
return;
|
||||
}
|
||||
auto* array_length = ctx.dst->Div(total_size, array_type->Stride());
|
||||
|
||||
ctx.Replace(call_expr, array_length);
|
||||
});
|
||||
|
||||
ctx.Clone();
|
||||
|
||||
outputs.Add<Result>(used_size_indices);
|
||||
}
|
||||
|
||||
ArrayLengthFromUniform::Config::Config(sem::BindingPoint ubo_bp)
|
||||
: ubo_binding(ubo_bp) {}
|
||||
ArrayLengthFromUniform::Config::Config(const Config&) = default;
|
||||
ArrayLengthFromUniform::Config& ArrayLengthFromUniform::Config::operator=(
|
||||
const Config&) = default;
|
||||
ArrayLengthFromUniform::Config::~Config() = default;
|
||||
|
||||
ArrayLengthFromUniform::Result::Result(
|
||||
std::unordered_set<uint32_t> used_size_indices_in)
|
||||
: used_size_indices(std::move(used_size_indices_in)) {}
|
||||
ArrayLengthFromUniform::Result::Result(const Result&) = default;
|
||||
ArrayLengthFromUniform::Result::~Result() = default;
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
125
src/tint/transform/array_length_from_uniform.h
Normal file
125
src/tint/transform/array_length_from_uniform.h
Normal file
@@ -0,0 +1,125 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef SRC_TINT_TRANSFORM_ARRAY_LENGTH_FROM_UNIFORM_H_
|
||||
#define SRC_TINT_TRANSFORM_ARRAY_LENGTH_FROM_UNIFORM_H_
|
||||
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "src/tint/sem/binding_point.h"
|
||||
#include "src/tint/transform/transform.h"
|
||||
|
||||
namespace tint {
|
||||
|
||||
// Forward declarations
|
||||
class CloneContext;
|
||||
|
||||
namespace transform {
|
||||
|
||||
/// ArrayLengthFromUniform is a transform that implements calls to arrayLength()
|
||||
/// by calculating the length from the total size of the storage buffer, which
|
||||
/// is received via a uniform buffer.
|
||||
///
|
||||
/// The generated uniform buffer will have the form:
|
||||
/// ```
|
||||
/// struct buffer_size_struct {
|
||||
/// buffer_size : array<u32, 8>;
|
||||
/// };
|
||||
///
|
||||
/// @group(0) @binding(30)
|
||||
/// var<uniform> buffer_size_ubo : buffer_size_struct;
|
||||
/// ```
|
||||
/// The binding group and number used for this uniform buffer is provided via
|
||||
/// the `Config` transform input. The `Config` struct also defines the mapping
|
||||
/// from a storage buffer's `BindingPoint` to the array index that will be used
|
||||
/// to get the size of that buffer.
|
||||
///
|
||||
/// This transform assumes that the `SimplifyPointers`
|
||||
/// transforms have been run before it so that arguments to the arrayLength
|
||||
/// builtin always have the form `&resource.array`.
|
||||
///
|
||||
/// @note Depends on the following transforms to have been run first:
|
||||
/// * SimplifyPointers
|
||||
class ArrayLengthFromUniform
|
||||
: public Castable<ArrayLengthFromUniform, Transform> {
|
||||
public:
|
||||
/// Constructor
|
||||
ArrayLengthFromUniform();
|
||||
/// Destructor
|
||||
~ArrayLengthFromUniform() override;
|
||||
|
||||
/// Configuration options for the ArrayLengthFromUniform transform.
|
||||
struct Config : public Castable<Data, transform::Data> {
|
||||
/// Constructor
|
||||
/// @param ubo_bp the binding point to use for the generated uniform buffer.
|
||||
explicit Config(sem::BindingPoint ubo_bp);
|
||||
|
||||
/// Copy constructor
|
||||
Config(const Config&);
|
||||
|
||||
/// Copy assignment
|
||||
/// @return this Config
|
||||
Config& operator=(const Config&);
|
||||
|
||||
/// Destructor
|
||||
~Config() override;
|
||||
|
||||
/// The binding point to use for the generated uniform buffer.
|
||||
sem::BindingPoint ubo_binding;
|
||||
|
||||
/// The mapping from binding point to the index for the buffer size lookup.
|
||||
std::unordered_map<sem::BindingPoint, uint32_t> bindpoint_to_size_index;
|
||||
};
|
||||
|
||||
/// Information produced about what the transform did.
|
||||
/// If there were no calls to the arrayLength() builtin, then no Result will
|
||||
/// be emitted.
|
||||
struct Result : public Castable<Result, transform::Data> {
|
||||
/// Constructor
|
||||
/// @param used_size_indices Indices into the UBO that are statically used.
|
||||
explicit Result(std::unordered_set<uint32_t> used_size_indices);
|
||||
|
||||
/// Copy constructor
|
||||
Result(const Result&);
|
||||
|
||||
/// Destructor
|
||||
~Result() override;
|
||||
|
||||
/// Indices into the UBO that are statically used.
|
||||
const std::unordered_set<uint32_t> used_size_indices;
|
||||
};
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program,
|
||||
const DataMap& data = {}) const override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
};
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
|
||||
#endif // SRC_TINT_TRANSFORM_ARRAY_LENGTH_FROM_UNIFORM_H_
|
||||
589
src/tint/transform/array_length_from_uniform_test.cc
Normal file
589
src/tint/transform/array_length_from_uniform_test.cc
Normal file
@@ -0,0 +1,589 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/tint/transform/array_length_from_uniform.h"
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "src/tint/transform/simplify_pointers.h"
|
||||
#include "src/tint/transform/test_helper.h"
|
||||
#include "src/tint/transform/unshadow.h"
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
namespace {
|
||||
|
||||
using ArrayLengthFromUniformTest = TransformTest;
|
||||
|
||||
TEST_F(ArrayLengthFromUniformTest, ShouldRunEmptyModule) {
|
||||
auto* src = R"()";
|
||||
|
||||
EXPECT_FALSE(ShouldRun<ArrayLengthFromUniform>(src));
|
||||
}
|
||||
|
||||
TEST_F(ArrayLengthFromUniformTest, ShouldRunNoArrayLength) {
|
||||
auto* src = R"(
|
||||
struct SB {
|
||||
x : i32;
|
||||
arr : array<i32>;
|
||||
};
|
||||
|
||||
[[group(0), binding(0)]] var<storage, read> sb : SB;
|
||||
|
||||
[[stage(compute), workgroup_size(1)]]
|
||||
fn main() {
|
||||
}
|
||||
)";
|
||||
|
||||
EXPECT_FALSE(ShouldRun<ArrayLengthFromUniform>(src));
|
||||
}
|
||||
|
||||
TEST_F(ArrayLengthFromUniformTest, ShouldRunWithArrayLength) {
|
||||
auto* src = R"(
|
||||
struct SB {
|
||||
x : i32;
|
||||
arr : array<i32>;
|
||||
};
|
||||
|
||||
[[group(0), binding(0)]] var<storage, read> sb : SB;
|
||||
|
||||
[[stage(compute), workgroup_size(1)]]
|
||||
fn main() {
|
||||
var len : u32 = arrayLength(&sb.arr);
|
||||
}
|
||||
)";
|
||||
|
||||
EXPECT_TRUE(ShouldRun<ArrayLengthFromUniform>(src));
|
||||
}
|
||||
|
||||
TEST_F(ArrayLengthFromUniformTest, Error_MissingTransformData) {
|
||||
auto* src = R"(
|
||||
struct SB {
|
||||
x : i32;
|
||||
arr : array<i32>;
|
||||
};
|
||||
|
||||
[[group(0), binding(0)]] var<storage, read> sb : SB;
|
||||
|
||||
[[stage(compute), workgroup_size(1)]]
|
||||
fn main() {
|
||||
var len : u32 = arrayLength(&sb.arr);
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect =
|
||||
"error: missing transform data for "
|
||||
"tint::transform::ArrayLengthFromUniform";
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers, ArrayLengthFromUniform>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(ArrayLengthFromUniformTest, Basic) {
|
||||
auto* src = R"(
|
||||
@group(0) @binding(0) var<storage, read> sb : array<i32>;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
var len : u32 = arrayLength(&sb);
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
struct tint_symbol {
|
||||
buffer_size : array<vec4<u32>, 1u>;
|
||||
}
|
||||
|
||||
@group(0) @binding(30) var<uniform> tint_symbol_1 : tint_symbol;
|
||||
|
||||
@group(0) @binding(0) var<storage, read> sb : array<i32>;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
var len : u32 = (tint_symbol_1.buffer_size[0u][0u] / 4u);
|
||||
}
|
||||
)";
|
||||
|
||||
ArrayLengthFromUniform::Config cfg({0, 30u});
|
||||
cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 0}, 0);
|
||||
|
||||
DataMap data;
|
||||
data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers, ArrayLengthFromUniform>(src, data);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
EXPECT_EQ(std::unordered_set<uint32_t>({0}),
|
||||
got.data.Get<ArrayLengthFromUniform::Result>()->used_size_indices);
|
||||
}
|
||||
|
||||
TEST_F(ArrayLengthFromUniformTest, BasicInStruct) {
|
||||
auto* src = R"(
|
||||
struct SB {
|
||||
x : i32;
|
||||
arr : array<i32>;
|
||||
};
|
||||
|
||||
@group(0) @binding(0) var<storage, read> sb : SB;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
var len : u32 = arrayLength(&sb.arr);
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
struct tint_symbol {
|
||||
buffer_size : array<vec4<u32>, 1u>;
|
||||
}
|
||||
|
||||
@group(0) @binding(30) var<uniform> tint_symbol_1 : tint_symbol;
|
||||
|
||||
struct SB {
|
||||
x : i32;
|
||||
arr : array<i32>;
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var<storage, read> sb : SB;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
var len : u32 = ((tint_symbol_1.buffer_size[0u][0u] - 4u) / 4u);
|
||||
}
|
||||
)";
|
||||
|
||||
ArrayLengthFromUniform::Config cfg({0, 30u});
|
||||
cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 0}, 0);
|
||||
|
||||
DataMap data;
|
||||
data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers, ArrayLengthFromUniform>(src, data);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
EXPECT_EQ(std::unordered_set<uint32_t>({0}),
|
||||
got.data.Get<ArrayLengthFromUniform::Result>()->used_size_indices);
|
||||
}
|
||||
|
||||
TEST_F(ArrayLengthFromUniformTest, WithStride) {
|
||||
auto* src = R"(
|
||||
@group(0) @binding(0) var<storage, read> sb : @stride(64) array<i32>;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
var len : u32 = arrayLength(&sb);
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
struct tint_symbol {
|
||||
buffer_size : array<vec4<u32>, 1u>;
|
||||
}
|
||||
|
||||
@group(0) @binding(30) var<uniform> tint_symbol_1 : tint_symbol;
|
||||
|
||||
@group(0) @binding(0) var<storage, read> sb : @stride(64) array<i32>;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
var len : u32 = (tint_symbol_1.buffer_size[0u][0u] / 64u);
|
||||
}
|
||||
)";
|
||||
|
||||
ArrayLengthFromUniform::Config cfg({0, 30u});
|
||||
cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 0}, 0);
|
||||
|
||||
DataMap data;
|
||||
data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers, ArrayLengthFromUniform>(src, data);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
EXPECT_EQ(std::unordered_set<uint32_t>({0}),
|
||||
got.data.Get<ArrayLengthFromUniform::Result>()->used_size_indices);
|
||||
}
|
||||
|
||||
TEST_F(ArrayLengthFromUniformTest, WithStride_InStruct) {
|
||||
auto* src = R"(
|
||||
struct SB {
|
||||
x : i32;
|
||||
y : f32;
|
||||
arr : @stride(64) array<i32>;
|
||||
};
|
||||
|
||||
@group(0) @binding(0) var<storage, read> sb : SB;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
var len : u32 = arrayLength(&sb.arr);
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
struct tint_symbol {
|
||||
buffer_size : array<vec4<u32>, 1u>;
|
||||
}
|
||||
|
||||
@group(0) @binding(30) var<uniform> tint_symbol_1 : tint_symbol;
|
||||
|
||||
struct SB {
|
||||
x : i32;
|
||||
y : f32;
|
||||
arr : @stride(64) array<i32>;
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var<storage, read> sb : SB;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
var len : u32 = ((tint_symbol_1.buffer_size[0u][0u] - 8u) / 64u);
|
||||
}
|
||||
)";
|
||||
|
||||
ArrayLengthFromUniform::Config cfg({0, 30u});
|
||||
cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 0}, 0);
|
||||
|
||||
DataMap data;
|
||||
data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers, ArrayLengthFromUniform>(src, data);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
EXPECT_EQ(std::unordered_set<uint32_t>({0}),
|
||||
got.data.Get<ArrayLengthFromUniform::Result>()->used_size_indices);
|
||||
}
|
||||
|
||||
TEST_F(ArrayLengthFromUniformTest, MultipleStorageBuffers) {
|
||||
auto* src = R"(
|
||||
struct SB1 {
|
||||
x : i32;
|
||||
arr1 : array<i32>;
|
||||
};
|
||||
struct SB2 {
|
||||
x : i32;
|
||||
arr2 : array<vec4<f32>>;
|
||||
};
|
||||
struct SB4 {
|
||||
x : i32;
|
||||
arr4 : array<vec4<f32>>;
|
||||
};
|
||||
|
||||
@group(0) @binding(2) var<storage, read> sb1 : SB1;
|
||||
@group(1) @binding(2) var<storage, read> sb2 : SB2;
|
||||
@group(2) @binding(2) var<storage, read> sb3 : array<vec4<f32>>;
|
||||
@group(3) @binding(2) var<storage, read> sb4 : SB4;
|
||||
@group(4) @binding(2) var<storage, read> sb5 : array<vec4<f32>>;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
var len1 : u32 = arrayLength(&(sb1.arr1));
|
||||
var len2 : u32 = arrayLength(&(sb2.arr2));
|
||||
var len3 : u32 = arrayLength(&sb3);
|
||||
var len4 : u32 = arrayLength(&(sb4.arr4));
|
||||
var len5 : u32 = arrayLength(&sb5);
|
||||
var x : u32 = (len1 + len2 + len3 + len4 + len5);
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
struct tint_symbol {
|
||||
buffer_size : array<vec4<u32>, 2u>;
|
||||
}
|
||||
|
||||
@group(0) @binding(30) var<uniform> tint_symbol_1 : tint_symbol;
|
||||
|
||||
struct SB1 {
|
||||
x : i32;
|
||||
arr1 : array<i32>;
|
||||
}
|
||||
|
||||
struct SB2 {
|
||||
x : i32;
|
||||
arr2 : array<vec4<f32>>;
|
||||
}
|
||||
|
||||
struct SB4 {
|
||||
x : i32;
|
||||
arr4 : array<vec4<f32>>;
|
||||
}
|
||||
|
||||
@group(0) @binding(2) var<storage, read> sb1 : SB1;
|
||||
|
||||
@group(1) @binding(2) var<storage, read> sb2 : SB2;
|
||||
|
||||
@group(2) @binding(2) var<storage, read> sb3 : array<vec4<f32>>;
|
||||
|
||||
@group(3) @binding(2) var<storage, read> sb4 : SB4;
|
||||
|
||||
@group(4) @binding(2) var<storage, read> sb5 : array<vec4<f32>>;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
var len1 : u32 = ((tint_symbol_1.buffer_size[0u][0u] - 4u) / 4u);
|
||||
var len2 : u32 = ((tint_symbol_1.buffer_size[0u][1u] - 16u) / 16u);
|
||||
var len3 : u32 = (tint_symbol_1.buffer_size[0u][2u] / 16u);
|
||||
var len4 : u32 = ((tint_symbol_1.buffer_size[0u][3u] - 16u) / 16u);
|
||||
var len5 : u32 = (tint_symbol_1.buffer_size[1u][0u] / 16u);
|
||||
var x : u32 = ((((len1 + len2) + len3) + len4) + len5);
|
||||
}
|
||||
)";
|
||||
|
||||
ArrayLengthFromUniform::Config cfg({0, 30u});
|
||||
cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 2u}, 0);
|
||||
cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{1u, 2u}, 1);
|
||||
cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{2u, 2u}, 2);
|
||||
cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{3u, 2u}, 3);
|
||||
cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{4u, 2u}, 4);
|
||||
|
||||
DataMap data;
|
||||
data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers, ArrayLengthFromUniform>(src, data);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
EXPECT_EQ(std::unordered_set<uint32_t>({0, 1, 2, 3, 4}),
|
||||
got.data.Get<ArrayLengthFromUniform::Result>()->used_size_indices);
|
||||
}
|
||||
|
||||
TEST_F(ArrayLengthFromUniformTest, MultipleUnusedStorageBuffers) {
|
||||
auto* src = R"(
|
||||
struct SB1 {
|
||||
x : i32;
|
||||
arr1 : array<i32>;
|
||||
};
|
||||
struct SB2 {
|
||||
x : i32;
|
||||
arr2 : array<vec4<f32>>;
|
||||
};
|
||||
struct SB4 {
|
||||
x : i32;
|
||||
arr4 : array<vec4<f32>>;
|
||||
};
|
||||
|
||||
@group(0) @binding(2) var<storage, read> sb1 : SB1;
|
||||
@group(1) @binding(2) var<storage, read> sb2 : SB2;
|
||||
@group(2) @binding(2) var<storage, read> sb3 : array<vec4<f32>>;
|
||||
@group(3) @binding(2) var<storage, read> sb4 : SB4;
|
||||
@group(4) @binding(2) var<storage, read> sb5 : array<vec4<f32>>;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
var len1 : u32 = arrayLength(&(sb1.arr1));
|
||||
var len3 : u32 = arrayLength(&sb3);
|
||||
var x : u32 = (len1 + len3);
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
struct tint_symbol {
|
||||
buffer_size : array<vec4<u32>, 1u>;
|
||||
}
|
||||
|
||||
@group(0) @binding(30) var<uniform> tint_symbol_1 : tint_symbol;
|
||||
|
||||
struct SB1 {
|
||||
x : i32;
|
||||
arr1 : array<i32>;
|
||||
}
|
||||
|
||||
struct SB2 {
|
||||
x : i32;
|
||||
arr2 : array<vec4<f32>>;
|
||||
}
|
||||
|
||||
struct SB4 {
|
||||
x : i32;
|
||||
arr4 : array<vec4<f32>>;
|
||||
}
|
||||
|
||||
@group(0) @binding(2) var<storage, read> sb1 : SB1;
|
||||
|
||||
@group(1) @binding(2) var<storage, read> sb2 : SB2;
|
||||
|
||||
@group(2) @binding(2) var<storage, read> sb3 : array<vec4<f32>>;
|
||||
|
||||
@group(3) @binding(2) var<storage, read> sb4 : SB4;
|
||||
|
||||
@group(4) @binding(2) var<storage, read> sb5 : array<vec4<f32>>;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
var len1 : u32 = ((tint_symbol_1.buffer_size[0u][0u] - 4u) / 4u);
|
||||
var len3 : u32 = (tint_symbol_1.buffer_size[0u][2u] / 16u);
|
||||
var x : u32 = (len1 + len3);
|
||||
}
|
||||
)";
|
||||
|
||||
ArrayLengthFromUniform::Config cfg({0, 30u});
|
||||
cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 2u}, 0);
|
||||
cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{1u, 2u}, 1);
|
||||
cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{2u, 2u}, 2);
|
||||
cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{3u, 2u}, 3);
|
||||
cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{4u, 2u}, 4);
|
||||
|
||||
DataMap data;
|
||||
data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers, ArrayLengthFromUniform>(src, data);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
EXPECT_EQ(std::unordered_set<uint32_t>({0, 2}),
|
||||
got.data.Get<ArrayLengthFromUniform::Result>()->used_size_indices);
|
||||
}
|
||||
|
||||
TEST_F(ArrayLengthFromUniformTest, NoArrayLengthCalls) {
|
||||
auto* src = R"(
|
||||
struct SB {
|
||||
x : i32;
|
||||
arr : array<i32>;
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var<storage, read> sb : SB;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
_ = &(sb.arr);
|
||||
}
|
||||
)";
|
||||
|
||||
ArrayLengthFromUniform::Config cfg({0, 30u});
|
||||
cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 0}, 0);
|
||||
|
||||
DataMap data;
|
||||
data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers, ArrayLengthFromUniform>(src, data);
|
||||
|
||||
EXPECT_EQ(src, str(got));
|
||||
EXPECT_EQ(got.data.Get<ArrayLengthFromUniform::Result>(), nullptr);
|
||||
}
|
||||
|
||||
TEST_F(ArrayLengthFromUniformTest, MissingBindingPointToIndexMapping) {
|
||||
auto* src = R"(
|
||||
struct SB1 {
|
||||
x : i32;
|
||||
arr1 : array<i32>;
|
||||
};
|
||||
|
||||
struct SB2 {
|
||||
x : i32;
|
||||
arr2 : array<vec4<f32>>;
|
||||
};
|
||||
|
||||
@group(0) @binding(2) var<storage, read> sb1 : SB1;
|
||||
|
||||
@group(1) @binding(2) var<storage, read> sb2 : SB2;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
var len1 : u32 = arrayLength(&(sb1.arr1));
|
||||
var len2 : u32 = arrayLength(&(sb2.arr2));
|
||||
var x : u32 = (len1 + len2);
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
struct tint_symbol {
|
||||
buffer_size : array<vec4<u32>, 1u>;
|
||||
}
|
||||
|
||||
@group(0) @binding(30) var<uniform> tint_symbol_1 : tint_symbol;
|
||||
|
||||
struct SB1 {
|
||||
x : i32;
|
||||
arr1 : array<i32>;
|
||||
}
|
||||
|
||||
struct SB2 {
|
||||
x : i32;
|
||||
arr2 : array<vec4<f32>>;
|
||||
}
|
||||
|
||||
@group(0) @binding(2) var<storage, read> sb1 : SB1;
|
||||
|
||||
@group(1) @binding(2) var<storage, read> sb2 : SB2;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
var len1 : u32 = ((tint_symbol_1.buffer_size[0u][0u] - 4u) / 4u);
|
||||
var len2 : u32 = arrayLength(&(sb2.arr2));
|
||||
var x : u32 = (len1 + len2);
|
||||
}
|
||||
)";
|
||||
|
||||
ArrayLengthFromUniform::Config cfg({0, 30u});
|
||||
cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 2}, 0);
|
||||
|
||||
DataMap data;
|
||||
data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers, ArrayLengthFromUniform>(src, data);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
EXPECT_EQ(std::unordered_set<uint32_t>({0}),
|
||||
got.data.Get<ArrayLengthFromUniform::Result>()->used_size_indices);
|
||||
}
|
||||
|
||||
TEST_F(ArrayLengthFromUniformTest, OutOfOrder) {
|
||||
auto* src = R"(
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
var len : u32 = arrayLength(&sb.arr);
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var<storage, read> sb : SB;
|
||||
|
||||
struct SB {
|
||||
x : i32;
|
||||
arr : array<i32>;
|
||||
};
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
struct tint_symbol {
|
||||
buffer_size : array<vec4<u32>, 1u>;
|
||||
}
|
||||
|
||||
@group(0) @binding(30) var<uniform> tint_symbol_1 : tint_symbol;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
var len : u32 = ((tint_symbol_1.buffer_size[0u][0u] - 4u) / 4u);
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var<storage, read> sb : SB;
|
||||
|
||||
struct SB {
|
||||
x : i32;
|
||||
arr : array<i32>;
|
||||
}
|
||||
)";
|
||||
|
||||
ArrayLengthFromUniform::Config cfg({0, 30u});
|
||||
cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 0}, 0);
|
||||
|
||||
DataMap data;
|
||||
data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers, ArrayLengthFromUniform>(src, data);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
EXPECT_EQ(std::unordered_set<uint32_t>({0}),
|
||||
got.data.Get<ArrayLengthFromUniform::Result>()->used_size_indices);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
163
src/tint/transform/binding_remapper.cc
Normal file
163
src/tint/transform/binding_remapper.cc
Normal file
@@ -0,0 +1,163 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/tint/transform/binding_remapper.h"
|
||||
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
#include <utility>
|
||||
|
||||
#include "src/tint/ast/disable_validation_attribute.h"
|
||||
#include "src/tint/program_builder.h"
|
||||
#include "src/tint/sem/function.h"
|
||||
#include "src/tint/sem/variable.h"
|
||||
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::transform::BindingRemapper);
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::transform::BindingRemapper::Remappings);
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
|
||||
BindingRemapper::Remappings::Remappings(BindingPoints bp,
|
||||
AccessControls ac,
|
||||
bool may_collide)
|
||||
: binding_points(std::move(bp)),
|
||||
access_controls(std::move(ac)),
|
||||
allow_collisions(may_collide) {}
|
||||
|
||||
BindingRemapper::Remappings::Remappings(const Remappings&) = default;
|
||||
BindingRemapper::Remappings::~Remappings() = default;
|
||||
|
||||
BindingRemapper::BindingRemapper() = default;
|
||||
BindingRemapper::~BindingRemapper() = default;
|
||||
|
||||
bool BindingRemapper::ShouldRun(const Program*, const DataMap& inputs) const {
|
||||
if (auto* remappings = inputs.Get<Remappings>()) {
|
||||
return !remappings->binding_points.empty() ||
|
||||
!remappings->access_controls.empty();
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void BindingRemapper::Run(CloneContext& ctx,
|
||||
const DataMap& inputs,
|
||||
DataMap&) const {
|
||||
auto* remappings = inputs.Get<Remappings>();
|
||||
if (!remappings) {
|
||||
ctx.dst->Diagnostics().add_error(
|
||||
diag::System::Transform,
|
||||
"missing transform data for " + std::string(TypeInfo().name));
|
||||
return;
|
||||
}
|
||||
|
||||
// A set of post-remapped binding points that need to be decorated with a
|
||||
// DisableValidationAttribute to disable binding-point-collision validation
|
||||
std::unordered_set<sem::BindingPoint> add_collision_attr;
|
||||
|
||||
if (remappings->allow_collisions) {
|
||||
// Scan for binding point collisions generated by this transform.
|
||||
// Populate all collisions in the `add_collision_attr` set.
|
||||
for (auto* func_ast : ctx.src->AST().Functions()) {
|
||||
if (!func_ast->IsEntryPoint()) {
|
||||
continue;
|
||||
}
|
||||
auto* func = ctx.src->Sem().Get(func_ast);
|
||||
std::unordered_map<sem::BindingPoint, int> binding_point_counts;
|
||||
for (auto* var : func->TransitivelyReferencedGlobals()) {
|
||||
if (auto binding_point = var->Declaration()->BindingPoint()) {
|
||||
BindingPoint from{binding_point.group->value,
|
||||
binding_point.binding->value};
|
||||
auto bp_it = remappings->binding_points.find(from);
|
||||
if (bp_it != remappings->binding_points.end()) {
|
||||
// Remapped
|
||||
BindingPoint to = bp_it->second;
|
||||
if (binding_point_counts[to]++) {
|
||||
add_collision_attr.emplace(to);
|
||||
}
|
||||
} else {
|
||||
// No remapping
|
||||
if (binding_point_counts[from]++) {
|
||||
add_collision_attr.emplace(from);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (auto* var : ctx.src->AST().GlobalVariables()) {
|
||||
if (auto binding_point = var->BindingPoint()) {
|
||||
// The original binding point
|
||||
BindingPoint from{binding_point.group->value,
|
||||
binding_point.binding->value};
|
||||
|
||||
// The binding point after remapping
|
||||
BindingPoint bp = from;
|
||||
|
||||
// Replace any group or binding attributes.
|
||||
// Note: This has to be performed *before* remapping access controls, as
|
||||
// `ctx.Clone(var->attributes)` depend on these replacements.
|
||||
auto bp_it = remappings->binding_points.find(from);
|
||||
if (bp_it != remappings->binding_points.end()) {
|
||||
BindingPoint to = bp_it->second;
|
||||
auto* new_group = ctx.dst->create<ast::GroupAttribute>(to.group);
|
||||
auto* new_binding = ctx.dst->create<ast::BindingAttribute>(to.binding);
|
||||
|
||||
ctx.Replace(binding_point.group, new_group);
|
||||
ctx.Replace(binding_point.binding, new_binding);
|
||||
bp = to;
|
||||
}
|
||||
|
||||
// Replace any access controls.
|
||||
auto ac_it = remappings->access_controls.find(from);
|
||||
if (ac_it != remappings->access_controls.end()) {
|
||||
ast::Access ac = ac_it->second;
|
||||
if (ac > ast::Access::kLastValid) {
|
||||
ctx.dst->Diagnostics().add_error(
|
||||
diag::System::Transform,
|
||||
"invalid access mode (" +
|
||||
std::to_string(static_cast<uint32_t>(ac)) + ")");
|
||||
return;
|
||||
}
|
||||
auto* sem = ctx.src->Sem().Get(var);
|
||||
if (sem->StorageClass() != ast::StorageClass::kStorage) {
|
||||
ctx.dst->Diagnostics().add_error(
|
||||
diag::System::Transform,
|
||||
"cannot apply access control to variable with storage class " +
|
||||
std::string(ast::ToString(sem->StorageClass())));
|
||||
return;
|
||||
}
|
||||
auto* ty = sem->Type()->UnwrapRef();
|
||||
const ast::Type* inner_ty = CreateASTTypeFor(ctx, ty);
|
||||
auto* new_var = ctx.dst->create<ast::Variable>(
|
||||
ctx.Clone(var->source), ctx.Clone(var->symbol),
|
||||
var->declared_storage_class, ac, inner_ty, false, false,
|
||||
ctx.Clone(var->constructor), ctx.Clone(var->attributes));
|
||||
ctx.Replace(var, new_var);
|
||||
}
|
||||
|
||||
// Add `DisableValidationAttribute`s if required
|
||||
if (add_collision_attr.count(bp)) {
|
||||
auto* attribute =
|
||||
ctx.dst->Disable(ast::DisabledValidation::kBindingPointCollision);
|
||||
ctx.InsertBefore(var->attributes, *var->attributes.begin(), attribute);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ctx.Clone();
|
||||
}
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
92
src/tint/transform/binding_remapper.h
Normal file
92
src/tint/transform/binding_remapper.h
Normal file
@@ -0,0 +1,92 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef SRC_TINT_TRANSFORM_BINDING_REMAPPER_H_
|
||||
#define SRC_TINT_TRANSFORM_BINDING_REMAPPER_H_
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
#include "src/tint/ast/access.h"
|
||||
#include "src/tint/sem/binding_point.h"
|
||||
#include "src/tint/transform/transform.h"
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
|
||||
/// BindingPoint is an alias to sem::BindingPoint
|
||||
using BindingPoint = sem::BindingPoint;
|
||||
|
||||
/// BindingRemapper is a transform used to remap resource binding points and
|
||||
/// access controls.
|
||||
class BindingRemapper : public Castable<BindingRemapper, Transform> {
|
||||
public:
|
||||
/// BindingPoints is a map of old binding point to new binding point
|
||||
using BindingPoints = std::unordered_map<BindingPoint, BindingPoint>;
|
||||
|
||||
/// AccessControls is a map of old binding point to new access control
|
||||
using AccessControls = std::unordered_map<BindingPoint, ast::Access>;
|
||||
|
||||
/// Remappings is consumed by the BindingRemapper transform.
|
||||
/// Data holds information about shader usage and constant buffer offsets.
|
||||
struct Remappings : public Castable<Data, transform::Data> {
|
||||
/// Constructor
|
||||
/// @param bp a map of new binding points
|
||||
/// @param ac a map of new access controls
|
||||
/// @param may_collide If true, then validation will be disabled for
|
||||
/// binding point collisions generated by this transform
|
||||
Remappings(BindingPoints bp, AccessControls ac, bool may_collide = true);
|
||||
|
||||
/// Copy constructor
|
||||
Remappings(const Remappings&);
|
||||
|
||||
/// Destructor
|
||||
~Remappings() override;
|
||||
|
||||
/// A map of old binding point to new binding point
|
||||
const BindingPoints binding_points;
|
||||
|
||||
/// A map of old binding point to new access controls
|
||||
const AccessControls access_controls;
|
||||
|
||||
/// If true, then validation will be disabled for binding point collisions
|
||||
/// generated by this transform
|
||||
const bool allow_collisions;
|
||||
};
|
||||
|
||||
/// Constructor
|
||||
BindingRemapper();
|
||||
~BindingRemapper() override;
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program,
|
||||
const DataMap& data = {}) const override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
};
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
|
||||
#endif // SRC_TINT_TRANSFORM_BINDING_REMAPPER_H_
|
||||
423
src/tint/transform/binding_remapper_test.cc
Normal file
423
src/tint/transform/binding_remapper_test.cc
Normal file
@@ -0,0 +1,423 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/tint/transform/binding_remapper.h"
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "src/tint/transform/test_helper.h"
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
namespace {
|
||||
|
||||
using BindingRemapperTest = TransformTest;
|
||||
|
||||
TEST_F(BindingRemapperTest, ShouldRunNoRemappings) {
|
||||
auto* src = R"()";
|
||||
|
||||
EXPECT_FALSE(ShouldRun<BindingRemapper>(src));
|
||||
}
|
||||
|
||||
TEST_F(BindingRemapperTest, ShouldRunEmptyRemappings) {
|
||||
auto* src = R"()";
|
||||
|
||||
DataMap data;
|
||||
data.Add<BindingRemapper::Remappings>(BindingRemapper::BindingPoints{},
|
||||
BindingRemapper::AccessControls{});
|
||||
|
||||
EXPECT_FALSE(ShouldRun<BindingRemapper>(src, data));
|
||||
}
|
||||
|
||||
TEST_F(BindingRemapperTest, ShouldRunBindingPointRemappings) {
|
||||
auto* src = R"()";
|
||||
|
||||
DataMap data;
|
||||
data.Add<BindingRemapper::Remappings>(
|
||||
BindingRemapper::BindingPoints{
|
||||
{{2, 1}, {1, 2}},
|
||||
},
|
||||
BindingRemapper::AccessControls{});
|
||||
|
||||
EXPECT_TRUE(ShouldRun<BindingRemapper>(src, data));
|
||||
}
|
||||
|
||||
TEST_F(BindingRemapperTest, ShouldRunAccessControlRemappings) {
|
||||
auto* src = R"()";
|
||||
|
||||
DataMap data;
|
||||
data.Add<BindingRemapper::Remappings>(BindingRemapper::BindingPoints{},
|
||||
BindingRemapper::AccessControls{
|
||||
{{2, 1}, ast::Access::kWrite},
|
||||
});
|
||||
|
||||
EXPECT_TRUE(ShouldRun<BindingRemapper>(src, data));
|
||||
}
|
||||
|
||||
TEST_F(BindingRemapperTest, NoRemappings) {
|
||||
auto* src = R"(
|
||||
struct S {
|
||||
a : f32;
|
||||
}
|
||||
|
||||
@group(2) @binding(1) var<storage, read> a : S;
|
||||
|
||||
@group(3) @binding(2) var<storage, read> b : S;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn f() {
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = src;
|
||||
|
||||
DataMap data;
|
||||
data.Add<BindingRemapper::Remappings>(BindingRemapper::BindingPoints{},
|
||||
BindingRemapper::AccessControls{});
|
||||
auto got = Run<BindingRemapper>(src, data);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(BindingRemapperTest, RemapBindingPoints) {
|
||||
auto* src = R"(
|
||||
struct S {
|
||||
a : f32;
|
||||
};
|
||||
|
||||
@group(2) @binding(1) var<storage, read> a : S;
|
||||
|
||||
@group(3) @binding(2) var<storage, read> b : S;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn f() {
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
struct S {
|
||||
a : f32;
|
||||
}
|
||||
|
||||
@group(1) @binding(2) var<storage, read> a : S;
|
||||
|
||||
@group(3) @binding(2) var<storage, read> b : S;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn f() {
|
||||
}
|
||||
)";
|
||||
|
||||
DataMap data;
|
||||
data.Add<BindingRemapper::Remappings>(
|
||||
BindingRemapper::BindingPoints{
|
||||
{{2, 1}, {1, 2}}, // Remap
|
||||
{{4, 5}, {6, 7}}, // Not found
|
||||
// Keep @group(3) @binding(2) as is
|
||||
},
|
||||
BindingRemapper::AccessControls{});
|
||||
auto got = Run<BindingRemapper>(src, data);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(BindingRemapperTest, RemapAccessControls) {
|
||||
auto* src = R"(
|
||||
struct S {
|
||||
a : f32;
|
||||
};
|
||||
|
||||
@group(2) @binding(1) var<storage, read> a : S;
|
||||
|
||||
@group(3) @binding(2) var<storage, write> b : S;
|
||||
|
||||
@group(4) @binding(3) var<storage, read> c : S;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn f() {
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
struct S {
|
||||
a : f32;
|
||||
}
|
||||
|
||||
@group(2) @binding(1) var<storage, write> a : S;
|
||||
|
||||
@group(3) @binding(2) var<storage, write> b : S;
|
||||
|
||||
@group(4) @binding(3) var<storage, read> c : S;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn f() {
|
||||
}
|
||||
)";
|
||||
|
||||
DataMap data;
|
||||
data.Add<BindingRemapper::Remappings>(
|
||||
BindingRemapper::BindingPoints{},
|
||||
BindingRemapper::AccessControls{
|
||||
{{2, 1}, ast::Access::kWrite}, // Modify access control
|
||||
// Keep @group(3) @binding(2) as is
|
||||
{{4, 3}, ast::Access::kRead}, // Add access control
|
||||
});
|
||||
auto got = Run<BindingRemapper>(src, data);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
// TODO(crbug.com/676): Possibly enable if the spec allows for access
|
||||
// attributes in type aliases. If not, just remove.
|
||||
TEST_F(BindingRemapperTest, DISABLED_RemapAccessControlsWithAliases) {
|
||||
auto* src = R"(
|
||||
struct S {
|
||||
a : f32;
|
||||
};
|
||||
|
||||
type, read ReadOnlyS = S;
|
||||
|
||||
type, write WriteOnlyS = S;
|
||||
|
||||
type A = S;
|
||||
|
||||
@group(2) @binding(1) var<storage> a : ReadOnlyS;
|
||||
|
||||
@group(3) @binding(2) var<storage> b : WriteOnlyS;
|
||||
|
||||
@group(4) @binding(3) var<storage> c : A;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn f() {
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
struct S {
|
||||
a : f32;
|
||||
};
|
||||
|
||||
type, read ReadOnlyS = S;
|
||||
|
||||
type, write WriteOnlyS = S;
|
||||
|
||||
type A = S;
|
||||
|
||||
@group(2) @binding(1) var<storage, write> a : S;
|
||||
|
||||
@group(3) @binding(2) var<storage> b : WriteOnlyS;
|
||||
|
||||
@group(4) @binding(3) var<storage, write> c : S;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn f() {
|
||||
}
|
||||
)";
|
||||
|
||||
DataMap data;
|
||||
data.Add<BindingRemapper::Remappings>(
|
||||
BindingRemapper::BindingPoints{},
|
||||
BindingRemapper::AccessControls{
|
||||
{{2, 1}, ast::Access::kWrite}, // Modify access control
|
||||
// Keep @group(3) @binding(2) as is
|
||||
{{4, 3}, ast::Access::kRead}, // Add access control
|
||||
});
|
||||
auto got = Run<BindingRemapper>(src, data);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(BindingRemapperTest, RemapAll) {
|
||||
auto* src = R"(
|
||||
struct S {
|
||||
a : f32;
|
||||
};
|
||||
|
||||
@group(2) @binding(1) var<storage, read> a : S;
|
||||
|
||||
@group(3) @binding(2) var<storage, read> b : S;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn f() {
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
struct S {
|
||||
a : f32;
|
||||
}
|
||||
|
||||
@group(4) @binding(5) var<storage, write> a : S;
|
||||
|
||||
@group(6) @binding(7) var<storage, write> b : S;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn f() {
|
||||
}
|
||||
)";
|
||||
|
||||
DataMap data;
|
||||
data.Add<BindingRemapper::Remappings>(
|
||||
BindingRemapper::BindingPoints{
|
||||
{{2, 1}, {4, 5}},
|
||||
{{3, 2}, {6, 7}},
|
||||
},
|
||||
BindingRemapper::AccessControls{
|
||||
{{2, 1}, ast::Access::kWrite},
|
||||
{{3, 2}, ast::Access::kWrite},
|
||||
});
|
||||
auto got = Run<BindingRemapper>(src, data);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(BindingRemapperTest, BindingCollisionsSameEntryPoint) {
|
||||
auto* src = R"(
|
||||
struct S {
|
||||
i : i32;
|
||||
};
|
||||
|
||||
@group(2) @binding(1) var<storage, read> a : S;
|
||||
|
||||
@group(3) @binding(2) var<storage, read> b : S;
|
||||
|
||||
@group(4) @binding(3) var<storage, read> c : S;
|
||||
|
||||
@group(5) @binding(4) var<storage, read> d : S;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn f() {
|
||||
let x : i32 = (((a.i + b.i) + c.i) + d.i);
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
struct S {
|
||||
i : i32;
|
||||
}
|
||||
|
||||
@internal(disable_validation__binding_point_collision) @group(1) @binding(1) var<storage, read> a : S;
|
||||
|
||||
@internal(disable_validation__binding_point_collision) @group(1) @binding(1) var<storage, read> b : S;
|
||||
|
||||
@internal(disable_validation__binding_point_collision) @group(5) @binding(4) var<storage, read> c : S;
|
||||
|
||||
@internal(disable_validation__binding_point_collision) @group(5) @binding(4) var<storage, read> d : S;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn f() {
|
||||
let x : i32 = (((a.i + b.i) + c.i) + d.i);
|
||||
}
|
||||
)";
|
||||
|
||||
DataMap data;
|
||||
data.Add<BindingRemapper::Remappings>(
|
||||
BindingRemapper::BindingPoints{
|
||||
{{2, 1}, {1, 1}},
|
||||
{{3, 2}, {1, 1}},
|
||||
{{4, 3}, {5, 4}},
|
||||
},
|
||||
BindingRemapper::AccessControls{}, true);
|
||||
auto got = Run<BindingRemapper>(src, data);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(BindingRemapperTest, BindingCollisionsDifferentEntryPoints) {
|
||||
auto* src = R"(
|
||||
struct S {
|
||||
i : i32;
|
||||
};
|
||||
|
||||
@group(2) @binding(1) var<storage, read> a : S;
|
||||
|
||||
@group(3) @binding(2) var<storage, read> b : S;
|
||||
|
||||
@group(4) @binding(3) var<storage, read> c : S;
|
||||
|
||||
@group(5) @binding(4) var<storage, read> d : S;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn f1() {
|
||||
let x : i32 = (a.i + c.i);
|
||||
}
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn f2() {
|
||||
let x : i32 = (b.i + d.i);
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
struct S {
|
||||
i : i32;
|
||||
}
|
||||
|
||||
@group(1) @binding(1) var<storage, read> a : S;
|
||||
|
||||
@group(1) @binding(1) var<storage, read> b : S;
|
||||
|
||||
@group(5) @binding(4) var<storage, read> c : S;
|
||||
|
||||
@group(5) @binding(4) var<storage, read> d : S;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn f1() {
|
||||
let x : i32 = (a.i + c.i);
|
||||
}
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn f2() {
|
||||
let x : i32 = (b.i + d.i);
|
||||
}
|
||||
)";
|
||||
|
||||
DataMap data;
|
||||
data.Add<BindingRemapper::Remappings>(
|
||||
BindingRemapper::BindingPoints{
|
||||
{{2, 1}, {1, 1}},
|
||||
{{3, 2}, {1, 1}},
|
||||
{{4, 3}, {5, 4}},
|
||||
},
|
||||
BindingRemapper::AccessControls{}, true);
|
||||
auto got = Run<BindingRemapper>(src, data);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(BindingRemapperTest, NoData) {
|
||||
auto* src = R"(
|
||||
struct S {
|
||||
a : f32;
|
||||
}
|
||||
|
||||
@group(2) @binding(1) var<storage, read> a : S;
|
||||
|
||||
@group(3) @binding(2) var<storage, read> b : S;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn f() {
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = src;
|
||||
|
||||
auto got = Run<BindingRemapper>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
243
src/tint/transform/calculate_array_length.cc
Normal file
243
src/tint/transform/calculate_array_length.cc
Normal file
@@ -0,0 +1,243 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/tint/transform/calculate_array_length.h"
|
||||
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
|
||||
#include "src/tint/ast/call_statement.h"
|
||||
#include "src/tint/ast/disable_validation_attribute.h"
|
||||
#include "src/tint/program_builder.h"
|
||||
#include "src/tint/sem/block_statement.h"
|
||||
#include "src/tint/sem/call.h"
|
||||
#include "src/tint/sem/function.h"
|
||||
#include "src/tint/sem/statement.h"
|
||||
#include "src/tint/sem/struct.h"
|
||||
#include "src/tint/sem/variable.h"
|
||||
#include "src/tint/transform/simplify_pointers.h"
|
||||
#include "src/tint/utils/hash.h"
|
||||
#include "src/tint/utils/map.h"
|
||||
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::transform::CalculateArrayLength);
|
||||
TINT_INSTANTIATE_TYPEINFO(
|
||||
tint::transform::CalculateArrayLength::BufferSizeIntrinsic);
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
|
||||
namespace {
|
||||
|
||||
/// ArrayUsage describes a runtime array usage.
|
||||
/// It is used as a key by the array_length_by_usage map.
|
||||
struct ArrayUsage {
|
||||
ast::BlockStatement const* const block;
|
||||
sem::Variable const* const buffer;
|
||||
bool operator==(const ArrayUsage& rhs) const {
|
||||
return block == rhs.block && buffer == rhs.buffer;
|
||||
}
|
||||
struct Hasher {
|
||||
inline std::size_t operator()(const ArrayUsage& u) const {
|
||||
return utils::Hash(u.block, u.buffer);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
CalculateArrayLength::BufferSizeIntrinsic::BufferSizeIntrinsic(ProgramID pid)
|
||||
: Base(pid) {}
|
||||
CalculateArrayLength::BufferSizeIntrinsic::~BufferSizeIntrinsic() = default;
|
||||
std::string CalculateArrayLength::BufferSizeIntrinsic::InternalName() const {
|
||||
return "intrinsic_buffer_size";
|
||||
}
|
||||
|
||||
const CalculateArrayLength::BufferSizeIntrinsic*
|
||||
CalculateArrayLength::BufferSizeIntrinsic::Clone(CloneContext* ctx) const {
|
||||
return ctx->dst->ASTNodes().Create<CalculateArrayLength::BufferSizeIntrinsic>(
|
||||
ctx->dst->ID());
|
||||
}
|
||||
|
||||
CalculateArrayLength::CalculateArrayLength() = default;
|
||||
CalculateArrayLength::~CalculateArrayLength() = default;
|
||||
|
||||
bool CalculateArrayLength::ShouldRun(const Program* program,
|
||||
const DataMap&) const {
|
||||
for (auto* fn : program->AST().Functions()) {
|
||||
if (auto* sem_fn = program->Sem().Get(fn)) {
|
||||
for (auto* builtin : sem_fn->DirectlyCalledBuiltins()) {
|
||||
if (builtin->Type() == sem::BuiltinType::kArrayLength) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void CalculateArrayLength::Run(CloneContext& ctx,
|
||||
const DataMap&,
|
||||
DataMap&) const {
|
||||
auto& sem = ctx.src->Sem();
|
||||
|
||||
// get_buffer_size_intrinsic() emits the function decorated with
|
||||
// BufferSizeIntrinsic that is transformed by the HLSL writer into a call to
|
||||
// [RW]ByteAddressBuffer.GetDimensions().
|
||||
std::unordered_map<const sem::Type*, Symbol> buffer_size_intrinsics;
|
||||
auto get_buffer_size_intrinsic = [&](const sem::Type* buffer_type) {
|
||||
return utils::GetOrCreate(buffer_size_intrinsics, buffer_type, [&] {
|
||||
auto name = ctx.dst->Sym();
|
||||
auto* type = CreateASTTypeFor(ctx, buffer_type);
|
||||
auto* disable_validation = ctx.dst->Disable(
|
||||
ast::DisabledValidation::kIgnoreConstructibleFunctionParameter);
|
||||
ctx.dst->AST().AddFunction(ctx.dst->create<ast::Function>(
|
||||
name,
|
||||
ast::VariableList{
|
||||
// Note: The buffer parameter requires the kStorage StorageClass
|
||||
// in order for HLSL to emit this as a ByteAddressBuffer.
|
||||
ctx.dst->create<ast::Variable>(
|
||||
ctx.dst->Sym("buffer"), ast::StorageClass::kStorage,
|
||||
ast::Access::kUndefined, type, true, false, nullptr,
|
||||
ast::AttributeList{disable_validation}),
|
||||
ctx.dst->Param("result",
|
||||
ctx.dst->ty.pointer(ctx.dst->ty.u32(),
|
||||
ast::StorageClass::kFunction)),
|
||||
},
|
||||
ctx.dst->ty.void_(), nullptr,
|
||||
ast::AttributeList{
|
||||
ctx.dst->ASTNodes().Create<BufferSizeIntrinsic>(ctx.dst->ID()),
|
||||
},
|
||||
ast::AttributeList{}));
|
||||
|
||||
return name;
|
||||
});
|
||||
};
|
||||
|
||||
std::unordered_map<ArrayUsage, Symbol, ArrayUsage::Hasher>
|
||||
array_length_by_usage;
|
||||
|
||||
// Find all the arrayLength() calls...
|
||||
for (auto* node : ctx.src->ASTNodes().Objects()) {
|
||||
if (auto* call_expr = node->As<ast::CallExpression>()) {
|
||||
auto* call = sem.Get(call_expr);
|
||||
if (auto* builtin = call->Target()->As<sem::Builtin>()) {
|
||||
if (builtin->Type() == sem::BuiltinType::kArrayLength) {
|
||||
// We're dealing with an arrayLength() call
|
||||
|
||||
// A runtime-sized array can only appear as the store type of a
|
||||
// variable, or the last element of a structure (which cannot itself
|
||||
// be nested). Given that we require SimplifyPointers, we can assume
|
||||
// that the arrayLength() call has one of two forms:
|
||||
// arrayLength(&struct_var.array_member)
|
||||
// arrayLength(&array_var)
|
||||
auto* arg = call_expr->args[0];
|
||||
auto* address_of = arg->As<ast::UnaryOpExpression>();
|
||||
if (!address_of || address_of->op != ast::UnaryOp::kAddressOf) {
|
||||
TINT_ICE(Transform, ctx.dst->Diagnostics())
|
||||
<< "arrayLength() expected address-of, got "
|
||||
<< arg->TypeInfo().name;
|
||||
}
|
||||
auto* storage_buffer_expr = address_of->expr;
|
||||
if (auto* accessor =
|
||||
storage_buffer_expr->As<ast::MemberAccessorExpression>()) {
|
||||
storage_buffer_expr = accessor->structure;
|
||||
}
|
||||
auto* storage_buffer_sem =
|
||||
sem.Get<sem::VariableUser>(storage_buffer_expr);
|
||||
if (!storage_buffer_sem) {
|
||||
TINT_ICE(Transform, ctx.dst->Diagnostics())
|
||||
<< "expected form of arrayLength argument to be &array_var or "
|
||||
"&struct_var.array_member";
|
||||
break;
|
||||
}
|
||||
auto* storage_buffer_var = storage_buffer_sem->Variable();
|
||||
auto* storage_buffer_type = storage_buffer_sem->Type()->UnwrapRef();
|
||||
|
||||
// Generate BufferSizeIntrinsic for this storage type if we haven't
|
||||
// already
|
||||
auto buffer_size = get_buffer_size_intrinsic(storage_buffer_type);
|
||||
|
||||
// Find the current statement block
|
||||
auto* block = call->Stmt()->Block()->Declaration();
|
||||
|
||||
auto array_length = utils::GetOrCreate(
|
||||
array_length_by_usage, {block, storage_buffer_var}, [&] {
|
||||
// First time this array length is used for this block.
|
||||
// Let's calculate it.
|
||||
|
||||
// Construct the variable that'll hold the result of
|
||||
// RWByteAddressBuffer.GetDimensions()
|
||||
auto* buffer_size_result = ctx.dst->Decl(
|
||||
ctx.dst->Var(ctx.dst->Sym(), ctx.dst->ty.u32(),
|
||||
ast::StorageClass::kNone, ctx.dst->Expr(0u)));
|
||||
|
||||
// Call storage_buffer.GetDimensions(&buffer_size_result)
|
||||
auto* call_get_dims = ctx.dst->CallStmt(ctx.dst->Call(
|
||||
// BufferSizeIntrinsic(X, ARGS...) is
|
||||
// translated to:
|
||||
// X.GetDimensions(ARGS..) by the writer
|
||||
buffer_size, ctx.Clone(storage_buffer_expr),
|
||||
ctx.dst->AddressOf(
|
||||
ctx.dst->Expr(buffer_size_result->variable->symbol))));
|
||||
|
||||
// Calculate actual array length
|
||||
// total_storage_buffer_size - array_offset
|
||||
// array_length = ----------------------------------------
|
||||
// array_stride
|
||||
auto name = ctx.dst->Sym();
|
||||
const ast::Expression* total_size =
|
||||
ctx.dst->Expr(buffer_size_result->variable);
|
||||
const sem::Array* array_type = nullptr;
|
||||
if (auto* str = storage_buffer_type->As<sem::Struct>()) {
|
||||
// The variable is a struct, so subtract the byte offset of
|
||||
// the array member.
|
||||
auto* array_member_sem = str->Members().back();
|
||||
array_type = array_member_sem->Type()->As<sem::Array>();
|
||||
total_size =
|
||||
ctx.dst->Sub(total_size, array_member_sem->Offset());
|
||||
} else if (auto* arr = storage_buffer_type->As<sem::Array>()) {
|
||||
array_type = arr;
|
||||
} else {
|
||||
TINT_ICE(Transform, ctx.dst->Diagnostics())
|
||||
<< "expected form of arrayLength argument to be "
|
||||
"&array_var or &struct_var.array_member";
|
||||
return name;
|
||||
}
|
||||
uint32_t array_stride = array_type->Size();
|
||||
auto* array_length_var = ctx.dst->Decl(
|
||||
ctx.dst->Const(name, ctx.dst->ty.u32(),
|
||||
ctx.dst->Div(total_size, array_stride)));
|
||||
|
||||
// Insert the array length calculations at the top of the block
|
||||
ctx.InsertBefore(block->statements, block->statements[0],
|
||||
buffer_size_result);
|
||||
ctx.InsertBefore(block->statements, block->statements[0],
|
||||
call_get_dims);
|
||||
ctx.InsertBefore(block->statements, block->statements[0],
|
||||
array_length_var);
|
||||
return name;
|
||||
});
|
||||
|
||||
// Replace the call to arrayLength() with the array length variable
|
||||
ctx.Replace(call_expr, ctx.dst->Expr(array_length));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ctx.Clone();
|
||||
}
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
83
src/tint/transform/calculate_array_length.h
Normal file
83
src/tint/transform/calculate_array_length.h
Normal file
@@ -0,0 +1,83 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef SRC_TINT_TRANSFORM_CALCULATE_ARRAY_LENGTH_H_
|
||||
#define SRC_TINT_TRANSFORM_CALCULATE_ARRAY_LENGTH_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "src/tint/ast/internal_attribute.h"
|
||||
#include "src/tint/transform/transform.h"
|
||||
|
||||
namespace tint {
|
||||
|
||||
// Forward declarations
|
||||
class CloneContext;
|
||||
|
||||
namespace transform {
|
||||
|
||||
/// CalculateArrayLength is a transform used to replace calls to arrayLength()
|
||||
/// with a value calculated from the size of the storage buffer.
|
||||
///
|
||||
/// @note Depends on the following transforms to have been run first:
|
||||
/// * SimplifyPointers
|
||||
class CalculateArrayLength : public Castable<CalculateArrayLength, Transform> {
|
||||
public:
|
||||
/// BufferSizeIntrinsic is an InternalAttribute that's applied to intrinsic
|
||||
/// functions used to obtain the runtime size of a storage buffer.
|
||||
class BufferSizeIntrinsic
|
||||
: public Castable<BufferSizeIntrinsic, ast::InternalAttribute> {
|
||||
public:
|
||||
/// Constructor
|
||||
/// @param program_id the identifier of the program that owns this node
|
||||
explicit BufferSizeIntrinsic(ProgramID program_id);
|
||||
/// Destructor
|
||||
~BufferSizeIntrinsic() override;
|
||||
|
||||
/// @return "buffer_size"
|
||||
std::string InternalName() const override;
|
||||
|
||||
/// Performs a deep clone of this object using the CloneContext `ctx`.
|
||||
/// @param ctx the clone context
|
||||
/// @return the newly cloned object
|
||||
const BufferSizeIntrinsic* Clone(CloneContext* ctx) const override;
|
||||
};
|
||||
|
||||
/// Constructor
|
||||
CalculateArrayLength();
|
||||
/// Destructor
|
||||
~CalculateArrayLength() override;
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program,
|
||||
const DataMap& data = {}) const override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
};
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
|
||||
#endif // SRC_TINT_TRANSFORM_CALCULATE_ARRAY_LENGTH_H_
|
||||
625
src/tint/transform/calculate_array_length_test.cc
Normal file
625
src/tint/transform/calculate_array_length_test.cc
Normal file
@@ -0,0 +1,625 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/tint/transform/calculate_array_length.h"
|
||||
|
||||
#include "src/tint/transform/simplify_pointers.h"
|
||||
#include "src/tint/transform/test_helper.h"
|
||||
#include "src/tint/transform/unshadow.h"
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
namespace {
|
||||
|
||||
using CalculateArrayLengthTest = TransformTest;
|
||||
|
||||
TEST_F(CalculateArrayLengthTest, ShouldRunEmptyModule) {
|
||||
auto* src = R"()";
|
||||
|
||||
EXPECT_FALSE(ShouldRun<CalculateArrayLength>(src));
|
||||
}
|
||||
|
||||
TEST_F(CalculateArrayLengthTest, ShouldRunNoArrayLength) {
|
||||
auto* src = R"(
|
||||
struct SB {
|
||||
x : i32;
|
||||
arr : array<i32>;
|
||||
};
|
||||
|
||||
[[group(0), binding(0)]] var<storage, read> sb : SB;
|
||||
|
||||
[[stage(compute), workgroup_size(1)]]
|
||||
fn main() {
|
||||
}
|
||||
)";
|
||||
|
||||
EXPECT_FALSE(ShouldRun<CalculateArrayLength>(src));
|
||||
}
|
||||
|
||||
TEST_F(CalculateArrayLengthTest, ShouldRunWithArrayLength) {
|
||||
auto* src = R"(
|
||||
struct SB {
|
||||
x : i32;
|
||||
arr : array<i32>;
|
||||
};
|
||||
|
||||
[[group(0), binding(0)]] var<storage, read> sb : SB;
|
||||
|
||||
[[stage(compute), workgroup_size(1)]]
|
||||
fn main() {
|
||||
var len : u32 = arrayLength(&sb.arr);
|
||||
}
|
||||
)";
|
||||
|
||||
EXPECT_TRUE(ShouldRun<CalculateArrayLength>(src));
|
||||
}
|
||||
|
||||
TEST_F(CalculateArrayLengthTest, BasicArray) {
|
||||
auto* src = R"(
|
||||
@group(0) @binding(0) var<storage, read> sb : array<i32>;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
var len : u32 = arrayLength(&sb);
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
@internal(intrinsic_buffer_size)
|
||||
fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : array<i32>, result : ptr<function, u32>)
|
||||
|
||||
@group(0) @binding(0) var<storage, read> sb : array<i32>;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
var tint_symbol_1 : u32 = 0u;
|
||||
tint_symbol(sb, &(tint_symbol_1));
|
||||
let tint_symbol_2 : u32 = (tint_symbol_1 / 4u);
|
||||
var len : u32 = tint_symbol_2;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(CalculateArrayLengthTest, BasicInStruct) {
|
||||
auto* src = R"(
|
||||
struct SB {
|
||||
x : i32;
|
||||
arr : array<i32>;
|
||||
};
|
||||
|
||||
@group(0) @binding(0) var<storage, read> sb : SB;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
var len : u32 = arrayLength(&sb.arr);
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
@internal(intrinsic_buffer_size)
|
||||
fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : SB, result : ptr<function, u32>)
|
||||
|
||||
struct SB {
|
||||
x : i32;
|
||||
arr : array<i32>;
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var<storage, read> sb : SB;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
var tint_symbol_1 : u32 = 0u;
|
||||
tint_symbol(sb, &(tint_symbol_1));
|
||||
let tint_symbol_2 : u32 = ((tint_symbol_1 - 4u) / 4u);
|
||||
var len : u32 = tint_symbol_2;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(CalculateArrayLengthTest, ArrayOfStruct) {
|
||||
auto* src = R"(
|
||||
struct S {
|
||||
f : f32;
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var<storage, read> arr : array<S>;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
let len = arrayLength(&arr);
|
||||
}
|
||||
)";
|
||||
auto* expect = R"(
|
||||
@internal(intrinsic_buffer_size)
|
||||
fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : array<S>, result : ptr<function, u32>)
|
||||
|
||||
struct S {
|
||||
f : f32;
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var<storage, read> arr : array<S>;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
var tint_symbol_1 : u32 = 0u;
|
||||
tint_symbol(arr, &(tint_symbol_1));
|
||||
let tint_symbol_2 : u32 = (tint_symbol_1 / 4u);
|
||||
let len = tint_symbol_2;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(CalculateArrayLengthTest, ArrayOfArrayOfStruct) {
|
||||
auto* src = R"(
|
||||
struct S {
|
||||
f : f32;
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var<storage, read> arr : array<array<S, 4>>;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
let len = arrayLength(&arr);
|
||||
}
|
||||
)";
|
||||
auto* expect = R"(
|
||||
@internal(intrinsic_buffer_size)
|
||||
fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : array<array<S, 4u>>, result : ptr<function, u32>)
|
||||
|
||||
struct S {
|
||||
f : f32;
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var<storage, read> arr : array<array<S, 4>>;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
var tint_symbol_1 : u32 = 0u;
|
||||
tint_symbol(arr, &(tint_symbol_1));
|
||||
let tint_symbol_2 : u32 = (tint_symbol_1 / 16u);
|
||||
let len = tint_symbol_2;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(CalculateArrayLengthTest, InSameBlock) {
|
||||
auto* src = R"(
|
||||
@group(0) @binding(0) var<storage, read> sb : array<i32>;;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
var a : u32 = arrayLength(&sb);
|
||||
var b : u32 = arrayLength(&sb);
|
||||
var c : u32 = arrayLength(&sb);
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
@internal(intrinsic_buffer_size)
|
||||
fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : array<i32>, result : ptr<function, u32>)
|
||||
|
||||
@group(0) @binding(0) var<storage, read> sb : array<i32>;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
var tint_symbol_1 : u32 = 0u;
|
||||
tint_symbol(sb, &(tint_symbol_1));
|
||||
let tint_symbol_2 : u32 = (tint_symbol_1 / 4u);
|
||||
var a : u32 = tint_symbol_2;
|
||||
var b : u32 = tint_symbol_2;
|
||||
var c : u32 = tint_symbol_2;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(CalculateArrayLengthTest, InSameBlock_Struct) {
|
||||
auto* src = R"(
|
||||
struct SB {
|
||||
x : i32;
|
||||
arr : array<i32>;
|
||||
};
|
||||
|
||||
@group(0) @binding(0) var<storage, read> sb : SB;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
var a : u32 = arrayLength(&sb.arr);
|
||||
var b : u32 = arrayLength(&sb.arr);
|
||||
var c : u32 = arrayLength(&sb.arr);
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
@internal(intrinsic_buffer_size)
|
||||
fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : SB, result : ptr<function, u32>)
|
||||
|
||||
struct SB {
|
||||
x : i32;
|
||||
arr : array<i32>;
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var<storage, read> sb : SB;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
var tint_symbol_1 : u32 = 0u;
|
||||
tint_symbol(sb, &(tint_symbol_1));
|
||||
let tint_symbol_2 : u32 = ((tint_symbol_1 - 4u) / 4u);
|
||||
var a : u32 = tint_symbol_2;
|
||||
var b : u32 = tint_symbol_2;
|
||||
var c : u32 = tint_symbol_2;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(CalculateArrayLengthTest, WithStride) {
|
||||
auto* src = R"(
|
||||
@group(0) @binding(0) var<storage, read> sb : @stride(64) array<i32>;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
var len : u32 = arrayLength(&sb);
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
@internal(intrinsic_buffer_size)
|
||||
fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : @stride(64) array<i32>, result : ptr<function, u32>)
|
||||
|
||||
@group(0) @binding(0) var<storage, read> sb : @stride(64) array<i32>;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
var tint_symbol_1 : u32 = 0u;
|
||||
tint_symbol(sb, &(tint_symbol_1));
|
||||
let tint_symbol_2 : u32 = (tint_symbol_1 / 64u);
|
||||
var len : u32 = tint_symbol_2;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(CalculateArrayLengthTest, WithStride_InStruct) {
|
||||
auto* src = R"(
|
||||
struct SB {
|
||||
x : i32;
|
||||
y : f32;
|
||||
arr : @stride(64) array<i32>;
|
||||
};
|
||||
|
||||
@group(0) @binding(0) var<storage, read> sb : SB;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
var len : u32 = arrayLength(&sb.arr);
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
@internal(intrinsic_buffer_size)
|
||||
fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : SB, result : ptr<function, u32>)
|
||||
|
||||
struct SB {
|
||||
x : i32;
|
||||
y : f32;
|
||||
arr : @stride(64) array<i32>;
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var<storage, read> sb : SB;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
var tint_symbol_1 : u32 = 0u;
|
||||
tint_symbol(sb, &(tint_symbol_1));
|
||||
let tint_symbol_2 : u32 = ((tint_symbol_1 - 8u) / 64u);
|
||||
var len : u32 = tint_symbol_2;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(CalculateArrayLengthTest, Nested) {
|
||||
auto* src = R"(
|
||||
struct SB {
|
||||
x : i32;
|
||||
arr : array<i32>;
|
||||
};
|
||||
|
||||
@group(0) @binding(0) var<storage, read> sb : SB;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
if (true) {
|
||||
var len : u32 = arrayLength(&sb.arr);
|
||||
} else {
|
||||
if (true) {
|
||||
var len : u32 = arrayLength(&sb.arr);
|
||||
}
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
@internal(intrinsic_buffer_size)
|
||||
fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : SB, result : ptr<function, u32>)
|
||||
|
||||
struct SB {
|
||||
x : i32;
|
||||
arr : array<i32>;
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var<storage, read> sb : SB;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
if (true) {
|
||||
var tint_symbol_1 : u32 = 0u;
|
||||
tint_symbol(sb, &(tint_symbol_1));
|
||||
let tint_symbol_2 : u32 = ((tint_symbol_1 - 4u) / 4u);
|
||||
var len : u32 = tint_symbol_2;
|
||||
} else {
|
||||
if (true) {
|
||||
var tint_symbol_3 : u32 = 0u;
|
||||
tint_symbol(sb, &(tint_symbol_3));
|
||||
let tint_symbol_4 : u32 = ((tint_symbol_3 - 4u) / 4u);
|
||||
var len : u32 = tint_symbol_4;
|
||||
}
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(CalculateArrayLengthTest, MultipleStorageBuffers) {
|
||||
auto* src = R"(
|
||||
struct SB1 {
|
||||
x : i32;
|
||||
arr1 : array<i32>;
|
||||
};
|
||||
|
||||
struct SB2 {
|
||||
x : i32;
|
||||
arr2 : array<vec4<f32>>;
|
||||
};
|
||||
|
||||
@group(0) @binding(0) var<storage, read> sb1 : SB1;
|
||||
|
||||
@group(0) @binding(1) var<storage, read> sb2 : SB2;
|
||||
|
||||
@group(0) @binding(2) var<storage, read> sb3 : array<i32>;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
var len1 : u32 = arrayLength(&(sb1.arr1));
|
||||
var len2 : u32 = arrayLength(&(sb2.arr2));
|
||||
var len3 : u32 = arrayLength(&sb3);
|
||||
var x : u32 = (len1 + len2 + len3);
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
@internal(intrinsic_buffer_size)
|
||||
fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : SB1, result : ptr<function, u32>)
|
||||
|
||||
@internal(intrinsic_buffer_size)
|
||||
fn tint_symbol_3(@internal(disable_validation__ignore_constructible_function_parameter) buffer : SB2, result : ptr<function, u32>)
|
||||
|
||||
@internal(intrinsic_buffer_size)
|
||||
fn tint_symbol_6(@internal(disable_validation__ignore_constructible_function_parameter) buffer : array<i32>, result : ptr<function, u32>)
|
||||
|
||||
struct SB1 {
|
||||
x : i32;
|
||||
arr1 : array<i32>;
|
||||
}
|
||||
|
||||
struct SB2 {
|
||||
x : i32;
|
||||
arr2 : array<vec4<f32>>;
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var<storage, read> sb1 : SB1;
|
||||
|
||||
@group(0) @binding(1) var<storage, read> sb2 : SB2;
|
||||
|
||||
@group(0) @binding(2) var<storage, read> sb3 : array<i32>;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
var tint_symbol_1 : u32 = 0u;
|
||||
tint_symbol(sb1, &(tint_symbol_1));
|
||||
let tint_symbol_2 : u32 = ((tint_symbol_1 - 4u) / 4u);
|
||||
var tint_symbol_4 : u32 = 0u;
|
||||
tint_symbol_3(sb2, &(tint_symbol_4));
|
||||
let tint_symbol_5 : u32 = ((tint_symbol_4 - 16u) / 16u);
|
||||
var tint_symbol_7 : u32 = 0u;
|
||||
tint_symbol_6(sb3, &(tint_symbol_7));
|
||||
let tint_symbol_8 : u32 = (tint_symbol_7 / 4u);
|
||||
var len1 : u32 = tint_symbol_2;
|
||||
var len2 : u32 = tint_symbol_5;
|
||||
var len3 : u32 = tint_symbol_8;
|
||||
var x : u32 = ((len1 + len2) + len3);
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(CalculateArrayLengthTest, Shadowing) {
|
||||
auto* src = R"(
|
||||
struct SB {
|
||||
x : i32;
|
||||
arr : array<i32>;
|
||||
};
|
||||
|
||||
@group(0) @binding(0) var<storage, read> a : SB;
|
||||
@group(0) @binding(1) var<storage, read> b : SB;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
let x = &a;
|
||||
var a : u32 = arrayLength(&a.arr);
|
||||
{
|
||||
var b : u32 = arrayLength(&((*x).arr));
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect =
|
||||
R"(
|
||||
@internal(intrinsic_buffer_size)
|
||||
fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : SB, result : ptr<function, u32>)
|
||||
|
||||
struct SB {
|
||||
x : i32;
|
||||
arr : array<i32>;
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var<storage, read> a : SB;
|
||||
|
||||
@group(0) @binding(1) var<storage, read> b : SB;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
var tint_symbol_1 : u32 = 0u;
|
||||
tint_symbol(a, &(tint_symbol_1));
|
||||
let tint_symbol_2 : u32 = ((tint_symbol_1 - 4u) / 4u);
|
||||
var a_1 : u32 = tint_symbol_2;
|
||||
{
|
||||
var tint_symbol_3 : u32 = 0u;
|
||||
tint_symbol(a, &(tint_symbol_3));
|
||||
let tint_symbol_4 : u32 = ((tint_symbol_3 - 4u) / 4u);
|
||||
var b_1 : u32 = tint_symbol_4;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(CalculateArrayLengthTest, OutOfOrder) {
|
||||
auto* src = R"(
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
var len1 : u32 = arrayLength(&(sb1.arr1));
|
||||
var len2 : u32 = arrayLength(&(sb2.arr2));
|
||||
var len3 : u32 = arrayLength(&sb3);
|
||||
var x : u32 = (len1 + len2 + len3);
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var<storage, read> sb1 : SB1;
|
||||
|
||||
struct SB1 {
|
||||
x : i32;
|
||||
arr1 : array<i32>;
|
||||
};
|
||||
|
||||
@group(0) @binding(1) var<storage, read> sb2 : SB2;
|
||||
|
||||
struct SB2 {
|
||||
x : i32;
|
||||
arr2 : array<vec4<f32>>;
|
||||
};
|
||||
|
||||
@group(0) @binding(2) var<storage, read> sb3 : array<i32>;
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
@internal(intrinsic_buffer_size)
|
||||
fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : SB1, result : ptr<function, u32>)
|
||||
|
||||
@internal(intrinsic_buffer_size)
|
||||
fn tint_symbol_3(@internal(disable_validation__ignore_constructible_function_parameter) buffer : SB2, result : ptr<function, u32>)
|
||||
|
||||
@internal(intrinsic_buffer_size)
|
||||
fn tint_symbol_6(@internal(disable_validation__ignore_constructible_function_parameter) buffer : array<i32>, result : ptr<function, u32>)
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
var tint_symbol_1 : u32 = 0u;
|
||||
tint_symbol(sb1, &(tint_symbol_1));
|
||||
let tint_symbol_2 : u32 = ((tint_symbol_1 - 4u) / 4u);
|
||||
var tint_symbol_4 : u32 = 0u;
|
||||
tint_symbol_3(sb2, &(tint_symbol_4));
|
||||
let tint_symbol_5 : u32 = ((tint_symbol_4 - 16u) / 16u);
|
||||
var tint_symbol_7 : u32 = 0u;
|
||||
tint_symbol_6(sb3, &(tint_symbol_7));
|
||||
let tint_symbol_8 : u32 = (tint_symbol_7 / 4u);
|
||||
var len1 : u32 = tint_symbol_2;
|
||||
var len2 : u32 = tint_symbol_5;
|
||||
var len3 : u32 = tint_symbol_8;
|
||||
var x : u32 = ((len1 + len2) + len3);
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var<storage, read> sb1 : SB1;
|
||||
|
||||
struct SB1 {
|
||||
x : i32;
|
||||
arr1 : array<i32>;
|
||||
}
|
||||
|
||||
@group(0) @binding(1) var<storage, read> sb2 : SB2;
|
||||
|
||||
struct SB2 {
|
||||
x : i32;
|
||||
arr2 : array<vec4<f32>>;
|
||||
}
|
||||
|
||||
@group(0) @binding(2) var<storage, read> sb3 : array<i32>;
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
779
src/tint/transform/canonicalize_entry_point_io.cc
Normal file
779
src/tint/transform/canonicalize_entry_point_io.cc
Normal file
@@ -0,0 +1,779 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/tint/transform/canonicalize_entry_point_io.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "src/tint/ast/disable_validation_attribute.h"
|
||||
#include "src/tint/program_builder.h"
|
||||
#include "src/tint/sem/function.h"
|
||||
#include "src/tint/transform/unshadow.h"
|
||||
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::transform::CanonicalizeEntryPointIO);
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::transform::CanonicalizeEntryPointIO::Config);
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
|
||||
CanonicalizeEntryPointIO::CanonicalizeEntryPointIO() = default;
|
||||
CanonicalizeEntryPointIO::~CanonicalizeEntryPointIO() = default;
|
||||
|
||||
namespace {
|
||||
|
||||
// Comparison function used to reorder struct members such that all members with
|
||||
// location attributes appear first (ordered by location slot), followed by
|
||||
// those with builtin attributes.
|
||||
bool StructMemberComparator(const ast::StructMember* a,
|
||||
const ast::StructMember* b) {
|
||||
auto* a_loc = ast::GetAttribute<ast::LocationAttribute>(a->attributes);
|
||||
auto* b_loc = ast::GetAttribute<ast::LocationAttribute>(b->attributes);
|
||||
auto* a_blt = ast::GetAttribute<ast::BuiltinAttribute>(a->attributes);
|
||||
auto* b_blt = ast::GetAttribute<ast::BuiltinAttribute>(b->attributes);
|
||||
if (a_loc) {
|
||||
if (!b_loc) {
|
||||
// `a` has location attribute and `b` does not: `a` goes first.
|
||||
return true;
|
||||
}
|
||||
// Both have location attributes: smallest goes first.
|
||||
return a_loc->value < b_loc->value;
|
||||
} else {
|
||||
if (b_loc) {
|
||||
// `b` has location attribute and `a` does not: `b` goes first.
|
||||
return false;
|
||||
}
|
||||
// Both are builtins: order doesn't matter, just use enum value.
|
||||
return a_blt->builtin < b_blt->builtin;
|
||||
}
|
||||
}
|
||||
|
||||
// Returns true if `attr` is a shader IO attribute.
|
||||
bool IsShaderIOAttribute(const ast::Attribute* attr) {
|
||||
return attr->IsAnyOf<ast::BuiltinAttribute, ast::InterpolateAttribute,
|
||||
ast::InvariantAttribute, ast::LocationAttribute>();
|
||||
}
|
||||
|
||||
// Returns true if `attrs` contains a `sample_mask` builtin.
|
||||
bool HasSampleMask(const ast::AttributeList& attrs) {
|
||||
auto* builtin = ast::GetAttribute<ast::BuiltinAttribute>(attrs);
|
||||
return builtin && builtin->builtin == ast::Builtin::kSampleMask;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
/// State holds the current transform state for a single entry point.
|
||||
struct CanonicalizeEntryPointIO::State {
|
||||
/// OutputValue represents a shader result that the wrapper function produces.
|
||||
struct OutputValue {
|
||||
/// The name of the output value.
|
||||
std::string name;
|
||||
/// The type of the output value.
|
||||
const ast::Type* type;
|
||||
/// The shader IO attributes.
|
||||
ast::AttributeList attributes;
|
||||
/// The value itself.
|
||||
const ast::Expression* value;
|
||||
};
|
||||
|
||||
/// The clone context.
|
||||
CloneContext& ctx;
|
||||
/// The transform config.
|
||||
CanonicalizeEntryPointIO::Config const cfg;
|
||||
/// The entry point function (AST).
|
||||
const ast::Function* func_ast;
|
||||
/// The entry point function (SEM).
|
||||
const sem::Function* func_sem;
|
||||
|
||||
/// The new entry point wrapper function's parameters.
|
||||
ast::VariableList wrapper_ep_parameters;
|
||||
/// The members of the wrapper function's struct parameter.
|
||||
ast::StructMemberList wrapper_struct_param_members;
|
||||
/// The name of the wrapper function's struct parameter.
|
||||
Symbol wrapper_struct_param_name;
|
||||
/// The parameters that will be passed to the original function.
|
||||
ast::ExpressionList inner_call_parameters;
|
||||
/// The members of the wrapper function's struct return type.
|
||||
ast::StructMemberList wrapper_struct_output_members;
|
||||
/// The wrapper function output values.
|
||||
std::vector<OutputValue> wrapper_output_values;
|
||||
/// The body of the wrapper function.
|
||||
ast::StatementList wrapper_body;
|
||||
/// Input names used by the entrypoint
|
||||
std::unordered_set<std::string> input_names;
|
||||
|
||||
/// Constructor
|
||||
/// @param context the clone context
|
||||
/// @param config the transform config
|
||||
/// @param function the entry point function
|
||||
State(CloneContext& context,
|
||||
const CanonicalizeEntryPointIO::Config& config,
|
||||
const ast::Function* function)
|
||||
: ctx(context),
|
||||
cfg(config),
|
||||
func_ast(function),
|
||||
func_sem(ctx.src->Sem().Get(function)) {}
|
||||
|
||||
/// Clones the shader IO attributes from `src`.
|
||||
/// @param src the attributes to clone
|
||||
/// @param do_interpolate whether to clone InterpolateAttribute
|
||||
/// @return the cloned attributes
|
||||
ast::AttributeList CloneShaderIOAttributes(const ast::AttributeList& src,
|
||||
bool do_interpolate) {
|
||||
ast::AttributeList new_attributes;
|
||||
for (auto* attr : src) {
|
||||
if (IsShaderIOAttribute(attr) &&
|
||||
(do_interpolate || !attr->Is<ast::InterpolateAttribute>())) {
|
||||
new_attributes.push_back(ctx.Clone(attr));
|
||||
}
|
||||
}
|
||||
return new_attributes;
|
||||
}
|
||||
|
||||
/// Create or return a symbol for the wrapper function's struct parameter.
|
||||
/// @returns the symbol for the struct parameter
|
||||
Symbol InputStructSymbol() {
|
||||
if (!wrapper_struct_param_name.IsValid()) {
|
||||
wrapper_struct_param_name = ctx.dst->Sym();
|
||||
}
|
||||
return wrapper_struct_param_name;
|
||||
}
|
||||
|
||||
/// Add a shader input to the entry point.
|
||||
/// @param name the name of the shader input
|
||||
/// @param type the type of the shader input
|
||||
/// @param attributes the attributes to apply to the shader input
|
||||
/// @returns an expression which evaluates to the value of the shader input
|
||||
const ast::Expression* AddInput(std::string name,
|
||||
const sem::Type* type,
|
||||
ast::AttributeList attributes) {
|
||||
auto* ast_type = CreateASTTypeFor(ctx, type);
|
||||
if (cfg.shader_style == ShaderStyle::kSpirv ||
|
||||
cfg.shader_style == ShaderStyle::kGlsl) {
|
||||
// Vulkan requires that integer user-defined fragment inputs are
|
||||
// always decorated with `Flat`.
|
||||
// TODO(crbug.com/tint/1224): Remove this once a flat interpolation
|
||||
// attribute is required for integers.
|
||||
if (type->is_integer_scalar_or_vector() &&
|
||||
ast::HasAttribute<ast::LocationAttribute>(attributes) &&
|
||||
!ast::HasAttribute<ast::InterpolateAttribute>(attributes) &&
|
||||
func_ast->PipelineStage() == ast::PipelineStage::kFragment) {
|
||||
attributes.push_back(ctx.dst->Interpolate(
|
||||
ast::InterpolationType::kFlat, ast::InterpolationSampling::kNone));
|
||||
}
|
||||
|
||||
// Disable validation for use of the `input` storage class.
|
||||
attributes.push_back(
|
||||
ctx.dst->Disable(ast::DisabledValidation::kIgnoreStorageClass));
|
||||
|
||||
// In GLSL, if it's a builtin, override the name with the
|
||||
// corresponding gl_ builtin name
|
||||
auto* builtin = ast::GetAttribute<ast::BuiltinAttribute>(attributes);
|
||||
if (cfg.shader_style == ShaderStyle::kGlsl && builtin) {
|
||||
name = GLSLBuiltinToString(builtin->builtin, func_ast->PipelineStage(),
|
||||
ast::StorageClass::kInput);
|
||||
}
|
||||
auto symbol = ctx.dst->Symbols().New(name);
|
||||
|
||||
// Create the global variable and use its value for the shader input.
|
||||
const ast::Expression* value = ctx.dst->Expr(symbol);
|
||||
|
||||
if (builtin) {
|
||||
if (cfg.shader_style == ShaderStyle::kGlsl) {
|
||||
value = FromGLSLBuiltin(builtin->builtin, value, ast_type);
|
||||
} else if (builtin->builtin == ast::Builtin::kSampleMask) {
|
||||
// Vulkan requires the type of a SampleMask builtin to be an array.
|
||||
// Declare it as array<u32, 1> and then load the first element.
|
||||
ast_type = ctx.dst->ty.array(ast_type, 1);
|
||||
value = ctx.dst->IndexAccessor(value, 0);
|
||||
}
|
||||
}
|
||||
ctx.dst->Global(symbol, ast_type, ast::StorageClass::kInput,
|
||||
std::move(attributes));
|
||||
return value;
|
||||
} else if (cfg.shader_style == ShaderStyle::kMsl &&
|
||||
ast::HasAttribute<ast::BuiltinAttribute>(attributes)) {
|
||||
// If this input is a builtin and we are targeting MSL, then add it to the
|
||||
// parameter list and pass it directly to the inner function.
|
||||
Symbol symbol = input_names.emplace(name).second
|
||||
? ctx.dst->Symbols().Register(name)
|
||||
: ctx.dst->Symbols().New(name);
|
||||
wrapper_ep_parameters.push_back(
|
||||
ctx.dst->Param(symbol, ast_type, std::move(attributes)));
|
||||
return ctx.dst->Expr(symbol);
|
||||
} else {
|
||||
// Otherwise, move it to the new structure member list.
|
||||
Symbol symbol = input_names.emplace(name).second
|
||||
? ctx.dst->Symbols().Register(name)
|
||||
: ctx.dst->Symbols().New(name);
|
||||
wrapper_struct_param_members.push_back(
|
||||
ctx.dst->Member(symbol, ast_type, std::move(attributes)));
|
||||
return ctx.dst->MemberAccessor(InputStructSymbol(), symbol);
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a shader output to the entry point.
|
||||
/// @param name the name of the shader output
|
||||
/// @param type the type of the shader output
|
||||
/// @param attributes the attributes to apply to the shader output
|
||||
/// @param value the value of the shader output
|
||||
void AddOutput(std::string name,
|
||||
const sem::Type* type,
|
||||
ast::AttributeList attributes,
|
||||
const ast::Expression* value) {
|
||||
// Vulkan requires that integer user-defined vertex outputs are
|
||||
// always decorated with `Flat`.
|
||||
// TODO(crbug.com/tint/1224): Remove this once a flat interpolation
|
||||
// attribute is required for integers.
|
||||
if (cfg.shader_style == ShaderStyle::kSpirv &&
|
||||
type->is_integer_scalar_or_vector() &&
|
||||
ast::HasAttribute<ast::LocationAttribute>(attributes) &&
|
||||
!ast::HasAttribute<ast::InterpolateAttribute>(attributes) &&
|
||||
func_ast->PipelineStage() == ast::PipelineStage::kVertex) {
|
||||
attributes.push_back(ctx.dst->Interpolate(
|
||||
ast::InterpolationType::kFlat, ast::InterpolationSampling::kNone));
|
||||
}
|
||||
|
||||
// In GLSL, if it's a builtin, override the name with the
|
||||
// corresponding gl_ builtin name
|
||||
if (cfg.shader_style == ShaderStyle::kGlsl) {
|
||||
if (auto* b = ast::GetAttribute<ast::BuiltinAttribute>(attributes)) {
|
||||
name = GLSLBuiltinToString(b->builtin, func_ast->PipelineStage(),
|
||||
ast::StorageClass::kOutput);
|
||||
value = ToGLSLBuiltin(b->builtin, value, type);
|
||||
}
|
||||
}
|
||||
|
||||
OutputValue output;
|
||||
output.name = name;
|
||||
output.type = CreateASTTypeFor(ctx, type);
|
||||
output.attributes = std::move(attributes);
|
||||
output.value = value;
|
||||
wrapper_output_values.push_back(output);
|
||||
}
|
||||
|
||||
/// Process a non-struct parameter.
|
||||
/// This creates a new object for the shader input, moving the shader IO
|
||||
/// attributes to it. It also adds an expression to the list of parameters
|
||||
/// that will be passed to the original function.
|
||||
/// @param param the original function parameter
|
||||
void ProcessNonStructParameter(const sem::Parameter* param) {
|
||||
// Remove the shader IO attributes from the inner function parameter, and
|
||||
// attach them to the new object instead.
|
||||
ast::AttributeList attributes;
|
||||
for (auto* attr : param->Declaration()->attributes) {
|
||||
if (IsShaderIOAttribute(attr)) {
|
||||
ctx.Remove(param->Declaration()->attributes, attr);
|
||||
attributes.push_back(ctx.Clone(attr));
|
||||
}
|
||||
}
|
||||
|
||||
auto name = ctx.src->Symbols().NameFor(param->Declaration()->symbol);
|
||||
auto* input_expr = AddInput(name, param->Type(), std::move(attributes));
|
||||
inner_call_parameters.push_back(input_expr);
|
||||
}
|
||||
|
||||
/// Process a struct parameter.
|
||||
/// This creates new objects for each struct member, moving the shader IO
|
||||
/// attributes to them. It also creates the structure that will be passed to
|
||||
/// the original function.
|
||||
/// @param param the original function parameter
|
||||
void ProcessStructParameter(const sem::Parameter* param) {
|
||||
auto* str = param->Type()->As<sem::Struct>();
|
||||
|
||||
// Recreate struct members in the outer entry point and build an initializer
|
||||
// list to pass them through to the inner function.
|
||||
ast::ExpressionList inner_struct_values;
|
||||
for (auto* member : str->Members()) {
|
||||
if (member->Type()->Is<sem::Struct>()) {
|
||||
TINT_ICE(Transform, ctx.dst->Diagnostics()) << "nested IO struct";
|
||||
continue;
|
||||
}
|
||||
|
||||
auto* member_ast = member->Declaration();
|
||||
auto name = ctx.src->Symbols().NameFor(member_ast->symbol);
|
||||
|
||||
// In GLSL, do not add interpolation attributes on vertex input
|
||||
bool do_interpolate = true;
|
||||
if (cfg.shader_style == ShaderStyle::kGlsl &&
|
||||
func_ast->PipelineStage() == ast::PipelineStage::kVertex) {
|
||||
do_interpolate = false;
|
||||
}
|
||||
auto attributes =
|
||||
CloneShaderIOAttributes(member_ast->attributes, do_interpolate);
|
||||
auto* input_expr = AddInput(name, member->Type(), std::move(attributes));
|
||||
inner_struct_values.push_back(input_expr);
|
||||
}
|
||||
|
||||
// Construct the original structure using the new shader input objects.
|
||||
inner_call_parameters.push_back(ctx.dst->Construct(
|
||||
ctx.Clone(param->Declaration()->type), inner_struct_values));
|
||||
}
|
||||
|
||||
/// Process the entry point return type.
|
||||
/// This generates a list of output values that are returned by the original
|
||||
/// function.
|
||||
/// @param inner_ret_type the original function return type
|
||||
/// @param original_result the result object produced by the original function
|
||||
void ProcessReturnType(const sem::Type* inner_ret_type,
|
||||
Symbol original_result) {
|
||||
bool do_interpolate = true;
|
||||
// In GLSL, do not add interpolation attributes on fragment output
|
||||
if (cfg.shader_style == ShaderStyle::kGlsl &&
|
||||
func_ast->PipelineStage() == ast::PipelineStage::kFragment) {
|
||||
do_interpolate = false;
|
||||
}
|
||||
if (auto* str = inner_ret_type->As<sem::Struct>()) {
|
||||
for (auto* member : str->Members()) {
|
||||
if (member->Type()->Is<sem::Struct>()) {
|
||||
TINT_ICE(Transform, ctx.dst->Diagnostics()) << "nested IO struct";
|
||||
continue;
|
||||
}
|
||||
|
||||
auto* member_ast = member->Declaration();
|
||||
auto name = ctx.src->Symbols().NameFor(member_ast->symbol);
|
||||
auto attributes =
|
||||
CloneShaderIOAttributes(member_ast->attributes, do_interpolate);
|
||||
|
||||
// Extract the original structure member.
|
||||
AddOutput(name, member->Type(), std::move(attributes),
|
||||
ctx.dst->MemberAccessor(original_result, name));
|
||||
}
|
||||
} else if (!inner_ret_type->Is<sem::Void>()) {
|
||||
auto attributes = CloneShaderIOAttributes(
|
||||
func_ast->return_type_attributes, do_interpolate);
|
||||
|
||||
// Propagate the non-struct return value as is.
|
||||
AddOutput("value", func_sem->ReturnType(), std::move(attributes),
|
||||
ctx.dst->Expr(original_result));
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a fixed sample mask to the wrapper function output.
|
||||
/// If there is already a sample mask, bitwise-and it with the fixed mask.
|
||||
/// Otherwise, create a new output value from the fixed mask.
|
||||
void AddFixedSampleMask() {
|
||||
// Check the existing output values for a sample mask builtin.
|
||||
for (auto& outval : wrapper_output_values) {
|
||||
if (HasSampleMask(outval.attributes)) {
|
||||
// Combine the authored sample mask with the fixed mask.
|
||||
outval.value = ctx.dst->And(outval.value, cfg.fixed_sample_mask);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// No existing sample mask builtin was found, so create a new output value
|
||||
// using the fixed sample mask.
|
||||
AddOutput("fixed_sample_mask", ctx.dst->create<sem::U32>(),
|
||||
{ctx.dst->Builtin(ast::Builtin::kSampleMask)},
|
||||
ctx.dst->Expr(cfg.fixed_sample_mask));
|
||||
}
|
||||
|
||||
/// Add a point size builtin to the wrapper function output.
|
||||
void AddVertexPointSize() {
|
||||
// Create a new output value and assign it a literal 1.0 value.
|
||||
AddOutput("vertex_point_size", ctx.dst->create<sem::F32>(),
|
||||
{ctx.dst->Builtin(ast::Builtin::kPointSize)}, ctx.dst->Expr(1.f));
|
||||
}
|
||||
|
||||
/// Create an expression for gl_Position.[component]
|
||||
/// @param component the component of gl_Position to access
|
||||
/// @returns the new expression
|
||||
const ast::Expression* GLPosition(const char* component) {
|
||||
Symbol pos = ctx.dst->Symbols().Register("gl_Position");
|
||||
Symbol c = ctx.dst->Symbols().Register(component);
|
||||
return ctx.dst->MemberAccessor(ctx.dst->Expr(pos), ctx.dst->Expr(c));
|
||||
}
|
||||
|
||||
/// Create the wrapper function's struct parameter and type objects.
|
||||
void CreateInputStruct() {
|
||||
// Sort the struct members to satisfy HLSL interfacing matching rules.
|
||||
std::sort(wrapper_struct_param_members.begin(),
|
||||
wrapper_struct_param_members.end(), StructMemberComparator);
|
||||
|
||||
// Create the new struct type.
|
||||
auto struct_name = ctx.dst->Sym();
|
||||
auto* in_struct = ctx.dst->create<ast::Struct>(
|
||||
struct_name, wrapper_struct_param_members, ast::AttributeList{});
|
||||
ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func_ast, in_struct);
|
||||
|
||||
// Create a new function parameter using this struct type.
|
||||
auto* param =
|
||||
ctx.dst->Param(InputStructSymbol(), ctx.dst->ty.type_name(struct_name));
|
||||
wrapper_ep_parameters.push_back(param);
|
||||
}
|
||||
|
||||
/// Create and return the wrapper function's struct result object.
|
||||
/// @returns the struct type
|
||||
ast::Struct* CreateOutputStruct() {
|
||||
ast::StatementList assignments;
|
||||
|
||||
auto wrapper_result = ctx.dst->Symbols().New("wrapper_result");
|
||||
|
||||
// Create the struct members and their corresponding assignment statements.
|
||||
std::unordered_set<std::string> member_names;
|
||||
for (auto& outval : wrapper_output_values) {
|
||||
// Use the original output name, unless that is already taken.
|
||||
Symbol name;
|
||||
if (member_names.count(outval.name)) {
|
||||
name = ctx.dst->Symbols().New(outval.name);
|
||||
} else {
|
||||
name = ctx.dst->Symbols().Register(outval.name);
|
||||
}
|
||||
member_names.insert(ctx.dst->Symbols().NameFor(name));
|
||||
|
||||
wrapper_struct_output_members.push_back(
|
||||
ctx.dst->Member(name, outval.type, std::move(outval.attributes)));
|
||||
assignments.push_back(ctx.dst->Assign(
|
||||
ctx.dst->MemberAccessor(wrapper_result, name), outval.value));
|
||||
}
|
||||
|
||||
// Sort the struct members to satisfy HLSL interfacing matching rules.
|
||||
std::sort(wrapper_struct_output_members.begin(),
|
||||
wrapper_struct_output_members.end(), StructMemberComparator);
|
||||
|
||||
// Create the new struct type.
|
||||
auto* out_struct = ctx.dst->create<ast::Struct>(
|
||||
ctx.dst->Sym(), wrapper_struct_output_members, ast::AttributeList{});
|
||||
ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func_ast, out_struct);
|
||||
|
||||
// Create the output struct object, assign its members, and return it.
|
||||
auto* result_object =
|
||||
ctx.dst->Var(wrapper_result, ctx.dst->ty.type_name(out_struct->name));
|
||||
wrapper_body.push_back(ctx.dst->Decl(result_object));
|
||||
wrapper_body.insert(wrapper_body.end(), assignments.begin(),
|
||||
assignments.end());
|
||||
wrapper_body.push_back(ctx.dst->Return(wrapper_result));
|
||||
|
||||
return out_struct;
|
||||
}
|
||||
|
||||
/// Create and assign the wrapper function's output variables.
|
||||
void CreateGlobalOutputVariables() {
|
||||
for (auto& outval : wrapper_output_values) {
|
||||
// Disable validation for use of the `output` storage class.
|
||||
ast::AttributeList attributes = std::move(outval.attributes);
|
||||
attributes.push_back(
|
||||
ctx.dst->Disable(ast::DisabledValidation::kIgnoreStorageClass));
|
||||
|
||||
// Create the global variable and assign it the output value.
|
||||
auto name = ctx.dst->Symbols().New(outval.name);
|
||||
auto* type = outval.type;
|
||||
const ast::Expression* lhs = ctx.dst->Expr(name);
|
||||
if (HasSampleMask(attributes)) {
|
||||
// Vulkan requires the type of a SampleMask builtin to be an array.
|
||||
// Declare it as array<u32, 1> and then store to the first element.
|
||||
type = ctx.dst->ty.array(type, 1);
|
||||
lhs = ctx.dst->IndexAccessor(lhs, 0);
|
||||
}
|
||||
ctx.dst->Global(name, type, ast::StorageClass::kOutput,
|
||||
std::move(attributes));
|
||||
wrapper_body.push_back(ctx.dst->Assign(lhs, outval.value));
|
||||
}
|
||||
}
|
||||
|
||||
// Recreate the original function without entry point attributes and call it.
|
||||
/// @returns the inner function call expression
|
||||
const ast::CallExpression* CallInnerFunction() {
|
||||
Symbol inner_name;
|
||||
if (cfg.shader_style == ShaderStyle::kGlsl) {
|
||||
// In GLSL, clone the original entry point name, as the wrapper will be
|
||||
// called "main".
|
||||
inner_name = ctx.Clone(func_ast->symbol);
|
||||
} else {
|
||||
// Add a suffix to the function name, as the wrapper function will take
|
||||
// the original entry point name.
|
||||
auto ep_name = ctx.src->Symbols().NameFor(func_ast->symbol);
|
||||
inner_name = ctx.dst->Symbols().New(ep_name + "_inner");
|
||||
}
|
||||
|
||||
// Clone everything, dropping the function and return type attributes.
|
||||
// The parameter attributes will have already been stripped during
|
||||
// processing.
|
||||
auto* inner_function = ctx.dst->create<ast::Function>(
|
||||
inner_name, ctx.Clone(func_ast->params),
|
||||
ctx.Clone(func_ast->return_type), ctx.Clone(func_ast->body),
|
||||
ast::AttributeList{}, ast::AttributeList{});
|
||||
ctx.Replace(func_ast, inner_function);
|
||||
|
||||
// Call the function.
|
||||
return ctx.dst->Call(inner_function->symbol, inner_call_parameters);
|
||||
}
|
||||
|
||||
/// Process the entry point function.
|
||||
void Process() {
|
||||
bool needs_fixed_sample_mask = false;
|
||||
bool needs_vertex_point_size = false;
|
||||
if (func_ast->PipelineStage() == ast::PipelineStage::kFragment &&
|
||||
cfg.fixed_sample_mask != 0xFFFFFFFF) {
|
||||
needs_fixed_sample_mask = true;
|
||||
}
|
||||
if (func_ast->PipelineStage() == ast::PipelineStage::kVertex &&
|
||||
cfg.emit_vertex_point_size) {
|
||||
needs_vertex_point_size = true;
|
||||
}
|
||||
|
||||
// Exit early if there is no shader IO to handle.
|
||||
if (func_sem->Parameters().size() == 0 &&
|
||||
func_sem->ReturnType()->Is<sem::Void>() && !needs_fixed_sample_mask &&
|
||||
!needs_vertex_point_size && cfg.shader_style != ShaderStyle::kGlsl) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Process the entry point parameters, collecting those that need to be
|
||||
// aggregated into a single structure.
|
||||
if (!func_sem->Parameters().empty()) {
|
||||
for (auto* param : func_sem->Parameters()) {
|
||||
if (param->Type()->Is<sem::Struct>()) {
|
||||
ProcessStructParameter(param);
|
||||
} else {
|
||||
ProcessNonStructParameter(param);
|
||||
}
|
||||
}
|
||||
|
||||
// Create a structure parameter for the outer entry point if necessary.
|
||||
if (!wrapper_struct_param_members.empty()) {
|
||||
CreateInputStruct();
|
||||
}
|
||||
}
|
||||
|
||||
// Recreate the original function and call it.
|
||||
auto* call_inner = CallInnerFunction();
|
||||
|
||||
// Process the return type, and start building the wrapper function body.
|
||||
std::function<const ast::Type*()> wrapper_ret_type = [&] {
|
||||
return ctx.dst->ty.void_();
|
||||
};
|
||||
if (func_sem->ReturnType()->Is<sem::Void>()) {
|
||||
// The function call is just a statement with no result.
|
||||
wrapper_body.push_back(ctx.dst->CallStmt(call_inner));
|
||||
} else {
|
||||
// Capture the result of calling the original function.
|
||||
auto* inner_result = ctx.dst->Const(
|
||||
ctx.dst->Symbols().New("inner_result"), nullptr, call_inner);
|
||||
wrapper_body.push_back(ctx.dst->Decl(inner_result));
|
||||
|
||||
// Process the original return type to determine the outputs that the
|
||||
// outer function needs to produce.
|
||||
ProcessReturnType(func_sem->ReturnType(), inner_result->symbol);
|
||||
}
|
||||
|
||||
// Add a fixed sample mask, if necessary.
|
||||
if (needs_fixed_sample_mask) {
|
||||
AddFixedSampleMask();
|
||||
}
|
||||
|
||||
// Add the pointsize builtin, if necessary.
|
||||
if (needs_vertex_point_size) {
|
||||
AddVertexPointSize();
|
||||
}
|
||||
|
||||
// Produce the entry point outputs, if necessary.
|
||||
if (!wrapper_output_values.empty()) {
|
||||
if (cfg.shader_style == ShaderStyle::kSpirv ||
|
||||
cfg.shader_style == ShaderStyle::kGlsl) {
|
||||
CreateGlobalOutputVariables();
|
||||
} else {
|
||||
auto* output_struct = CreateOutputStruct();
|
||||
wrapper_ret_type = [&, output_struct] {
|
||||
return ctx.dst->ty.type_name(output_struct->name);
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
if (cfg.shader_style == ShaderStyle::kGlsl &&
|
||||
func_ast->PipelineStage() == ast::PipelineStage::kVertex) {
|
||||
auto* pos_y = GLPosition("y");
|
||||
auto* negate_pos_y = ctx.dst->create<ast::UnaryOpExpression>(
|
||||
ast::UnaryOp::kNegation, GLPosition("y"));
|
||||
wrapper_body.push_back(ctx.dst->Assign(pos_y, negate_pos_y));
|
||||
|
||||
auto* two_z = ctx.dst->Mul(ctx.dst->Expr(2.0f), GLPosition("z"));
|
||||
auto* fixed_z = ctx.dst->Sub(two_z, GLPosition("w"));
|
||||
wrapper_body.push_back(ctx.dst->Assign(GLPosition("z"), fixed_z));
|
||||
}
|
||||
|
||||
// Create the wrapper entry point function.
|
||||
// For GLSL, use "main", otherwise take the name of the original
|
||||
// entry point function.
|
||||
Symbol name;
|
||||
if (cfg.shader_style == ShaderStyle::kGlsl) {
|
||||
name = ctx.dst->Symbols().New("main");
|
||||
} else {
|
||||
name = ctx.Clone(func_ast->symbol);
|
||||
}
|
||||
|
||||
auto* wrapper_func = ctx.dst->create<ast::Function>(
|
||||
name, wrapper_ep_parameters, wrapper_ret_type(),
|
||||
ctx.dst->Block(wrapper_body), ctx.Clone(func_ast->attributes),
|
||||
ast::AttributeList{});
|
||||
ctx.InsertAfter(ctx.src->AST().GlobalDeclarations(), func_ast,
|
||||
wrapper_func);
|
||||
}
|
||||
|
||||
/// Retrieve the gl_ string corresponding to a builtin.
|
||||
/// @param builtin the builtin
|
||||
/// @param stage the current pipeline stage
|
||||
/// @param storage_class the storage class (input or output)
|
||||
/// @returns the gl_ string corresponding to that builtin
|
||||
const char* GLSLBuiltinToString(ast::Builtin builtin,
|
||||
ast::PipelineStage stage,
|
||||
ast::StorageClass storage_class) {
|
||||
switch (builtin) {
|
||||
case ast::Builtin::kPosition:
|
||||
switch (stage) {
|
||||
case ast::PipelineStage::kVertex:
|
||||
return "gl_Position";
|
||||
case ast::PipelineStage::kFragment:
|
||||
return "gl_FragCoord";
|
||||
default:
|
||||
return "";
|
||||
}
|
||||
case ast::Builtin::kVertexIndex:
|
||||
return "gl_VertexID";
|
||||
case ast::Builtin::kInstanceIndex:
|
||||
return "gl_InstanceID";
|
||||
case ast::Builtin::kFrontFacing:
|
||||
return "gl_FrontFacing";
|
||||
case ast::Builtin::kFragDepth:
|
||||
return "gl_FragDepth";
|
||||
case ast::Builtin::kLocalInvocationId:
|
||||
return "gl_LocalInvocationID";
|
||||
case ast::Builtin::kLocalInvocationIndex:
|
||||
return "gl_LocalInvocationIndex";
|
||||
case ast::Builtin::kGlobalInvocationId:
|
||||
return "gl_GlobalInvocationID";
|
||||
case ast::Builtin::kNumWorkgroups:
|
||||
return "gl_NumWorkGroups";
|
||||
case ast::Builtin::kWorkgroupId:
|
||||
return "gl_WorkGroupID";
|
||||
case ast::Builtin::kSampleIndex:
|
||||
return "gl_SampleID";
|
||||
case ast::Builtin::kSampleMask:
|
||||
if (storage_class == ast::StorageClass::kInput) {
|
||||
return "gl_SampleMaskIn";
|
||||
} else {
|
||||
return "gl_SampleMask";
|
||||
}
|
||||
default:
|
||||
return "";
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert a given GLSL builtin value to the corresponding WGSL value.
|
||||
/// @param builtin the builtin variable
|
||||
/// @param value the value to convert
|
||||
/// @param ast_type (inout) the incoming WGSL and outgoing GLSL types
|
||||
/// @returns an expression representing the GLSL builtin converted to what
|
||||
/// WGSL expects
|
||||
const ast::Expression* FromGLSLBuiltin(ast::Builtin builtin,
|
||||
const ast::Expression* value,
|
||||
const ast::Type*& ast_type) {
|
||||
switch (builtin) {
|
||||
case ast::Builtin::kVertexIndex:
|
||||
case ast::Builtin::kInstanceIndex:
|
||||
case ast::Builtin::kSampleIndex:
|
||||
// GLSL uses i32 for these, so bitcast to u32.
|
||||
value = ctx.dst->Bitcast(ast_type, value);
|
||||
ast_type = ctx.dst->ty.i32();
|
||||
break;
|
||||
case ast::Builtin::kSampleMask:
|
||||
// gl_SampleMask is an array of i32. Retrieve the first element and
|
||||
// bitcast it to u32.
|
||||
value = ctx.dst->IndexAccessor(value, 0);
|
||||
value = ctx.dst->Bitcast(ast_type, value);
|
||||
ast_type = ctx.dst->ty.array(ctx.dst->ty.i32(), 1);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
/// Convert a given WGSL value to the type expected when assigning to a
|
||||
/// GLSL builtin.
|
||||
/// @param builtin the builtin variable
|
||||
/// @param value the value to convert
|
||||
/// @param type (out) the type to which the value was converted
|
||||
/// @returns the converted value which can be assigned to the GLSL builtin
|
||||
const ast::Expression* ToGLSLBuiltin(ast::Builtin builtin,
|
||||
const ast::Expression* value,
|
||||
const sem::Type*& type) {
|
||||
switch (builtin) {
|
||||
case ast::Builtin::kVertexIndex:
|
||||
case ast::Builtin::kInstanceIndex:
|
||||
case ast::Builtin::kSampleIndex:
|
||||
case ast::Builtin::kSampleMask:
|
||||
type = ctx.dst->create<sem::I32>();
|
||||
value = ctx.dst->Bitcast(CreateASTTypeFor(ctx, type), value);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
return value;
|
||||
}
|
||||
};
|
||||
|
||||
void CanonicalizeEntryPointIO::Run(CloneContext& ctx,
|
||||
const DataMap& inputs,
|
||||
DataMap&) const {
|
||||
auto* cfg = inputs.Get<Config>();
|
||||
if (cfg == nullptr) {
|
||||
ctx.dst->Diagnostics().add_error(
|
||||
diag::System::Transform,
|
||||
"missing transform data for " + std::string(TypeInfo().name));
|
||||
return;
|
||||
}
|
||||
|
||||
// Remove entry point IO attributes from struct declarations.
|
||||
// New structures will be created for each entry point, as necessary.
|
||||
for (auto* ty : ctx.src->AST().TypeDecls()) {
|
||||
if (auto* struct_ty = ty->As<ast::Struct>()) {
|
||||
for (auto* member : struct_ty->members) {
|
||||
for (auto* attr : member->attributes) {
|
||||
if (IsShaderIOAttribute(attr)) {
|
||||
ctx.Remove(member->attributes, attr);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (auto* func_ast : ctx.src->AST().Functions()) {
|
||||
if (!func_ast->IsEntryPoint()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
State state(ctx, *cfg, func_ast);
|
||||
state.Process();
|
||||
}
|
||||
|
||||
ctx.Clone();
|
||||
}
|
||||
|
||||
CanonicalizeEntryPointIO::Config::Config(ShaderStyle style,
|
||||
uint32_t sample_mask,
|
||||
bool emit_point_size)
|
||||
: shader_style(style),
|
||||
fixed_sample_mask(sample_mask),
|
||||
emit_vertex_point_size(emit_point_size) {}
|
||||
|
||||
CanonicalizeEntryPointIO::Config::Config(const Config&) = default;
|
||||
CanonicalizeEntryPointIO::Config::~Config() = default;
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
149
src/tint/transform/canonicalize_entry_point_io.h
Normal file
149
src/tint/transform/canonicalize_entry_point_io.h
Normal file
@@ -0,0 +1,149 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef SRC_TINT_TRANSFORM_CANONICALIZE_ENTRY_POINT_IO_H_
|
||||
#define SRC_TINT_TRANSFORM_CANONICALIZE_ENTRY_POINT_IO_H_
|
||||
|
||||
#include "src/tint/transform/transform.h"
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
|
||||
/// CanonicalizeEntryPointIO is a transform used to rewrite shader entry point
|
||||
/// interfaces into a form that the generators can handle. Each entry point
|
||||
/// function is stripped of all shader IO attributes and wrapped in a function
|
||||
/// that provides the shader interface.
|
||||
/// The transform config determines whether to use global variables, structures,
|
||||
/// or parameters for the shader inputs and outputs, and optionally adds
|
||||
/// additional builtins to the shader interface.
|
||||
///
|
||||
/// Before:
|
||||
/// ```
|
||||
/// struct Locations{
|
||||
/// @location(1) loc1 : f32;
|
||||
/// @location(2) loc2 : vec4<u32>;
|
||||
/// };
|
||||
///
|
||||
/// @stage(fragment)
|
||||
/// fn frag_main(@builtin(position) coord : vec4<f32>,
|
||||
/// locations : Locations) -> @location(0) f32 {
|
||||
/// if (coord.w > 1.0) {
|
||||
/// return 0.0;
|
||||
/// }
|
||||
/// var col : f32 = (coord.x * locations.loc1);
|
||||
/// return col;
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// After (using structures for all parameters):
|
||||
/// ```
|
||||
/// struct Locations{
|
||||
/// loc1 : f32;
|
||||
/// loc2 : vec4<u32>;
|
||||
/// };
|
||||
///
|
||||
/// struct frag_main_in {
|
||||
/// @builtin(position) coord : vec4<f32>;
|
||||
/// @location(1) loc1 : f32;
|
||||
/// @location(2) loc2 : vec4<u32>
|
||||
/// };
|
||||
///
|
||||
/// struct frag_main_out {
|
||||
/// @location(0) loc0 : f32;
|
||||
/// };
|
||||
///
|
||||
/// fn frag_main_inner(coord : vec4<f32>,
|
||||
/// locations : Locations) -> f32 {
|
||||
/// if (coord.w > 1.0) {
|
||||
/// return 0.0;
|
||||
/// }
|
||||
/// var col : f32 = (coord.x * locations.loc1);
|
||||
/// return col;
|
||||
/// }
|
||||
///
|
||||
/// @stage(fragment)
|
||||
/// fn frag_main(in : frag_main_in) -> frag_main_out {
|
||||
/// let inner_retval = frag_main_inner(in.coord, Locations(in.loc1, in.loc2));
|
||||
/// var wrapper_result : frag_main_out;
|
||||
/// wrapper_result.loc0 = inner_retval;
|
||||
/// return wrapper_result;
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// @note Depends on the following transforms to have been run first:
|
||||
/// * Unshadow
|
||||
class CanonicalizeEntryPointIO
|
||||
: public Castable<CanonicalizeEntryPointIO, Transform> {
|
||||
public:
|
||||
/// ShaderStyle is an enumerator of different ways to emit shader IO.
|
||||
enum class ShaderStyle {
|
||||
/// Target SPIR-V (using global variables).
|
||||
kSpirv,
|
||||
/// Target GLSL (using global variables).
|
||||
kGlsl,
|
||||
/// Target MSL (using non-struct function parameters for builtins).
|
||||
kMsl,
|
||||
/// Target HLSL (using structures for all IO).
|
||||
kHlsl,
|
||||
};
|
||||
|
||||
/// Configuration options for the transform.
|
||||
struct Config : public Castable<Config, Data> {
|
||||
/// Constructor
|
||||
/// @param style the approach to use for emitting shader IO.
|
||||
/// @param sample_mask an optional sample mask to combine with shader masks
|
||||
/// @param emit_vertex_point_size `true` to generate a pointsize builtin
|
||||
explicit Config(ShaderStyle style,
|
||||
uint32_t sample_mask = 0xFFFFFFFF,
|
||||
bool emit_vertex_point_size = false);
|
||||
|
||||
/// Copy constructor
|
||||
Config(const Config&);
|
||||
|
||||
/// Destructor
|
||||
~Config() override;
|
||||
|
||||
/// The approach to use for emitting shader IO.
|
||||
const ShaderStyle shader_style;
|
||||
|
||||
/// A fixed sample mask to combine into masks produced by fragment shaders.
|
||||
const uint32_t fixed_sample_mask;
|
||||
|
||||
/// Set to `true` to generate a pointsize builtin and have it set to 1.0
|
||||
/// from all vertex shaders in the module.
|
||||
const bool emit_vertex_point_size;
|
||||
};
|
||||
|
||||
/// Constructor
|
||||
CanonicalizeEntryPointIO();
|
||||
~CanonicalizeEntryPointIO() override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
|
||||
struct State;
|
||||
};
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
|
||||
#endif // SRC_TINT_TRANSFORM_CANONICALIZE_ENTRY_POINT_IO_H_
|
||||
4041
src/tint/transform/canonicalize_entry_point_io_test.cc
Normal file
4041
src/tint/transform/canonicalize_entry_point_io_test.cc
Normal file
File diff suppressed because it is too large
Load Diff
355
src/tint/transform/combine_samplers.cc
Normal file
355
src/tint/transform/combine_samplers.cc
Normal file
@@ -0,0 +1,355 @@
|
||||
// Copyright 2022 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/tint/transform/combine_samplers.h"
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "src/tint/sem/function.h"
|
||||
#include "src/tint/sem/statement.h"
|
||||
|
||||
#include "src/tint/utils/map.h"
|
||||
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::transform::CombineSamplers);
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::transform::CombineSamplers::BindingInfo);
|
||||
|
||||
namespace {
|
||||
|
||||
bool IsGlobal(const tint::sem::VariablePair& pair) {
|
||||
return pair.first->Is<tint::sem::GlobalVariable>() &&
|
||||
(!pair.second || pair.second->Is<tint::sem::GlobalVariable>());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
|
||||
CombineSamplers::BindingInfo::BindingInfo(const BindingMap& map,
|
||||
const sem::BindingPoint& placeholder)
|
||||
: binding_map(map), placeholder_binding_point(placeholder) {}
|
||||
CombineSamplers::BindingInfo::BindingInfo(const BindingInfo& other) = default;
|
||||
CombineSamplers::BindingInfo::~BindingInfo() = default;
|
||||
|
||||
/// The PIMPL state for the CombineSamplers transform
|
||||
struct CombineSamplers::State {
|
||||
/// The clone context
|
||||
CloneContext& ctx;
|
||||
|
||||
/// The binding info
|
||||
const BindingInfo* binding_info;
|
||||
|
||||
/// Map from a texture/sampler pair to the corresponding combined sampler
|
||||
/// variable
|
||||
using CombinedTextureSamplerMap =
|
||||
std::unordered_map<sem::VariablePair, const ast::Variable*>;
|
||||
|
||||
/// Use sem::BindingPoint without scope.
|
||||
using BindingPoint = sem::BindingPoint;
|
||||
|
||||
/// A map of all global texture/sampler variable pairs to the global
|
||||
/// combined sampler variable that will replace it.
|
||||
CombinedTextureSamplerMap global_combined_texture_samplers_;
|
||||
|
||||
/// A map of all texture/sampler variable pairs that contain a function
|
||||
/// parameter to the combined sampler function paramter that will replace it.
|
||||
std::unordered_map<const sem::Function*, CombinedTextureSamplerMap>
|
||||
function_combined_texture_samplers_;
|
||||
|
||||
/// Placeholder global samplers used when a function contains texture-only
|
||||
/// references (one comparison sampler, one regular). These are also used as
|
||||
/// temporary sampler parameters to the texture builtins to satisfy the WGSL
|
||||
/// resolver, but are then ignored and removed by the GLSL writer.
|
||||
const ast::Variable* placeholder_samplers_[2] = {};
|
||||
|
||||
/// Group and binding attributes used by all combined sampler globals.
|
||||
/// Group 0 and binding 0 are used, with collisions disabled.
|
||||
/// @returns the newly-created attribute list
|
||||
ast::AttributeList Attributes() const {
|
||||
auto attributes = ctx.dst->GroupAndBinding(0, 0);
|
||||
attributes.push_back(
|
||||
ctx.dst->Disable(ast::DisabledValidation::kBindingPointCollision));
|
||||
return attributes;
|
||||
}
|
||||
|
||||
/// Constructor
|
||||
/// @param context the clone context
|
||||
/// @param info the binding map information
|
||||
State(CloneContext& context, const BindingInfo* info)
|
||||
: ctx(context), binding_info(info) {}
|
||||
|
||||
/// Creates a combined sampler global variables.
|
||||
/// (Note this is actually a Texture node at the AST level, but it will be
|
||||
/// written as the corresponding sampler (eg., sampler2D) on GLSL output.)
|
||||
/// @param texture_var the texture (global) variable
|
||||
/// @param sampler_var the sampler (global) variable
|
||||
/// @param name the default name to use (may be overridden by map lookup)
|
||||
/// @returns the newly-created global variable
|
||||
const ast::Variable* CreateCombinedGlobal(const sem::Variable* texture_var,
|
||||
const sem::Variable* sampler_var,
|
||||
std::string name) {
|
||||
SamplerTexturePair bp_pair;
|
||||
bp_pair.texture_binding_point =
|
||||
texture_var->As<sem::GlobalVariable>()->BindingPoint();
|
||||
bp_pair.sampler_binding_point =
|
||||
sampler_var ? sampler_var->As<sem::GlobalVariable>()->BindingPoint()
|
||||
: binding_info->placeholder_binding_point;
|
||||
auto it = binding_info->binding_map.find(bp_pair);
|
||||
if (it != binding_info->binding_map.end()) {
|
||||
name = it->second;
|
||||
}
|
||||
const ast::Type* type = CreateCombinedASTTypeFor(texture_var, sampler_var);
|
||||
Symbol symbol = ctx.dst->Symbols().New(name);
|
||||
return ctx.dst->Global(symbol, type, Attributes());
|
||||
}
|
||||
|
||||
/// Creates placeholder global sampler variables.
|
||||
/// @param kind the sampler kind to create for
|
||||
/// @returns the newly-created global variable
|
||||
const ast::Variable* CreatePlaceholder(ast::SamplerKind kind) {
|
||||
const ast::Type* type = ctx.dst->ty.sampler(kind);
|
||||
const char* name = kind == ast::SamplerKind::kComparisonSampler
|
||||
? "placeholder_comparison_sampler"
|
||||
: "placeholder_sampler";
|
||||
Symbol symbol = ctx.dst->Symbols().New(name);
|
||||
return ctx.dst->Global(symbol, type, Attributes());
|
||||
}
|
||||
|
||||
/// Creates ast::Type for a given texture and sampler variable pair.
|
||||
/// Depth textures with no samplers are turned into the corresponding
|
||||
/// f32 texture (e.g., texture_depth_2d -> texture_2d<f32>).
|
||||
/// @param texture the texture variable of interest
|
||||
/// @param sampler the texture variable of interest
|
||||
/// @returns the newly-created type
|
||||
const ast::Type* CreateCombinedASTTypeFor(const sem::Variable* texture,
|
||||
const sem::Variable* sampler) {
|
||||
const sem::Type* texture_type = texture->Type()->UnwrapRef();
|
||||
const sem::DepthTexture* depth = texture_type->As<sem::DepthTexture>();
|
||||
if (depth && !sampler) {
|
||||
return ctx.dst->create<ast::SampledTexture>(depth->dim(),
|
||||
ctx.dst->create<ast::F32>());
|
||||
} else {
|
||||
return CreateASTTypeFor(ctx, texture_type);
|
||||
}
|
||||
}
|
||||
|
||||
/// Performs the transformation
|
||||
void Run() {
|
||||
auto& sem = ctx.src->Sem();
|
||||
|
||||
// Remove all texture and sampler global variables. These will be replaced
|
||||
// by combined samplers.
|
||||
for (auto* var : ctx.src->AST().GlobalVariables()) {
|
||||
auto* type = sem.Get(var->type);
|
||||
if (type && type->IsAnyOf<sem::Texture, sem::Sampler>() &&
|
||||
!type->Is<sem::StorageTexture>()) {
|
||||
ctx.Remove(ctx.src->AST().GlobalDeclarations(), var);
|
||||
} else if (auto binding_point = var->BindingPoint()) {
|
||||
if (binding_point.group->value == 0 &&
|
||||
binding_point.binding->value == 0) {
|
||||
auto* attribute =
|
||||
ctx.dst->Disable(ast::DisabledValidation::kBindingPointCollision);
|
||||
ctx.InsertFront(var->attributes, attribute);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Rewrite all function signatures to use combined samplers, and remove
|
||||
// separate textures & samplers. Create new combined globals where found.
|
||||
ctx.ReplaceAll([&](const ast::Function* src) -> const ast::Function* {
|
||||
if (auto* func = sem.Get(src)) {
|
||||
auto pairs = func->TextureSamplerPairs();
|
||||
if (pairs.empty()) {
|
||||
return nullptr;
|
||||
}
|
||||
ast::VariableList params;
|
||||
for (auto pair : func->TextureSamplerPairs()) {
|
||||
const sem::Variable* texture_var = pair.first;
|
||||
const sem::Variable* sampler_var = pair.second;
|
||||
std::string name =
|
||||
ctx.src->Symbols().NameFor(texture_var->Declaration()->symbol);
|
||||
if (sampler_var) {
|
||||
name += "_" + ctx.src->Symbols().NameFor(
|
||||
sampler_var->Declaration()->symbol);
|
||||
}
|
||||
if (IsGlobal(pair)) {
|
||||
// Both texture and sampler are global; add a new global variable
|
||||
// to represent the combined sampler (if not already created).
|
||||
utils::GetOrCreate(global_combined_texture_samplers_, pair, [&] {
|
||||
return CreateCombinedGlobal(texture_var, sampler_var, name);
|
||||
});
|
||||
} else {
|
||||
// Either texture or sampler (or both) is a function parameter;
|
||||
// add a new function parameter to represent the combined sampler.
|
||||
const ast::Type* type =
|
||||
CreateCombinedASTTypeFor(texture_var, sampler_var);
|
||||
const ast::Variable* var =
|
||||
ctx.dst->Param(ctx.dst->Symbols().New(name), type);
|
||||
params.push_back(var);
|
||||
function_combined_texture_samplers_[func][pair] = var;
|
||||
}
|
||||
}
|
||||
// Filter out separate textures and samplers from the original
|
||||
// function signature.
|
||||
for (auto* var : src->params) {
|
||||
if (!sem.Get(var->type)->IsAnyOf<sem::Texture, sem::Sampler>()) {
|
||||
params.push_back(ctx.Clone(var));
|
||||
}
|
||||
}
|
||||
// Create a new function signature that differs only in the parameter
|
||||
// list.
|
||||
auto symbol = ctx.Clone(src->symbol);
|
||||
auto* return_type = ctx.Clone(src->return_type);
|
||||
auto* body = ctx.Clone(src->body);
|
||||
auto attributes = ctx.Clone(src->attributes);
|
||||
auto return_type_attributes = ctx.Clone(src->return_type_attributes);
|
||||
return ctx.dst->create<ast::Function>(
|
||||
symbol, params, return_type, body, std::move(attributes),
|
||||
std::move(return_type_attributes));
|
||||
}
|
||||
return nullptr;
|
||||
});
|
||||
|
||||
// Replace all function call expressions containing texture or
|
||||
// sampler parameters to use the current function's combined samplers or
|
||||
// the combined global samplers, as appropriate.
|
||||
ctx.ReplaceAll([&](const ast::CallExpression* expr)
|
||||
-> const ast::Expression* {
|
||||
if (auto* call = sem.Get(expr)) {
|
||||
ast::ExpressionList args;
|
||||
// Replace all texture builtin calls.
|
||||
if (auto* builtin = call->Target()->As<sem::Builtin>()) {
|
||||
const auto& signature = builtin->Signature();
|
||||
int sampler_index = signature.IndexOf(sem::ParameterUsage::kSampler);
|
||||
int texture_index = signature.IndexOf(sem::ParameterUsage::kTexture);
|
||||
if (texture_index == -1) {
|
||||
return nullptr;
|
||||
}
|
||||
const sem::Expression* texture = call->Arguments()[texture_index];
|
||||
// We don't want to combine storage textures with anything, since
|
||||
// they never have associated samplers in GLSL.
|
||||
if (texture->Type()->UnwrapRef()->Is<sem::StorageTexture>()) {
|
||||
return nullptr;
|
||||
}
|
||||
const sem::Expression* sampler =
|
||||
sampler_index != -1 ? call->Arguments()[sampler_index] : nullptr;
|
||||
auto* texture_var = texture->As<sem::VariableUser>()->Variable();
|
||||
auto* sampler_var =
|
||||
sampler ? sampler->As<sem::VariableUser>()->Variable() : nullptr;
|
||||
sem::VariablePair new_pair(texture_var, sampler_var);
|
||||
for (auto* arg : expr->args) {
|
||||
auto* type = ctx.src->TypeOf(arg)->UnwrapRef();
|
||||
if (type->Is<sem::Texture>()) {
|
||||
const ast::Variable* var =
|
||||
IsGlobal(new_pair)
|
||||
? global_combined_texture_samplers_[new_pair]
|
||||
: function_combined_texture_samplers_
|
||||
[call->Stmt()->Function()][new_pair];
|
||||
args.push_back(ctx.dst->Expr(var->symbol));
|
||||
} else if (auto* sampler_type = type->As<sem::Sampler>()) {
|
||||
ast::SamplerKind kind = sampler_type->kind();
|
||||
int index = (kind == ast::SamplerKind::kSampler) ? 0 : 1;
|
||||
const ast::Variable*& p = placeholder_samplers_[index];
|
||||
if (!p) {
|
||||
p = CreatePlaceholder(kind);
|
||||
}
|
||||
args.push_back(ctx.dst->Expr(p->symbol));
|
||||
} else {
|
||||
args.push_back(ctx.Clone(arg));
|
||||
}
|
||||
}
|
||||
const ast::Expression* value =
|
||||
ctx.dst->Call(ctx.Clone(expr->target.name), args);
|
||||
if (builtin->Type() == sem::BuiltinType::kTextureLoad &&
|
||||
texture_var->Type()->UnwrapRef()->Is<sem::DepthTexture>() &&
|
||||
!call->Stmt()->Declaration()->Is<ast::CallStatement>()) {
|
||||
value = ctx.dst->MemberAccessor(value, "x");
|
||||
}
|
||||
return value;
|
||||
}
|
||||
// Replace all function calls.
|
||||
if (auto* callee = call->Target()->As<sem::Function>()) {
|
||||
for (auto pair : callee->TextureSamplerPairs()) {
|
||||
// Global pairs used by the callee do not require a function
|
||||
// parameter at the call site.
|
||||
if (IsGlobal(pair)) {
|
||||
continue;
|
||||
}
|
||||
const sem::Variable* texture_var = pair.first;
|
||||
const sem::Variable* sampler_var = pair.second;
|
||||
if (auto* param = texture_var->As<sem::Parameter>()) {
|
||||
const sem::Expression* texture =
|
||||
call->Arguments()[param->Index()];
|
||||
texture_var = texture->As<sem::VariableUser>()->Variable();
|
||||
}
|
||||
if (sampler_var) {
|
||||
if (auto* param = sampler_var->As<sem::Parameter>()) {
|
||||
const sem::Expression* sampler =
|
||||
call->Arguments()[param->Index()];
|
||||
sampler_var = sampler->As<sem::VariableUser>()->Variable();
|
||||
}
|
||||
}
|
||||
sem::VariablePair new_pair(texture_var, sampler_var);
|
||||
// If both texture and sampler are (now) global, pass that
|
||||
// global variable to the callee. Otherwise use the caller's
|
||||
// function parameter for this pair.
|
||||
const ast::Variable* var =
|
||||
IsGlobal(new_pair) ? global_combined_texture_samplers_[new_pair]
|
||||
: function_combined_texture_samplers_
|
||||
[call->Stmt()->Function()][new_pair];
|
||||
auto* arg = ctx.dst->Expr(var->symbol);
|
||||
args.push_back(arg);
|
||||
}
|
||||
// Append all of the remaining non-texture and non-sampler
|
||||
// parameters.
|
||||
for (auto* arg : expr->args) {
|
||||
if (!ctx.src->TypeOf(arg)
|
||||
->UnwrapRef()
|
||||
->IsAnyOf<sem::Texture, sem::Sampler>()) {
|
||||
args.push_back(ctx.Clone(arg));
|
||||
}
|
||||
}
|
||||
return ctx.dst->Call(ctx.Clone(expr->target.name), args);
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
});
|
||||
|
||||
ctx.Clone();
|
||||
}
|
||||
};
|
||||
|
||||
CombineSamplers::CombineSamplers() = default;
|
||||
|
||||
CombineSamplers::~CombineSamplers() = default;
|
||||
|
||||
void CombineSamplers::Run(CloneContext& ctx,
|
||||
const DataMap& inputs,
|
||||
DataMap&) const {
|
||||
auto* binding_info = inputs.Get<BindingInfo>();
|
||||
if (!binding_info) {
|
||||
ctx.dst->Diagnostics().add_error(
|
||||
diag::System::Transform,
|
||||
"missing transform data for " + std::string(TypeInfo().name));
|
||||
return;
|
||||
}
|
||||
|
||||
State(ctx, binding_info).Run();
|
||||
}
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
110
src/tint/transform/combine_samplers.h
Normal file
110
src/tint/transform/combine_samplers.h
Normal file
@@ -0,0 +1,110 @@
|
||||
// Copyright 2022 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef SRC_TINT_TRANSFORM_COMBINE_SAMPLERS_H_
|
||||
#define SRC_TINT_TRANSFORM_COMBINE_SAMPLERS_H_
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "src/tint/sem/sampler_texture_pair.h"
|
||||
#include "src/tint/transform/transform.h"
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
|
||||
/// This transform converts all separate texture/sampler refences in a
|
||||
/// program into combined texture/samplers. This is required for GLSL,
|
||||
/// which does not support separate texture/samplers.
|
||||
///
|
||||
/// It utilizes the texture/sampler information collected by the
|
||||
/// Resolver and stored on each sem::Function. For each function, all
|
||||
/// separate texture/sampler parameters in the function signature are
|
||||
/// removed. For each unique pair, if both texture and sampler are
|
||||
/// global variables, the function passes the corresponding combined
|
||||
/// global stored in global_combined_texture_samplers_ at the call
|
||||
/// site. Otherwise, either the texture or sampler must be a function
|
||||
/// parameter. In this case, a new parameter is added to the function
|
||||
/// signature. All separate texture/sampler parameters are removed.
|
||||
///
|
||||
/// All texture builtin callsites are modified to pass the combined
|
||||
/// texture/sampler as the first argument, and separate texture/sampler
|
||||
/// arguments are removed.
|
||||
///
|
||||
/// Note that the sampler may be null, indicating that only a texture
|
||||
/// reference was required (e.g., textureLoad). In this case, a
|
||||
/// placeholder global sampler is used at the AST level. This will be
|
||||
/// combined with the original texture to give a combined global, and
|
||||
/// the placeholder removed (ignored) by the GLSL writer.
|
||||
///
|
||||
/// Note that the combined samplers are actually represented by a
|
||||
/// Texture node at the AST level, since this contains all the
|
||||
/// information needed to represent a combined sampler in GLSL
|
||||
/// (dimensionality, component type, etc). The GLSL writer outputs such
|
||||
/// (Tint) Textures as (GLSL) Samplers.
|
||||
class CombineSamplers : public Castable<CombineSamplers, Transform> {
|
||||
public:
|
||||
/// A pair of binding points.
|
||||
using SamplerTexturePair = sem::SamplerTexturePair;
|
||||
|
||||
/// A map from a sampler/texture pair to a named global.
|
||||
using BindingMap = std::unordered_map<SamplerTexturePair, std::string>;
|
||||
|
||||
/// The client-provided mapping from separate texture and sampler binding
|
||||
/// points to combined sampler binding point.
|
||||
struct BindingInfo : public Castable<Data, transform::Data> {
|
||||
/// Constructor
|
||||
/// @param map the map of all (texture, sampler) -> (combined) pairs
|
||||
/// @param placeholder the binding point to use for placeholder samplers.
|
||||
BindingInfo(const BindingMap& map, const sem::BindingPoint& placeholder);
|
||||
|
||||
/// Copy constructor
|
||||
/// @param other the other BindingInfo to copy
|
||||
BindingInfo(const BindingInfo& other);
|
||||
|
||||
/// Destructor
|
||||
~BindingInfo() override;
|
||||
|
||||
/// A map of bindings from (texture, sampler) -> combined sampler.
|
||||
BindingMap binding_map;
|
||||
|
||||
/// The binding point to use for placeholder samplers.
|
||||
sem::BindingPoint placeholder_binding_point;
|
||||
};
|
||||
|
||||
/// Constructor
|
||||
CombineSamplers();
|
||||
|
||||
/// Destructor
|
||||
~CombineSamplers() override;
|
||||
|
||||
protected:
|
||||
/// The PIMPL state for this transform
|
||||
struct State;
|
||||
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
};
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
|
||||
#endif // SRC_TINT_TRANSFORM_COMBINE_SAMPLERS_H_
|
||||
1012
src/tint/transform/combine_samplers_test.cc
Normal file
1012
src/tint/transform/combine_samplers_test.cc
Normal file
File diff suppressed because it is too large
Load Diff
998
src/tint/transform/decompose_memory_access.cc
Normal file
998
src/tint/transform/decompose_memory_access.cc
Normal file
@@ -0,0 +1,998 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/tint/transform/decompose_memory_access.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "src/tint/ast/assignment_statement.h"
|
||||
#include "src/tint/ast/call_statement.h"
|
||||
#include "src/tint/ast/disable_validation_attribute.h"
|
||||
#include "src/tint/ast/type_name.h"
|
||||
#include "src/tint/ast/unary_op.h"
|
||||
#include "src/tint/block_allocator.h"
|
||||
#include "src/tint/program_builder.h"
|
||||
#include "src/tint/sem/array.h"
|
||||
#include "src/tint/sem/atomic_type.h"
|
||||
#include "src/tint/sem/call.h"
|
||||
#include "src/tint/sem/member_accessor_expression.h"
|
||||
#include "src/tint/sem/reference_type.h"
|
||||
#include "src/tint/sem/statement.h"
|
||||
#include "src/tint/sem/struct.h"
|
||||
#include "src/tint/sem/variable.h"
|
||||
#include "src/tint/utils/hash.h"
|
||||
#include "src/tint/utils/map.h"
|
||||
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::transform::DecomposeMemoryAccess);
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::transform::DecomposeMemoryAccess::Intrinsic);
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
|
||||
namespace {
|
||||
|
||||
/// Offset is a simple ast::Expression builder interface, used to build byte
|
||||
/// offsets for storage and uniform buffer accesses.
|
||||
struct Offset : Castable<Offset> {
|
||||
/// @returns builds and returns the ast::Expression in `ctx.dst`
|
||||
virtual const ast::Expression* Build(CloneContext& ctx) const = 0;
|
||||
};
|
||||
|
||||
/// OffsetExpr is an implementation of Offset that clones and casts the given
|
||||
/// expression to `u32`.
|
||||
struct OffsetExpr : Offset {
|
||||
const ast::Expression* const expr = nullptr;
|
||||
|
||||
explicit OffsetExpr(const ast::Expression* e) : expr(e) {}
|
||||
|
||||
const ast::Expression* Build(CloneContext& ctx) const override {
|
||||
auto* type = ctx.src->Sem().Get(expr)->Type()->UnwrapRef();
|
||||
auto* res = ctx.Clone(expr);
|
||||
if (!type->Is<sem::U32>()) {
|
||||
res = ctx.dst->Construct<ProgramBuilder::u32>(res);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
};
|
||||
|
||||
/// OffsetLiteral is an implementation of Offset that constructs a u32 literal
|
||||
/// value.
|
||||
struct OffsetLiteral : Castable<OffsetLiteral, Offset> {
|
||||
uint32_t const literal = 0;
|
||||
|
||||
explicit OffsetLiteral(uint32_t lit) : literal(lit) {}
|
||||
|
||||
const ast::Expression* Build(CloneContext& ctx) const override {
|
||||
return ctx.dst->Expr(literal);
|
||||
}
|
||||
};
|
||||
|
||||
/// OffsetBinOp is an implementation of Offset that constructs a binary-op of
|
||||
/// two Offsets.
|
||||
struct OffsetBinOp : Offset {
|
||||
ast::BinaryOp op;
|
||||
Offset const* lhs = nullptr;
|
||||
Offset const* rhs = nullptr;
|
||||
|
||||
const ast::Expression* Build(CloneContext& ctx) const override {
|
||||
return ctx.dst->create<ast::BinaryExpression>(op, lhs->Build(ctx),
|
||||
rhs->Build(ctx));
|
||||
}
|
||||
};
|
||||
|
||||
/// LoadStoreKey is the unordered map key to a load or store intrinsic.
|
||||
struct LoadStoreKey {
|
||||
ast::StorageClass const storage_class; // buffer storage class
|
||||
sem::Type const* buf_ty = nullptr; // buffer type
|
||||
sem::Type const* el_ty = nullptr; // element type
|
||||
bool operator==(const LoadStoreKey& rhs) const {
|
||||
return storage_class == rhs.storage_class && buf_ty == rhs.buf_ty &&
|
||||
el_ty == rhs.el_ty;
|
||||
}
|
||||
struct Hasher {
|
||||
inline std::size_t operator()(const LoadStoreKey& u) const {
|
||||
return utils::Hash(u.storage_class, u.buf_ty, u.el_ty);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
/// AtomicKey is the unordered map key to an atomic intrinsic.
|
||||
struct AtomicKey {
|
||||
sem::Type const* buf_ty = nullptr; // buffer type
|
||||
sem::Type const* el_ty = nullptr; // element type
|
||||
sem::BuiltinType const op; // atomic op
|
||||
bool operator==(const AtomicKey& rhs) const {
|
||||
return buf_ty == rhs.buf_ty && el_ty == rhs.el_ty && op == rhs.op;
|
||||
}
|
||||
struct Hasher {
|
||||
inline std::size_t operator()(const AtomicKey& u) const {
|
||||
return utils::Hash(u.buf_ty, u.el_ty, u.op);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
bool IntrinsicDataTypeFor(const sem::Type* ty,
|
||||
DecomposeMemoryAccess::Intrinsic::DataType& out) {
|
||||
if (ty->Is<sem::I32>()) {
|
||||
out = DecomposeMemoryAccess::Intrinsic::DataType::kI32;
|
||||
return true;
|
||||
}
|
||||
if (ty->Is<sem::U32>()) {
|
||||
out = DecomposeMemoryAccess::Intrinsic::DataType::kU32;
|
||||
return true;
|
||||
}
|
||||
if (ty->Is<sem::F32>()) {
|
||||
out = DecomposeMemoryAccess::Intrinsic::DataType::kF32;
|
||||
return true;
|
||||
}
|
||||
if (auto* vec = ty->As<sem::Vector>()) {
|
||||
switch (vec->Width()) {
|
||||
case 2:
|
||||
if (vec->type()->Is<sem::I32>()) {
|
||||
out = DecomposeMemoryAccess::Intrinsic::DataType::kVec2I32;
|
||||
return true;
|
||||
}
|
||||
if (vec->type()->Is<sem::U32>()) {
|
||||
out = DecomposeMemoryAccess::Intrinsic::DataType::kVec2U32;
|
||||
return true;
|
||||
}
|
||||
if (vec->type()->Is<sem::F32>()) {
|
||||
out = DecomposeMemoryAccess::Intrinsic::DataType::kVec2F32;
|
||||
return true;
|
||||
}
|
||||
break;
|
||||
case 3:
|
||||
if (vec->type()->Is<sem::I32>()) {
|
||||
out = DecomposeMemoryAccess::Intrinsic::DataType::kVec3I32;
|
||||
return true;
|
||||
}
|
||||
if (vec->type()->Is<sem::U32>()) {
|
||||
out = DecomposeMemoryAccess::Intrinsic::DataType::kVec3U32;
|
||||
return true;
|
||||
}
|
||||
if (vec->type()->Is<sem::F32>()) {
|
||||
out = DecomposeMemoryAccess::Intrinsic::DataType::kVec3F32;
|
||||
return true;
|
||||
}
|
||||
break;
|
||||
case 4:
|
||||
if (vec->type()->Is<sem::I32>()) {
|
||||
out = DecomposeMemoryAccess::Intrinsic::DataType::kVec4I32;
|
||||
return true;
|
||||
}
|
||||
if (vec->type()->Is<sem::U32>()) {
|
||||
out = DecomposeMemoryAccess::Intrinsic::DataType::kVec4U32;
|
||||
return true;
|
||||
}
|
||||
if (vec->type()->Is<sem::F32>()) {
|
||||
out = DecomposeMemoryAccess::Intrinsic::DataType::kVec4F32;
|
||||
return true;
|
||||
}
|
||||
break;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
/// @returns a DecomposeMemoryAccess::Intrinsic attribute that can be applied
|
||||
/// to a stub function to load the type `ty`.
|
||||
DecomposeMemoryAccess::Intrinsic* IntrinsicLoadFor(
|
||||
ProgramBuilder* builder,
|
||||
ast::StorageClass storage_class,
|
||||
const sem::Type* ty) {
|
||||
DecomposeMemoryAccess::Intrinsic::DataType type;
|
||||
if (!IntrinsicDataTypeFor(ty, type)) {
|
||||
return nullptr;
|
||||
}
|
||||
return builder->ASTNodes().Create<DecomposeMemoryAccess::Intrinsic>(
|
||||
builder->ID(), DecomposeMemoryAccess::Intrinsic::Op::kLoad, storage_class,
|
||||
type);
|
||||
}
|
||||
|
||||
/// @returns a DecomposeMemoryAccess::Intrinsic attribute that can be applied
|
||||
/// to a stub function to store the type `ty`.
|
||||
DecomposeMemoryAccess::Intrinsic* IntrinsicStoreFor(
|
||||
ProgramBuilder* builder,
|
||||
ast::StorageClass storage_class,
|
||||
const sem::Type* ty) {
|
||||
DecomposeMemoryAccess::Intrinsic::DataType type;
|
||||
if (!IntrinsicDataTypeFor(ty, type)) {
|
||||
return nullptr;
|
||||
}
|
||||
return builder->ASTNodes().Create<DecomposeMemoryAccess::Intrinsic>(
|
||||
builder->ID(), DecomposeMemoryAccess::Intrinsic::Op::kStore,
|
||||
storage_class, type);
|
||||
}
|
||||
|
||||
/// @returns a DecomposeMemoryAccess::Intrinsic attribute that can be applied
|
||||
/// to a stub function for the atomic op and the type `ty`.
|
||||
DecomposeMemoryAccess::Intrinsic* IntrinsicAtomicFor(ProgramBuilder* builder,
|
||||
sem::BuiltinType ity,
|
||||
const sem::Type* ty) {
|
||||
auto op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicLoad;
|
||||
switch (ity) {
|
||||
case sem::BuiltinType::kAtomicLoad:
|
||||
op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicLoad;
|
||||
break;
|
||||
case sem::BuiltinType::kAtomicStore:
|
||||
op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicStore;
|
||||
break;
|
||||
case sem::BuiltinType::kAtomicAdd:
|
||||
op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicAdd;
|
||||
break;
|
||||
case sem::BuiltinType::kAtomicSub:
|
||||
op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicSub;
|
||||
break;
|
||||
case sem::BuiltinType::kAtomicMax:
|
||||
op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicMax;
|
||||
break;
|
||||
case sem::BuiltinType::kAtomicMin:
|
||||
op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicMin;
|
||||
break;
|
||||
case sem::BuiltinType::kAtomicAnd:
|
||||
op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicAnd;
|
||||
break;
|
||||
case sem::BuiltinType::kAtomicOr:
|
||||
op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicOr;
|
||||
break;
|
||||
case sem::BuiltinType::kAtomicXor:
|
||||
op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicXor;
|
||||
break;
|
||||
case sem::BuiltinType::kAtomicExchange:
|
||||
op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicExchange;
|
||||
break;
|
||||
case sem::BuiltinType::kAtomicCompareExchangeWeak:
|
||||
op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicCompareExchangeWeak;
|
||||
break;
|
||||
default:
|
||||
TINT_ICE(Transform, builder->Diagnostics())
|
||||
<< "invalid IntrinsicType for DecomposeMemoryAccess::Intrinsic: "
|
||||
<< ty->type_name();
|
||||
break;
|
||||
}
|
||||
|
||||
DecomposeMemoryAccess::Intrinsic::DataType type;
|
||||
if (!IntrinsicDataTypeFor(ty, type)) {
|
||||
return nullptr;
|
||||
}
|
||||
return builder->ASTNodes().Create<DecomposeMemoryAccess::Intrinsic>(
|
||||
builder->ID(), op, ast::StorageClass::kStorage, type);
|
||||
}
|
||||
|
||||
/// BufferAccess describes a single storage or uniform buffer access
|
||||
struct BufferAccess {
|
||||
sem::Expression const* var = nullptr; // Storage buffer variable
|
||||
Offset const* offset = nullptr; // The byte offset on var
|
||||
sem::Type const* type = nullptr; // The type of the access
|
||||
operator bool() const { return var; } // Returns true if valid
|
||||
};
|
||||
|
||||
/// Store describes a single storage or uniform buffer write
|
||||
struct Store {
|
||||
const ast::AssignmentStatement* assignment; // The AST assignment statement
|
||||
BufferAccess target; // The target for the write
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
/// State holds the current transform state
|
||||
struct DecomposeMemoryAccess::State {
|
||||
/// The clone context
|
||||
CloneContext& ctx;
|
||||
/// Alias to `*ctx.dst`
|
||||
ProgramBuilder& b;
|
||||
/// Map of AST expression to storage or uniform buffer access
|
||||
/// This map has entries added when encountered, and removed when outer
|
||||
/// expressions chain the access.
|
||||
/// Subset of #expression_order, as expressions are not removed from
|
||||
/// #expression_order.
|
||||
std::unordered_map<const ast::Expression*, BufferAccess> accesses;
|
||||
/// The visited order of AST expressions (superset of #accesses)
|
||||
std::vector<const ast::Expression*> expression_order;
|
||||
/// [buffer-type, element-type] -> load function name
|
||||
std::unordered_map<LoadStoreKey, Symbol, LoadStoreKey::Hasher> load_funcs;
|
||||
/// [buffer-type, element-type] -> store function name
|
||||
std::unordered_map<LoadStoreKey, Symbol, LoadStoreKey::Hasher> store_funcs;
|
||||
/// [buffer-type, element-type, atomic-op] -> load function name
|
||||
std::unordered_map<AtomicKey, Symbol, AtomicKey::Hasher> atomic_funcs;
|
||||
/// List of storage or uniform buffer writes
|
||||
std::vector<Store> stores;
|
||||
/// Allocations for offsets
|
||||
BlockAllocator<Offset> offsets_;
|
||||
|
||||
/// Constructor
|
||||
/// @param context the CloneContext
|
||||
explicit State(CloneContext& context) : ctx(context), b(*ctx.dst) {}
|
||||
|
||||
/// @param offset the offset value to wrap in an Offset
|
||||
/// @returns an Offset for the given literal value
|
||||
const Offset* ToOffset(uint32_t offset) {
|
||||
return offsets_.Create<OffsetLiteral>(offset);
|
||||
}
|
||||
|
||||
/// @param expr the expression to convert to an Offset
|
||||
/// @returns an Offset for the given ast::Expression
|
||||
const Offset* ToOffset(const ast::Expression* expr) {
|
||||
if (auto* u32 = expr->As<ast::UintLiteralExpression>()) {
|
||||
return offsets_.Create<OffsetLiteral>(u32->value);
|
||||
} else if (auto* i32 = expr->As<ast::SintLiteralExpression>()) {
|
||||
if (i32->value > 0) {
|
||||
return offsets_.Create<OffsetLiteral>(i32->value);
|
||||
}
|
||||
}
|
||||
return offsets_.Create<OffsetExpr>(expr);
|
||||
}
|
||||
|
||||
/// @param offset the Offset that is returned
|
||||
/// @returns the given offset (pass-through)
|
||||
const Offset* ToOffset(const Offset* offset) { return offset; }
|
||||
|
||||
/// @param lhs_ the left-hand side of the add expression
|
||||
/// @param rhs_ the right-hand side of the add expression
|
||||
/// @return an Offset that is a sum of lhs and rhs, performing basic constant
|
||||
/// folding if possible
|
||||
template <typename LHS, typename RHS>
|
||||
const Offset* Add(LHS&& lhs_, RHS&& rhs_) {
|
||||
auto* lhs = ToOffset(std::forward<LHS>(lhs_));
|
||||
auto* rhs = ToOffset(std::forward<RHS>(rhs_));
|
||||
auto* lhs_lit = tint::As<OffsetLiteral>(lhs);
|
||||
auto* rhs_lit = tint::As<OffsetLiteral>(rhs);
|
||||
if (lhs_lit && lhs_lit->literal == 0) {
|
||||
return rhs;
|
||||
}
|
||||
if (rhs_lit && rhs_lit->literal == 0) {
|
||||
return lhs;
|
||||
}
|
||||
if (lhs_lit && rhs_lit) {
|
||||
if (static_cast<uint64_t>(lhs_lit->literal) +
|
||||
static_cast<uint64_t>(rhs_lit->literal) <=
|
||||
0xffffffff) {
|
||||
return offsets_.Create<OffsetLiteral>(lhs_lit->literal +
|
||||
rhs_lit->literal);
|
||||
}
|
||||
}
|
||||
auto* out = offsets_.Create<OffsetBinOp>();
|
||||
out->op = ast::BinaryOp::kAdd;
|
||||
out->lhs = lhs;
|
||||
out->rhs = rhs;
|
||||
return out;
|
||||
}
|
||||
|
||||
/// @param lhs_ the left-hand side of the multiply expression
|
||||
/// @param rhs_ the right-hand side of the multiply expression
|
||||
/// @return an Offset that is the multiplication of lhs and rhs, performing
|
||||
/// basic constant folding if possible
|
||||
template <typename LHS, typename RHS>
|
||||
const Offset* Mul(LHS&& lhs_, RHS&& rhs_) {
|
||||
auto* lhs = ToOffset(std::forward<LHS>(lhs_));
|
||||
auto* rhs = ToOffset(std::forward<RHS>(rhs_));
|
||||
auto* lhs_lit = tint::As<OffsetLiteral>(lhs);
|
||||
auto* rhs_lit = tint::As<OffsetLiteral>(rhs);
|
||||
if (lhs_lit && lhs_lit->literal == 0) {
|
||||
return offsets_.Create<OffsetLiteral>(0);
|
||||
}
|
||||
if (rhs_lit && rhs_lit->literal == 0) {
|
||||
return offsets_.Create<OffsetLiteral>(0);
|
||||
}
|
||||
if (lhs_lit && lhs_lit->literal == 1) {
|
||||
return rhs;
|
||||
}
|
||||
if (rhs_lit && rhs_lit->literal == 1) {
|
||||
return lhs;
|
||||
}
|
||||
if (lhs_lit && rhs_lit) {
|
||||
return offsets_.Create<OffsetLiteral>(lhs_lit->literal *
|
||||
rhs_lit->literal);
|
||||
}
|
||||
auto* out = offsets_.Create<OffsetBinOp>();
|
||||
out->op = ast::BinaryOp::kMultiply;
|
||||
out->lhs = lhs;
|
||||
out->rhs = rhs;
|
||||
return out;
|
||||
}
|
||||
|
||||
/// AddAccess() adds the `expr -> access` map item to #accesses, and `expr`
|
||||
/// to #expression_order.
|
||||
/// @param expr the expression that performs the access
|
||||
/// @param access the access
|
||||
void AddAccess(const ast::Expression* expr, const BufferAccess& access) {
|
||||
TINT_ASSERT(Transform, access.type);
|
||||
accesses.emplace(expr, access);
|
||||
expression_order.emplace_back(expr);
|
||||
}
|
||||
|
||||
/// TakeAccess() removes the `node` item from #accesses (if it exists),
|
||||
/// returning the BufferAccess. If #accesses does not hold an item for
|
||||
/// `node`, an invalid BufferAccess is returned.
|
||||
/// @param node the expression that performed an access
|
||||
/// @return the BufferAccess for the given expression
|
||||
BufferAccess TakeAccess(const ast::Expression* node) {
|
||||
auto lhs_it = accesses.find(node);
|
||||
if (lhs_it == accesses.end()) {
|
||||
return {};
|
||||
}
|
||||
auto access = lhs_it->second;
|
||||
accesses.erase(node);
|
||||
return access;
|
||||
}
|
||||
|
||||
/// LoadFunc() returns a symbol to an intrinsic function that loads an element
|
||||
/// of type `el_ty` from a storage or uniform buffer of type `buf_ty`.
|
||||
/// The emitted function has the signature:
|
||||
/// `fn load(buf : buf_ty, offset : u32) -> el_ty`
|
||||
/// @param buf_ty the storage or uniform buffer type
|
||||
/// @param el_ty the storage or uniform buffer element type
|
||||
/// @param var_user the variable user
|
||||
/// @return the name of the function that performs the load
|
||||
Symbol LoadFunc(const sem::Type* buf_ty,
|
||||
const sem::Type* el_ty,
|
||||
const sem::VariableUser* var_user) {
|
||||
auto storage_class = var_user->Variable()->StorageClass();
|
||||
return utils::GetOrCreate(
|
||||
load_funcs, LoadStoreKey{storage_class, buf_ty, el_ty}, [&] {
|
||||
auto* buf_ast_ty = CreateASTTypeFor(ctx, buf_ty);
|
||||
auto* disable_validation = b.Disable(
|
||||
ast::DisabledValidation::kIgnoreConstructibleFunctionParameter);
|
||||
|
||||
ast::VariableList params = {
|
||||
// Note: The buffer parameter requires the StorageClass in
|
||||
// order for HLSL to emit this as a ByteAddressBuffer or cbuffer
|
||||
// array.
|
||||
b.create<ast::Variable>(b.Sym("buffer"), storage_class,
|
||||
var_user->Variable()->Access(),
|
||||
buf_ast_ty, true, false, nullptr,
|
||||
ast::AttributeList{disable_validation}),
|
||||
b.Param("offset", b.ty.u32()),
|
||||
};
|
||||
|
||||
auto name = b.Sym();
|
||||
|
||||
if (auto* intrinsic =
|
||||
IntrinsicLoadFor(ctx.dst, storage_class, el_ty)) {
|
||||
auto* el_ast_ty = CreateASTTypeFor(ctx, el_ty);
|
||||
auto* func = b.create<ast::Function>(
|
||||
name, params, el_ast_ty, nullptr,
|
||||
ast::AttributeList{
|
||||
intrinsic,
|
||||
b.Disable(ast::DisabledValidation::kFunctionHasNoBody),
|
||||
},
|
||||
ast::AttributeList{});
|
||||
b.AST().AddFunction(func);
|
||||
} else if (auto* arr_ty = el_ty->As<sem::Array>()) {
|
||||
// fn load_func(buf : buf_ty, offset : u32) -> array<T, N> {
|
||||
// var arr : array<T, N>;
|
||||
// for (var i = 0u; i < array_count; i = i + 1) {
|
||||
// arr[i] = el_load_func(buf, offset + i * array_stride)
|
||||
// }
|
||||
// return arr;
|
||||
// }
|
||||
auto load =
|
||||
LoadFunc(buf_ty, arr_ty->ElemType()->UnwrapRef(), var_user);
|
||||
auto* arr =
|
||||
b.Var(b.Symbols().New("arr"), CreateASTTypeFor(ctx, arr_ty));
|
||||
auto* i = b.Var(b.Symbols().New("i"), nullptr, b.Expr(0u));
|
||||
auto* for_init = b.Decl(i);
|
||||
auto* for_cond = b.create<ast::BinaryExpression>(
|
||||
ast::BinaryOp::kLessThan, b.Expr(i), b.Expr(arr_ty->Count()));
|
||||
auto* for_cont = b.Assign(i, b.Add(i, 1u));
|
||||
auto* arr_el = b.IndexAccessor(arr, i);
|
||||
auto* el_offset =
|
||||
b.Add(b.Expr("offset"), b.Mul(i, arr_ty->Stride()));
|
||||
auto* el_val = b.Call(load, "buffer", el_offset);
|
||||
auto* for_loop = b.For(for_init, for_cond, for_cont,
|
||||
b.Block(b.Assign(arr_el, el_val)));
|
||||
|
||||
b.Func(name, params, CreateASTTypeFor(ctx, arr_ty),
|
||||
{
|
||||
b.Decl(arr),
|
||||
for_loop,
|
||||
b.Return(arr),
|
||||
});
|
||||
} else {
|
||||
ast::ExpressionList values;
|
||||
if (auto* mat_ty = el_ty->As<sem::Matrix>()) {
|
||||
auto* vec_ty = mat_ty->ColumnType();
|
||||
Symbol load = LoadFunc(buf_ty, vec_ty, var_user);
|
||||
for (uint32_t i = 0; i < mat_ty->columns(); i++) {
|
||||
auto* offset = b.Add("offset", i * mat_ty->ColumnStride());
|
||||
values.emplace_back(b.Call(load, "buffer", offset));
|
||||
}
|
||||
} else if (auto* str = el_ty->As<sem::Struct>()) {
|
||||
for (auto* member : str->Members()) {
|
||||
auto* offset = b.Add("offset", member->Offset());
|
||||
Symbol load =
|
||||
LoadFunc(buf_ty, member->Type()->UnwrapRef(), var_user);
|
||||
values.emplace_back(b.Call(load, "buffer", offset));
|
||||
}
|
||||
}
|
||||
b.Func(
|
||||
name, params, CreateASTTypeFor(ctx, el_ty),
|
||||
{
|
||||
b.Return(b.Construct(CreateASTTypeFor(ctx, el_ty), values)),
|
||||
});
|
||||
}
|
||||
return name;
|
||||
});
|
||||
}
|
||||
|
||||
/// StoreFunc() returns a symbol to an intrinsic function that stores an
|
||||
/// element of type `el_ty` to a storage buffer of type `buf_ty`.
|
||||
/// The function has the signature:
|
||||
/// `fn store(buf : buf_ty, offset : u32, value : el_ty)`
|
||||
/// @param buf_ty the storage buffer type
|
||||
/// @param el_ty the storage buffer element type
|
||||
/// @param var_user the variable user
|
||||
/// @return the name of the function that performs the store
|
||||
Symbol StoreFunc(const sem::Type* buf_ty,
|
||||
const sem::Type* el_ty,
|
||||
const sem::VariableUser* var_user) {
|
||||
auto storage_class = var_user->Variable()->StorageClass();
|
||||
return utils::GetOrCreate(
|
||||
store_funcs, LoadStoreKey{storage_class, buf_ty, el_ty}, [&] {
|
||||
auto* buf_ast_ty = CreateASTTypeFor(ctx, buf_ty);
|
||||
auto* el_ast_ty = CreateASTTypeFor(ctx, el_ty);
|
||||
auto* disable_validation = b.Disable(
|
||||
ast::DisabledValidation::kIgnoreConstructibleFunctionParameter);
|
||||
ast::VariableList params{
|
||||
// Note: The buffer parameter requires the StorageClass in
|
||||
// order for HLSL to emit this as a ByteAddressBuffer.
|
||||
|
||||
b.create<ast::Variable>(b.Sym("buffer"), storage_class,
|
||||
var_user->Variable()->Access(),
|
||||
buf_ast_ty, true, false, nullptr,
|
||||
ast::AttributeList{disable_validation}),
|
||||
b.Param("offset", b.ty.u32()),
|
||||
b.Param("value", el_ast_ty),
|
||||
};
|
||||
|
||||
auto name = b.Sym();
|
||||
|
||||
if (auto* intrinsic =
|
||||
IntrinsicStoreFor(ctx.dst, storage_class, el_ty)) {
|
||||
auto* func = b.create<ast::Function>(
|
||||
name, params, b.ty.void_(), nullptr,
|
||||
ast::AttributeList{
|
||||
intrinsic,
|
||||
b.Disable(ast::DisabledValidation::kFunctionHasNoBody),
|
||||
},
|
||||
ast::AttributeList{});
|
||||
b.AST().AddFunction(func);
|
||||
} else {
|
||||
ast::StatementList body;
|
||||
if (auto* arr_ty = el_ty->As<sem::Array>()) {
|
||||
// fn store_func(buf : buf_ty, offset : u32, value : el_ty) {
|
||||
// var array = value; // No dynamic indexing on constant arrays
|
||||
// for (var i = 0u; i < array_count; i = i + 1) {
|
||||
// arr[i] = el_store_func(buf, offset + i * array_stride,
|
||||
// value[i])
|
||||
// }
|
||||
// return arr;
|
||||
// }
|
||||
auto* array =
|
||||
b.Var(b.Symbols().New("array"), nullptr, b.Expr("value"));
|
||||
auto store =
|
||||
StoreFunc(buf_ty, arr_ty->ElemType()->UnwrapRef(), var_user);
|
||||
auto* i = b.Var(b.Symbols().New("i"), nullptr, b.Expr(0u));
|
||||
auto* for_init = b.Decl(i);
|
||||
auto* for_cond = b.create<ast::BinaryExpression>(
|
||||
ast::BinaryOp::kLessThan, b.Expr(i), b.Expr(arr_ty->Count()));
|
||||
auto* for_cont = b.Assign(i, b.Add(i, 1u));
|
||||
auto* arr_el = b.IndexAccessor(array, i);
|
||||
auto* el_offset =
|
||||
b.Add(b.Expr("offset"), b.Mul(i, arr_ty->Stride()));
|
||||
auto* store_stmt =
|
||||
b.CallStmt(b.Call(store, "buffer", el_offset, arr_el));
|
||||
auto* for_loop =
|
||||
b.For(for_init, for_cond, for_cont, b.Block(store_stmt));
|
||||
|
||||
body = {b.Decl(array), for_loop};
|
||||
} else if (auto* mat_ty = el_ty->As<sem::Matrix>()) {
|
||||
auto* vec_ty = mat_ty->ColumnType();
|
||||
Symbol store = StoreFunc(buf_ty, vec_ty, var_user);
|
||||
for (uint32_t i = 0; i < mat_ty->columns(); i++) {
|
||||
auto* offset = b.Add("offset", i * mat_ty->ColumnStride());
|
||||
auto* access = b.IndexAccessor("value", i);
|
||||
auto* call = b.Call(store, "buffer", offset, access);
|
||||
body.emplace_back(b.CallStmt(call));
|
||||
}
|
||||
} else if (auto* str = el_ty->As<sem::Struct>()) {
|
||||
for (auto* member : str->Members()) {
|
||||
auto* offset = b.Add("offset", member->Offset());
|
||||
auto* access = b.MemberAccessor(
|
||||
"value", ctx.Clone(member->Declaration()->symbol));
|
||||
Symbol store =
|
||||
StoreFunc(buf_ty, member->Type()->UnwrapRef(), var_user);
|
||||
auto* call = b.Call(store, "buffer", offset, access);
|
||||
body.emplace_back(b.CallStmt(call));
|
||||
}
|
||||
}
|
||||
b.Func(name, params, b.ty.void_(), body);
|
||||
}
|
||||
|
||||
return name;
|
||||
});
|
||||
}
|
||||
|
||||
/// AtomicFunc() returns a symbol to an intrinsic function that performs an
|
||||
/// atomic operation from a storage buffer of type `buf_ty`. The function has
|
||||
/// the signature:
|
||||
// `fn atomic_op(buf : buf_ty, offset : u32, ...) -> T`
|
||||
/// @param buf_ty the storage buffer type
|
||||
/// @param el_ty the storage buffer element type
|
||||
/// @param intrinsic the atomic intrinsic
|
||||
/// @param var_user the variable user
|
||||
/// @return the name of the function that performs the load
|
||||
Symbol AtomicFunc(const sem::Type* buf_ty,
|
||||
const sem::Type* el_ty,
|
||||
const sem::Builtin* intrinsic,
|
||||
const sem::VariableUser* var_user) {
|
||||
auto op = intrinsic->Type();
|
||||
return utils::GetOrCreate(atomic_funcs, AtomicKey{buf_ty, el_ty, op}, [&] {
|
||||
auto* buf_ast_ty = CreateASTTypeFor(ctx, buf_ty);
|
||||
auto* disable_validation = b.Disable(
|
||||
ast::DisabledValidation::kIgnoreConstructibleFunctionParameter);
|
||||
// The first parameter to all WGSL atomics is the expression to the
|
||||
// atomic. This is replaced with two parameters: the buffer and offset.
|
||||
|
||||
ast::VariableList params = {
|
||||
// Note: The buffer parameter requires the kStorage StorageClass in
|
||||
// order for HLSL to emit this as a ByteAddressBuffer.
|
||||
b.create<ast::Variable>(b.Sym("buffer"), ast::StorageClass::kStorage,
|
||||
var_user->Variable()->Access(), buf_ast_ty,
|
||||
true, false, nullptr,
|
||||
ast::AttributeList{disable_validation}),
|
||||
b.Param("offset", b.ty.u32()),
|
||||
};
|
||||
|
||||
// Other parameters are copied as-is:
|
||||
for (size_t i = 1; i < intrinsic->Parameters().size(); i++) {
|
||||
auto* param = intrinsic->Parameters()[i];
|
||||
auto* ty = CreateASTTypeFor(ctx, param->Type());
|
||||
params.emplace_back(b.Param("param_" + std::to_string(i), ty));
|
||||
}
|
||||
|
||||
auto* atomic = IntrinsicAtomicFor(ctx.dst, op, el_ty);
|
||||
if (atomic == nullptr) {
|
||||
TINT_ICE(Transform, b.Diagnostics())
|
||||
<< "IntrinsicAtomicFor() returned nullptr for op " << op
|
||||
<< " and type " << el_ty->type_name();
|
||||
}
|
||||
|
||||
auto* ret_ty = CreateASTTypeFor(ctx, intrinsic->ReturnType());
|
||||
auto* func = b.create<ast::Function>(
|
||||
b.Sym(), params, ret_ty, nullptr,
|
||||
ast::AttributeList{
|
||||
atomic,
|
||||
b.Disable(ast::DisabledValidation::kFunctionHasNoBody),
|
||||
},
|
||||
ast::AttributeList{});
|
||||
|
||||
b.AST().AddFunction(func);
|
||||
return func->symbol;
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
DecomposeMemoryAccess::Intrinsic::Intrinsic(ProgramID pid,
|
||||
Op o,
|
||||
ast::StorageClass sc,
|
||||
DataType ty)
|
||||
: Base(pid), op(o), storage_class(sc), type(ty) {}
|
||||
DecomposeMemoryAccess::Intrinsic::~Intrinsic() = default;
|
||||
std::string DecomposeMemoryAccess::Intrinsic::InternalName() const {
|
||||
std::stringstream ss;
|
||||
switch (op) {
|
||||
case Op::kLoad:
|
||||
ss << "intrinsic_load_";
|
||||
break;
|
||||
case Op::kStore:
|
||||
ss << "intrinsic_store_";
|
||||
break;
|
||||
case Op::kAtomicLoad:
|
||||
ss << "intrinsic_atomic_load_";
|
||||
break;
|
||||
case Op::kAtomicStore:
|
||||
ss << "intrinsic_atomic_store_";
|
||||
break;
|
||||
case Op::kAtomicAdd:
|
||||
ss << "intrinsic_atomic_add_";
|
||||
break;
|
||||
case Op::kAtomicSub:
|
||||
ss << "intrinsic_atomic_sub_";
|
||||
break;
|
||||
case Op::kAtomicMax:
|
||||
ss << "intrinsic_atomic_max_";
|
||||
break;
|
||||
case Op::kAtomicMin:
|
||||
ss << "intrinsic_atomic_min_";
|
||||
break;
|
||||
case Op::kAtomicAnd:
|
||||
ss << "intrinsic_atomic_and_";
|
||||
break;
|
||||
case Op::kAtomicOr:
|
||||
ss << "intrinsic_atomic_or_";
|
||||
break;
|
||||
case Op::kAtomicXor:
|
||||
ss << "intrinsic_atomic_xor_";
|
||||
break;
|
||||
case Op::kAtomicExchange:
|
||||
ss << "intrinsic_atomic_exchange_";
|
||||
break;
|
||||
case Op::kAtomicCompareExchangeWeak:
|
||||
ss << "intrinsic_atomic_compare_exchange_weak_";
|
||||
break;
|
||||
}
|
||||
ss << storage_class << "_";
|
||||
switch (type) {
|
||||
case DataType::kU32:
|
||||
ss << "u32";
|
||||
break;
|
||||
case DataType::kF32:
|
||||
ss << "f32";
|
||||
break;
|
||||
case DataType::kI32:
|
||||
ss << "i32";
|
||||
break;
|
||||
case DataType::kVec2U32:
|
||||
ss << "vec2_u32";
|
||||
break;
|
||||
case DataType::kVec2F32:
|
||||
ss << "vec2_f32";
|
||||
break;
|
||||
case DataType::kVec2I32:
|
||||
ss << "vec2_i32";
|
||||
break;
|
||||
case DataType::kVec3U32:
|
||||
ss << "vec3_u32";
|
||||
break;
|
||||
case DataType::kVec3F32:
|
||||
ss << "vec3_f32";
|
||||
break;
|
||||
case DataType::kVec3I32:
|
||||
ss << "vec3_i32";
|
||||
break;
|
||||
case DataType::kVec4U32:
|
||||
ss << "vec4_u32";
|
||||
break;
|
||||
case DataType::kVec4F32:
|
||||
ss << "vec4_f32";
|
||||
break;
|
||||
case DataType::kVec4I32:
|
||||
ss << "vec4_i32";
|
||||
break;
|
||||
}
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
const DecomposeMemoryAccess::Intrinsic* DecomposeMemoryAccess::Intrinsic::Clone(
|
||||
CloneContext* ctx) const {
|
||||
return ctx->dst->ASTNodes().Create<DecomposeMemoryAccess::Intrinsic>(
|
||||
ctx->dst->ID(), op, storage_class, type);
|
||||
}
|
||||
|
||||
DecomposeMemoryAccess::DecomposeMemoryAccess() = default;
|
||||
DecomposeMemoryAccess::~DecomposeMemoryAccess() = default;
|
||||
|
||||
bool DecomposeMemoryAccess::ShouldRun(const Program* program,
|
||||
const DataMap&) const {
|
||||
for (auto* decl : program->AST().GlobalDeclarations()) {
|
||||
if (auto* var = program->Sem().Get<sem::Variable>(decl)) {
|
||||
if (var->StorageClass() == ast::StorageClass::kStorage ||
|
||||
var->StorageClass() == ast::StorageClass::kUniform) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void DecomposeMemoryAccess::Run(CloneContext& ctx,
|
||||
const DataMap&,
|
||||
DataMap&) const {
|
||||
auto& sem = ctx.src->Sem();
|
||||
|
||||
State state(ctx);
|
||||
|
||||
// Scan the AST nodes for storage and uniform buffer accesses. Complex
|
||||
// expression chains (e.g. `storage_buffer.foo.bar[20].x`) are handled by
|
||||
// maintaining an offset chain via the `state.TakeAccess()`,
|
||||
// `state.AddAccess()` methods.
|
||||
//
|
||||
// Inner-most expression nodes are guaranteed to be visited first because AST
|
||||
// nodes are fully immutable and require their children to be constructed
|
||||
// first so their pointer can be passed to the parent's constructor.
|
||||
for (auto* node : ctx.src->ASTNodes().Objects()) {
|
||||
if (auto* ident = node->As<ast::IdentifierExpression>()) {
|
||||
// X
|
||||
if (auto* var = sem.Get<sem::VariableUser>(ident)) {
|
||||
if (var->Variable()->StorageClass() == ast::StorageClass::kStorage ||
|
||||
var->Variable()->StorageClass() == ast::StorageClass::kUniform) {
|
||||
// Variable to a storage or uniform buffer
|
||||
state.AddAccess(ident, {
|
||||
var,
|
||||
state.ToOffset(0u),
|
||||
var->Type()->UnwrapRef(),
|
||||
});
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto* accessor = node->As<ast::MemberAccessorExpression>()) {
|
||||
// X.Y
|
||||
auto* accessor_sem = sem.Get(accessor);
|
||||
if (auto* swizzle = accessor_sem->As<sem::Swizzle>()) {
|
||||
if (swizzle->Indices().size() == 1) {
|
||||
if (auto access = state.TakeAccess(accessor->structure)) {
|
||||
auto* vec_ty = access.type->As<sem::Vector>();
|
||||
auto* offset =
|
||||
state.Mul(vec_ty->type()->Size(), swizzle->Indices()[0]);
|
||||
state.AddAccess(accessor, {
|
||||
access.var,
|
||||
state.Add(access.offset, offset),
|
||||
vec_ty->type()->UnwrapRef(),
|
||||
});
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (auto access = state.TakeAccess(accessor->structure)) {
|
||||
auto* str_ty = access.type->As<sem::Struct>();
|
||||
auto* member = str_ty->FindMember(accessor->member->symbol);
|
||||
auto offset = member->Offset();
|
||||
state.AddAccess(accessor, {
|
||||
access.var,
|
||||
state.Add(access.offset, offset),
|
||||
member->Type()->UnwrapRef(),
|
||||
});
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto* accessor = node->As<ast::IndexAccessorExpression>()) {
|
||||
if (auto access = state.TakeAccess(accessor->object)) {
|
||||
// X[Y]
|
||||
if (auto* arr = access.type->As<sem::Array>()) {
|
||||
auto* offset = state.Mul(arr->Stride(), accessor->index);
|
||||
state.AddAccess(accessor, {
|
||||
access.var,
|
||||
state.Add(access.offset, offset),
|
||||
arr->ElemType()->UnwrapRef(),
|
||||
});
|
||||
continue;
|
||||
}
|
||||
if (auto* vec_ty = access.type->As<sem::Vector>()) {
|
||||
auto* offset = state.Mul(vec_ty->type()->Size(), accessor->index);
|
||||
state.AddAccess(accessor, {
|
||||
access.var,
|
||||
state.Add(access.offset, offset),
|
||||
vec_ty->type()->UnwrapRef(),
|
||||
});
|
||||
continue;
|
||||
}
|
||||
if (auto* mat_ty = access.type->As<sem::Matrix>()) {
|
||||
auto* offset = state.Mul(mat_ty->ColumnStride(), accessor->index);
|
||||
state.AddAccess(accessor, {
|
||||
access.var,
|
||||
state.Add(access.offset, offset),
|
||||
mat_ty->ColumnType(),
|
||||
});
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (auto* op = node->As<ast::UnaryOpExpression>()) {
|
||||
if (op->op == ast::UnaryOp::kAddressOf) {
|
||||
// &X
|
||||
if (auto access = state.TakeAccess(op->expr)) {
|
||||
// HLSL does not support pointers, so just take the access from the
|
||||
// reference and place it on the pointer.
|
||||
state.AddAccess(op, access);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (auto* assign = node->As<ast::AssignmentStatement>()) {
|
||||
// X = Y
|
||||
// Move the LHS access to a store.
|
||||
if (auto lhs = state.TakeAccess(assign->lhs)) {
|
||||
state.stores.emplace_back(Store{assign, lhs});
|
||||
}
|
||||
}
|
||||
|
||||
if (auto* call_expr = node->As<ast::CallExpression>()) {
|
||||
auto* call = sem.Get(call_expr);
|
||||
if (auto* builtin = call->Target()->As<sem::Builtin>()) {
|
||||
if (builtin->Type() == sem::BuiltinType::kArrayLength) {
|
||||
// arrayLength(X)
|
||||
// Don't convert X into a load, this builtin actually requires the
|
||||
// real pointer.
|
||||
state.TakeAccess(call_expr->args[0]);
|
||||
continue;
|
||||
}
|
||||
if (builtin->IsAtomic()) {
|
||||
if (auto access = state.TakeAccess(call_expr->args[0])) {
|
||||
// atomic___(X)
|
||||
ctx.Replace(call_expr, [=, &ctx, &state] {
|
||||
auto* buf = access.var->Declaration();
|
||||
auto* offset = access.offset->Build(ctx);
|
||||
auto* buf_ty = access.var->Type()->UnwrapRef();
|
||||
auto* el_ty = access.type->UnwrapRef()->As<sem::Atomic>()->Type();
|
||||
Symbol func = state.AtomicFunc(
|
||||
buf_ty, el_ty, builtin, access.var->As<sem::VariableUser>());
|
||||
|
||||
ast::ExpressionList args{ctx.Clone(buf), offset};
|
||||
for (size_t i = 1; i < call_expr->args.size(); i++) {
|
||||
auto* arg = call_expr->args[i];
|
||||
args.emplace_back(ctx.Clone(arg));
|
||||
}
|
||||
return ctx.dst->Call(func, args);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// All remaining accesses are loads, transform these into calls to the
|
||||
// corresponding load function
|
||||
for (auto* expr : state.expression_order) {
|
||||
auto access_it = state.accesses.find(expr);
|
||||
if (access_it == state.accesses.end()) {
|
||||
continue;
|
||||
}
|
||||
BufferAccess access = access_it->second;
|
||||
ctx.Replace(expr, [=, &ctx, &state] {
|
||||
auto* buf = access.var->Declaration();
|
||||
auto* offset = access.offset->Build(ctx);
|
||||
auto* buf_ty = access.var->Type()->UnwrapRef();
|
||||
auto* el_ty = access.type->UnwrapRef();
|
||||
Symbol func =
|
||||
state.LoadFunc(buf_ty, el_ty, access.var->As<sem::VariableUser>());
|
||||
return ctx.dst->Call(func, ctx.CloneWithoutTransform(buf), offset);
|
||||
});
|
||||
}
|
||||
|
||||
// And replace all storage and uniform buffer assignments with stores
|
||||
for (auto store : state.stores) {
|
||||
ctx.Replace(store.assignment, [=, &ctx, &state] {
|
||||
auto* buf = store.target.var->Declaration();
|
||||
auto* offset = store.target.offset->Build(ctx);
|
||||
auto* buf_ty = store.target.var->Type()->UnwrapRef();
|
||||
auto* el_ty = store.target.type->UnwrapRef();
|
||||
auto* value = store.assignment->rhs;
|
||||
Symbol func = state.StoreFunc(buf_ty, el_ty,
|
||||
store.target.var->As<sem::VariableUser>());
|
||||
auto* call = ctx.dst->Call(func, ctx.CloneWithoutTransform(buf), offset,
|
||||
ctx.Clone(value));
|
||||
return ctx.dst->CallStmt(call);
|
||||
});
|
||||
}
|
||||
|
||||
ctx.Clone();
|
||||
}
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::transform::Offset);
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::transform::OffsetLiteral);
|
||||
131
src/tint/transform/decompose_memory_access.h
Normal file
131
src/tint/transform/decompose_memory_access.h
Normal file
@@ -0,0 +1,131 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef SRC_TINT_TRANSFORM_DECOMPOSE_MEMORY_ACCESS_H_
|
||||
#define SRC_TINT_TRANSFORM_DECOMPOSE_MEMORY_ACCESS_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "src/tint/ast/internal_attribute.h"
|
||||
#include "src/tint/transform/transform.h"
|
||||
|
||||
namespace tint {
|
||||
|
||||
// Forward declarations
|
||||
class CloneContext;
|
||||
|
||||
namespace transform {
|
||||
|
||||
/// DecomposeMemoryAccess is a transform used to replace storage and uniform
|
||||
/// buffer accesses with a combination of load, store or atomic functions on
|
||||
/// primitive types.
|
||||
class DecomposeMemoryAccess
|
||||
: public Castable<DecomposeMemoryAccess, Transform> {
|
||||
public:
|
||||
/// Intrinsic is an InternalAttribute that's used to decorate a stub function
|
||||
/// so that the HLSL transforms this into calls to
|
||||
/// `[RW]ByteAddressBuffer.Load[N]()` or `[RW]ByteAddressBuffer.Store[N]()`,
|
||||
/// with a possible cast.
|
||||
class Intrinsic : public Castable<Intrinsic, ast::InternalAttribute> {
|
||||
public:
|
||||
/// Intrinsic op
|
||||
enum class Op {
|
||||
kLoad,
|
||||
kStore,
|
||||
kAtomicLoad,
|
||||
kAtomicStore,
|
||||
kAtomicAdd,
|
||||
kAtomicSub,
|
||||
kAtomicMax,
|
||||
kAtomicMin,
|
||||
kAtomicAnd,
|
||||
kAtomicOr,
|
||||
kAtomicXor,
|
||||
kAtomicExchange,
|
||||
kAtomicCompareExchangeWeak,
|
||||
};
|
||||
|
||||
/// Intrinsic data type
|
||||
enum class DataType {
|
||||
kU32,
|
||||
kF32,
|
||||
kI32,
|
||||
kVec2U32,
|
||||
kVec2F32,
|
||||
kVec2I32,
|
||||
kVec3U32,
|
||||
kVec3F32,
|
||||
kVec3I32,
|
||||
kVec4U32,
|
||||
kVec4F32,
|
||||
kVec4I32,
|
||||
};
|
||||
|
||||
/// Constructor
|
||||
/// @param program_id the identifier of the program that owns this node
|
||||
/// @param o the op of the intrinsic
|
||||
/// @param sc the storage class of the buffer
|
||||
/// @param ty the data type of the intrinsic
|
||||
Intrinsic(ProgramID program_id, Op o, ast::StorageClass sc, DataType ty);
|
||||
/// Destructor
|
||||
~Intrinsic() override;
|
||||
|
||||
/// @return a short description of the internal attribute which will be
|
||||
/// displayed as `@internal(<name>)`
|
||||
std::string InternalName() const override;
|
||||
|
||||
/// Performs a deep clone of this object using the CloneContext `ctx`.
|
||||
/// @param ctx the clone context
|
||||
/// @return the newly cloned object
|
||||
const Intrinsic* Clone(CloneContext* ctx) const override;
|
||||
|
||||
/// The op of the intrinsic
|
||||
const Op op;
|
||||
|
||||
/// The storage class of the buffer this intrinsic operates on
|
||||
ast::StorageClass const storage_class;
|
||||
|
||||
/// The type of the intrinsic
|
||||
const DataType type;
|
||||
};
|
||||
|
||||
/// Constructor
|
||||
DecomposeMemoryAccess();
|
||||
/// Destructor
|
||||
~DecomposeMemoryAccess() override;
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program,
|
||||
const DataMap& data = {}) const override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
|
||||
struct State;
|
||||
};
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
|
||||
#endif // SRC_TINT_TRANSFORM_DECOMPOSE_MEMORY_ACCESS_H_
|
||||
2800
src/tint/transform/decompose_memory_access_test.cc
Normal file
2800
src/tint/transform/decompose_memory_access_test.cc
Normal file
File diff suppressed because it is too large
Load Diff
162
src/tint/transform/decompose_strided_array.cc
Normal file
162
src/tint/transform/decompose_strided_array.cc
Normal file
@@ -0,0 +1,162 @@
|
||||
// Copyright 2022 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/tint/transform/decompose_strided_array.h"
|
||||
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "src/tint/program_builder.h"
|
||||
#include "src/tint/sem/call.h"
|
||||
#include "src/tint/sem/expression.h"
|
||||
#include "src/tint/sem/member_accessor_expression.h"
|
||||
#include "src/tint/sem/type_constructor.h"
|
||||
#include "src/tint/transform/simplify_pointers.h"
|
||||
#include "src/tint/utils/hash.h"
|
||||
#include "src/tint/utils/map.h"
|
||||
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::transform::DecomposeStridedArray);
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
namespace {
|
||||
|
||||
using DecomposedArrays = std::unordered_map<const sem::Array*, Symbol>;
|
||||
|
||||
} // namespace
|
||||
|
||||
DecomposeStridedArray::DecomposeStridedArray() = default;
|
||||
|
||||
DecomposeStridedArray::~DecomposeStridedArray() = default;
|
||||
|
||||
bool DecomposeStridedArray::ShouldRun(const Program* program,
|
||||
const DataMap&) const {
|
||||
for (auto* node : program->ASTNodes().Objects()) {
|
||||
if (auto* ast = node->As<ast::Array>()) {
|
||||
if (ast::GetAttribute<ast::StrideAttribute>(ast->attributes)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void DecomposeStridedArray::Run(CloneContext& ctx,
|
||||
const DataMap&,
|
||||
DataMap&) const {
|
||||
const auto& sem = ctx.src->Sem();
|
||||
|
||||
static constexpr const char* kMemberName = "el";
|
||||
|
||||
// Maps an array type in the source program to the name of the struct wrapper
|
||||
// type in the target program.
|
||||
std::unordered_map<const sem::Array*, Symbol> decomposed;
|
||||
|
||||
// Find and replace all arrays with a @stride attribute with a array that has
|
||||
// the @stride removed. If the source array stride does not match the natural
|
||||
// stride for the array element type, then replace the array element type with
|
||||
// a structure, holding a single field with a @size attribute equal to the
|
||||
// array stride.
|
||||
ctx.ReplaceAll([&](const ast::Array* ast) -> const ast::Array* {
|
||||
if (auto* arr = sem.Get(ast)) {
|
||||
if (!arr->IsStrideImplicit()) {
|
||||
auto el_ty = utils::GetOrCreate(decomposed, arr, [&] {
|
||||
auto name = ctx.dst->Symbols().New("strided_arr");
|
||||
auto* member_ty = ctx.Clone(ast->type);
|
||||
auto* member = ctx.dst->Member(kMemberName, member_ty,
|
||||
{ctx.dst->MemberSize(arr->Stride())});
|
||||
ctx.dst->Structure(name, {member});
|
||||
return name;
|
||||
});
|
||||
auto* count = ctx.Clone(ast->count);
|
||||
return ctx.dst->ty.array(ctx.dst->ty.type_name(el_ty), count);
|
||||
}
|
||||
if (ast::GetAttribute<ast::StrideAttribute>(ast->attributes)) {
|
||||
// Strip the @stride attribute
|
||||
auto* ty = ctx.Clone(ast->type);
|
||||
auto* count = ctx.Clone(ast->count);
|
||||
return ctx.dst->ty.array(ty, count);
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
});
|
||||
|
||||
// Find all array index-accessors expressions for arrays that have had their
|
||||
// element changed to a single field structure. These expressions are adjusted
|
||||
// to insert an additional member accessor for the single structure field.
|
||||
// Example: `arr[i]` -> `arr[i].el`
|
||||
ctx.ReplaceAll(
|
||||
[&](const ast::IndexAccessorExpression* idx) -> const ast::Expression* {
|
||||
if (auto* ty = ctx.src->TypeOf(idx->object)) {
|
||||
if (auto* arr = ty->UnwrapRef()->As<sem::Array>()) {
|
||||
if (!arr->IsStrideImplicit()) {
|
||||
auto* expr = ctx.CloneWithoutTransform(idx);
|
||||
return ctx.dst->MemberAccessor(expr, kMemberName);
|
||||
}
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
});
|
||||
|
||||
// Find all array type constructor expressions for array types that have had
|
||||
// their element changed to a single field structure. These constructors are
|
||||
// adjusted to wrap each of the arguments with an additional constructor for
|
||||
// the new element structure type.
|
||||
// Example:
|
||||
// `@stride(32) array<i32, 3>(1, 2, 3)`
|
||||
// ->
|
||||
// `array<strided_arr, 3>(strided_arr(1), strided_arr(2), strided_arr(3))`
|
||||
ctx.ReplaceAll(
|
||||
[&](const ast::CallExpression* expr) -> const ast::Expression* {
|
||||
if (!expr->args.empty()) {
|
||||
if (auto* call = sem.Get(expr)) {
|
||||
if (auto* ctor = call->Target()->As<sem::TypeConstructor>()) {
|
||||
if (auto* arr = ctor->ReturnType()->As<sem::Array>()) {
|
||||
// Begin by cloning the array constructor type or name
|
||||
// If this is an unaliased array, this may add a new entry to
|
||||
// decomposed.
|
||||
// If this is an aliased array, decomposed should already be
|
||||
// populated with any strided aliases.
|
||||
ast::CallExpression::Target target;
|
||||
if (expr->target.type) {
|
||||
target.type = ctx.Clone(expr->target.type);
|
||||
} else {
|
||||
target.name = ctx.Clone(expr->target.name);
|
||||
}
|
||||
|
||||
ast::ExpressionList args;
|
||||
if (auto it = decomposed.find(arr); it != decomposed.end()) {
|
||||
args.reserve(expr->args.size());
|
||||
for (auto* arg : expr->args) {
|
||||
args.emplace_back(
|
||||
ctx.dst->Call(it->second, ctx.Clone(arg)));
|
||||
}
|
||||
} else {
|
||||
args = ctx.Clone(expr->args);
|
||||
}
|
||||
|
||||
return target.type ? ctx.dst->Construct(target.type, args)
|
||||
: ctx.dst->Call(target.name, args);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
});
|
||||
ctx.Clone();
|
||||
}
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
61
src/tint/transform/decompose_strided_array.h
Normal file
61
src/tint/transform/decompose_strided_array.h
Normal file
@@ -0,0 +1,61 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef SRC_TINT_TRANSFORM_DECOMPOSE_STRIDED_ARRAY_H_
|
||||
#define SRC_TINT_TRANSFORM_DECOMPOSE_STRIDED_ARRAY_H_
|
||||
|
||||
#include "src/tint/transform/transform.h"
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
|
||||
/// DecomposeStridedArray transforms replaces arrays with a non-default
|
||||
/// `@stride` attribute with an array of structure elements, where the
|
||||
/// structure contains a single field with an equivalent `@size` attribute.
|
||||
/// `@stride` attributes on arrays that match the default stride are also
|
||||
/// removed.
|
||||
///
|
||||
/// @note Depends on the following transforms to have been run first:
|
||||
/// * SimplifyPointers
|
||||
class DecomposeStridedArray
|
||||
: public Castable<DecomposeStridedArray, Transform> {
|
||||
public:
|
||||
/// Constructor
|
||||
DecomposeStridedArray();
|
||||
|
||||
/// Destructor
|
||||
~DecomposeStridedArray() override;
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program,
|
||||
const DataMap& data = {}) const override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
};
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
|
||||
#endif // SRC_TINT_TRANSFORM_DECOMPOSE_STRIDED_ARRAY_H_
|
||||
698
src/tint/transform/decompose_strided_array_test.cc
Normal file
698
src/tint/transform/decompose_strided_array_test.cc
Normal file
@@ -0,0 +1,698 @@
|
||||
// Copyright 2022 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/tint/transform/decompose_strided_array.h"
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "src/tint/program_builder.h"
|
||||
#include "src/tint/transform/simplify_pointers.h"
|
||||
#include "src/tint/transform/test_helper.h"
|
||||
#include "src/tint/transform/unshadow.h"
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
namespace {
|
||||
|
||||
using DecomposeStridedArrayTest = TransformTest;
|
||||
using f32 = ProgramBuilder::f32;
|
||||
|
||||
TEST_F(DecomposeStridedArrayTest, ShouldRunEmptyModule) {
|
||||
ProgramBuilder b;
|
||||
EXPECT_FALSE(ShouldRun<DecomposeStridedArray>(Program(std::move(b))));
|
||||
}
|
||||
|
||||
TEST_F(DecomposeStridedArrayTest, ShouldRunNonStridedArray) {
|
||||
// var<private> arr : array<f32, 4>
|
||||
|
||||
ProgramBuilder b;
|
||||
b.Global("arr", b.ty.array<f32, 4>(), ast::StorageClass::kPrivate);
|
||||
EXPECT_FALSE(ShouldRun<DecomposeStridedArray>(Program(std::move(b))));
|
||||
}
|
||||
|
||||
TEST_F(DecomposeStridedArrayTest, ShouldRunDefaultStridedArray) {
|
||||
// var<private> arr : @stride(4) array<f32, 4>
|
||||
|
||||
ProgramBuilder b;
|
||||
b.Global("arr", b.ty.array<f32, 4>(4), ast::StorageClass::kPrivate);
|
||||
EXPECT_TRUE(ShouldRun<DecomposeStridedArray>(Program(std::move(b))));
|
||||
}
|
||||
|
||||
TEST_F(DecomposeStridedArrayTest, ShouldRunExplicitStridedArray) {
|
||||
// var<private> arr : @stride(16) array<f32, 4>
|
||||
|
||||
ProgramBuilder b;
|
||||
b.Global("arr", b.ty.array<f32, 4>(16), ast::StorageClass::kPrivate);
|
||||
EXPECT_TRUE(ShouldRun<DecomposeStridedArray>(Program(std::move(b))));
|
||||
}
|
||||
|
||||
TEST_F(DecomposeStridedArrayTest, Empty) {
|
||||
auto* src = R"()";
|
||||
auto* expect = src;
|
||||
|
||||
auto got = Run<DecomposeStridedArray>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(DecomposeStridedArrayTest, PrivateDefaultStridedArray) {
|
||||
// var<private> arr : @stride(4) array<f32, 4>
|
||||
//
|
||||
// @stage(compute) @workgroup_size(1)
|
||||
// fn f() {
|
||||
// let a : @stride(4) array<f32, 4> = a;
|
||||
// let b : f32 = arr[1];
|
||||
// }
|
||||
|
||||
ProgramBuilder b;
|
||||
b.Global("arr", b.ty.array<f32, 4>(4), ast::StorageClass::kPrivate);
|
||||
b.Func("f", {}, b.ty.void_(),
|
||||
{
|
||||
b.Decl(b.Const("a", b.ty.array<f32, 4>(4), b.Expr("arr"))),
|
||||
b.Decl(b.Const("b", b.ty.f32(), b.IndexAccessor("arr", 1))),
|
||||
},
|
||||
{
|
||||
b.Stage(ast::PipelineStage::kCompute),
|
||||
b.WorkgroupSize(1),
|
||||
});
|
||||
|
||||
auto* expect = R"(
|
||||
var<private> arr : array<f32, 4>;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn f() {
|
||||
let a : array<f32, 4> = arr;
|
||||
let b : f32 = arr[1];
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(
|
||||
Program(std::move(b)));
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(DecomposeStridedArrayTest, PrivateStridedArray) {
|
||||
// var<private> arr : @stride(32) array<f32, 4>
|
||||
//
|
||||
// @stage(compute) @workgroup_size(1)
|
||||
// fn f() {
|
||||
// let a : @stride(32) array<f32, 4> = a;
|
||||
// let b : f32 = arr[1];
|
||||
// }
|
||||
|
||||
ProgramBuilder b;
|
||||
b.Global("arr", b.ty.array<f32, 4>(32), ast::StorageClass::kPrivate);
|
||||
b.Func("f", {}, b.ty.void_(),
|
||||
{
|
||||
b.Decl(b.Const("a", b.ty.array<f32, 4>(32), b.Expr("arr"))),
|
||||
b.Decl(b.Const("b", b.ty.f32(), b.IndexAccessor("arr", 1))),
|
||||
},
|
||||
{
|
||||
b.Stage(ast::PipelineStage::kCompute),
|
||||
b.WorkgroupSize(1),
|
||||
});
|
||||
|
||||
auto* expect = R"(
|
||||
struct strided_arr {
|
||||
@size(32)
|
||||
el : f32;
|
||||
}
|
||||
|
||||
var<private> arr : array<strided_arr, 4>;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn f() {
|
||||
let a : array<strided_arr, 4> = arr;
|
||||
let b : f32 = arr[1].el;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(
|
||||
Program(std::move(b)));
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(DecomposeStridedArrayTest, ReadUniformStridedArray) {
|
||||
// struct S {
|
||||
// a : @stride(32) array<f32, 4>;
|
||||
// };
|
||||
// @group(0) @binding(0) var<uniform> s : S;
|
||||
//
|
||||
// @stage(compute) @workgroup_size(1)
|
||||
// fn f() {
|
||||
// let a : @stride(32) array<f32, 4> = s.a;
|
||||
// let b : f32 = s.a[1];
|
||||
// }
|
||||
ProgramBuilder b;
|
||||
auto* S = b.Structure("S", {b.Member("a", b.ty.array<f32, 4>(32))});
|
||||
b.Global("s", b.ty.Of(S), ast::StorageClass::kUniform,
|
||||
b.GroupAndBinding(0, 0));
|
||||
b.Func("f", {}, b.ty.void_(),
|
||||
{
|
||||
b.Decl(b.Const("a", b.ty.array<f32, 4>(32),
|
||||
b.MemberAccessor("s", "a"))),
|
||||
b.Decl(b.Const("b", b.ty.f32(),
|
||||
b.IndexAccessor(b.MemberAccessor("s", "a"), 1))),
|
||||
},
|
||||
{
|
||||
b.Stage(ast::PipelineStage::kCompute),
|
||||
b.WorkgroupSize(1),
|
||||
});
|
||||
|
||||
auto* expect = R"(
|
||||
struct strided_arr {
|
||||
@size(32)
|
||||
el : f32;
|
||||
}
|
||||
|
||||
struct S {
|
||||
a : array<strided_arr, 4>;
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var<uniform> s : S;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn f() {
|
||||
let a : array<strided_arr, 4> = s.a;
|
||||
let b : f32 = s.a[1].el;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(
|
||||
Program(std::move(b)));
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(DecomposeStridedArrayTest, ReadUniformDefaultStridedArray) {
|
||||
// struct S {
|
||||
// a : @stride(16) array<vec4<f32>, 4>;
|
||||
// };
|
||||
// @group(0) @binding(0) var<uniform> s : S;
|
||||
//
|
||||
// @stage(compute) @workgroup_size(1)
|
||||
// fn f() {
|
||||
// let a : @stride(16) array<vec4<f32>, 4> = s.a;
|
||||
// let b : f32 = s.a[1][2];
|
||||
// }
|
||||
ProgramBuilder b;
|
||||
auto* S =
|
||||
b.Structure("S", {b.Member("a", b.ty.array(b.ty.vec4<f32>(), 4, 16))});
|
||||
b.Global("s", b.ty.Of(S), ast::StorageClass::kUniform,
|
||||
b.GroupAndBinding(0, 0));
|
||||
b.Func("f", {}, b.ty.void_(),
|
||||
{
|
||||
b.Decl(b.Const("a", b.ty.array(b.ty.vec4<f32>(), 4, 16),
|
||||
b.MemberAccessor("s", "a"))),
|
||||
b.Decl(b.Const(
|
||||
"b", b.ty.f32(),
|
||||
b.IndexAccessor(b.IndexAccessor(b.MemberAccessor("s", "a"), 1),
|
||||
2))),
|
||||
},
|
||||
{
|
||||
b.Stage(ast::PipelineStage::kCompute),
|
||||
b.WorkgroupSize(1),
|
||||
});
|
||||
|
||||
auto* expect =
|
||||
R"(
|
||||
struct S {
|
||||
a : array<vec4<f32>, 4>;
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var<uniform> s : S;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn f() {
|
||||
let a : array<vec4<f32>, 4> = s.a;
|
||||
let b : f32 = s.a[1][2];
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(
|
||||
Program(std::move(b)));
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(DecomposeStridedArrayTest, ReadStorageStridedArray) {
|
||||
// struct S {
|
||||
// a : @stride(32) array<f32, 4>;
|
||||
// };
|
||||
// @group(0) @binding(0) var<storage> s : S;
|
||||
//
|
||||
// @stage(compute) @workgroup_size(1)
|
||||
// fn f() {
|
||||
// let a : @stride(32) array<f32, 4> = s.a;
|
||||
// let b : f32 = s.a[1];
|
||||
// }
|
||||
ProgramBuilder b;
|
||||
auto* S = b.Structure("S", {b.Member("a", b.ty.array<f32, 4>(32))});
|
||||
b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
|
||||
b.GroupAndBinding(0, 0));
|
||||
b.Func("f", {}, b.ty.void_(),
|
||||
{
|
||||
b.Decl(b.Const("a", b.ty.array<f32, 4>(32),
|
||||
b.MemberAccessor("s", "a"))),
|
||||
b.Decl(b.Const("b", b.ty.f32(),
|
||||
b.IndexAccessor(b.MemberAccessor("s", "a"), 1))),
|
||||
},
|
||||
{
|
||||
b.Stage(ast::PipelineStage::kCompute),
|
||||
b.WorkgroupSize(1),
|
||||
});
|
||||
|
||||
auto* expect = R"(
|
||||
struct strided_arr {
|
||||
@size(32)
|
||||
el : f32;
|
||||
}
|
||||
|
||||
struct S {
|
||||
a : array<strided_arr, 4>;
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var<storage> s : S;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn f() {
|
||||
let a : array<strided_arr, 4> = s.a;
|
||||
let b : f32 = s.a[1].el;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(
|
||||
Program(std::move(b)));
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(DecomposeStridedArrayTest, ReadStorageDefaultStridedArray) {
|
||||
// struct S {
|
||||
// a : @stride(4) array<f32, 4>;
|
||||
// };
|
||||
// @group(0) @binding(0) var<storage> s : S;
|
||||
//
|
||||
// @stage(compute) @workgroup_size(1)
|
||||
// fn f() {
|
||||
// let a : @stride(4) array<f32, 4> = s.a;
|
||||
// let b : f32 = s.a[1];
|
||||
// }
|
||||
ProgramBuilder b;
|
||||
auto* S = b.Structure("S", {b.Member("a", b.ty.array<f32, 4>(4))});
|
||||
b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
|
||||
b.GroupAndBinding(0, 0));
|
||||
b.Func("f", {}, b.ty.void_(),
|
||||
{
|
||||
b.Decl(b.Const("a", b.ty.array<f32, 4>(4),
|
||||
b.MemberAccessor("s", "a"))),
|
||||
b.Decl(b.Const("b", b.ty.f32(),
|
||||
b.IndexAccessor(b.MemberAccessor("s", "a"), 1))),
|
||||
},
|
||||
{
|
||||
b.Stage(ast::PipelineStage::kCompute),
|
||||
b.WorkgroupSize(1),
|
||||
});
|
||||
|
||||
auto* expect = R"(
|
||||
struct S {
|
||||
a : array<f32, 4>;
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var<storage> s : S;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn f() {
|
||||
let a : array<f32, 4> = s.a;
|
||||
let b : f32 = s.a[1];
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(
|
||||
Program(std::move(b)));
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(DecomposeStridedArrayTest, WriteStorageStridedArray) {
|
||||
// struct S {
|
||||
// a : @stride(32) array<f32, 4>;
|
||||
// };
|
||||
// @group(0) @binding(0) var<storage, read_write> s : S;
|
||||
//
|
||||
// @stage(compute) @workgroup_size(1)
|
||||
// fn f() {
|
||||
// s.a = @stride(32) array<f32, 4>();
|
||||
// s.a = @stride(32) array<f32, 4>(1.0, 2.0, 3.0, 4.0);
|
||||
// s.a[1] = 5.0;
|
||||
// }
|
||||
ProgramBuilder b;
|
||||
auto* S = b.Structure("S", {b.Member("a", b.ty.array<f32, 4>(32))});
|
||||
b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
|
||||
ast::Access::kReadWrite, b.GroupAndBinding(0, 0));
|
||||
b.Func(
|
||||
"f", {}, b.ty.void_(),
|
||||
{
|
||||
b.Assign(b.MemberAccessor("s", "a"),
|
||||
b.Construct(b.ty.array<f32, 4>(32))),
|
||||
b.Assign(b.MemberAccessor("s", "a"),
|
||||
b.Construct(b.ty.array<f32, 4>(32), 1.0f, 2.0f, 3.0f, 4.0f)),
|
||||
b.Assign(b.IndexAccessor(b.MemberAccessor("s", "a"), 1), 5.0f),
|
||||
},
|
||||
{
|
||||
b.Stage(ast::PipelineStage::kCompute),
|
||||
b.WorkgroupSize(1),
|
||||
});
|
||||
|
||||
auto* expect =
|
||||
R"(
|
||||
struct strided_arr {
|
||||
@size(32)
|
||||
el : f32;
|
||||
}
|
||||
|
||||
struct S {
|
||||
a : array<strided_arr, 4>;
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var<storage, read_write> s : S;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn f() {
|
||||
s.a = array<strided_arr, 4>();
|
||||
s.a = array<strided_arr, 4>(strided_arr(1.0), strided_arr(2.0), strided_arr(3.0), strided_arr(4.0));
|
||||
s.a[1].el = 5.0;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(
|
||||
Program(std::move(b)));
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(DecomposeStridedArrayTest, WriteStorageDefaultStridedArray) {
|
||||
// struct S {
|
||||
// a : @stride(4) array<f32, 4>;
|
||||
// };
|
||||
// @group(0) @binding(0) var<storage, read_write> s : S;
|
||||
//
|
||||
// @stage(compute) @workgroup_size(1)
|
||||
// fn f() {
|
||||
// s.a = @stride(4) array<f32, 4>();
|
||||
// s.a = @stride(4) array<f32, 4>(1.0, 2.0, 3.0, 4.0);
|
||||
// s.a[1] = 5.0;
|
||||
// }
|
||||
ProgramBuilder b;
|
||||
auto* S = b.Structure("S", {b.Member("a", b.ty.array<f32, 4>(4))});
|
||||
b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
|
||||
ast::Access::kReadWrite, b.GroupAndBinding(0, 0));
|
||||
b.Func(
|
||||
"f", {}, b.ty.void_(),
|
||||
{
|
||||
b.Assign(b.MemberAccessor("s", "a"),
|
||||
b.Construct(b.ty.array<f32, 4>(4))),
|
||||
b.Assign(b.MemberAccessor("s", "a"),
|
||||
b.Construct(b.ty.array<f32, 4>(4), 1.0f, 2.0f, 3.0f, 4.0f)),
|
||||
b.Assign(b.IndexAccessor(b.MemberAccessor("s", "a"), 1), 5.0f),
|
||||
},
|
||||
{
|
||||
b.Stage(ast::PipelineStage::kCompute),
|
||||
b.WorkgroupSize(1),
|
||||
});
|
||||
|
||||
auto* expect =
|
||||
R"(
|
||||
struct S {
|
||||
a : array<f32, 4>;
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var<storage, read_write> s : S;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn f() {
|
||||
s.a = array<f32, 4>();
|
||||
s.a = array<f32, 4>(1.0, 2.0, 3.0, 4.0);
|
||||
s.a[1] = 5.0;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(
|
||||
Program(std::move(b)));
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(DecomposeStridedArrayTest, ReadWriteViaPointerLets) {
|
||||
// struct S {
|
||||
// a : @stride(32) array<f32, 4>;
|
||||
// };
|
||||
// @group(0) @binding(0) var<storage, read_write> s : S;
|
||||
//
|
||||
// @stage(compute) @workgroup_size(1)
|
||||
// fn f() {
|
||||
// let a = &s.a;
|
||||
// let b = &*&*(a);
|
||||
// let c = *b;
|
||||
// let d = (*b)[1];
|
||||
// (*b) = @stride(32) array<f32, 4>(1.0, 2.0, 3.0, 4.0);
|
||||
// (*b)[1] = 5.0;
|
||||
// }
|
||||
ProgramBuilder b;
|
||||
auto* S = b.Structure("S", {b.Member("a", b.ty.array<f32, 4>(32))});
|
||||
b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
|
||||
ast::Access::kReadWrite, b.GroupAndBinding(0, 0));
|
||||
b.Func("f", {}, b.ty.void_(),
|
||||
{
|
||||
b.Decl(b.Const("a", nullptr,
|
||||
b.AddressOf(b.MemberAccessor("s", "a")))),
|
||||
b.Decl(b.Const("b", nullptr,
|
||||
b.AddressOf(b.Deref(b.AddressOf(b.Deref("a")))))),
|
||||
b.Decl(b.Const("c", nullptr, b.Deref("b"))),
|
||||
b.Decl(b.Const("d", nullptr, b.IndexAccessor(b.Deref("b"), 1))),
|
||||
b.Assign(b.Deref("b"), b.Construct(b.ty.array<f32, 4>(32), 1.0f,
|
||||
2.0f, 3.0f, 4.0f)),
|
||||
b.Assign(b.IndexAccessor(b.Deref("b"), 1), 5.0f),
|
||||
},
|
||||
{
|
||||
b.Stage(ast::PipelineStage::kCompute),
|
||||
b.WorkgroupSize(1),
|
||||
});
|
||||
|
||||
auto* expect =
|
||||
R"(
|
||||
struct strided_arr {
|
||||
@size(32)
|
||||
el : f32;
|
||||
}
|
||||
|
||||
struct S {
|
||||
a : array<strided_arr, 4>;
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var<storage, read_write> s : S;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn f() {
|
||||
let c = s.a;
|
||||
let d = s.a[1].el;
|
||||
s.a = array<strided_arr, 4>(strided_arr(1.0), strided_arr(2.0), strided_arr(3.0), strided_arr(4.0));
|
||||
s.a[1].el = 5.0;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(
|
||||
Program(std::move(b)));
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(DecomposeStridedArrayTest, PrivateAliasedStridedArray) {
|
||||
// type ARR = @stride(32) array<f32, 4>;
|
||||
// struct S {
|
||||
// a : ARR;
|
||||
// };
|
||||
// @group(0) @binding(0) var<storage, read_write> s : S;
|
||||
//
|
||||
// @stage(compute) @workgroup_size(1)
|
||||
// fn f() {
|
||||
// let a : ARR = s.a;
|
||||
// let b : f32 = s.a[1];
|
||||
// s.a = ARR();
|
||||
// s.a = ARR(1.0, 2.0, 3.0, 4.0);
|
||||
// s.a[1] = 5.0;
|
||||
// }
|
||||
ProgramBuilder b;
|
||||
b.Alias("ARR", b.ty.array<f32, 4>(32));
|
||||
auto* S = b.Structure("S", {b.Member("a", b.ty.type_name("ARR"))});
|
||||
b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
|
||||
ast::Access::kReadWrite, b.GroupAndBinding(0, 0));
|
||||
b.Func(
|
||||
"f", {}, b.ty.void_(),
|
||||
{
|
||||
b.Decl(
|
||||
b.Const("a", b.ty.type_name("ARR"), b.MemberAccessor("s", "a"))),
|
||||
b.Decl(b.Const("b", b.ty.f32(),
|
||||
b.IndexAccessor(b.MemberAccessor("s", "a"), 1))),
|
||||
b.Assign(b.MemberAccessor("s", "a"),
|
||||
b.Construct(b.ty.type_name("ARR"))),
|
||||
b.Assign(b.MemberAccessor("s", "a"),
|
||||
b.Construct(b.ty.type_name("ARR"), 1.0f, 2.0f, 3.0f, 4.0f)),
|
||||
b.Assign(b.IndexAccessor(b.MemberAccessor("s", "a"), 1), 5.0f),
|
||||
},
|
||||
{
|
||||
b.Stage(ast::PipelineStage::kCompute),
|
||||
b.WorkgroupSize(1),
|
||||
});
|
||||
|
||||
auto* expect = R"(
|
||||
struct strided_arr {
|
||||
@size(32)
|
||||
el : f32;
|
||||
}
|
||||
|
||||
type ARR = array<strided_arr, 4>;
|
||||
|
||||
struct S {
|
||||
a : ARR;
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var<storage, read_write> s : S;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn f() {
|
||||
let a : ARR = s.a;
|
||||
let b : f32 = s.a[1].el;
|
||||
s.a = ARR();
|
||||
s.a = ARR(strided_arr(1.0), strided_arr(2.0), strided_arr(3.0), strided_arr(4.0));
|
||||
s.a[1].el = 5.0;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(
|
||||
Program(std::move(b)));
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(DecomposeStridedArrayTest, PrivateNestedStridedArray) {
|
||||
// type ARR_A = @stride(8) array<f32, 2>;
|
||||
// type ARR_B = @stride(128) array<@stride(16) array<ARR_A, 3>, 4>;
|
||||
// struct S {
|
||||
// a : ARR_B;
|
||||
// };
|
||||
// @group(0) @binding(0) var<storage, read_write> s : S;
|
||||
//
|
||||
// @stage(compute) @workgroup_size(1)
|
||||
// fn f() {
|
||||
// let a : ARR_B = s.a;
|
||||
// let b : array<@stride(8) array<f32, 2>, 3> = s.a[3];
|
||||
// let c = s.a[3][2];
|
||||
// let d = s.a[3][2][1];
|
||||
// s.a = ARR_B();
|
||||
// s.a[3][2][1] = 5.0;
|
||||
// }
|
||||
|
||||
ProgramBuilder b;
|
||||
b.Alias("ARR_A", b.ty.array<f32, 2>(8));
|
||||
b.Alias("ARR_B",
|
||||
b.ty.array( //
|
||||
b.ty.array(b.ty.type_name("ARR_A"), 3, 16), //
|
||||
4, 128));
|
||||
auto* S = b.Structure("S", {b.Member("a", b.ty.type_name("ARR_B"))});
|
||||
b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
|
||||
ast::Access::kReadWrite, b.GroupAndBinding(0, 0));
|
||||
b.Func("f", {}, b.ty.void_(),
|
||||
{
|
||||
b.Decl(b.Const("a", b.ty.type_name("ARR_B"),
|
||||
b.MemberAccessor("s", "a"))),
|
||||
b.Decl(b.Const("b", b.ty.array(b.ty.type_name("ARR_A"), 3, 16),
|
||||
b.IndexAccessor( //
|
||||
b.MemberAccessor("s", "a"), //
|
||||
3))),
|
||||
b.Decl(b.Const("c", b.ty.type_name("ARR_A"),
|
||||
b.IndexAccessor( //
|
||||
b.IndexAccessor( //
|
||||
b.MemberAccessor("s", "a"), //
|
||||
3),
|
||||
2))),
|
||||
b.Decl(b.Const("d", b.ty.f32(),
|
||||
b.IndexAccessor( //
|
||||
b.IndexAccessor( //
|
||||
b.IndexAccessor( //
|
||||
b.MemberAccessor("s", "a"), //
|
||||
3),
|
||||
2),
|
||||
1))),
|
||||
b.Assign(b.MemberAccessor("s", "a"),
|
||||
b.Construct(b.ty.type_name("ARR_B"))),
|
||||
b.Assign(b.IndexAccessor( //
|
||||
b.IndexAccessor( //
|
||||
b.IndexAccessor( //
|
||||
b.MemberAccessor("s", "a"), //
|
||||
3),
|
||||
2),
|
||||
1),
|
||||
5.0f),
|
||||
},
|
||||
{
|
||||
b.Stage(ast::PipelineStage::kCompute),
|
||||
b.WorkgroupSize(1),
|
||||
});
|
||||
|
||||
auto* expect =
|
||||
R"(
|
||||
struct strided_arr {
|
||||
@size(8)
|
||||
el : f32;
|
||||
}
|
||||
|
||||
type ARR_A = array<strided_arr, 2>;
|
||||
|
||||
struct strided_arr_1 {
|
||||
@size(128)
|
||||
el : array<ARR_A, 3>;
|
||||
}
|
||||
|
||||
type ARR_B = array<strided_arr_1, 4>;
|
||||
|
||||
struct S {
|
||||
a : ARR_B;
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var<storage, read_write> s : S;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn f() {
|
||||
let a : ARR_B = s.a;
|
||||
let b : array<ARR_A, 3> = s.a[3].el;
|
||||
let c : ARR_A = s.a[3].el[2];
|
||||
let d : f32 = s.a[3].el[2][1].el;
|
||||
s.a = ARR_B();
|
||||
s.a[3].el[2][1].el = 5.0;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(
|
||||
Program(std::move(b)));
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
} // namespace
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
251
src/tint/transform/decompose_strided_matrix.cc
Normal file
251
src/tint/transform/decompose_strided_matrix.cc
Normal file
@@ -0,0 +1,251 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/tint/transform/decompose_strided_matrix.h"
|
||||
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "src/tint/program_builder.h"
|
||||
#include "src/tint/sem/expression.h"
|
||||
#include "src/tint/sem/member_accessor_expression.h"
|
||||
#include "src/tint/transform/simplify_pointers.h"
|
||||
#include "src/tint/utils/hash.h"
|
||||
#include "src/tint/utils/map.h"
|
||||
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::transform::DecomposeStridedMatrix);
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
namespace {
|
||||
|
||||
/// MatrixInfo describes a matrix member with a custom stride
|
||||
struct MatrixInfo {
|
||||
/// The stride in bytes between columns of the matrix
|
||||
uint32_t stride = 0;
|
||||
/// The type of the matrix
|
||||
const sem::Matrix* matrix = nullptr;
|
||||
|
||||
/// @returns a new ast::Array that holds an vector column for each row of the
|
||||
/// matrix.
|
||||
const ast::Array* array(ProgramBuilder* b) const {
|
||||
return b->ty.array(b->ty.vec<ProgramBuilder::f32>(matrix->rows()),
|
||||
matrix->columns(), stride);
|
||||
}
|
||||
|
||||
/// Equality operator
|
||||
bool operator==(const MatrixInfo& info) const {
|
||||
return stride == info.stride && matrix == info.matrix;
|
||||
}
|
||||
/// Hash function
|
||||
struct Hasher {
|
||||
size_t operator()(const MatrixInfo& t) const {
|
||||
return utils::Hash(t.stride, t.matrix);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
/// Return type of the callback function of GatherCustomStrideMatrixMembers
|
||||
enum GatherResult { kContinue, kStop };
|
||||
|
||||
/// GatherCustomStrideMatrixMembers scans `program` for all matrix members of
|
||||
/// storage and uniform structs, which are of a matrix type, and have a custom
|
||||
/// matrix stride attribute. For each matrix member found, `callback` is called.
|
||||
/// `callback` is a function with the signature:
|
||||
/// GatherResult(const sem::StructMember* member,
|
||||
/// sem::Matrix* matrix,
|
||||
/// uint32_t stride)
|
||||
/// If `callback` return GatherResult::kStop, then the scanning will immediately
|
||||
/// terminate, and GatherCustomStrideMatrixMembers() will return, otherwise
|
||||
/// scanning will continue.
|
||||
template <typename F>
|
||||
void GatherCustomStrideMatrixMembers(const Program* program, F&& callback) {
|
||||
for (auto* node : program->ASTNodes().Objects()) {
|
||||
if (auto* str = node->As<ast::Struct>()) {
|
||||
auto* str_ty = program->Sem().Get(str);
|
||||
if (!str_ty->UsedAs(ast::StorageClass::kUniform) &&
|
||||
!str_ty->UsedAs(ast::StorageClass::kStorage)) {
|
||||
continue;
|
||||
}
|
||||
for (auto* member : str_ty->Members()) {
|
||||
auto* matrix = member->Type()->As<sem::Matrix>();
|
||||
if (!matrix) {
|
||||
continue;
|
||||
}
|
||||
auto* attr = ast::GetAttribute<ast::StrideAttribute>(
|
||||
member->Declaration()->attributes);
|
||||
if (!attr) {
|
||||
continue;
|
||||
}
|
||||
uint32_t stride = attr->stride;
|
||||
if (matrix->ColumnStride() == stride) {
|
||||
continue;
|
||||
}
|
||||
if (callback(member, matrix, stride) == GatherResult::kStop) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
DecomposeStridedMatrix::DecomposeStridedMatrix() = default;
|
||||
|
||||
DecomposeStridedMatrix::~DecomposeStridedMatrix() = default;
|
||||
|
||||
bool DecomposeStridedMatrix::ShouldRun(const Program* program,
|
||||
const DataMap&) const {
|
||||
bool should_run = false;
|
||||
GatherCustomStrideMatrixMembers(
|
||||
program, [&](const sem::StructMember*, sem::Matrix*, uint32_t) {
|
||||
should_run = true;
|
||||
return GatherResult::kStop;
|
||||
});
|
||||
return should_run;
|
||||
}
|
||||
|
||||
void DecomposeStridedMatrix::Run(CloneContext& ctx,
|
||||
const DataMap&,
|
||||
DataMap&) const {
|
||||
// Scan the program for all storage and uniform structure matrix members with
|
||||
// a custom stride attribute. Replace these matrices with an equivalent array,
|
||||
// and populate the `decomposed` map with the members that have been replaced.
|
||||
std::unordered_map<const ast::StructMember*, MatrixInfo> decomposed;
|
||||
GatherCustomStrideMatrixMembers(
|
||||
ctx.src, [&](const sem::StructMember* member, sem::Matrix* matrix,
|
||||
uint32_t stride) {
|
||||
// We've got ourselves a struct member of a matrix type with a custom
|
||||
// stride. Replace this with an array of column vectors.
|
||||
MatrixInfo info{stride, matrix};
|
||||
auto* replacement = ctx.dst->Member(
|
||||
member->Offset(), ctx.Clone(member->Name()), info.array(ctx.dst));
|
||||
ctx.Replace(member->Declaration(), replacement);
|
||||
decomposed.emplace(member->Declaration(), info);
|
||||
return GatherResult::kContinue;
|
||||
});
|
||||
|
||||
// For all expressions where a single matrix column vector was indexed, we can
|
||||
// preserve these without calling conversion functions.
|
||||
// Example:
|
||||
// ssbo.mat[2] -> ssbo.mat[2]
|
||||
ctx.ReplaceAll([&](const ast::IndexAccessorExpression* expr)
|
||||
-> const ast::IndexAccessorExpression* {
|
||||
if (auto* access =
|
||||
ctx.src->Sem().Get<sem::StructMemberAccess>(expr->object)) {
|
||||
auto it = decomposed.find(access->Member()->Declaration());
|
||||
if (it != decomposed.end()) {
|
||||
auto* obj = ctx.CloneWithoutTransform(expr->object);
|
||||
auto* idx = ctx.Clone(expr->index);
|
||||
return ctx.dst->IndexAccessor(obj, idx);
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
});
|
||||
|
||||
// For all struct member accesses to the matrix on the LHS of an assignment,
|
||||
// we need to convert the matrix to the array before assigning to the
|
||||
// structure.
|
||||
// Example:
|
||||
// ssbo.mat = mat_to_arr(m)
|
||||
std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> mat_to_arr;
|
||||
ctx.ReplaceAll([&](const ast::AssignmentStatement* stmt)
|
||||
-> const ast::Statement* {
|
||||
if (auto* access = ctx.src->Sem().Get<sem::StructMemberAccess>(stmt->lhs)) {
|
||||
auto it = decomposed.find(access->Member()->Declaration());
|
||||
if (it == decomposed.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
MatrixInfo info = it->second;
|
||||
auto fn = utils::GetOrCreate(mat_to_arr, info, [&] {
|
||||
auto name = ctx.dst->Symbols().New(
|
||||
"mat" + std::to_string(info.matrix->columns()) + "x" +
|
||||
std::to_string(info.matrix->rows()) + "_stride_" +
|
||||
std::to_string(info.stride) + "_to_arr");
|
||||
|
||||
auto matrix = [&] { return CreateASTTypeFor(ctx, info.matrix); };
|
||||
auto array = [&] { return info.array(ctx.dst); };
|
||||
|
||||
auto mat = ctx.dst->Sym("m");
|
||||
ast::ExpressionList columns(info.matrix->columns());
|
||||
for (uint32_t i = 0; i < static_cast<uint32_t>(columns.size()); i++) {
|
||||
columns[i] = ctx.dst->IndexAccessor(mat, i);
|
||||
}
|
||||
ctx.dst->Func(name,
|
||||
{
|
||||
ctx.dst->Param(mat, matrix()),
|
||||
},
|
||||
array(),
|
||||
{
|
||||
ctx.dst->Return(ctx.dst->Construct(array(), columns)),
|
||||
});
|
||||
return name;
|
||||
});
|
||||
auto* lhs = ctx.CloneWithoutTransform(stmt->lhs);
|
||||
auto* rhs = ctx.dst->Call(fn, ctx.Clone(stmt->rhs));
|
||||
return ctx.dst->Assign(lhs, rhs);
|
||||
}
|
||||
return nullptr;
|
||||
});
|
||||
|
||||
// For all other struct member accesses, we need to convert the array to the
|
||||
// matrix type. Example:
|
||||
// m = arr_to_mat(ssbo.mat)
|
||||
std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> arr_to_mat;
|
||||
ctx.ReplaceAll(
|
||||
[&](const ast::MemberAccessorExpression* expr) -> const ast::Expression* {
|
||||
if (auto* access = ctx.src->Sem().Get<sem::StructMemberAccess>(expr)) {
|
||||
auto it = decomposed.find(access->Member()->Declaration());
|
||||
if (it == decomposed.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
MatrixInfo info = it->second;
|
||||
auto fn = utils::GetOrCreate(arr_to_mat, info, [&] {
|
||||
auto name = ctx.dst->Symbols().New(
|
||||
"arr_to_mat" + std::to_string(info.matrix->columns()) + "x" +
|
||||
std::to_string(info.matrix->rows()) + "_stride_" +
|
||||
std::to_string(info.stride));
|
||||
|
||||
auto matrix = [&] { return CreateASTTypeFor(ctx, info.matrix); };
|
||||
auto array = [&] { return info.array(ctx.dst); };
|
||||
|
||||
auto arr = ctx.dst->Sym("arr");
|
||||
ast::ExpressionList columns(info.matrix->columns());
|
||||
for (uint32_t i = 0; i < static_cast<uint32_t>(columns.size());
|
||||
i++) {
|
||||
columns[i] = ctx.dst->IndexAccessor(arr, i);
|
||||
}
|
||||
ctx.dst->Func(
|
||||
name,
|
||||
{
|
||||
ctx.dst->Param(arr, array()),
|
||||
},
|
||||
matrix(),
|
||||
{
|
||||
ctx.dst->Return(ctx.dst->Construct(matrix(), columns)),
|
||||
});
|
||||
return name;
|
||||
});
|
||||
return ctx.dst->Call(fn, ctx.CloneWithoutTransform(expr));
|
||||
}
|
||||
return nullptr;
|
||||
});
|
||||
|
||||
ctx.Clone();
|
||||
}
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
61
src/tint/transform/decompose_strided_matrix.h
Normal file
61
src/tint/transform/decompose_strided_matrix.h
Normal file
@@ -0,0 +1,61 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef SRC_TINT_TRANSFORM_DECOMPOSE_STRIDED_MATRIX_H_
|
||||
#define SRC_TINT_TRANSFORM_DECOMPOSE_STRIDED_MATRIX_H_
|
||||
|
||||
#include "src/tint/transform/transform.h"
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
|
||||
/// DecomposeStridedMatrix transforms replaces matrix members of storage or
|
||||
/// uniform buffer structures, that have a [[stride]] attribute, into an array
|
||||
/// of N column vectors.
|
||||
/// This transform is used by the SPIR-V reader to handle the SPIR-V
|
||||
/// MatrixStride attribute.
|
||||
///
|
||||
/// @note Depends on the following transforms to have been run first:
|
||||
/// * SimplifyPointers
|
||||
class DecomposeStridedMatrix
|
||||
: public Castable<DecomposeStridedMatrix, Transform> {
|
||||
public:
|
||||
/// Constructor
|
||||
DecomposeStridedMatrix();
|
||||
|
||||
/// Destructor
|
||||
~DecomposeStridedMatrix() override;
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program,
|
||||
const DataMap& data = {}) const override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
};
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
|
||||
#endif // SRC_TINT_TRANSFORM_DECOMPOSE_STRIDED_MATRIX_H_
|
||||
671
src/tint/transform/decompose_strided_matrix_test.cc
Normal file
671
src/tint/transform/decompose_strided_matrix_test.cc
Normal file
@@ -0,0 +1,671 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/tint/transform/decompose_strided_matrix.h"
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "src/tint/ast/disable_validation_attribute.h"
|
||||
#include "src/tint/program_builder.h"
|
||||
#include "src/tint/transform/simplify_pointers.h"
|
||||
#include "src/tint/transform/test_helper.h"
|
||||
#include "src/tint/transform/unshadow.h"
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
namespace {
|
||||
|
||||
using DecomposeStridedMatrixTest = TransformTest;
|
||||
using f32 = ProgramBuilder::f32;
|
||||
|
||||
TEST_F(DecomposeStridedMatrixTest, ShouldRunEmptyModule) {
|
||||
auto* src = R"()";
|
||||
|
||||
EXPECT_FALSE(ShouldRun<DecomposeStridedMatrix>(src));
|
||||
}
|
||||
|
||||
TEST_F(DecomposeStridedMatrixTest, ShouldRunNonStridedMatrox) {
|
||||
auto* src = R"(
|
||||
var<private> m : mat3x2<f32>;
|
||||
)";
|
||||
|
||||
EXPECT_FALSE(ShouldRun<DecomposeStridedMatrix>(src));
|
||||
}
|
||||
|
||||
TEST_F(DecomposeStridedMatrixTest, Empty) {
|
||||
auto* src = R"()";
|
||||
auto* expect = src;
|
||||
|
||||
auto got = Run<DecomposeStridedMatrix>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(DecomposeStridedMatrixTest, ReadUniformMatrix) {
|
||||
// struct S {
|
||||
// @offset(16) @stride(32)
|
||||
// @internal(ignore_stride_attribute)
|
||||
// m : mat2x2<f32>;
|
||||
// };
|
||||
// @group(0) @binding(0) var<uniform> s : S;
|
||||
//
|
||||
// @stage(compute) @workgroup_size(1)
|
||||
// fn f() {
|
||||
// let x : mat2x2<f32> = s.m;
|
||||
// }
|
||||
ProgramBuilder b;
|
||||
auto* S = b.Structure(
|
||||
"S",
|
||||
{
|
||||
b.Member(
|
||||
"m", b.ty.mat2x2<f32>(),
|
||||
{
|
||||
b.create<ast::StructMemberOffsetAttribute>(16),
|
||||
b.create<ast::StrideAttribute>(32),
|
||||
b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
|
||||
}),
|
||||
});
|
||||
b.Global("s", b.ty.Of(S), ast::StorageClass::kUniform,
|
||||
b.GroupAndBinding(0, 0));
|
||||
b.Func(
|
||||
"f", {}, b.ty.void_(),
|
||||
{
|
||||
b.Decl(b.Const("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))),
|
||||
},
|
||||
{
|
||||
b.Stage(ast::PipelineStage::kCompute),
|
||||
b.WorkgroupSize(1),
|
||||
});
|
||||
|
||||
auto* expect = R"(
|
||||
struct S {
|
||||
@size(16)
|
||||
padding : u32;
|
||||
m : @stride(32) array<vec2<f32>, 2u>;
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var<uniform> s : S;
|
||||
|
||||
fn arr_to_mat2x2_stride_32(arr : @stride(32) array<vec2<f32>, 2u>) -> mat2x2<f32> {
|
||||
return mat2x2<f32>(arr[0u], arr[1u]);
|
||||
}
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn f() {
|
||||
let x : mat2x2<f32> = arr_to_mat2x2_stride_32(s.m);
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(
|
||||
Program(std::move(b)));
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(DecomposeStridedMatrixTest, ReadUniformColumn) {
|
||||
// struct S {
|
||||
// @offset(16) @stride(32)
|
||||
// @internal(ignore_stride_attribute)
|
||||
// m : mat2x2<f32>;
|
||||
// };
|
||||
// @group(0) @binding(0) var<uniform> s : S;
|
||||
//
|
||||
// @stage(compute) @workgroup_size(1)
|
||||
// fn f() {
|
||||
// let x : vec2<f32> = s.m[1];
|
||||
// }
|
||||
ProgramBuilder b;
|
||||
auto* S = b.Structure(
|
||||
"S",
|
||||
{
|
||||
b.Member(
|
||||
"m", b.ty.mat2x2<f32>(),
|
||||
{
|
||||
b.create<ast::StructMemberOffsetAttribute>(16),
|
||||
b.create<ast::StrideAttribute>(32),
|
||||
b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
|
||||
}),
|
||||
});
|
||||
b.Global("s", b.ty.Of(S), ast::StorageClass::kUniform,
|
||||
b.GroupAndBinding(0, 0));
|
||||
b.Func("f", {}, b.ty.void_(),
|
||||
{
|
||||
b.Decl(b.Const("x", b.ty.vec2<f32>(),
|
||||
b.IndexAccessor(b.MemberAccessor("s", "m"), 1))),
|
||||
},
|
||||
{
|
||||
b.Stage(ast::PipelineStage::kCompute),
|
||||
b.WorkgroupSize(1),
|
||||
});
|
||||
|
||||
auto* expect = R"(
|
||||
struct S {
|
||||
@size(16)
|
||||
padding : u32;
|
||||
m : @stride(32) array<vec2<f32>, 2u>;
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var<uniform> s : S;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn f() {
|
||||
let x : vec2<f32> = s.m[1];
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(
|
||||
Program(std::move(b)));
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(DecomposeStridedMatrixTest, ReadUniformMatrix_DefaultStride) {
|
||||
// struct S {
|
||||
// @offset(16) @stride(8)
|
||||
// @internal(ignore_stride_attribute)
|
||||
// m : mat2x2<f32>;
|
||||
// };
|
||||
// @group(0) @binding(0) var<uniform> s : S;
|
||||
//
|
||||
// @stage(compute) @workgroup_size(1)
|
||||
// fn f() {
|
||||
// let x : mat2x2<f32> = s.m;
|
||||
// }
|
||||
ProgramBuilder b;
|
||||
auto* S = b.Structure(
|
||||
"S",
|
||||
{
|
||||
b.Member(
|
||||
"m", b.ty.mat2x2<f32>(),
|
||||
{
|
||||
b.create<ast::StructMemberOffsetAttribute>(16),
|
||||
b.create<ast::StrideAttribute>(8),
|
||||
b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
|
||||
}),
|
||||
});
|
||||
b.Global("s", b.ty.Of(S), ast::StorageClass::kUniform,
|
||||
b.GroupAndBinding(0, 0));
|
||||
b.Func(
|
||||
"f", {}, b.ty.void_(),
|
||||
{
|
||||
b.Decl(b.Const("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))),
|
||||
},
|
||||
{
|
||||
b.Stage(ast::PipelineStage::kCompute),
|
||||
b.WorkgroupSize(1),
|
||||
});
|
||||
|
||||
auto* expect = R"(
|
||||
struct S {
|
||||
@size(16)
|
||||
padding : u32;
|
||||
@stride(8) @internal(disable_validation__ignore_stride)
|
||||
m : mat2x2<f32>;
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var<uniform> s : S;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn f() {
|
||||
let x : mat2x2<f32> = s.m;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(
|
||||
Program(std::move(b)));
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(DecomposeStridedMatrixTest, ReadStorageMatrix) {
|
||||
// struct S {
|
||||
// @offset(8) @stride(32)
|
||||
// @internal(ignore_stride_attribute)
|
||||
// m : mat2x2<f32>;
|
||||
// };
|
||||
// @group(0) @binding(0) var<storage, read_write> s : S;
|
||||
//
|
||||
// @stage(compute) @workgroup_size(1)
|
||||
// fn f() {
|
||||
// let x : mat2x2<f32> = s.m;
|
||||
// }
|
||||
ProgramBuilder b;
|
||||
auto* S = b.Structure(
|
||||
"S",
|
||||
{
|
||||
b.Member(
|
||||
"m", b.ty.mat2x2<f32>(),
|
||||
{
|
||||
b.create<ast::StructMemberOffsetAttribute>(8),
|
||||
b.create<ast::StrideAttribute>(32),
|
||||
b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
|
||||
}),
|
||||
});
|
||||
b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
|
||||
ast::Access::kReadWrite, b.GroupAndBinding(0, 0));
|
||||
b.Func(
|
||||
"f", {}, b.ty.void_(),
|
||||
{
|
||||
b.Decl(b.Const("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))),
|
||||
},
|
||||
{
|
||||
b.Stage(ast::PipelineStage::kCompute),
|
||||
b.WorkgroupSize(1),
|
||||
});
|
||||
|
||||
auto* expect = R"(
|
||||
struct S {
|
||||
@size(8)
|
||||
padding : u32;
|
||||
m : @stride(32) array<vec2<f32>, 2u>;
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var<storage, read_write> s : S;
|
||||
|
||||
fn arr_to_mat2x2_stride_32(arr : @stride(32) array<vec2<f32>, 2u>) -> mat2x2<f32> {
|
||||
return mat2x2<f32>(arr[0u], arr[1u]);
|
||||
}
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn f() {
|
||||
let x : mat2x2<f32> = arr_to_mat2x2_stride_32(s.m);
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(
|
||||
Program(std::move(b)));
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(DecomposeStridedMatrixTest, ReadStorageColumn) {
|
||||
// struct S {
|
||||
// @offset(16) @stride(32)
|
||||
// @internal(ignore_stride_attribute)
|
||||
// m : mat2x2<f32>;
|
||||
// };
|
||||
// @group(0) @binding(0) var<storage, read_write> s : S;
|
||||
//
|
||||
// @stage(compute) @workgroup_size(1)
|
||||
// fn f() {
|
||||
// let x : vec2<f32> = s.m[1];
|
||||
// }
|
||||
ProgramBuilder b;
|
||||
auto* S = b.Structure(
|
||||
"S",
|
||||
{
|
||||
b.Member(
|
||||
"m", b.ty.mat2x2<f32>(),
|
||||
{
|
||||
b.create<ast::StructMemberOffsetAttribute>(16),
|
||||
b.create<ast::StrideAttribute>(32),
|
||||
b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
|
||||
}),
|
||||
});
|
||||
b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
|
||||
ast::Access::kReadWrite, b.GroupAndBinding(0, 0));
|
||||
b.Func("f", {}, b.ty.void_(),
|
||||
{
|
||||
b.Decl(b.Const("x", b.ty.vec2<f32>(),
|
||||
b.IndexAccessor(b.MemberAccessor("s", "m"), 1))),
|
||||
},
|
||||
{
|
||||
b.Stage(ast::PipelineStage::kCompute),
|
||||
b.WorkgroupSize(1),
|
||||
});
|
||||
|
||||
auto* expect = R"(
|
||||
struct S {
|
||||
@size(16)
|
||||
padding : u32;
|
||||
m : @stride(32) array<vec2<f32>, 2u>;
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var<storage, read_write> s : S;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn f() {
|
||||
let x : vec2<f32> = s.m[1];
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(
|
||||
Program(std::move(b)));
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(DecomposeStridedMatrixTest, WriteStorageMatrix) {
|
||||
// struct S {
|
||||
// @offset(8) @stride(32)
|
||||
// @internal(ignore_stride_attribute)
|
||||
// m : mat2x2<f32>;
|
||||
// };
|
||||
// @group(0) @binding(0) var<storage, read_write> s : S;
|
||||
//
|
||||
// @stage(compute) @workgroup_size(1)
|
||||
// fn f() {
|
||||
// s.m = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
|
||||
// }
|
||||
ProgramBuilder b;
|
||||
auto* S = b.Structure(
|
||||
"S",
|
||||
{
|
||||
b.Member(
|
||||
"m", b.ty.mat2x2<f32>(),
|
||||
{
|
||||
b.create<ast::StructMemberOffsetAttribute>(8),
|
||||
b.create<ast::StrideAttribute>(32),
|
||||
b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
|
||||
}),
|
||||
});
|
||||
b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
|
||||
ast::Access::kReadWrite, b.GroupAndBinding(0, 0));
|
||||
b.Func("f", {}, b.ty.void_(),
|
||||
{
|
||||
b.Assign(b.MemberAccessor("s", "m"),
|
||||
b.mat2x2<f32>(b.vec2<f32>(1.0f, 2.0f),
|
||||
b.vec2<f32>(3.0f, 4.0f))),
|
||||
},
|
||||
{
|
||||
b.Stage(ast::PipelineStage::kCompute),
|
||||
b.WorkgroupSize(1),
|
||||
});
|
||||
|
||||
auto* expect = R"(
|
||||
struct S {
|
||||
@size(8)
|
||||
padding : u32;
|
||||
m : @stride(32) array<vec2<f32>, 2u>;
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var<storage, read_write> s : S;
|
||||
|
||||
fn mat2x2_stride_32_to_arr(m : mat2x2<f32>) -> @stride(32) array<vec2<f32>, 2u> {
|
||||
return @stride(32) array<vec2<f32>, 2u>(m[0u], m[1u]);
|
||||
}
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn f() {
|
||||
s.m = mat2x2_stride_32_to_arr(mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0)));
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(
|
||||
Program(std::move(b)));
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(DecomposeStridedMatrixTest, WriteStorageColumn) {
|
||||
// struct S {
|
||||
// @offset(8) @stride(32)
|
||||
// @internal(ignore_stride_attribute)
|
||||
// m : mat2x2<f32>;
|
||||
// };
|
||||
// @group(0) @binding(0) var<storage, read_write> s : S;
|
||||
//
|
||||
// @stage(compute) @workgroup_size(1)
|
||||
// fn f() {
|
||||
// s.m[1] = vec2<f32>(1.0, 2.0);
|
||||
// }
|
||||
ProgramBuilder b;
|
||||
auto* S = b.Structure(
|
||||
"S",
|
||||
{
|
||||
b.Member(
|
||||
"m", b.ty.mat2x2<f32>(),
|
||||
{
|
||||
b.create<ast::StructMemberOffsetAttribute>(8),
|
||||
b.create<ast::StrideAttribute>(32),
|
||||
b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
|
||||
}),
|
||||
});
|
||||
b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
|
||||
ast::Access::kReadWrite, b.GroupAndBinding(0, 0));
|
||||
b.Func("f", {}, b.ty.void_(),
|
||||
{
|
||||
b.Assign(b.IndexAccessor(b.MemberAccessor("s", "m"), 1),
|
||||
b.vec2<f32>(1.0f, 2.0f)),
|
||||
},
|
||||
{
|
||||
b.Stage(ast::PipelineStage::kCompute),
|
||||
b.WorkgroupSize(1),
|
||||
});
|
||||
|
||||
auto* expect = R"(
|
||||
struct S {
|
||||
@size(8)
|
||||
padding : u32;
|
||||
m : @stride(32) array<vec2<f32>, 2u>;
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var<storage, read_write> s : S;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn f() {
|
||||
s.m[1] = vec2<f32>(1.0, 2.0);
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(
|
||||
Program(std::move(b)));
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(DecomposeStridedMatrixTest, ReadWriteViaPointerLets) {
|
||||
// struct S {
|
||||
// @offset(8) @stride(32)
|
||||
// @internal(ignore_stride_attribute)
|
||||
// m : mat2x2<f32>;
|
||||
// };
|
||||
// @group(0) @binding(0) var<storage, read_write> s : S;
|
||||
//
|
||||
// @stage(compute) @workgroup_size(1)
|
||||
// fn f() {
|
||||
// let a = &s.m;
|
||||
// let b = &*&*(a);
|
||||
// let x = *b;
|
||||
// let y = (*b)[1];
|
||||
// let z = x[1];
|
||||
// (*b) = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
|
||||
// (*b)[1] = vec2<f32>(5.0, 6.0);
|
||||
// }
|
||||
ProgramBuilder b;
|
||||
auto* S = b.Structure(
|
||||
"S",
|
||||
{
|
||||
b.Member(
|
||||
"m", b.ty.mat2x2<f32>(),
|
||||
{
|
||||
b.create<ast::StructMemberOffsetAttribute>(8),
|
||||
b.create<ast::StrideAttribute>(32),
|
||||
b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
|
||||
}),
|
||||
});
|
||||
b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
|
||||
ast::Access::kReadWrite, b.GroupAndBinding(0, 0));
|
||||
b.Func(
|
||||
"f", {}, b.ty.void_(),
|
||||
{
|
||||
b.Decl(
|
||||
b.Const("a", nullptr, b.AddressOf(b.MemberAccessor("s", "m")))),
|
||||
b.Decl(b.Const("b", nullptr,
|
||||
b.AddressOf(b.Deref(b.AddressOf(b.Deref("a")))))),
|
||||
b.Decl(b.Const("x", nullptr, b.Deref("b"))),
|
||||
b.Decl(b.Const("y", nullptr, b.IndexAccessor(b.Deref("b"), 1))),
|
||||
b.Decl(b.Const("z", nullptr, b.IndexAccessor("x", 1))),
|
||||
b.Assign(b.Deref("b"), b.mat2x2<f32>(b.vec2<f32>(1.0f, 2.0f),
|
||||
b.vec2<f32>(3.0f, 4.0f))),
|
||||
b.Assign(b.IndexAccessor(b.Deref("b"), 1), b.vec2<f32>(5.0f, 6.0f)),
|
||||
},
|
||||
{
|
||||
b.Stage(ast::PipelineStage::kCompute),
|
||||
b.WorkgroupSize(1),
|
||||
});
|
||||
|
||||
auto* expect = R"(
|
||||
struct S {
|
||||
@size(8)
|
||||
padding : u32;
|
||||
m : @stride(32) array<vec2<f32>, 2u>;
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var<storage, read_write> s : S;
|
||||
|
||||
fn arr_to_mat2x2_stride_32(arr : @stride(32) array<vec2<f32>, 2u>) -> mat2x2<f32> {
|
||||
return mat2x2<f32>(arr[0u], arr[1u]);
|
||||
}
|
||||
|
||||
fn mat2x2_stride_32_to_arr(m : mat2x2<f32>) -> @stride(32) array<vec2<f32>, 2u> {
|
||||
return @stride(32) array<vec2<f32>, 2u>(m[0u], m[1u]);
|
||||
}
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn f() {
|
||||
let x = arr_to_mat2x2_stride_32(s.m);
|
||||
let y = s.m[1];
|
||||
let z = x[1];
|
||||
s.m = mat2x2_stride_32_to_arr(mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0)));
|
||||
s.m[1] = vec2<f32>(5.0, 6.0);
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(
|
||||
Program(std::move(b)));
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(DecomposeStridedMatrixTest, ReadPrivateMatrix) {
|
||||
// struct S {
|
||||
// @offset(8) @stride(32)
|
||||
// @internal(ignore_stride_attribute)
|
||||
// m : mat2x2<f32>;
|
||||
// };
|
||||
// var<private> s : S;
|
||||
//
|
||||
// @stage(compute) @workgroup_size(1)
|
||||
// fn f() {
|
||||
// let x : mat2x2<f32> = s.m;
|
||||
// }
|
||||
ProgramBuilder b;
|
||||
auto* S = b.Structure(
|
||||
"S",
|
||||
{
|
||||
b.Member(
|
||||
"m", b.ty.mat2x2<f32>(),
|
||||
{
|
||||
b.create<ast::StructMemberOffsetAttribute>(8),
|
||||
b.create<ast::StrideAttribute>(32),
|
||||
b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
|
||||
}),
|
||||
});
|
||||
b.Global("s", b.ty.Of(S), ast::StorageClass::kPrivate);
|
||||
b.Func(
|
||||
"f", {}, b.ty.void_(),
|
||||
{
|
||||
b.Decl(b.Const("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))),
|
||||
},
|
||||
{
|
||||
b.Stage(ast::PipelineStage::kCompute),
|
||||
b.WorkgroupSize(1),
|
||||
});
|
||||
|
||||
auto* expect = R"(
|
||||
struct S {
|
||||
@size(8)
|
||||
padding : u32;
|
||||
@stride(32) @internal(disable_validation__ignore_stride)
|
||||
m : mat2x2<f32>;
|
||||
}
|
||||
|
||||
var<private> s : S;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn f() {
|
||||
let x : mat2x2<f32> = s.m;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(
|
||||
Program(std::move(b)));
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(DecomposeStridedMatrixTest, WritePrivateMatrix) {
|
||||
// struct S {
|
||||
// @offset(8) @stride(32)
|
||||
// @internal(ignore_stride_attribute)
|
||||
// m : mat2x2<f32>;
|
||||
// };
|
||||
// var<private> s : S;
|
||||
//
|
||||
// @stage(compute) @workgroup_size(1)
|
||||
// fn f() {
|
||||
// s.m = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
|
||||
// }
|
||||
ProgramBuilder b;
|
||||
auto* S = b.Structure(
|
||||
"S",
|
||||
{
|
||||
b.Member(
|
||||
"m", b.ty.mat2x2<f32>(),
|
||||
{
|
||||
b.create<ast::StructMemberOffsetAttribute>(8),
|
||||
b.create<ast::StrideAttribute>(32),
|
||||
b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
|
||||
}),
|
||||
});
|
||||
b.Global("s", b.ty.Of(S), ast::StorageClass::kPrivate);
|
||||
b.Func("f", {}, b.ty.void_(),
|
||||
{
|
||||
b.Assign(b.MemberAccessor("s", "m"),
|
||||
b.mat2x2<f32>(b.vec2<f32>(1.0f, 2.0f),
|
||||
b.vec2<f32>(3.0f, 4.0f))),
|
||||
},
|
||||
{
|
||||
b.Stage(ast::PipelineStage::kCompute),
|
||||
b.WorkgroupSize(1),
|
||||
});
|
||||
|
||||
auto* expect = R"(
|
||||
struct S {
|
||||
@size(8)
|
||||
padding : u32;
|
||||
@stride(32) @internal(disable_validation__ignore_stride)
|
||||
m : mat2x2<f32>;
|
||||
}
|
||||
|
||||
var<private> s : S;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn f() {
|
||||
s.m = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(
|
||||
Program(std::move(b)));
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
131
src/tint/transform/external_texture_transform.cc
Normal file
131
src/tint/transform/external_texture_transform.cc
Normal file
@@ -0,0 +1,131 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/tint/transform/external_texture_transform.h"
|
||||
|
||||
#include "src/tint/program_builder.h"
|
||||
#include "src/tint/sem/call.h"
|
||||
#include "src/tint/sem/variable.h"
|
||||
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::transform::ExternalTextureTransform);
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
|
||||
ExternalTextureTransform::ExternalTextureTransform() = default;
|
||||
ExternalTextureTransform::~ExternalTextureTransform() = default;
|
||||
|
||||
void ExternalTextureTransform::Run(CloneContext& ctx,
|
||||
const DataMap&,
|
||||
DataMap&) const {
|
||||
auto& sem = ctx.src->Sem();
|
||||
|
||||
// Within this transform, usages of texture_external are replaced with a
|
||||
// texture_2d<f32>, which will allow us perform operations on a
|
||||
// texture_external without maintaining texture_external-specific code
|
||||
// generation paths in the backends.
|
||||
|
||||
// When replacing instances of texture_external with texture_2d<f32> we must
|
||||
// also modify calls to the texture_external overloads of textureLoad and
|
||||
// textureSampleLevel, which unlike their texture_2d<f32> overloads do not
|
||||
// require a level parameter. To do this we identify calls to textureLoad and
|
||||
// textureSampleLevel that use texture_external as the first parameter and add
|
||||
// a parameter for the level (which is always 0).
|
||||
|
||||
// Scan the AST nodes for calls to textureLoad or textureSampleLevel.
|
||||
for (auto* node : ctx.src->ASTNodes().Objects()) {
|
||||
if (auto* call_expr = node->As<ast::CallExpression>()) {
|
||||
if (auto* builtin = sem.Get(call_expr)->Target()->As<sem::Builtin>()) {
|
||||
if (builtin->Type() == sem::BuiltinType::kTextureLoad ||
|
||||
builtin->Type() == sem::BuiltinType::kTextureSampleLevel) {
|
||||
// When a textureLoad or textureSampleLevel has been identified, check
|
||||
// if the first parameter is an external texture.
|
||||
if (auto* var =
|
||||
sem.Get(call_expr->args[0])->As<sem::VariableUser>()) {
|
||||
if (var->Variable()
|
||||
->Type()
|
||||
->UnwrapRef()
|
||||
->Is<sem::ExternalTexture>()) {
|
||||
if (builtin->Type() == sem::BuiltinType::kTextureLoad &&
|
||||
call_expr->args.size() != 2) {
|
||||
TINT_ICE(Transform, ctx.dst->Diagnostics())
|
||||
<< "expected textureLoad call with a texture_external to "
|
||||
"have 2 parameters, found "
|
||||
<< call_expr->args.size() << " parameters";
|
||||
}
|
||||
|
||||
if (builtin->Type() == sem::BuiltinType::kTextureSampleLevel &&
|
||||
call_expr->args.size() != 3) {
|
||||
TINT_ICE(Transform, ctx.dst->Diagnostics())
|
||||
<< "expected textureSampleLevel call with a "
|
||||
"texture_external to have 3 parameters, found "
|
||||
<< call_expr->args.size() << " parameters";
|
||||
}
|
||||
|
||||
// Replace the call with another that has the same parameters in
|
||||
// addition to a level parameter (always zero for external
|
||||
// textures).
|
||||
auto* exp = ctx.Clone(call_expr->target.name);
|
||||
auto* externalTextureParam = ctx.Clone(call_expr->args[0]);
|
||||
|
||||
ast::ExpressionList params;
|
||||
if (builtin->Type() == sem::BuiltinType::kTextureLoad) {
|
||||
auto* coordsParam = ctx.Clone(call_expr->args[1]);
|
||||
auto* levelParam = ctx.dst->Expr(0);
|
||||
params = {externalTextureParam, coordsParam, levelParam};
|
||||
} else if (builtin->Type() ==
|
||||
sem::BuiltinType::kTextureSampleLevel) {
|
||||
auto* samplerParam = ctx.Clone(call_expr->args[1]);
|
||||
auto* coordsParam = ctx.Clone(call_expr->args[2]);
|
||||
auto* levelParam = ctx.dst->Expr(0.0f);
|
||||
params = {externalTextureParam, samplerParam, coordsParam,
|
||||
levelParam};
|
||||
}
|
||||
|
||||
auto* newCall = ctx.dst->create<ast::CallExpression>(exp, params);
|
||||
ctx.Replace(call_expr, newCall);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Scan the AST nodes for external texture declarations.
|
||||
for (auto* node : ctx.src->ASTNodes().Objects()) {
|
||||
if (auto* var = node->As<ast::Variable>()) {
|
||||
if (::tint::Is<ast::ExternalTexture>(var->type)) {
|
||||
// Replace a single-plane external texture with a 2D, f32 sampled
|
||||
// texture.
|
||||
auto* newType = ctx.dst->ty.sampled_texture(ast::TextureDimension::k2d,
|
||||
ctx.dst->ty.f32());
|
||||
auto clonedSrc = ctx.Clone(var->source);
|
||||
auto clonedSym = ctx.Clone(var->symbol);
|
||||
auto* clonedConstructor = ctx.Clone(var->constructor);
|
||||
auto clonedAttributes = ctx.Clone(var->attributes);
|
||||
auto* newVar = ctx.dst->create<ast::Variable>(
|
||||
clonedSrc, clonedSym, var->declared_storage_class,
|
||||
var->declared_access, newType, false, false, clonedConstructor,
|
||||
clonedAttributes);
|
||||
|
||||
ctx.Replace(var, newVar);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ctx.Clone();
|
||||
}
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
53
src/tint/transform/external_texture_transform.h
Normal file
53
src/tint/transform/external_texture_transform.h
Normal file
@@ -0,0 +1,53 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef SRC_TINT_TRANSFORM_EXTERNAL_TEXTURE_TRANSFORM_H_
|
||||
#define SRC_TINT_TRANSFORM_EXTERNAL_TEXTURE_TRANSFORM_H_
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "src/tint/transform/transform.h"
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
|
||||
/// Because an external texture is comprised of 1-3 texture views we can simply
|
||||
/// transform external textures into the appropriate number of sampled textures.
|
||||
/// This allows us to share SPIR-V/HLSL writer paths for sampled textures
|
||||
/// instead of adding dedicated writer paths for external textures.
|
||||
/// ExternalTextureTransform performs this transformation.
|
||||
class ExternalTextureTransform
|
||||
: public Castable<ExternalTextureTransform, Transform> {
|
||||
public:
|
||||
/// Constructor
|
||||
ExternalTextureTransform();
|
||||
/// Destructor
|
||||
~ExternalTextureTransform() override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
};
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
|
||||
#endif // SRC_TINT_TRANSFORM_EXTERNAL_TEXTURE_TRANSFORM_H_
|
||||
187
src/tint/transform/external_texture_transform_test.cc
Normal file
187
src/tint/transform/external_texture_transform_test.cc
Normal file
@@ -0,0 +1,187 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/tint/transform/external_texture_transform.h"
|
||||
|
||||
#include "src/tint/transform/test_helper.h"
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
namespace {
|
||||
|
||||
using ExternalTextureTransformTest = TransformTest;
|
||||
|
||||
TEST_F(ExternalTextureTransformTest, SampleLevelSinglePlane) {
|
||||
auto* src = R"(
|
||||
@group(0) @binding(0) var s : sampler;
|
||||
|
||||
@group(0) @binding(1) var t : texture_external;
|
||||
|
||||
@stage(fragment)
|
||||
fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
|
||||
return textureSampleLevel(t, s, (coord.xy / vec2<f32>(4.0, 4.0)));
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
@group(0) @binding(0) var s : sampler;
|
||||
|
||||
@group(0) @binding(1) var t : texture_2d<f32>;
|
||||
|
||||
@stage(fragment)
|
||||
fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
|
||||
return textureSampleLevel(t, s, (coord.xy / vec2<f32>(4.0, 4.0)), 0.0);
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<ExternalTextureTransform>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(ExternalTextureTransformTest, SampleLevelSinglePlane_OutOfOrder) {
|
||||
auto* src = R"(
|
||||
@stage(fragment)
|
||||
fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
|
||||
return textureSampleLevel(t, s, (coord.xy / vec2<f32>(4.0, 4.0)));
|
||||
}
|
||||
|
||||
@group(0) @binding(1) var t : texture_external;
|
||||
|
||||
@group(0) @binding(0) var s : sampler;
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
@stage(fragment)
|
||||
fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
|
||||
return textureSampleLevel(t, s, (coord.xy / vec2<f32>(4.0, 4.0)), 0.0);
|
||||
}
|
||||
|
||||
@group(0) @binding(1) var t : texture_2d<f32>;
|
||||
|
||||
@group(0) @binding(0) var s : sampler;
|
||||
)";
|
||||
|
||||
auto got = Run<ExternalTextureTransform>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(ExternalTextureTransformTest, LoadSinglePlane) {
|
||||
auto* src = R"(
|
||||
@group(0) @binding(0) var t : texture_external;
|
||||
|
||||
@stage(fragment)
|
||||
fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
|
||||
return textureLoad(t, vec2<i32>(1, 1));
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
@group(0) @binding(0) var t : texture_2d<f32>;
|
||||
|
||||
@stage(fragment)
|
||||
fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
|
||||
return textureLoad(t, vec2<i32>(1, 1), 0);
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<ExternalTextureTransform>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(ExternalTextureTransformTest, LoadSinglePlane_OutOfOrder) {
|
||||
auto* src = R"(
|
||||
@stage(fragment)
|
||||
fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
|
||||
return textureLoad(t, vec2<i32>(1, 1));
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var t : texture_external;
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
@stage(fragment)
|
||||
fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
|
||||
return textureLoad(t, vec2<i32>(1, 1), 0);
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var t : texture_2d<f32>;
|
||||
)";
|
||||
|
||||
auto got = Run<ExternalTextureTransform>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(ExternalTextureTransformTest, DimensionsSinglePlane) {
|
||||
auto* src = R"(
|
||||
@group(0) @binding(0) var t : texture_external;
|
||||
|
||||
@stage(fragment)
|
||||
fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
|
||||
var dim : vec2<i32>;
|
||||
dim = textureDimensions(t);
|
||||
return vec4<f32>(0.0, 0.0, 0.0, 0.0);
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
@group(0) @binding(0) var t : texture_2d<f32>;
|
||||
|
||||
@stage(fragment)
|
||||
fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
|
||||
var dim : vec2<i32>;
|
||||
dim = textureDimensions(t);
|
||||
return vec4<f32>(0.0, 0.0, 0.0, 0.0);
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<ExternalTextureTransform>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(ExternalTextureTransformTest, DimensionsSinglePlane_OutOfOrder) {
|
||||
auto* src = R"(
|
||||
@stage(fragment)
|
||||
fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
|
||||
var dim : vec2<i32>;
|
||||
dim = textureDimensions(t);
|
||||
return vec4<f32>(0.0, 0.0, 0.0, 0.0);
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var t : texture_external;
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
@stage(fragment)
|
||||
fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
|
||||
var dim : vec2<i32>;
|
||||
dim = textureDimensions(t);
|
||||
return vec4<f32>(0.0, 0.0, 0.0, 0.0);
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var t : texture_2d<f32>;
|
||||
)";
|
||||
|
||||
auto got = Run<ExternalTextureTransform>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
188
src/tint/transform/first_index_offset.cc
Normal file
188
src/tint/transform/first_index_offset.cc
Normal file
@@ -0,0 +1,188 @@
|
||||
// Copyright 2020 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/tint/transform/first_index_offset.h"
|
||||
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
|
||||
#include "src/tint/program_builder.h"
|
||||
#include "src/tint/sem/function.h"
|
||||
#include "src/tint/sem/member_accessor_expression.h"
|
||||
#include "src/tint/sem/struct.h"
|
||||
#include "src/tint/sem/variable.h"
|
||||
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::transform::FirstIndexOffset);
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::transform::FirstIndexOffset::BindingPoint);
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::transform::FirstIndexOffset::Data);
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
namespace {
|
||||
|
||||
// Uniform buffer member names
|
||||
constexpr char kFirstVertexName[] = "first_vertex_index";
|
||||
constexpr char kFirstInstanceName[] = "first_instance_index";
|
||||
|
||||
} // namespace
|
||||
|
||||
FirstIndexOffset::BindingPoint::BindingPoint() = default;
|
||||
FirstIndexOffset::BindingPoint::BindingPoint(uint32_t b, uint32_t g)
|
||||
: binding(b), group(g) {}
|
||||
FirstIndexOffset::BindingPoint::~BindingPoint() = default;
|
||||
|
||||
FirstIndexOffset::Data::Data(bool has_vtx_index,
|
||||
bool has_inst_index,
|
||||
uint32_t first_vtx_offset,
|
||||
uint32_t first_inst_offset)
|
||||
: has_vertex_index(has_vtx_index),
|
||||
has_instance_index(has_inst_index),
|
||||
first_vertex_offset(first_vtx_offset),
|
||||
first_instance_offset(first_inst_offset) {}
|
||||
FirstIndexOffset::Data::Data(const Data&) = default;
|
||||
FirstIndexOffset::Data::~Data() = default;
|
||||
|
||||
FirstIndexOffset::FirstIndexOffset() = default;
|
||||
FirstIndexOffset::~FirstIndexOffset() = default;
|
||||
|
||||
bool FirstIndexOffset::ShouldRun(const Program* program, const DataMap&) const {
|
||||
for (auto* fn : program->AST().Functions()) {
|
||||
if (fn->PipelineStage() == ast::PipelineStage::kVertex) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void FirstIndexOffset::Run(CloneContext& ctx,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const {
|
||||
// Get the uniform buffer binding point
|
||||
uint32_t ub_binding = binding_;
|
||||
uint32_t ub_group = group_;
|
||||
if (auto* binding_point = inputs.Get<BindingPoint>()) {
|
||||
ub_binding = binding_point->binding;
|
||||
ub_group = binding_point->group;
|
||||
}
|
||||
|
||||
// Map of builtin usages
|
||||
std::unordered_map<const sem::Variable*, const char*> builtin_vars;
|
||||
std::unordered_map<const sem::StructMember*, const char*> builtin_members;
|
||||
|
||||
bool has_vertex_index = false;
|
||||
bool has_instance_index = false;
|
||||
|
||||
// Traverse the AST scanning for builtin accesses via variables (includes
|
||||
// parameters) or structure member accesses.
|
||||
for (auto* node : ctx.src->ASTNodes().Objects()) {
|
||||
if (auto* var = node->As<ast::Variable>()) {
|
||||
for (auto* attr : var->attributes) {
|
||||
if (auto* builtin_attr = attr->As<ast::BuiltinAttribute>()) {
|
||||
ast::Builtin builtin = builtin_attr->builtin;
|
||||
if (builtin == ast::Builtin::kVertexIndex) {
|
||||
auto* sem_var = ctx.src->Sem().Get(var);
|
||||
builtin_vars.emplace(sem_var, kFirstVertexName);
|
||||
has_vertex_index = true;
|
||||
}
|
||||
if (builtin == ast::Builtin::kInstanceIndex) {
|
||||
auto* sem_var = ctx.src->Sem().Get(var);
|
||||
builtin_vars.emplace(sem_var, kFirstInstanceName);
|
||||
has_instance_index = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (auto* member = node->As<ast::StructMember>()) {
|
||||
for (auto* attr : member->attributes) {
|
||||
if (auto* builtin_attr = attr->As<ast::BuiltinAttribute>()) {
|
||||
ast::Builtin builtin = builtin_attr->builtin;
|
||||
if (builtin == ast::Builtin::kVertexIndex) {
|
||||
auto* sem_mem = ctx.src->Sem().Get(member);
|
||||
builtin_members.emplace(sem_mem, kFirstVertexName);
|
||||
has_vertex_index = true;
|
||||
}
|
||||
if (builtin == ast::Builtin::kInstanceIndex) {
|
||||
auto* sem_mem = ctx.src->Sem().Get(member);
|
||||
builtin_members.emplace(sem_mem, kFirstInstanceName);
|
||||
has_instance_index = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Byte offsets on the uniform buffer
|
||||
uint32_t vertex_index_offset = 0;
|
||||
uint32_t instance_index_offset = 0;
|
||||
|
||||
if (has_vertex_index || has_instance_index) {
|
||||
// Add uniform buffer members and calculate byte offsets
|
||||
uint32_t offset = 0;
|
||||
ast::StructMemberList members;
|
||||
if (has_vertex_index) {
|
||||
members.push_back(ctx.dst->Member(kFirstVertexName, ctx.dst->ty.u32()));
|
||||
vertex_index_offset = offset;
|
||||
offset += 4;
|
||||
}
|
||||
if (has_instance_index) {
|
||||
members.push_back(ctx.dst->Member(kFirstInstanceName, ctx.dst->ty.u32()));
|
||||
instance_index_offset = offset;
|
||||
offset += 4;
|
||||
}
|
||||
auto* struct_ = ctx.dst->Structure(ctx.dst->Sym(), std::move(members));
|
||||
|
||||
// Create a global to hold the uniform buffer
|
||||
Symbol buffer_name = ctx.dst->Sym();
|
||||
ctx.dst->Global(buffer_name, ctx.dst->ty.Of(struct_),
|
||||
ast::StorageClass::kUniform, nullptr,
|
||||
ast::AttributeList{
|
||||
ctx.dst->create<ast::BindingAttribute>(ub_binding),
|
||||
ctx.dst->create<ast::GroupAttribute>(ub_group),
|
||||
});
|
||||
|
||||
// Fix up all references to the builtins with the offsets
|
||||
ctx.ReplaceAll(
|
||||
[=, &ctx](const ast::Expression* expr) -> const ast::Expression* {
|
||||
if (auto* sem = ctx.src->Sem().Get(expr)) {
|
||||
if (auto* user = sem->As<sem::VariableUser>()) {
|
||||
auto it = builtin_vars.find(user->Variable());
|
||||
if (it != builtin_vars.end()) {
|
||||
return ctx.dst->Add(
|
||||
ctx.CloneWithoutTransform(expr),
|
||||
ctx.dst->MemberAccessor(buffer_name, it->second));
|
||||
}
|
||||
}
|
||||
if (auto* access = sem->As<sem::StructMemberAccess>()) {
|
||||
auto it = builtin_members.find(access->Member());
|
||||
if (it != builtin_members.end()) {
|
||||
return ctx.dst->Add(
|
||||
ctx.CloneWithoutTransform(expr),
|
||||
ctx.dst->MemberAccessor(buffer_name, it->second));
|
||||
}
|
||||
}
|
||||
}
|
||||
// Not interested in this experssion. Just clone.
|
||||
return nullptr;
|
||||
});
|
||||
}
|
||||
|
||||
ctx.Clone();
|
||||
|
||||
outputs.Add<Data>(has_vertex_index, has_instance_index, vertex_index_offset,
|
||||
instance_index_offset);
|
||||
}
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
143
src/tint/transform/first_index_offset.h
Normal file
143
src/tint/transform/first_index_offset.h
Normal file
@@ -0,0 +1,143 @@
|
||||
// Copyright 2020 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef SRC_TINT_TRANSFORM_FIRST_INDEX_OFFSET_H_
|
||||
#define SRC_TINT_TRANSFORM_FIRST_INDEX_OFFSET_H_
|
||||
|
||||
#include "src/tint/transform/transform.h"
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
|
||||
/// Adds firstVertex/Instance (injected via root constants) to
|
||||
/// vertex/instance index builtins.
|
||||
///
|
||||
/// This transform assumes that Name transform has been run before.
|
||||
///
|
||||
/// Unlike other APIs, D3D always starts vertex and instance numbering at 0,
|
||||
/// regardless of the firstVertex/Instance value specified. This transformer
|
||||
/// adds the value of firstVertex/Instance to each builtin. This action is
|
||||
/// performed by adding a new constant equal to original builtin +
|
||||
/// firstVertex/Instance to each function that references one of these builtins.
|
||||
///
|
||||
/// Note that D3D does not have any semantics for firstVertex/Instance.
|
||||
/// Therefore, these values must by passed to the shader.
|
||||
///
|
||||
/// Before:
|
||||
/// ```
|
||||
/// @builtin(vertex_index) var<in> vert_idx : u32;
|
||||
/// fn func() -> u32 {
|
||||
/// return vert_idx;
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// After:
|
||||
/// ```
|
||||
/// struct TintFirstIndexOffsetData {
|
||||
/// tint_first_vertex_index : u32;
|
||||
/// tint_first_instance_index : u32;
|
||||
/// };
|
||||
/// @builtin(vertex_index) var<in> tint_first_index_offset_vert_idx : u32;
|
||||
/// @binding(N) @group(M) var<uniform> tint_first_index_data :
|
||||
/// TintFirstIndexOffsetData;
|
||||
/// fn func() -> u32 {
|
||||
/// const vert_idx = (tint_first_index_offset_vert_idx +
|
||||
/// tint_first_index_data.tint_first_vertex_index);
|
||||
/// return vert_idx;
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
class FirstIndexOffset : public Castable<FirstIndexOffset, Transform> {
|
||||
public:
|
||||
/// BindingPoint is consumed by the FirstIndexOffset transform.
|
||||
/// BindingPoint specifies the binding point of the first index uniform
|
||||
/// buffer.
|
||||
struct BindingPoint : public Castable<BindingPoint, transform::Data> {
|
||||
/// Constructor
|
||||
BindingPoint();
|
||||
|
||||
/// Constructor
|
||||
/// @param b the binding index
|
||||
/// @param g the binding group
|
||||
BindingPoint(uint32_t b, uint32_t g);
|
||||
|
||||
/// Destructor
|
||||
~BindingPoint() override;
|
||||
|
||||
/// `@binding()` for the first vertex / first instance uniform buffer
|
||||
uint32_t binding = 0;
|
||||
/// `@group()` for the first vertex / first instance uniform buffer
|
||||
uint32_t group = 0;
|
||||
};
|
||||
|
||||
/// Data is outputted by the FirstIndexOffset transform.
|
||||
/// Data holds information about shader usage and constant buffer offsets.
|
||||
struct Data : public Castable<Data, transform::Data> {
|
||||
/// Constructor
|
||||
/// @param has_vtx_index True if the shader uses vertex_index
|
||||
/// @param has_inst_index True if the shader uses instance_index
|
||||
/// @param first_vtx_offset Offset of first vertex into constant buffer
|
||||
/// @param first_inst_offset Offset of first instance into constant buffer
|
||||
Data(bool has_vtx_index,
|
||||
bool has_inst_index,
|
||||
uint32_t first_vtx_offset,
|
||||
uint32_t first_inst_offset);
|
||||
|
||||
/// Copy constructor
|
||||
Data(const Data&);
|
||||
|
||||
/// Destructor
|
||||
~Data() override;
|
||||
|
||||
/// True if the shader uses vertex_index
|
||||
const bool has_vertex_index;
|
||||
/// True if the shader uses instance_index
|
||||
const bool has_instance_index;
|
||||
/// Offset of first vertex into constant buffer
|
||||
const uint32_t first_vertex_offset;
|
||||
/// Offset of first instance into constant buffer
|
||||
const uint32_t first_instance_offset;
|
||||
};
|
||||
|
||||
/// Constructor
|
||||
FirstIndexOffset();
|
||||
/// Destructor
|
||||
~FirstIndexOffset() override;
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program,
|
||||
const DataMap& data = {}) const override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
|
||||
private:
|
||||
uint32_t binding_ = 0;
|
||||
uint32_t group_ = 0;
|
||||
};
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
|
||||
#endif // SRC_TINT_TRANSFORM_FIRST_INDEX_OFFSET_H_
|
||||
650
src/tint/transform/first_index_offset_test.cc
Normal file
650
src/tint/transform/first_index_offset_test.cc
Normal file
@@ -0,0 +1,650 @@
|
||||
// Copyright 2020 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/tint/transform/first_index_offset.h"
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "src/tint/transform/test_helper.h"
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
namespace {
|
||||
|
||||
using FirstIndexOffsetTest = TransformTest;
|
||||
|
||||
TEST_F(FirstIndexOffsetTest, ShouldRunEmptyModule) {
|
||||
auto* src = R"()";
|
||||
|
||||
EXPECT_FALSE(ShouldRun<FirstIndexOffset>(src));
|
||||
}
|
||||
|
||||
TEST_F(FirstIndexOffsetTest, ShouldRunFragmentStage) {
|
||||
auto* src = R"(
|
||||
[[stage(fragment)]]
|
||||
fn entry() {
|
||||
return;
|
||||
}
|
||||
)";
|
||||
|
||||
EXPECT_FALSE(ShouldRun<FirstIndexOffset>(src));
|
||||
}
|
||||
|
||||
TEST_F(FirstIndexOffsetTest, ShouldRunVertexStage) {
|
||||
auto* src = R"(
|
||||
[[stage(vertex)]]
|
||||
fn entry() -> [[builtin(position)]] vec4<f32> {
|
||||
return vec4<f32>();
|
||||
}
|
||||
)";
|
||||
|
||||
EXPECT_TRUE(ShouldRun<FirstIndexOffset>(src));
|
||||
}
|
||||
|
||||
TEST_F(FirstIndexOffsetTest, EmptyModule) {
|
||||
auto* src = "";
|
||||
auto* expect = "";
|
||||
|
||||
DataMap config;
|
||||
config.Add<FirstIndexOffset::BindingPoint>(0, 0);
|
||||
auto got = Run<FirstIndexOffset>(src, std::move(config));
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
|
||||
auto* data = got.data.Get<FirstIndexOffset::Data>();
|
||||
|
||||
EXPECT_EQ(data, nullptr);
|
||||
}
|
||||
|
||||
TEST_F(FirstIndexOffsetTest, BasicVertexShader) {
|
||||
auto* src = R"(
|
||||
@stage(vertex)
|
||||
fn entry() -> @builtin(position) vec4<f32> {
|
||||
return vec4<f32>();
|
||||
}
|
||||
)";
|
||||
auto* expect = src;
|
||||
|
||||
DataMap config;
|
||||
config.Add<FirstIndexOffset::BindingPoint>(0, 0);
|
||||
auto got = Run<FirstIndexOffset>(src, std::move(config));
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
|
||||
auto* data = got.data.Get<FirstIndexOffset::Data>();
|
||||
|
||||
ASSERT_NE(data, nullptr);
|
||||
EXPECT_EQ(data->has_vertex_index, false);
|
||||
EXPECT_EQ(data->has_instance_index, false);
|
||||
EXPECT_EQ(data->first_vertex_offset, 0u);
|
||||
EXPECT_EQ(data->first_instance_offset, 0u);
|
||||
}
|
||||
|
||||
TEST_F(FirstIndexOffsetTest, BasicModuleVertexIndex) {
|
||||
auto* src = R"(
|
||||
fn test(vert_idx : u32) -> u32 {
|
||||
return vert_idx;
|
||||
}
|
||||
|
||||
@stage(vertex)
|
||||
fn entry(@builtin(vertex_index) vert_idx : u32) -> @builtin(position) vec4<f32> {
|
||||
test(vert_idx);
|
||||
return vec4<f32>();
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
struct tint_symbol {
|
||||
first_vertex_index : u32;
|
||||
}
|
||||
|
||||
@binding(1) @group(2) var<uniform> tint_symbol_1 : tint_symbol;
|
||||
|
||||
fn test(vert_idx : u32) -> u32 {
|
||||
return vert_idx;
|
||||
}
|
||||
|
||||
@stage(vertex)
|
||||
fn entry(@builtin(vertex_index) vert_idx : u32) -> @builtin(position) vec4<f32> {
|
||||
test((vert_idx + tint_symbol_1.first_vertex_index));
|
||||
return vec4<f32>();
|
||||
}
|
||||
)";
|
||||
|
||||
DataMap config;
|
||||
config.Add<FirstIndexOffset::BindingPoint>(1, 2);
|
||||
auto got = Run<FirstIndexOffset>(src, std::move(config));
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
|
||||
auto* data = got.data.Get<FirstIndexOffset::Data>();
|
||||
|
||||
ASSERT_NE(data, nullptr);
|
||||
EXPECT_EQ(data->has_vertex_index, true);
|
||||
EXPECT_EQ(data->has_instance_index, false);
|
||||
EXPECT_EQ(data->first_vertex_offset, 0u);
|
||||
EXPECT_EQ(data->first_instance_offset, 0u);
|
||||
}
|
||||
|
||||
TEST_F(FirstIndexOffsetTest, BasicModuleVertexIndex_OutOfOrder) {
|
||||
auto* src = R"(
|
||||
@stage(vertex)
|
||||
fn entry(@builtin(vertex_index) vert_idx : u32) -> @builtin(position) vec4<f32> {
|
||||
test(vert_idx);
|
||||
return vec4<f32>();
|
||||
}
|
||||
|
||||
fn test(vert_idx : u32) -> u32 {
|
||||
return vert_idx;
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
struct tint_symbol {
|
||||
first_vertex_index : u32;
|
||||
}
|
||||
|
||||
@binding(1) @group(2) var<uniform> tint_symbol_1 : tint_symbol;
|
||||
|
||||
@stage(vertex)
|
||||
fn entry(@builtin(vertex_index) vert_idx : u32) -> @builtin(position) vec4<f32> {
|
||||
test((vert_idx + tint_symbol_1.first_vertex_index));
|
||||
return vec4<f32>();
|
||||
}
|
||||
|
||||
fn test(vert_idx : u32) -> u32 {
|
||||
return vert_idx;
|
||||
}
|
||||
)";
|
||||
|
||||
DataMap config;
|
||||
config.Add<FirstIndexOffset::BindingPoint>(1, 2);
|
||||
auto got = Run<FirstIndexOffset>(src, std::move(config));
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
|
||||
auto* data = got.data.Get<FirstIndexOffset::Data>();
|
||||
|
||||
ASSERT_NE(data, nullptr);
|
||||
EXPECT_EQ(data->has_vertex_index, true);
|
||||
EXPECT_EQ(data->has_instance_index, false);
|
||||
EXPECT_EQ(data->first_vertex_offset, 0u);
|
||||
EXPECT_EQ(data->first_instance_offset, 0u);
|
||||
}
|
||||
|
||||
TEST_F(FirstIndexOffsetTest, BasicModuleInstanceIndex) {
|
||||
auto* src = R"(
|
||||
fn test(inst_idx : u32) -> u32 {
|
||||
return inst_idx;
|
||||
}
|
||||
|
||||
@stage(vertex)
|
||||
fn entry(@builtin(instance_index) inst_idx : u32) -> @builtin(position) vec4<f32> {
|
||||
test(inst_idx);
|
||||
return vec4<f32>();
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
struct tint_symbol {
|
||||
first_instance_index : u32;
|
||||
}
|
||||
|
||||
@binding(1) @group(7) var<uniform> tint_symbol_1 : tint_symbol;
|
||||
|
||||
fn test(inst_idx : u32) -> u32 {
|
||||
return inst_idx;
|
||||
}
|
||||
|
||||
@stage(vertex)
|
||||
fn entry(@builtin(instance_index) inst_idx : u32) -> @builtin(position) vec4<f32> {
|
||||
test((inst_idx + tint_symbol_1.first_instance_index));
|
||||
return vec4<f32>();
|
||||
}
|
||||
)";
|
||||
|
||||
DataMap config;
|
||||
config.Add<FirstIndexOffset::BindingPoint>(1, 7);
|
||||
auto got = Run<FirstIndexOffset>(src, std::move(config));
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
|
||||
auto* data = got.data.Get<FirstIndexOffset::Data>();
|
||||
|
||||
ASSERT_NE(data, nullptr);
|
||||
EXPECT_EQ(data->has_vertex_index, false);
|
||||
EXPECT_EQ(data->has_instance_index, true);
|
||||
EXPECT_EQ(data->first_vertex_offset, 0u);
|
||||
EXPECT_EQ(data->first_instance_offset, 0u);
|
||||
}
|
||||
|
||||
TEST_F(FirstIndexOffsetTest, BasicModuleInstanceIndex_OutOfOrder) {
|
||||
auto* src = R"(
|
||||
@stage(vertex)
|
||||
fn entry(@builtin(instance_index) inst_idx : u32) -> @builtin(position) vec4<f32> {
|
||||
test(inst_idx);
|
||||
return vec4<f32>();
|
||||
}
|
||||
|
||||
fn test(inst_idx : u32) -> u32 {
|
||||
return inst_idx;
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
struct tint_symbol {
|
||||
first_instance_index : u32;
|
||||
}
|
||||
|
||||
@binding(1) @group(7) var<uniform> tint_symbol_1 : tint_symbol;
|
||||
|
||||
@stage(vertex)
|
||||
fn entry(@builtin(instance_index) inst_idx : u32) -> @builtin(position) vec4<f32> {
|
||||
test((inst_idx + tint_symbol_1.first_instance_index));
|
||||
return vec4<f32>();
|
||||
}
|
||||
|
||||
fn test(inst_idx : u32) -> u32 {
|
||||
return inst_idx;
|
||||
}
|
||||
)";
|
||||
|
||||
DataMap config;
|
||||
config.Add<FirstIndexOffset::BindingPoint>(1, 7);
|
||||
auto got = Run<FirstIndexOffset>(src, std::move(config));
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
|
||||
auto* data = got.data.Get<FirstIndexOffset::Data>();
|
||||
|
||||
ASSERT_NE(data, nullptr);
|
||||
EXPECT_EQ(data->has_vertex_index, false);
|
||||
EXPECT_EQ(data->has_instance_index, true);
|
||||
EXPECT_EQ(data->first_vertex_offset, 0u);
|
||||
EXPECT_EQ(data->first_instance_offset, 0u);
|
||||
}
|
||||
|
||||
TEST_F(FirstIndexOffsetTest, BasicModuleBothIndex) {
|
||||
auto* src = R"(
|
||||
fn test(instance_idx : u32, vert_idx : u32) -> u32 {
|
||||
return instance_idx + vert_idx;
|
||||
}
|
||||
|
||||
struct Inputs {
|
||||
@builtin(instance_index) instance_idx : u32;
|
||||
@builtin(vertex_index) vert_idx : u32;
|
||||
};
|
||||
|
||||
@stage(vertex)
|
||||
fn entry(inputs : Inputs) -> @builtin(position) vec4<f32> {
|
||||
test(inputs.instance_idx, inputs.vert_idx);
|
||||
return vec4<f32>();
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
struct tint_symbol {
|
||||
first_vertex_index : u32;
|
||||
first_instance_index : u32;
|
||||
}
|
||||
|
||||
@binding(1) @group(2) var<uniform> tint_symbol_1 : tint_symbol;
|
||||
|
||||
fn test(instance_idx : u32, vert_idx : u32) -> u32 {
|
||||
return (instance_idx + vert_idx);
|
||||
}
|
||||
|
||||
struct Inputs {
|
||||
@builtin(instance_index)
|
||||
instance_idx : u32;
|
||||
@builtin(vertex_index)
|
||||
vert_idx : u32;
|
||||
}
|
||||
|
||||
@stage(vertex)
|
||||
fn entry(inputs : Inputs) -> @builtin(position) vec4<f32> {
|
||||
test((inputs.instance_idx + tint_symbol_1.first_instance_index), (inputs.vert_idx + tint_symbol_1.first_vertex_index));
|
||||
return vec4<f32>();
|
||||
}
|
||||
)";
|
||||
|
||||
DataMap config;
|
||||
config.Add<FirstIndexOffset::BindingPoint>(1, 2);
|
||||
auto got = Run<FirstIndexOffset>(src, std::move(config));
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
|
||||
auto* data = got.data.Get<FirstIndexOffset::Data>();
|
||||
|
||||
ASSERT_NE(data, nullptr);
|
||||
EXPECT_EQ(data->has_vertex_index, true);
|
||||
EXPECT_EQ(data->has_instance_index, true);
|
||||
EXPECT_EQ(data->first_vertex_offset, 0u);
|
||||
EXPECT_EQ(data->first_instance_offset, 4u);
|
||||
}
|
||||
|
||||
TEST_F(FirstIndexOffsetTest, BasicModuleBothIndex_OutOfOrder) {
|
||||
auto* src = R"(
|
||||
@stage(vertex)
|
||||
fn entry(inputs : Inputs) -> @builtin(position) vec4<f32> {
|
||||
test(inputs.instance_idx, inputs.vert_idx);
|
||||
return vec4<f32>();
|
||||
}
|
||||
|
||||
struct Inputs {
|
||||
@builtin(instance_index) instance_idx : u32;
|
||||
@builtin(vertex_index) vert_idx : u32;
|
||||
};
|
||||
|
||||
fn test(instance_idx : u32, vert_idx : u32) -> u32 {
|
||||
return instance_idx + vert_idx;
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
struct tint_symbol {
|
||||
first_vertex_index : u32;
|
||||
first_instance_index : u32;
|
||||
}
|
||||
|
||||
@binding(1) @group(2) var<uniform> tint_symbol_1 : tint_symbol;
|
||||
|
||||
@stage(vertex)
|
||||
fn entry(inputs : Inputs) -> @builtin(position) vec4<f32> {
|
||||
test((inputs.instance_idx + tint_symbol_1.first_instance_index), (inputs.vert_idx + tint_symbol_1.first_vertex_index));
|
||||
return vec4<f32>();
|
||||
}
|
||||
|
||||
struct Inputs {
|
||||
@builtin(instance_index)
|
||||
instance_idx : u32;
|
||||
@builtin(vertex_index)
|
||||
vert_idx : u32;
|
||||
}
|
||||
|
||||
fn test(instance_idx : u32, vert_idx : u32) -> u32 {
|
||||
return (instance_idx + vert_idx);
|
||||
}
|
||||
)";
|
||||
|
||||
DataMap config;
|
||||
config.Add<FirstIndexOffset::BindingPoint>(1, 2);
|
||||
auto got = Run<FirstIndexOffset>(src, std::move(config));
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
|
||||
auto* data = got.data.Get<FirstIndexOffset::Data>();
|
||||
|
||||
ASSERT_NE(data, nullptr);
|
||||
EXPECT_EQ(data->has_vertex_index, true);
|
||||
EXPECT_EQ(data->has_instance_index, true);
|
||||
EXPECT_EQ(data->first_vertex_offset, 0u);
|
||||
EXPECT_EQ(data->first_instance_offset, 4u);
|
||||
}
|
||||
|
||||
TEST_F(FirstIndexOffsetTest, NestedCalls) {
|
||||
auto* src = R"(
|
||||
fn func1(vert_idx : u32) -> u32 {
|
||||
return vert_idx;
|
||||
}
|
||||
|
||||
fn func2(vert_idx : u32) -> u32 {
|
||||
return func1(vert_idx);
|
||||
}
|
||||
|
||||
@stage(vertex)
|
||||
fn entry(@builtin(vertex_index) vert_idx : u32) -> @builtin(position) vec4<f32> {
|
||||
func2(vert_idx);
|
||||
return vec4<f32>();
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
struct tint_symbol {
|
||||
first_vertex_index : u32;
|
||||
}
|
||||
|
||||
@binding(1) @group(2) var<uniform> tint_symbol_1 : tint_symbol;
|
||||
|
||||
fn func1(vert_idx : u32) -> u32 {
|
||||
return vert_idx;
|
||||
}
|
||||
|
||||
fn func2(vert_idx : u32) -> u32 {
|
||||
return func1(vert_idx);
|
||||
}
|
||||
|
||||
@stage(vertex)
|
||||
fn entry(@builtin(vertex_index) vert_idx : u32) -> @builtin(position) vec4<f32> {
|
||||
func2((vert_idx + tint_symbol_1.first_vertex_index));
|
||||
return vec4<f32>();
|
||||
}
|
||||
)";
|
||||
|
||||
DataMap config;
|
||||
config.Add<FirstIndexOffset::BindingPoint>(1, 2);
|
||||
auto got = Run<FirstIndexOffset>(src, std::move(config));
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
|
||||
auto* data = got.data.Get<FirstIndexOffset::Data>();
|
||||
|
||||
ASSERT_NE(data, nullptr);
|
||||
EXPECT_EQ(data->has_vertex_index, true);
|
||||
EXPECT_EQ(data->has_instance_index, false);
|
||||
EXPECT_EQ(data->first_vertex_offset, 0u);
|
||||
EXPECT_EQ(data->first_instance_offset, 0u);
|
||||
}
|
||||
|
||||
TEST_F(FirstIndexOffsetTest, NestedCalls_OutOfOrder) {
|
||||
auto* src = R"(
|
||||
@stage(vertex)
|
||||
fn entry(@builtin(vertex_index) vert_idx : u32) -> @builtin(position) vec4<f32> {
|
||||
func2(vert_idx);
|
||||
return vec4<f32>();
|
||||
}
|
||||
|
||||
fn func2(vert_idx : u32) -> u32 {
|
||||
return func1(vert_idx);
|
||||
}
|
||||
|
||||
fn func1(vert_idx : u32) -> u32 {
|
||||
return vert_idx;
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
struct tint_symbol {
|
||||
first_vertex_index : u32;
|
||||
}
|
||||
|
||||
@binding(1) @group(2) var<uniform> tint_symbol_1 : tint_symbol;
|
||||
|
||||
@stage(vertex)
|
||||
fn entry(@builtin(vertex_index) vert_idx : u32) -> @builtin(position) vec4<f32> {
|
||||
func2((vert_idx + tint_symbol_1.first_vertex_index));
|
||||
return vec4<f32>();
|
||||
}
|
||||
|
||||
fn func2(vert_idx : u32) -> u32 {
|
||||
return func1(vert_idx);
|
||||
}
|
||||
|
||||
fn func1(vert_idx : u32) -> u32 {
|
||||
return vert_idx;
|
||||
}
|
||||
)";
|
||||
|
||||
DataMap config;
|
||||
config.Add<FirstIndexOffset::BindingPoint>(1, 2);
|
||||
auto got = Run<FirstIndexOffset>(src, std::move(config));
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
|
||||
auto* data = got.data.Get<FirstIndexOffset::Data>();
|
||||
|
||||
ASSERT_NE(data, nullptr);
|
||||
EXPECT_EQ(data->has_vertex_index, true);
|
||||
EXPECT_EQ(data->has_instance_index, false);
|
||||
EXPECT_EQ(data->first_vertex_offset, 0u);
|
||||
EXPECT_EQ(data->first_instance_offset, 0u);
|
||||
}
|
||||
|
||||
TEST_F(FirstIndexOffsetTest, MultipleEntryPoints) {
|
||||
auto* src = R"(
|
||||
fn func(i : u32) -> u32 {
|
||||
return i;
|
||||
}
|
||||
|
||||
@stage(vertex)
|
||||
fn entry_a(@builtin(vertex_index) vert_idx : u32) -> @builtin(position) vec4<f32> {
|
||||
func(vert_idx);
|
||||
return vec4<f32>();
|
||||
}
|
||||
|
||||
@stage(vertex)
|
||||
fn entry_b(@builtin(vertex_index) vert_idx : u32, @builtin(instance_index) inst_idx : u32) -> @builtin(position) vec4<f32> {
|
||||
func(vert_idx + inst_idx);
|
||||
return vec4<f32>();
|
||||
}
|
||||
|
||||
@stage(vertex)
|
||||
fn entry_c(@builtin(instance_index) inst_idx : u32) -> @builtin(position) vec4<f32> {
|
||||
func(inst_idx);
|
||||
return vec4<f32>();
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
struct tint_symbol {
|
||||
first_vertex_index : u32;
|
||||
first_instance_index : u32;
|
||||
}
|
||||
|
||||
@binding(1) @group(2) var<uniform> tint_symbol_1 : tint_symbol;
|
||||
|
||||
fn func(i : u32) -> u32 {
|
||||
return i;
|
||||
}
|
||||
|
||||
@stage(vertex)
|
||||
fn entry_a(@builtin(vertex_index) vert_idx : u32) -> @builtin(position) vec4<f32> {
|
||||
func((vert_idx + tint_symbol_1.first_vertex_index));
|
||||
return vec4<f32>();
|
||||
}
|
||||
|
||||
@stage(vertex)
|
||||
fn entry_b(@builtin(vertex_index) vert_idx : u32, @builtin(instance_index) inst_idx : u32) -> @builtin(position) vec4<f32> {
|
||||
func(((vert_idx + tint_symbol_1.first_vertex_index) + (inst_idx + tint_symbol_1.first_instance_index)));
|
||||
return vec4<f32>();
|
||||
}
|
||||
|
||||
@stage(vertex)
|
||||
fn entry_c(@builtin(instance_index) inst_idx : u32) -> @builtin(position) vec4<f32> {
|
||||
func((inst_idx + tint_symbol_1.first_instance_index));
|
||||
return vec4<f32>();
|
||||
}
|
||||
)";
|
||||
|
||||
DataMap config;
|
||||
config.Add<FirstIndexOffset::BindingPoint>(1, 2);
|
||||
auto got = Run<FirstIndexOffset>(src, std::move(config));
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
|
||||
auto* data = got.data.Get<FirstIndexOffset::Data>();
|
||||
|
||||
ASSERT_NE(data, nullptr);
|
||||
EXPECT_EQ(data->has_vertex_index, true);
|
||||
EXPECT_EQ(data->has_instance_index, true);
|
||||
EXPECT_EQ(data->first_vertex_offset, 0u);
|
||||
EXPECT_EQ(data->first_instance_offset, 4u);
|
||||
}
|
||||
|
||||
TEST_F(FirstIndexOffsetTest, MultipleEntryPoints_OutOfOrder) {
|
||||
auto* src = R"(
|
||||
@stage(vertex)
|
||||
fn entry_a(@builtin(vertex_index) vert_idx : u32) -> @builtin(position) vec4<f32> {
|
||||
func(vert_idx);
|
||||
return vec4<f32>();
|
||||
}
|
||||
|
||||
@stage(vertex)
|
||||
fn entry_b(@builtin(vertex_index) vert_idx : u32, @builtin(instance_index) inst_idx : u32) -> @builtin(position) vec4<f32> {
|
||||
func(vert_idx + inst_idx);
|
||||
return vec4<f32>();
|
||||
}
|
||||
|
||||
@stage(vertex)
|
||||
fn entry_c(@builtin(instance_index) inst_idx : u32) -> @builtin(position) vec4<f32> {
|
||||
func(inst_idx);
|
||||
return vec4<f32>();
|
||||
}
|
||||
|
||||
fn func(i : u32) -> u32 {
|
||||
return i;
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
struct tint_symbol {
|
||||
first_vertex_index : u32;
|
||||
first_instance_index : u32;
|
||||
}
|
||||
|
||||
@binding(1) @group(2) var<uniform> tint_symbol_1 : tint_symbol;
|
||||
|
||||
@stage(vertex)
|
||||
fn entry_a(@builtin(vertex_index) vert_idx : u32) -> @builtin(position) vec4<f32> {
|
||||
func((vert_idx + tint_symbol_1.first_vertex_index));
|
||||
return vec4<f32>();
|
||||
}
|
||||
|
||||
@stage(vertex)
|
||||
fn entry_b(@builtin(vertex_index) vert_idx : u32, @builtin(instance_index) inst_idx : u32) -> @builtin(position) vec4<f32> {
|
||||
func(((vert_idx + tint_symbol_1.first_vertex_index) + (inst_idx + tint_symbol_1.first_instance_index)));
|
||||
return vec4<f32>();
|
||||
}
|
||||
|
||||
@stage(vertex)
|
||||
fn entry_c(@builtin(instance_index) inst_idx : u32) -> @builtin(position) vec4<f32> {
|
||||
func((inst_idx + tint_symbol_1.first_instance_index));
|
||||
return vec4<f32>();
|
||||
}
|
||||
|
||||
fn func(i : u32) -> u32 {
|
||||
return i;
|
||||
}
|
||||
)";
|
||||
|
||||
DataMap config;
|
||||
config.Add<FirstIndexOffset::BindingPoint>(1, 2);
|
||||
auto got = Run<FirstIndexOffset>(src, std::move(config));
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
|
||||
auto* data = got.data.Get<FirstIndexOffset::Data>();
|
||||
|
||||
ASSERT_NE(data, nullptr);
|
||||
EXPECT_EQ(data->has_vertex_index, true);
|
||||
EXPECT_EQ(data->has_instance_index, true);
|
||||
EXPECT_EQ(data->first_vertex_offset, 0u);
|
||||
EXPECT_EQ(data->first_instance_offset, 4u);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
99
src/tint/transform/fold_constants.cc
Normal file
99
src/tint/transform/fold_constants.cc
Normal file
@@ -0,0 +1,99 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/tint/transform/fold_constants.h"
|
||||
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "src/tint/program_builder.h"
|
||||
#include "src/tint/sem/call.h"
|
||||
#include "src/tint/sem/expression.h"
|
||||
#include "src/tint/sem/type_constructor.h"
|
||||
#include "src/tint/sem/type_conversion.h"
|
||||
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::transform::FoldConstants);
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
|
||||
FoldConstants::FoldConstants() = default;
|
||||
|
||||
FoldConstants::~FoldConstants() = default;
|
||||
|
||||
void FoldConstants::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
|
||||
ctx.ReplaceAll([&](const ast::Expression* expr) -> const ast::Expression* {
|
||||
auto* call = ctx.src->Sem().Get<sem::Call>(expr);
|
||||
if (!call) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto value = call->ConstantValue();
|
||||
if (!value.IsValid()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto* ty = call->Type();
|
||||
|
||||
if (!call->Target()->IsAnyOf<sem::TypeConversion, sem::TypeConstructor>()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// If original ctor expression had no init values, don't replace the
|
||||
// expression
|
||||
if (call->Arguments().empty()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (auto* vec = ty->As<sem::Vector>()) {
|
||||
uint32_t vec_size = static_cast<uint32_t>(vec->Width());
|
||||
|
||||
// We'd like to construct the new vector with the same number of
|
||||
// constructor args that the original node had, but after folding
|
||||
// constants, cases like the following are problematic:
|
||||
//
|
||||
// vec3<f32> = vec3<f32>(vec2<f32>, 1.0) // vec_size=3, ctor_size=2
|
||||
//
|
||||
// In this case, creating a vec3 with 2 args is invalid, so we should
|
||||
// create it with 3. So what we do is construct with vec_size args,
|
||||
// except if the original vector was single-value initialized, in
|
||||
// which case, we only construct with one arg again.
|
||||
uint32_t ctor_size = (call->Arguments().size() == 1) ? 1 : vec_size;
|
||||
|
||||
ast::ExpressionList ctors;
|
||||
for (uint32_t i = 0; i < ctor_size; ++i) {
|
||||
value.WithScalarAt(
|
||||
i, [&](auto&& s) { ctors.emplace_back(ctx.dst->Expr(s)); });
|
||||
}
|
||||
|
||||
auto* el_ty = CreateASTTypeFor(ctx, vec->type());
|
||||
return ctx.dst->vec(el_ty, vec_size, ctors);
|
||||
}
|
||||
|
||||
if (ty->is_scalar()) {
|
||||
return value.WithScalarAt(0,
|
||||
[&](auto&& s) -> const ast::LiteralExpression* {
|
||||
return ctx.dst->Expr(s);
|
||||
});
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
});
|
||||
|
||||
ctx.Clone();
|
||||
}
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
47
src/tint/transform/fold_constants.h
Normal file
47
src/tint/transform/fold_constants.h
Normal file
@@ -0,0 +1,47 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef SRC_TINT_TRANSFORM_FOLD_CONSTANTS_H_
|
||||
#define SRC_TINT_TRANSFORM_FOLD_CONSTANTS_H_
|
||||
|
||||
#include "src/tint/transform/transform.h"
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
|
||||
/// FoldConstants transforms the AST by folding constant expressions
|
||||
class FoldConstants : public Castable<FoldConstants, Transform> {
|
||||
public:
|
||||
/// Constructor
|
||||
FoldConstants();
|
||||
|
||||
/// Destructor
|
||||
~FoldConstants() override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
};
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
|
||||
#endif // SRC_TINT_TRANSFORM_FOLD_CONSTANTS_H_
|
||||
427
src/tint/transform/fold_constants_test.cc
Normal file
427
src/tint/transform/fold_constants_test.cc
Normal file
@@ -0,0 +1,427 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/tint/transform/fold_constants.h"
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "src/tint/transform/test_helper.h"
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
namespace {
|
||||
|
||||
using FoldConstantsTest = TransformTest;
|
||||
|
||||
TEST_F(FoldConstantsTest, Module_Scalar_NoConversion) {
|
||||
auto* src = R"(
|
||||
var<private> a : i32 = i32(123);
|
||||
var<private> b : u32 = u32(123u);
|
||||
var<private> c : f32 = f32(123.0);
|
||||
var<private> d : bool = bool(true);
|
||||
|
||||
fn f() {
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
var<private> a : i32 = 123;
|
||||
|
||||
var<private> b : u32 = 123u;
|
||||
|
||||
var<private> c : f32 = 123.0;
|
||||
|
||||
var<private> d : bool = true;
|
||||
|
||||
fn f() {
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<FoldConstants>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(FoldConstantsTest, Module_Scalar_Conversion) {
|
||||
auto* src = R"(
|
||||
var<private> a : i32 = i32(123.0);
|
||||
var<private> b : u32 = u32(123);
|
||||
var<private> c : f32 = f32(123u);
|
||||
var<private> d : bool = bool(123);
|
||||
|
||||
fn f() {
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
var<private> a : i32 = 123;
|
||||
|
||||
var<private> b : u32 = 123u;
|
||||
|
||||
var<private> c : f32 = 123.0;
|
||||
|
||||
var<private> d : bool = true;
|
||||
|
||||
fn f() {
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<FoldConstants>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(FoldConstantsTest, Module_Scalar_MultipleConversions) {
|
||||
auto* src = R"(
|
||||
var<private> a : i32 = i32(u32(f32(u32(i32(123.0)))));
|
||||
var<private> b : u32 = u32(i32(f32(i32(u32(123)))));
|
||||
var<private> c : f32 = f32(u32(i32(u32(f32(123u)))));
|
||||
var<private> d : bool = bool(i32(f32(i32(u32(123)))));
|
||||
|
||||
fn f() {
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
var<private> a : i32 = 123;
|
||||
|
||||
var<private> b : u32 = 123u;
|
||||
|
||||
var<private> c : f32 = 123.0;
|
||||
|
||||
var<private> d : bool = true;
|
||||
|
||||
fn f() {
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<FoldConstants>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(FoldConstantsTest, Module_Vector_NoConversion) {
|
||||
auto* src = R"(
|
||||
var<private> a : vec3<i32> = vec3<i32>(123);
|
||||
var<private> b : vec3<u32> = vec3<u32>(123u);
|
||||
var<private> c : vec3<f32> = vec3<f32>(123.0);
|
||||
var<private> d : vec3<bool> = vec3<bool>(true);
|
||||
|
||||
fn f() {
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
var<private> a : vec3<i32> = vec3<i32>(123);
|
||||
|
||||
var<private> b : vec3<u32> = vec3<u32>(123u);
|
||||
|
||||
var<private> c : vec3<f32> = vec3<f32>(123.0);
|
||||
|
||||
var<private> d : vec3<bool> = vec3<bool>(true);
|
||||
|
||||
fn f() {
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<FoldConstants>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(FoldConstantsTest, Module_Vector_Conversion) {
|
||||
auto* src = R"(
|
||||
var<private> a : vec3<i32> = vec3<i32>(vec3<f32>(123.0));
|
||||
var<private> b : vec3<u32> = vec3<u32>(vec3<i32>(123));
|
||||
var<private> c : vec3<f32> = vec3<f32>(vec3<u32>(123u));
|
||||
var<private> d : vec3<bool> = vec3<bool>(vec3<i32>(123));
|
||||
|
||||
fn f() {
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
var<private> a : vec3<i32> = vec3<i32>(123);
|
||||
|
||||
var<private> b : vec3<u32> = vec3<u32>(123u);
|
||||
|
||||
var<private> c : vec3<f32> = vec3<f32>(123.0);
|
||||
|
||||
var<private> d : vec3<bool> = vec3<bool>(true);
|
||||
|
||||
fn f() {
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<FoldConstants>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(FoldConstantsTest, Module_Vector_MultipleConversions) {
|
||||
auto* src = R"(
|
||||
var<private> a : vec3<i32> = vec3<i32>(vec3<u32>(vec3<f32>(vec3<u32>(u32(123.0)))));
|
||||
var<private> b : vec3<u32> = vec3<u32>(vec3<i32>(vec3<f32>(vec3<i32>(i32(123)))));
|
||||
var<private> c : vec3<f32> = vec3<f32>(vec3<u32>(vec3<i32>(vec3<u32>(u32(123u)))));
|
||||
var<private> d : vec3<bool> = vec3<bool>(vec3<i32>(vec3<f32>(vec3<i32>(i32(123)))));
|
||||
|
||||
fn f() {
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
var<private> a : vec3<i32> = vec3<i32>(123);
|
||||
|
||||
var<private> b : vec3<u32> = vec3<u32>(123u);
|
||||
|
||||
var<private> c : vec3<f32> = vec3<f32>(123.0);
|
||||
|
||||
var<private> d : vec3<bool> = vec3<bool>(true);
|
||||
|
||||
fn f() {
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<FoldConstants>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(FoldConstantsTest, Module_Vector_MixedSizeConversions) {
|
||||
auto* src = R"(
|
||||
var<private> a : vec4<i32> = vec4<i32>(vec3<i32>(vec3<u32>(1u, 2u, 3u)), 4);
|
||||
var<private> b : vec4<i32> = vec4<i32>(vec2<i32>(vec2<u32>(1u, 2u)), vec2<i32>(4, 5));
|
||||
var<private> c : vec4<i32> = vec4<i32>(1, vec2<i32>(vec2<f32>(2.0, 3.0)), 4);
|
||||
var<private> d : vec4<i32> = vec4<i32>(1, 2, vec2<i32>(vec2<f32>(3.0, 4.0)));
|
||||
var<private> e : vec4<bool> = vec4<bool>(false, bool(f32(1.0)), vec2<bool>(vec2<i32>(0, i32(4u))));
|
||||
|
||||
fn f() {
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
var<private> a : vec4<i32> = vec4<i32>(1, 2, 3, 4);
|
||||
|
||||
var<private> b : vec4<i32> = vec4<i32>(1, 2, 4, 5);
|
||||
|
||||
var<private> c : vec4<i32> = vec4<i32>(1, 2, 3, 4);
|
||||
|
||||
var<private> d : vec4<i32> = vec4<i32>(1, 2, 3, 4);
|
||||
|
||||
var<private> e : vec4<bool> = vec4<bool>(false, true, false, true);
|
||||
|
||||
fn f() {
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<FoldConstants>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(FoldConstantsTest, Function_Scalar_NoConversion) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
var a : i32 = i32(123);
|
||||
var b : u32 = u32(123u);
|
||||
var c : f32 = f32(123.0);
|
||||
var d : bool = bool(true);
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn f() {
|
||||
var a : i32 = 123;
|
||||
var b : u32 = 123u;
|
||||
var c : f32 = 123.0;
|
||||
var d : bool = true;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<FoldConstants>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(FoldConstantsTest, Function_Scalar_Conversion) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
var a : i32 = i32(123.0);
|
||||
var b : u32 = u32(123);
|
||||
var c : f32 = f32(123u);
|
||||
var d : bool = bool(123);
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn f() {
|
||||
var a : i32 = 123;
|
||||
var b : u32 = 123u;
|
||||
var c : f32 = 123.0;
|
||||
var d : bool = true;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<FoldConstants>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(FoldConstantsTest, Function_Scalar_MultipleConversions) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
var a : i32 = i32(u32(f32(u32(i32(123.0)))));
|
||||
var b : u32 = u32(i32(f32(i32(u32(123)))));
|
||||
var c : f32 = f32(u32(i32(u32(f32(123u)))));
|
||||
var d : bool = bool(i32(f32(i32(u32(123)))));
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn f() {
|
||||
var a : i32 = 123;
|
||||
var b : u32 = 123u;
|
||||
var c : f32 = 123.0;
|
||||
var d : bool = true;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<FoldConstants>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(FoldConstantsTest, Function_Vector_NoConversion) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
var a : vec3<i32> = vec3<i32>(123);
|
||||
var b : vec3<u32> = vec3<u32>(123u);
|
||||
var c : vec3<f32> = vec3<f32>(123.0);
|
||||
var d : vec3<bool> = vec3<bool>(true);
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn f() {
|
||||
var a : vec3<i32> = vec3<i32>(123);
|
||||
var b : vec3<u32> = vec3<u32>(123u);
|
||||
var c : vec3<f32> = vec3<f32>(123.0);
|
||||
var d : vec3<bool> = vec3<bool>(true);
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<FoldConstants>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(FoldConstantsTest, Function_Vector_Conversion) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
var a : vec3<i32> = vec3<i32>(vec3<f32>(123.0));
|
||||
var b : vec3<u32> = vec3<u32>(vec3<i32>(123));
|
||||
var c : vec3<f32> = vec3<f32>(vec3<u32>(123u));
|
||||
var d : vec3<bool> = vec3<bool>(vec3<i32>(123));
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn f() {
|
||||
var a : vec3<i32> = vec3<i32>(123);
|
||||
var b : vec3<u32> = vec3<u32>(123u);
|
||||
var c : vec3<f32> = vec3<f32>(123.0);
|
||||
var d : vec3<bool> = vec3<bool>(true);
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<FoldConstants>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(FoldConstantsTest, Function_Vector_MultipleConversions) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
var a : vec3<i32> = vec3<i32>(vec3<u32>(vec3<f32>(vec3<u32>(u32(123.0)))));
|
||||
var b : vec3<u32> = vec3<u32>(vec3<i32>(vec3<f32>(vec3<i32>(i32(123)))));
|
||||
var c : vec3<f32> = vec3<f32>(vec3<u32>(vec3<i32>(vec3<u32>(u32(123u)))));
|
||||
var d : vec3<bool> = vec3<bool>(vec3<i32>(vec3<f32>(vec3<i32>(i32(123)))));
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn f() {
|
||||
var a : vec3<i32> = vec3<i32>(123);
|
||||
var b : vec3<u32> = vec3<u32>(123u);
|
||||
var c : vec3<f32> = vec3<f32>(123.0);
|
||||
var d : vec3<bool> = vec3<bool>(true);
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<FoldConstants>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(FoldConstantsTest, Function_Vector_MixedSizeConversions) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
var a : vec4<i32> = vec4<i32>(vec3<i32>(vec3<u32>(1u, 2u, 3u)), 4);
|
||||
var b : vec4<i32> = vec4<i32>(vec2<i32>(vec2<u32>(1u, 2u)), vec2<i32>(4, 5));
|
||||
var c : vec4<i32> = vec4<i32>(1, vec2<i32>(vec2<f32>(2.0, 3.0)), 4);
|
||||
var d : vec4<i32> = vec4<i32>(1, 2, vec2<i32>(vec2<f32>(3.0, 4.0)));
|
||||
var e : vec4<bool> = vec4<bool>(false, bool(f32(1.0)), vec2<bool>(vec2<i32>(0, i32(4u))));
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn f() {
|
||||
var a : vec4<i32> = vec4<i32>(1, 2, 3, 4);
|
||||
var b : vec4<i32> = vec4<i32>(1, 2, 4, 5);
|
||||
var c : vec4<i32> = vec4<i32>(1, 2, 3, 4);
|
||||
var d : vec4<i32> = vec4<i32>(1, 2, 3, 4);
|
||||
var e : vec4<bool> = vec4<bool>(false, true, false, true);
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<FoldConstants>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(FoldConstantsTest, Function_Vector_ConstantWithNonConstant) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
var a : f32 = f32();
|
||||
var b : vec2<f32> = vec2<f32>(f32(i32(1)), a);
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn f() {
|
||||
var a : f32 = f32();
|
||||
var b : vec2<f32> = vec2<f32>(1.0, a);
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<FoldConstants>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
92
src/tint/transform/fold_trivial_single_use_lets.cc
Normal file
92
src/tint/transform/fold_trivial_single_use_lets.cc
Normal file
@@ -0,0 +1,92 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/tint/transform/fold_trivial_single_use_lets.h"
|
||||
|
||||
#include "src/tint/program_builder.h"
|
||||
#include "src/tint/sem/block_statement.h"
|
||||
#include "src/tint/sem/function.h"
|
||||
#include "src/tint/sem/statement.h"
|
||||
#include "src/tint/sem/variable.h"
|
||||
#include "src/tint/utils/scoped_assignment.h"
|
||||
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::transform::FoldTrivialSingleUseLets);
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
namespace {
|
||||
|
||||
const ast::VariableDeclStatement* AsTrivialLetDecl(const ast::Statement* stmt) {
|
||||
auto* var_decl = stmt->As<ast::VariableDeclStatement>();
|
||||
if (!var_decl) {
|
||||
return nullptr;
|
||||
}
|
||||
auto* var = var_decl->variable;
|
||||
if (!var->is_const) {
|
||||
return nullptr;
|
||||
}
|
||||
auto* ctor = var->constructor;
|
||||
if (!IsAnyOf<ast::IdentifierExpression, ast::LiteralExpression>(ctor)) {
|
||||
return nullptr;
|
||||
}
|
||||
return var_decl;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
FoldTrivialSingleUseLets::FoldTrivialSingleUseLets() = default;
|
||||
|
||||
FoldTrivialSingleUseLets::~FoldTrivialSingleUseLets() = default;
|
||||
|
||||
void FoldTrivialSingleUseLets::Run(CloneContext& ctx,
|
||||
const DataMap&,
|
||||
DataMap&) const {
|
||||
for (auto* node : ctx.src->ASTNodes().Objects()) {
|
||||
if (auto* block = node->As<ast::BlockStatement>()) {
|
||||
auto& stmts = block->statements;
|
||||
for (size_t stmt_idx = 0; stmt_idx < stmts.size(); stmt_idx++) {
|
||||
auto* stmt = stmts[stmt_idx];
|
||||
if (auto* let_decl = AsTrivialLetDecl(stmt)) {
|
||||
auto* let = let_decl->variable;
|
||||
auto* sem_let = ctx.src->Sem().Get(let);
|
||||
auto& users = sem_let->Users();
|
||||
if (users.size() != 1) {
|
||||
continue; // Does not have a single user.
|
||||
}
|
||||
|
||||
auto* user = users[0];
|
||||
auto* user_stmt = user->Stmt()->Declaration();
|
||||
|
||||
for (size_t i = stmt_idx; i < stmts.size(); i++) {
|
||||
if (user_stmt == stmts[i]) {
|
||||
auto* user_expr = user->Declaration();
|
||||
ctx.Remove(stmts, let_decl);
|
||||
ctx.Replace(user_expr, ctx.Clone(let->constructor));
|
||||
}
|
||||
if (!AsTrivialLetDecl(stmts[i])) {
|
||||
// Stop if we hit a statement that isn't the single use of the
|
||||
// let, and isn't a let itself.
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ctx.Clone();
|
||||
}
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
61
src/tint/transform/fold_trivial_single_use_lets.h
Normal file
61
src/tint/transform/fold_trivial_single_use_lets.h
Normal file
@@ -0,0 +1,61 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef SRC_TINT_TRANSFORM_FOLD_TRIVIAL_SINGLE_USE_LETS_H_
|
||||
#define SRC_TINT_TRANSFORM_FOLD_TRIVIAL_SINGLE_USE_LETS_H_
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "src/tint/transform/transform.h"
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
|
||||
/// FoldTrivialSingleUseLets is an optimizer for folding away trivial `let`
|
||||
/// statements into their single place of use. This transform is intended to
|
||||
/// clean up the SSA `let`s produced by the SPIR-V reader.
|
||||
/// `let`s can only be folded if:
|
||||
/// * There is a single usage of the `let` value.
|
||||
/// * The `let` is constructed with a ScalarConstructorExpression, or with an
|
||||
/// IdentifierExpression.
|
||||
/// * There are only other foldable `let`s between the `let` declaration and its
|
||||
/// single usage.
|
||||
/// These rules prevent any hoisting of the let that may affect execution
|
||||
/// behaviour.
|
||||
class FoldTrivialSingleUseLets
|
||||
: public Castable<FoldTrivialSingleUseLets, Transform> {
|
||||
public:
|
||||
/// Constructor
|
||||
FoldTrivialSingleUseLets();
|
||||
|
||||
/// Destructor
|
||||
~FoldTrivialSingleUseLets() override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
};
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
|
||||
#endif // SRC_TINT_TRANSFORM_FOLD_TRIVIAL_SINGLE_USE_LETS_H_
|
||||
188
src/tint/transform/fold_trivial_single_use_lets_test.cc
Normal file
188
src/tint/transform/fold_trivial_single_use_lets_test.cc
Normal file
@@ -0,0 +1,188 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/tint/transform/fold_trivial_single_use_lets.h"
|
||||
|
||||
#include "src/tint/transform/test_helper.h"
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
namespace {
|
||||
|
||||
using FoldTrivialSingleUseLetsTest = TransformTest;
|
||||
|
||||
TEST_F(FoldTrivialSingleUseLetsTest, EmptyModule) {
|
||||
auto* src = "";
|
||||
auto* expect = "";
|
||||
|
||||
auto got = Run<FoldTrivialSingleUseLets>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(FoldTrivialSingleUseLetsTest, Single) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
let x = 1;
|
||||
_ = x;
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn f() {
|
||||
_ = 1;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<FoldTrivialSingleUseLets>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(FoldTrivialSingleUseLetsTest, Multiple) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
let x = 1;
|
||||
let y = 2;
|
||||
let z = 3;
|
||||
_ = x + y + z;
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn f() {
|
||||
_ = ((1 + 2) + 3);
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<FoldTrivialSingleUseLets>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(FoldTrivialSingleUseLetsTest, Chained) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
let x = 1;
|
||||
let y = x;
|
||||
let z = y;
|
||||
_ = z;
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn f() {
|
||||
_ = 1;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<FoldTrivialSingleUseLets>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(FoldTrivialSingleUseLetsTest, NoFold_NonTrivialLet) {
|
||||
auto* src = R"(
|
||||
fn function_with_posssible_side_effect() -> i32 {
|
||||
return 1;
|
||||
}
|
||||
|
||||
fn f() {
|
||||
let x = 1;
|
||||
let y = function_with_posssible_side_effect();
|
||||
_ = (x + y);
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = src;
|
||||
|
||||
auto got = Run<FoldTrivialSingleUseLets>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(FoldTrivialSingleUseLetsTest, NoFold_NonTrivialLet_OutOfOrder) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
let x = 1;
|
||||
let y = function_with_posssible_side_effect();
|
||||
_ = (x + y);
|
||||
}
|
||||
|
||||
fn function_with_posssible_side_effect() -> i32 {
|
||||
return 1;
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = src;
|
||||
|
||||
auto got = Run<FoldTrivialSingleUseLets>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(FoldTrivialSingleUseLetsTest, NoFold_UseInSubBlock) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
let x = 1;
|
||||
{
|
||||
_ = x;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = src;
|
||||
|
||||
auto got = Run<FoldTrivialSingleUseLets>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(FoldTrivialSingleUseLetsTest, NoFold_MultipleUses) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
let x = 1;
|
||||
_ = (x + x);
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = src;
|
||||
|
||||
auto got = Run<FoldTrivialSingleUseLets>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(FoldTrivialSingleUseLetsTest, NoFold_Shadowing) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
var y = 1;
|
||||
let x = y;
|
||||
{
|
||||
let y = false;
|
||||
_ = (x + x);
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = src;
|
||||
|
||||
auto got = Run<FoldTrivialSingleUseLets>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
76
src/tint/transform/for_loop_to_loop.cc
Normal file
76
src/tint/transform/for_loop_to_loop.cc
Normal file
@@ -0,0 +1,76 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/tint/transform/for_loop_to_loop.h"
|
||||
|
||||
#include "src/tint/ast/break_statement.h"
|
||||
#include "src/tint/program_builder.h"
|
||||
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::transform::ForLoopToLoop);
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
ForLoopToLoop::ForLoopToLoop() = default;
|
||||
|
||||
ForLoopToLoop::~ForLoopToLoop() = default;
|
||||
|
||||
bool ForLoopToLoop::ShouldRun(const Program* program, const DataMap&) const {
|
||||
for (auto* node : program->ASTNodes().Objects()) {
|
||||
if (node->Is<ast::ForLoopStatement>()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void ForLoopToLoop::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
|
||||
ctx.ReplaceAll(
|
||||
[&](const ast::ForLoopStatement* for_loop) -> const ast::Statement* {
|
||||
ast::StatementList stmts;
|
||||
if (auto* cond = for_loop->condition) {
|
||||
// !condition
|
||||
auto* not_cond = ctx.dst->create<ast::UnaryOpExpression>(
|
||||
ast::UnaryOp::kNot, ctx.Clone(cond));
|
||||
|
||||
// { break; }
|
||||
auto* break_body =
|
||||
ctx.dst->Block(ctx.dst->create<ast::BreakStatement>());
|
||||
|
||||
// if (!condition) { break; }
|
||||
stmts.emplace_back(ctx.dst->If(not_cond, break_body));
|
||||
}
|
||||
for (auto* stmt : for_loop->body->statements) {
|
||||
stmts.emplace_back(ctx.Clone(stmt));
|
||||
}
|
||||
|
||||
const ast::BlockStatement* continuing = nullptr;
|
||||
if (auto* cont = for_loop->continuing) {
|
||||
continuing = ctx.dst->Block(ctx.Clone(cont));
|
||||
}
|
||||
|
||||
auto* body = ctx.dst->Block(stmts);
|
||||
auto* loop = ctx.dst->create<ast::LoopStatement>(body, continuing);
|
||||
|
||||
if (auto* init = for_loop->initializer) {
|
||||
return ctx.dst->Block(ctx.Clone(init), loop);
|
||||
}
|
||||
|
||||
return loop;
|
||||
});
|
||||
|
||||
ctx.Clone();
|
||||
}
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
54
src/tint/transform/for_loop_to_loop.h
Normal file
54
src/tint/transform/for_loop_to_loop.h
Normal file
@@ -0,0 +1,54 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef SRC_TINT_TRANSFORM_FOR_LOOP_TO_LOOP_H_
|
||||
#define SRC_TINT_TRANSFORM_FOR_LOOP_TO_LOOP_H_
|
||||
|
||||
#include "src/tint/transform/transform.h"
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
|
||||
/// ForLoopToLoop is a Transform that converts a for-loop statement into a loop
|
||||
/// statement. This is required by the SPIR-V writer.
|
||||
class ForLoopToLoop : public Castable<ForLoopToLoop, Transform> {
|
||||
public:
|
||||
/// Constructor
|
||||
ForLoopToLoop();
|
||||
|
||||
/// Destructor
|
||||
~ForLoopToLoop() override;
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program,
|
||||
const DataMap& data = {}) const override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
};
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
|
||||
#endif // SRC_TINT_TRANSFORM_FOR_LOOP_TO_LOOP_H_
|
||||
374
src/tint/transform/for_loop_to_loop_test.cc
Normal file
374
src/tint/transform/for_loop_to_loop_test.cc
Normal file
@@ -0,0 +1,374 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/tint/transform/for_loop_to_loop.h"
|
||||
|
||||
#include "src/tint/transform/test_helper.h"
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
namespace {
|
||||
|
||||
using ForLoopToLoopTest = TransformTest;
|
||||
|
||||
TEST_F(ForLoopToLoopTest, ShouldRunEmptyModule) {
|
||||
auto* src = R"()";
|
||||
|
||||
EXPECT_FALSE(ShouldRun<ForLoopToLoop>(src));
|
||||
}
|
||||
|
||||
TEST_F(ForLoopToLoopTest, ShouldRunHasForLoop) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
for (;;) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
EXPECT_TRUE(ShouldRun<ForLoopToLoop>(src));
|
||||
}
|
||||
|
||||
TEST_F(ForLoopToLoopTest, EmptyModule) {
|
||||
auto* src = "";
|
||||
auto* expect = src;
|
||||
|
||||
auto got = Run<ForLoopToLoop>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
// Test an empty for loop.
|
||||
TEST_F(ForLoopToLoopTest, Empty) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
for (;;) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn f() {
|
||||
loop {
|
||||
break;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<ForLoopToLoop>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
// Test a for loop with non-empty body.
|
||||
TEST_F(ForLoopToLoopTest, Body) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
for (;;) {
|
||||
discard;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn f() {
|
||||
loop {
|
||||
discard;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<ForLoopToLoop>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
// Test a for loop declaring a variable in the initializer statement.
|
||||
TEST_F(ForLoopToLoopTest, InitializerStatementDecl) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
for (var i: i32;;) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn f() {
|
||||
{
|
||||
var i : i32;
|
||||
loop {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<ForLoopToLoop>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
// Test a for loop declaring and initializing a variable in the initializer
|
||||
// statement.
|
||||
TEST_F(ForLoopToLoopTest, InitializerStatementDeclEqual) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
for (var i: i32 = 0;;) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn f() {
|
||||
{
|
||||
var i : i32 = 0;
|
||||
loop {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<ForLoopToLoop>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
// Test a for loop declaring a const variable in the initializer statement.
|
||||
TEST_F(ForLoopToLoopTest, InitializerStatementConstDecl) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
for (let i: i32 = 0;;) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn f() {
|
||||
{
|
||||
let i : i32 = 0;
|
||||
loop {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<ForLoopToLoop>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
// Test a for loop assigning a variable in the initializer statement.
|
||||
TEST_F(ForLoopToLoopTest, InitializerStatementAssignment) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
var i: i32;
|
||||
for (i = 0;;) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn f() {
|
||||
var i : i32;
|
||||
{
|
||||
i = 0;
|
||||
loop {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<ForLoopToLoop>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
// Test a for loop calling a function in the initializer statement.
|
||||
TEST_F(ForLoopToLoopTest, InitializerStatementFuncCall) {
|
||||
auto* src = R"(
|
||||
fn a(x : i32, y : i32) {
|
||||
}
|
||||
|
||||
fn f() {
|
||||
var b : i32;
|
||||
var c : i32;
|
||||
for (a(b,c);;) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn a(x : i32, y : i32) {
|
||||
}
|
||||
|
||||
fn f() {
|
||||
var b : i32;
|
||||
var c : i32;
|
||||
{
|
||||
a(b, c);
|
||||
loop {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<ForLoopToLoop>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
// Test a for loop with a break condition
|
||||
TEST_F(ForLoopToLoopTest, BreakCondition) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
for (; 0 == 1;) {
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn f() {
|
||||
loop {
|
||||
if (!((0 == 1))) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<ForLoopToLoop>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
// Test a for loop assigning a variable in the continuing statement.
|
||||
TEST_F(ForLoopToLoopTest, ContinuingAssignment) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
var x: i32;
|
||||
for (;;x = 2) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn f() {
|
||||
var x : i32;
|
||||
loop {
|
||||
break;
|
||||
|
||||
continuing {
|
||||
x = 2;
|
||||
}
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<ForLoopToLoop>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
// Test a for loop calling a function in the continuing statement.
|
||||
TEST_F(ForLoopToLoopTest, ContinuingFuncCall) {
|
||||
auto* src = R"(
|
||||
fn a(x : i32, y : i32) {
|
||||
}
|
||||
|
||||
fn f() {
|
||||
var b : i32;
|
||||
var c : i32;
|
||||
for (;;a(b,c)) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn a(x : i32, y : i32) {
|
||||
}
|
||||
|
||||
fn f() {
|
||||
var b : i32;
|
||||
var c : i32;
|
||||
loop {
|
||||
break;
|
||||
|
||||
continuing {
|
||||
a(b, c);
|
||||
}
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<ForLoopToLoop>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
// Test a for loop with all statements non-empty.
|
||||
TEST_F(ForLoopToLoopTest, All) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
var a : i32;
|
||||
for(var i : i32 = 0; i < 4; i = i + 1) {
|
||||
if (a == 0) {
|
||||
continue;
|
||||
}
|
||||
a = a + 2;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn f() {
|
||||
var a : i32;
|
||||
{
|
||||
var i : i32 = 0;
|
||||
loop {
|
||||
if (!((i < 4))) {
|
||||
break;
|
||||
}
|
||||
if ((a == 0)) {
|
||||
continue;
|
||||
}
|
||||
a = (a + 2);
|
||||
|
||||
continuing {
|
||||
i = (i + 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<ForLoopToLoop>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
118
src/tint/transform/glsl.cc
Normal file
118
src/tint/transform/glsl.cc
Normal file
@@ -0,0 +1,118 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/tint/transform/glsl.h"
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "src/tint/program_builder.h"
|
||||
#include "src/tint/transform/add_empty_entry_point.h"
|
||||
#include "src/tint/transform/add_spirv_block_attribute.h"
|
||||
#include "src/tint/transform/binding_remapper.h"
|
||||
#include "src/tint/transform/canonicalize_entry_point_io.h"
|
||||
#include "src/tint/transform/combine_samplers.h"
|
||||
#include "src/tint/transform/decompose_memory_access.h"
|
||||
#include "src/tint/transform/external_texture_transform.h"
|
||||
#include "src/tint/transform/fold_trivial_single_use_lets.h"
|
||||
#include "src/tint/transform/loop_to_for_loop.h"
|
||||
#include "src/tint/transform/manager.h"
|
||||
#include "src/tint/transform/pad_array_elements.h"
|
||||
#include "src/tint/transform/promote_initializers_to_const_var.h"
|
||||
#include "src/tint/transform/remove_phonies.h"
|
||||
#include "src/tint/transform/renamer.h"
|
||||
#include "src/tint/transform/simplify_pointers.h"
|
||||
#include "src/tint/transform/single_entry_point.h"
|
||||
#include "src/tint/transform/unshadow.h"
|
||||
#include "src/tint/transform/zero_init_workgroup_memory.h"
|
||||
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::transform::Glsl);
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::transform::Glsl::Config);
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
|
||||
Glsl::Glsl() = default;
|
||||
Glsl::~Glsl() = default;
|
||||
|
||||
Output Glsl::Run(const Program* in, const DataMap& inputs) const {
|
||||
Manager manager;
|
||||
DataMap data;
|
||||
|
||||
auto* cfg = inputs.Get<Config>();
|
||||
|
||||
if (cfg && !cfg->entry_point.empty()) {
|
||||
manager.Add<SingleEntryPoint>();
|
||||
data.Add<SingleEntryPoint::Config>(cfg->entry_point);
|
||||
}
|
||||
manager.Add<Renamer>();
|
||||
data.Add<Renamer::Config>(Renamer::Target::kGlslKeywords,
|
||||
/* preserve_unicode */ false);
|
||||
manager.Add<Unshadow>();
|
||||
|
||||
// Attempt to convert `loop`s into for-loops. This is to try and massage the
|
||||
// output into something that will not cause FXC to choke or misbehave.
|
||||
manager.Add<FoldTrivialSingleUseLets>();
|
||||
manager.Add<LoopToForLoop>();
|
||||
|
||||
if (!cfg || !cfg->disable_workgroup_init) {
|
||||
// ZeroInitWorkgroupMemory must come before CanonicalizeEntryPointIO as
|
||||
// ZeroInitWorkgroupMemory may inject new builtin parameters.
|
||||
manager.Add<ZeroInitWorkgroupMemory>();
|
||||
}
|
||||
manager.Add<CanonicalizeEntryPointIO>();
|
||||
manager.Add<SimplifyPointers>();
|
||||
|
||||
manager.Add<RemovePhonies>();
|
||||
manager.Add<CombineSamplers>();
|
||||
if (auto* binding_info = inputs.Get<CombineSamplers::BindingInfo>()) {
|
||||
data.Add<CombineSamplers::BindingInfo>(*binding_info);
|
||||
} else {
|
||||
data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(),
|
||||
sem::BindingPoint());
|
||||
}
|
||||
manager.Add<BindingRemapper>();
|
||||
if (auto* remappings = inputs.Get<BindingRemapper::Remappings>()) {
|
||||
data.Add<BindingRemapper::Remappings>(*remappings);
|
||||
} else {
|
||||
BindingRemapper::BindingPoints bp;
|
||||
BindingRemapper::AccessControls ac;
|
||||
data.Add<BindingRemapper::Remappings>(bp, ac, /* mayCollide */ true);
|
||||
}
|
||||
manager.Add<ExternalTextureTransform>();
|
||||
manager.Add<PromoteInitializersToConstVar>();
|
||||
|
||||
manager.Add<PadArrayElements>();
|
||||
manager.Add<AddEmptyEntryPoint>();
|
||||
manager.Add<AddSpirvBlockAttribute>();
|
||||
|
||||
data.Add<CanonicalizeEntryPointIO::Config>(
|
||||
CanonicalizeEntryPointIO::ShaderStyle::kGlsl);
|
||||
auto out = manager.Run(in, data);
|
||||
if (!out.program.IsValid()) {
|
||||
return out;
|
||||
}
|
||||
|
||||
ProgramBuilder builder;
|
||||
CloneContext ctx(&builder, &out.program);
|
||||
ctx.Clone();
|
||||
return Output{Program(std::move(builder))};
|
||||
}
|
||||
|
||||
Glsl::Config::Config(const std::string& ep, bool disable_wi)
|
||||
: entry_point(ep), disable_workgroup_init(disable_wi) {}
|
||||
Glsl::Config::Config(const Config&) = default;
|
||||
Glsl::Config::~Config() = default;
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
70
src/tint/transform/glsl.h
Normal file
70
src/tint/transform/glsl.h
Normal file
@@ -0,0 +1,70 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef SRC_TINT_TRANSFORM_GLSL_H_
|
||||
#define SRC_TINT_TRANSFORM_GLSL_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "src/tint/transform/transform.h"
|
||||
|
||||
namespace tint {
|
||||
|
||||
// Forward declarations
|
||||
class CloneContext;
|
||||
|
||||
namespace transform {
|
||||
|
||||
/// Glsl is a transform used to sanitize a Program for use with the Glsl writer.
|
||||
/// Passing a non-sanitized Program to the Glsl writer will result in undefined
|
||||
/// behavior.
|
||||
class Glsl : public Castable<Glsl, Transform> {
|
||||
public:
|
||||
/// Configuration options for the Glsl sanitizer transform.
|
||||
struct Config : public Castable<Data, transform::Data> {
|
||||
/// Constructor
|
||||
/// @param entry_point the root entry point function to generate
|
||||
/// @param disable_workgroup_init `true` to disable workgroup memory zero
|
||||
/// initialization
|
||||
explicit Config(const std::string& entry_point,
|
||||
bool disable_workgroup_init = false);
|
||||
|
||||
/// Copy constructor
|
||||
Config(const Config&);
|
||||
|
||||
/// Destructor
|
||||
~Config() override;
|
||||
|
||||
/// GLSL generator wraps a single entry point in a main() function.
|
||||
std::string entry_point;
|
||||
|
||||
/// Set to `true` to disable workgroup memory zero initialization
|
||||
bool disable_workgroup_init = false;
|
||||
};
|
||||
|
||||
/// Constructor
|
||||
Glsl();
|
||||
~Glsl() override;
|
||||
|
||||
/// Runs the transform on `program`, returning the transformation result.
|
||||
/// @param program the source program to transform
|
||||
/// @param data optional extra transform-specific data
|
||||
/// @returns the transformation result
|
||||
Output Run(const Program* program, const DataMap& data = {}) const override;
|
||||
};
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
|
||||
#endif // SRC_TINT_TRANSFORM_GLSL_H_
|
||||
41
src/tint/transform/glsl_test.cc
Normal file
41
src/tint/transform/glsl_test.cc
Normal file
@@ -0,0 +1,41 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/tint/transform/glsl.h"
|
||||
|
||||
#include "src/tint/transform/test_helper.h"
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
namespace {
|
||||
|
||||
using GlslTest = TransformTest;
|
||||
|
||||
TEST_F(GlslTest, AddEmptyEntryPoint) {
|
||||
auto* src = R"()";
|
||||
|
||||
auto* expect = R"(
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn unused_entry_point() {
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Glsl>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
224
src/tint/transform/localize_struct_array_assignment.cc
Normal file
224
src/tint/transform/localize_struct_array_assignment.cc
Normal file
@@ -0,0 +1,224 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/tint/transform/localize_struct_array_assignment.h"
|
||||
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
|
||||
#include "src/tint/ast/assignment_statement.h"
|
||||
#include "src/tint/ast/traverse_expressions.h"
|
||||
#include "src/tint/program_builder.h"
|
||||
#include "src/tint/sem/expression.h"
|
||||
#include "src/tint/sem/member_accessor_expression.h"
|
||||
#include "src/tint/sem/reference_type.h"
|
||||
#include "src/tint/sem/statement.h"
|
||||
#include "src/tint/sem/variable.h"
|
||||
#include "src/tint/transform/simplify_pointers.h"
|
||||
#include "src/tint/utils/scoped_assignment.h"
|
||||
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::transform::LocalizeStructArrayAssignment);
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
|
||||
/// Private implementation of LocalizeStructArrayAssignment transform
|
||||
class LocalizeStructArrayAssignment::State {
|
||||
private:
|
||||
CloneContext& ctx;
|
||||
ProgramBuilder& b;
|
||||
|
||||
/// Returns true if `expr` contains an index accessor expression to a
|
||||
/// structure member of array type.
|
||||
bool ContainsStructArrayIndex(const ast::Expression* expr) {
|
||||
bool result = false;
|
||||
ast::TraverseExpressions(
|
||||
expr, b.Diagnostics(), [&](const ast::IndexAccessorExpression* ia) {
|
||||
// Indexing using a runtime value?
|
||||
auto* idx_sem = ctx.src->Sem().Get(ia->index);
|
||||
if (!idx_sem->ConstantValue().IsValid()) {
|
||||
// Indexing a member access expr?
|
||||
if (auto* ma = ia->object->As<ast::MemberAccessorExpression>()) {
|
||||
// That accesses an array?
|
||||
if (ctx.src->TypeOf(ma)->UnwrapRef()->Is<sem::Array>()) {
|
||||
result = true;
|
||||
return ast::TraverseAction::Stop;
|
||||
}
|
||||
}
|
||||
}
|
||||
return ast::TraverseAction::Descend;
|
||||
});
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// Returns the type and storage class of the originating variable of the lhs
|
||||
// of the assignment statement.
|
||||
// See https://www.w3.org/TR/WGSL/#originating-variable-section
|
||||
std::pair<const sem::Type*, ast::StorageClass>
|
||||
GetOriginatingTypeAndStorageClass(
|
||||
const ast::AssignmentStatement* assign_stmt) {
|
||||
// Get first IdentifierExpr from lhs of assignment, which should resolve to
|
||||
// the pointer or reference of the originating variable of the assignment.
|
||||
// TraverseExpressions traverses left to right, and this code depends on the
|
||||
// fact that for an assignment statement, the variable will be the left-most
|
||||
// expression.
|
||||
// TODO(crbug.com/tint/1341): do this in the Resolver, setting the
|
||||
// originating variable on sem::Expression.
|
||||
const ast::IdentifierExpression* ident = nullptr;
|
||||
ast::TraverseExpressions(assign_stmt->lhs, b.Diagnostics(),
|
||||
[&](const ast::IdentifierExpression* id) {
|
||||
ident = id;
|
||||
return ast::TraverseAction::Stop;
|
||||
});
|
||||
auto* sem_var_user = ctx.src->Sem().Get<sem::VariableUser>(ident);
|
||||
if (!sem_var_user) {
|
||||
TINT_ICE(Transform, b.Diagnostics())
|
||||
<< "Expected to find variable of lhs of assignment statement";
|
||||
return {};
|
||||
}
|
||||
|
||||
auto* var = sem_var_user->Variable();
|
||||
if (auto* ptr = var->Type()->As<sem::Pointer>()) {
|
||||
return {ptr->StoreType(), ptr->StorageClass()};
|
||||
}
|
||||
|
||||
auto* ref = var->Type()->As<sem::Reference>();
|
||||
if (!ref) {
|
||||
TINT_ICE(Transform, b.Diagnostics())
|
||||
<< "Expecting to find variable of type pointer or reference on lhs "
|
||||
"of assignment statement";
|
||||
return {};
|
||||
}
|
||||
|
||||
return {ref->StoreType(), ref->StorageClass()};
|
||||
}
|
||||
|
||||
public:
|
||||
/// Constructor
|
||||
/// @param ctx_in the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
explicit State(CloneContext& ctx_in) : ctx(ctx_in), b(*ctx_in.dst) {}
|
||||
|
||||
/// Runs the transform
|
||||
void Run() {
|
||||
struct Shared {
|
||||
bool process_nested_nodes = false;
|
||||
ast::StatementList insert_before_stmts;
|
||||
ast::StatementList insert_after_stmts;
|
||||
} s;
|
||||
|
||||
ctx.ReplaceAll([&](const ast::AssignmentStatement* assign_stmt)
|
||||
-> const ast::Statement* {
|
||||
// Process if it's an assignment statement to a dynamically indexed array
|
||||
// within a struct on a function or private storage variable. This
|
||||
// specific use-case is what FXC fails to compile with:
|
||||
// error X3500: array reference cannot be used as an l-value; not natively
|
||||
// addressable
|
||||
if (!ContainsStructArrayIndex(assign_stmt->lhs)) {
|
||||
return nullptr;
|
||||
}
|
||||
auto og = GetOriginatingTypeAndStorageClass(assign_stmt);
|
||||
if (!(og.first->Is<sem::Struct>() &&
|
||||
(og.second == ast::StorageClass::kFunction ||
|
||||
og.second == ast::StorageClass::kPrivate))) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Reset shared state for this assignment statement
|
||||
s = Shared{};
|
||||
|
||||
const ast::Expression* new_lhs = nullptr;
|
||||
{
|
||||
TINT_SCOPED_ASSIGNMENT(s.process_nested_nodes, true);
|
||||
new_lhs = ctx.Clone(assign_stmt->lhs);
|
||||
}
|
||||
|
||||
auto* new_assign_stmt = b.Assign(new_lhs, ctx.Clone(assign_stmt->rhs));
|
||||
|
||||
// Combine insert_before_stmts + new_assign_stmt + insert_after_stmts into
|
||||
// a block and return it
|
||||
ast::StatementList stmts = std::move(s.insert_before_stmts);
|
||||
stmts.reserve(1 + s.insert_after_stmts.size());
|
||||
stmts.emplace_back(new_assign_stmt);
|
||||
stmts.insert(stmts.end(), s.insert_after_stmts.begin(),
|
||||
s.insert_after_stmts.end());
|
||||
|
||||
return b.Block(std::move(stmts));
|
||||
});
|
||||
|
||||
ctx.ReplaceAll([&](const ast::IndexAccessorExpression* index_access)
|
||||
-> const ast::Expression* {
|
||||
if (!s.process_nested_nodes) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Indexing a member access expr?
|
||||
auto* mem_access =
|
||||
index_access->object->As<ast::MemberAccessorExpression>();
|
||||
if (!mem_access) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Process any nested IndexAccessorExpressions
|
||||
mem_access = ctx.Clone(mem_access);
|
||||
|
||||
// Store the address of the member access into a let as we need to read
|
||||
// the value twice e.g. let tint_symbol = &(s.a1);
|
||||
auto mem_access_ptr = b.Sym();
|
||||
s.insert_before_stmts.push_back(
|
||||
b.Decl(b.Const(mem_access_ptr, nullptr, b.AddressOf(mem_access))));
|
||||
|
||||
// Disable further transforms when cloning
|
||||
TINT_SCOPED_ASSIGNMENT(s.process_nested_nodes, false);
|
||||
|
||||
// Copy entire array out of struct into local temp var
|
||||
// e.g. var tint_symbol_1 = *(tint_symbol);
|
||||
auto tmp_var = b.Sym();
|
||||
s.insert_before_stmts.push_back(
|
||||
b.Decl(b.Var(tmp_var, nullptr, b.Deref(mem_access_ptr))));
|
||||
|
||||
// Replace input index_access with a clone of itself, but with its
|
||||
// .object replaced by the new temp var. This is returned from this
|
||||
// function to modify the original assignment statement. e.g.
|
||||
// tint_symbol_1[uniforms.i]
|
||||
auto* new_index_access =
|
||||
b.IndexAccessor(tmp_var, ctx.Clone(index_access->index));
|
||||
|
||||
// Assign temp var back to array
|
||||
// e.g. *(tint_symbol) = tint_symbol_1;
|
||||
auto* assign_rhs_to_temp = b.Assign(b.Deref(mem_access_ptr), tmp_var);
|
||||
s.insert_after_stmts.insert(s.insert_after_stmts.begin(),
|
||||
assign_rhs_to_temp); // push_front
|
||||
|
||||
return new_index_access;
|
||||
});
|
||||
|
||||
ctx.Clone();
|
||||
}
|
||||
};
|
||||
|
||||
LocalizeStructArrayAssignment::LocalizeStructArrayAssignment() = default;
|
||||
|
||||
LocalizeStructArrayAssignment::~LocalizeStructArrayAssignment() = default;
|
||||
|
||||
void LocalizeStructArrayAssignment::Run(CloneContext& ctx,
|
||||
const DataMap&,
|
||||
DataMap&) const {
|
||||
State state(ctx);
|
||||
state.Run();
|
||||
}
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
58
src/tint/transform/localize_struct_array_assignment.h
Normal file
58
src/tint/transform/localize_struct_array_assignment.h
Normal file
@@ -0,0 +1,58 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef SRC_TINT_TRANSFORM_LOCALIZE_STRUCT_ARRAY_ASSIGNMENT_H_
|
||||
#define SRC_TINT_TRANSFORM_LOCALIZE_STRUCT_ARRAY_ASSIGNMENT_H_
|
||||
|
||||
#include "src/tint/transform/transform.h"
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
|
||||
/// This transforms replaces assignment to dynamically-indexed fixed-size arrays
|
||||
/// in structs on shader-local variables with code that copies the arrays to a
|
||||
/// temporary local variable, assigns to the local variable, and copies the
|
||||
/// array back. This is to work around FXC's compilation failure for these cases
|
||||
/// (see crbug.com/tint/1206).
|
||||
///
|
||||
/// @note Depends on the following transforms to have been run first:
|
||||
/// * SimplifyPointers
|
||||
class LocalizeStructArrayAssignment
|
||||
: public Castable<LocalizeStructArrayAssignment, Transform> {
|
||||
public:
|
||||
/// Constructor
|
||||
LocalizeStructArrayAssignment();
|
||||
|
||||
/// Destructor
|
||||
~LocalizeStructArrayAssignment() override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
|
||||
private:
|
||||
class State;
|
||||
};
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
|
||||
#endif // SRC_TINT_TRANSFORM_LOCALIZE_STRUCT_ARRAY_ASSIGNMENT_H_
|
||||
884
src/tint/transform/localize_struct_array_assignment_test.cc
Normal file
884
src/tint/transform/localize_struct_array_assignment_test.cc
Normal file
@@ -0,0 +1,884 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/tint/transform/localize_struct_array_assignment.h"
|
||||
#include "src/tint/transform/simplify_pointers.h"
|
||||
#include "src/tint/transform/unshadow.h"
|
||||
|
||||
#include "src/tint/transform/test_helper.h"
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
namespace {
|
||||
|
||||
using LocalizeStructArrayAssignmentTest = TransformTest;
|
||||
|
||||
TEST_F(LocalizeStructArrayAssignmentTest, EmptyModule) {
|
||||
auto* src = R"()";
|
||||
auto* expect = src;
|
||||
auto got =
|
||||
Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(LocalizeStructArrayAssignmentTest, StructArray) {
|
||||
auto* src = R"(
|
||||
@block struct Uniforms {
|
||||
i : u32;
|
||||
};
|
||||
|
||||
struct InnerS {
|
||||
v : i32;
|
||||
};
|
||||
|
||||
struct OuterS {
|
||||
a1 : array<InnerS, 8>;
|
||||
};
|
||||
|
||||
@group(1) @binding(4) var<uniform> uniforms : Uniforms;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
var v : InnerS;
|
||||
var s1 : OuterS;
|
||||
s1.a1[uniforms.i] = v;
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
@block
|
||||
struct Uniforms {
|
||||
i : u32;
|
||||
}
|
||||
|
||||
struct InnerS {
|
||||
v : i32;
|
||||
}
|
||||
|
||||
struct OuterS {
|
||||
a1 : array<InnerS, 8>;
|
||||
}
|
||||
|
||||
@group(1) @binding(4) var<uniform> uniforms : Uniforms;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
var v : InnerS;
|
||||
var s1 : OuterS;
|
||||
{
|
||||
let tint_symbol = &(s1.a1);
|
||||
var tint_symbol_1 = *(tint_symbol);
|
||||
tint_symbol_1[uniforms.i] = v;
|
||||
*(tint_symbol) = tint_symbol_1;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto got =
|
||||
Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(LocalizeStructArrayAssignmentTest, StructArray_OutOfOrder) {
|
||||
auto* src = R"(
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
var v : InnerS;
|
||||
var s1 : OuterS;
|
||||
s1.a1[uniforms.i] = v;
|
||||
}
|
||||
|
||||
@group(1) @binding(4) var<uniform> uniforms : Uniforms;
|
||||
|
||||
struct OuterS {
|
||||
a1 : array<InnerS, 8>;
|
||||
};
|
||||
|
||||
struct InnerS {
|
||||
v : i32;
|
||||
};
|
||||
|
||||
@block struct Uniforms {
|
||||
i : u32;
|
||||
};
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
var v : InnerS;
|
||||
var s1 : OuterS;
|
||||
{
|
||||
let tint_symbol = &(s1.a1);
|
||||
var tint_symbol_1 = *(tint_symbol);
|
||||
tint_symbol_1[uniforms.i] = v;
|
||||
*(tint_symbol) = tint_symbol_1;
|
||||
}
|
||||
}
|
||||
|
||||
@group(1) @binding(4) var<uniform> uniforms : Uniforms;
|
||||
|
||||
struct OuterS {
|
||||
a1 : array<InnerS, 8>;
|
||||
}
|
||||
|
||||
struct InnerS {
|
||||
v : i32;
|
||||
}
|
||||
|
||||
@block
|
||||
struct Uniforms {
|
||||
i : u32;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got =
|
||||
Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(LocalizeStructArrayAssignmentTest, StructStructArray) {
|
||||
auto* src = R"(
|
||||
@block struct Uniforms {
|
||||
i : u32;
|
||||
};
|
||||
|
||||
struct InnerS {
|
||||
v : i32;
|
||||
};
|
||||
|
||||
struct S1 {
|
||||
a : array<InnerS, 8>;
|
||||
};
|
||||
|
||||
struct OuterS {
|
||||
s2 : S1;
|
||||
};
|
||||
|
||||
@group(1) @binding(4) var<uniform> uniforms : Uniforms;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
var v : InnerS;
|
||||
var s1 : OuterS;
|
||||
s1.s2.a[uniforms.i] = v;
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
@block
|
||||
struct Uniforms {
|
||||
i : u32;
|
||||
}
|
||||
|
||||
struct InnerS {
|
||||
v : i32;
|
||||
}
|
||||
|
||||
struct S1 {
|
||||
a : array<InnerS, 8>;
|
||||
}
|
||||
|
||||
struct OuterS {
|
||||
s2 : S1;
|
||||
}
|
||||
|
||||
@group(1) @binding(4) var<uniform> uniforms : Uniforms;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
var v : InnerS;
|
||||
var s1 : OuterS;
|
||||
{
|
||||
let tint_symbol = &(s1.s2.a);
|
||||
var tint_symbol_1 = *(tint_symbol);
|
||||
tint_symbol_1[uniforms.i] = v;
|
||||
*(tint_symbol) = tint_symbol_1;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto got =
|
||||
Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(LocalizeStructArrayAssignmentTest, StructStructArray_OutOfOrder) {
|
||||
auto* src = R"(
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
var v : InnerS;
|
||||
var s1 : OuterS;
|
||||
s1.s2.a[uniforms.i] = v;
|
||||
}
|
||||
|
||||
@group(1) @binding(4) var<uniform> uniforms : Uniforms;
|
||||
|
||||
struct OuterS {
|
||||
s2 : S1;
|
||||
};
|
||||
|
||||
struct S1 {
|
||||
a : array<InnerS, 8>;
|
||||
};
|
||||
|
||||
struct InnerS {
|
||||
v : i32;
|
||||
};
|
||||
|
||||
@block struct Uniforms {
|
||||
i : u32;
|
||||
};
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
var v : InnerS;
|
||||
var s1 : OuterS;
|
||||
{
|
||||
let tint_symbol = &(s1.s2.a);
|
||||
var tint_symbol_1 = *(tint_symbol);
|
||||
tint_symbol_1[uniforms.i] = v;
|
||||
*(tint_symbol) = tint_symbol_1;
|
||||
}
|
||||
}
|
||||
|
||||
@group(1) @binding(4) var<uniform> uniforms : Uniforms;
|
||||
|
||||
struct OuterS {
|
||||
s2 : S1;
|
||||
}
|
||||
|
||||
struct S1 {
|
||||
a : array<InnerS, 8>;
|
||||
}
|
||||
|
||||
struct InnerS {
|
||||
v : i32;
|
||||
}
|
||||
|
||||
@block
|
||||
struct Uniforms {
|
||||
i : u32;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got =
|
||||
Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(LocalizeStructArrayAssignmentTest, StructArrayArray) {
|
||||
auto* src = R"(
|
||||
@block struct Uniforms {
|
||||
i : u32;
|
||||
j : u32;
|
||||
};
|
||||
|
||||
struct InnerS {
|
||||
v : i32;
|
||||
};
|
||||
|
||||
struct OuterS {
|
||||
a1 : array<array<InnerS, 8>, 8>;
|
||||
};
|
||||
|
||||
@group(1) @binding(4) var<uniform> uniforms : Uniforms;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
var v : InnerS;
|
||||
var s1 : OuterS;
|
||||
s1.a1[uniforms.i][uniforms.j] = v;
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
@block
|
||||
struct Uniforms {
|
||||
i : u32;
|
||||
j : u32;
|
||||
}
|
||||
|
||||
struct InnerS {
|
||||
v : i32;
|
||||
}
|
||||
|
||||
struct OuterS {
|
||||
a1 : array<array<InnerS, 8>, 8>;
|
||||
}
|
||||
|
||||
@group(1) @binding(4) var<uniform> uniforms : Uniforms;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
var v : InnerS;
|
||||
var s1 : OuterS;
|
||||
{
|
||||
let tint_symbol = &(s1.a1);
|
||||
var tint_symbol_1 = *(tint_symbol);
|
||||
tint_symbol_1[uniforms.i][uniforms.j] = v;
|
||||
*(tint_symbol) = tint_symbol_1;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto got =
|
||||
Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(LocalizeStructArrayAssignmentTest, StructArrayStruct) {
|
||||
auto* src = R"(
|
||||
@block struct Uniforms {
|
||||
i : u32;
|
||||
};
|
||||
|
||||
struct InnerS {
|
||||
v : i32;
|
||||
};
|
||||
|
||||
struct S1 {
|
||||
s2 : InnerS;
|
||||
};
|
||||
|
||||
struct OuterS {
|
||||
a1 : array<S1, 8>;
|
||||
};
|
||||
|
||||
@group(1) @binding(4) var<uniform> uniforms : Uniforms;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
var v : InnerS;
|
||||
var s1 : OuterS;
|
||||
s1.a1[uniforms.i].s2 = v;
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
@block
|
||||
struct Uniforms {
|
||||
i : u32;
|
||||
}
|
||||
|
||||
struct InnerS {
|
||||
v : i32;
|
||||
}
|
||||
|
||||
struct S1 {
|
||||
s2 : InnerS;
|
||||
}
|
||||
|
||||
struct OuterS {
|
||||
a1 : array<S1, 8>;
|
||||
}
|
||||
|
||||
@group(1) @binding(4) var<uniform> uniforms : Uniforms;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
var v : InnerS;
|
||||
var s1 : OuterS;
|
||||
{
|
||||
let tint_symbol = &(s1.a1);
|
||||
var tint_symbol_1 = *(tint_symbol);
|
||||
tint_symbol_1[uniforms.i].s2 = v;
|
||||
*(tint_symbol) = tint_symbol_1;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto got =
|
||||
Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(LocalizeStructArrayAssignmentTest, StructArrayStructArray) {
|
||||
auto* src = R"(
|
||||
@block struct Uniforms {
|
||||
i : u32;
|
||||
j : u32;
|
||||
};
|
||||
|
||||
struct InnerS {
|
||||
v : i32;
|
||||
};
|
||||
|
||||
struct S1 {
|
||||
a2 : array<InnerS, 8>;
|
||||
};
|
||||
|
||||
struct OuterS {
|
||||
a1 : array<S1, 8>;
|
||||
};
|
||||
|
||||
@group(1) @binding(4) var<uniform> uniforms : Uniforms;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
var v : InnerS;
|
||||
var s : OuterS;
|
||||
s.a1[uniforms.i].a2[uniforms.j] = v;
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
@block
|
||||
struct Uniforms {
|
||||
i : u32;
|
||||
j : u32;
|
||||
}
|
||||
|
||||
struct InnerS {
|
||||
v : i32;
|
||||
}
|
||||
|
||||
struct S1 {
|
||||
a2 : array<InnerS, 8>;
|
||||
}
|
||||
|
||||
struct OuterS {
|
||||
a1 : array<S1, 8>;
|
||||
}
|
||||
|
||||
@group(1) @binding(4) var<uniform> uniforms : Uniforms;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
var v : InnerS;
|
||||
var s : OuterS;
|
||||
{
|
||||
let tint_symbol = &(s.a1);
|
||||
var tint_symbol_1 = *(tint_symbol);
|
||||
let tint_symbol_2 = &(tint_symbol_1[uniforms.i].a2);
|
||||
var tint_symbol_3 = *(tint_symbol_2);
|
||||
tint_symbol_3[uniforms.j] = v;
|
||||
*(tint_symbol_2) = tint_symbol_3;
|
||||
*(tint_symbol) = tint_symbol_1;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto got =
|
||||
Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(LocalizeStructArrayAssignmentTest, IndexingWithSideEffectFunc) {
|
||||
auto* src = R"(
|
||||
@block struct Uniforms {
|
||||
i : u32;
|
||||
j : u32;
|
||||
};
|
||||
|
||||
struct InnerS {
|
||||
v : i32;
|
||||
};
|
||||
|
||||
struct S1 {
|
||||
a2 : array<InnerS, 8>;
|
||||
};
|
||||
|
||||
struct OuterS {
|
||||
a1 : array<S1, 8>;
|
||||
};
|
||||
|
||||
var<private> nextIndex : u32;
|
||||
fn getNextIndex() -> u32 {
|
||||
nextIndex = nextIndex + 1u;
|
||||
return nextIndex;
|
||||
}
|
||||
|
||||
@group(1) @binding(4) var<uniform> uniforms : Uniforms;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
var v : InnerS;
|
||||
var s : OuterS;
|
||||
s.a1[getNextIndex()].a2[uniforms.j] = v;
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
@block
|
||||
struct Uniforms {
|
||||
i : u32;
|
||||
j : u32;
|
||||
}
|
||||
|
||||
struct InnerS {
|
||||
v : i32;
|
||||
}
|
||||
|
||||
struct S1 {
|
||||
a2 : array<InnerS, 8>;
|
||||
}
|
||||
|
||||
struct OuterS {
|
||||
a1 : array<S1, 8>;
|
||||
}
|
||||
|
||||
var<private> nextIndex : u32;
|
||||
|
||||
fn getNextIndex() -> u32 {
|
||||
nextIndex = (nextIndex + 1u);
|
||||
return nextIndex;
|
||||
}
|
||||
|
||||
@group(1) @binding(4) var<uniform> uniforms : Uniforms;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
var v : InnerS;
|
||||
var s : OuterS;
|
||||
{
|
||||
let tint_symbol = &(s.a1);
|
||||
var tint_symbol_1 = *(tint_symbol);
|
||||
let tint_symbol_2 = &(tint_symbol_1[getNextIndex()].a2);
|
||||
var tint_symbol_3 = *(tint_symbol_2);
|
||||
tint_symbol_3[uniforms.j] = v;
|
||||
*(tint_symbol_2) = tint_symbol_3;
|
||||
*(tint_symbol) = tint_symbol_1;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto got =
|
||||
Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(LocalizeStructArrayAssignmentTest,
|
||||
IndexingWithSideEffectFunc_OutOfOrder) {
|
||||
auto* src = R"(
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
var v : InnerS;
|
||||
var s : OuterS;
|
||||
s.a1[getNextIndex()].a2[uniforms.j] = v;
|
||||
}
|
||||
|
||||
@group(1) @binding(4) var<uniform> uniforms : Uniforms;
|
||||
|
||||
@block struct Uniforms {
|
||||
i : u32;
|
||||
j : u32;
|
||||
};
|
||||
|
||||
var<private> nextIndex : u32;
|
||||
fn getNextIndex() -> u32 {
|
||||
nextIndex = nextIndex + 1u;
|
||||
return nextIndex;
|
||||
}
|
||||
|
||||
struct OuterS {
|
||||
a1 : array<S1, 8>;
|
||||
};
|
||||
|
||||
struct S1 {
|
||||
a2 : array<InnerS, 8>;
|
||||
};
|
||||
|
||||
struct InnerS {
|
||||
v : i32;
|
||||
};
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
var v : InnerS;
|
||||
var s : OuterS;
|
||||
{
|
||||
let tint_symbol = &(s.a1);
|
||||
var tint_symbol_1 = *(tint_symbol);
|
||||
let tint_symbol_2 = &(tint_symbol_1[getNextIndex()].a2);
|
||||
var tint_symbol_3 = *(tint_symbol_2);
|
||||
tint_symbol_3[uniforms.j] = v;
|
||||
*(tint_symbol_2) = tint_symbol_3;
|
||||
*(tint_symbol) = tint_symbol_1;
|
||||
}
|
||||
}
|
||||
|
||||
@group(1) @binding(4) var<uniform> uniforms : Uniforms;
|
||||
|
||||
@block
|
||||
struct Uniforms {
|
||||
i : u32;
|
||||
j : u32;
|
||||
}
|
||||
|
||||
var<private> nextIndex : u32;
|
||||
|
||||
fn getNextIndex() -> u32 {
|
||||
nextIndex = (nextIndex + 1u);
|
||||
return nextIndex;
|
||||
}
|
||||
|
||||
struct OuterS {
|
||||
a1 : array<S1, 8>;
|
||||
}
|
||||
|
||||
struct S1 {
|
||||
a2 : array<InnerS, 8>;
|
||||
}
|
||||
|
||||
struct InnerS {
|
||||
v : i32;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got =
|
||||
Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(LocalizeStructArrayAssignmentTest, ViaPointerArg) {
|
||||
auto* src = R"(
|
||||
@block struct Uniforms {
|
||||
i : u32;
|
||||
};
|
||||
struct InnerS {
|
||||
v : i32;
|
||||
};
|
||||
struct OuterS {
|
||||
a1 : array<InnerS, 8>;
|
||||
};
|
||||
@group(1) @binding(4) var<uniform> uniforms : Uniforms;
|
||||
|
||||
fn f(p : ptr<function, OuterS>) {
|
||||
var v : InnerS;
|
||||
(*p).a1[uniforms.i] = v;
|
||||
}
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
var s1 : OuterS;
|
||||
f(&s1);
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
@block
|
||||
struct Uniforms {
|
||||
i : u32;
|
||||
}
|
||||
|
||||
struct InnerS {
|
||||
v : i32;
|
||||
}
|
||||
|
||||
struct OuterS {
|
||||
a1 : array<InnerS, 8>;
|
||||
}
|
||||
|
||||
@group(1) @binding(4) var<uniform> uniforms : Uniforms;
|
||||
|
||||
fn f(p : ptr<function, OuterS>) {
|
||||
var v : InnerS;
|
||||
{
|
||||
let tint_symbol = &((*(p)).a1);
|
||||
var tint_symbol_1 = *(tint_symbol);
|
||||
tint_symbol_1[uniforms.i] = v;
|
||||
*(tint_symbol) = tint_symbol_1;
|
||||
}
|
||||
}
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
var s1 : OuterS;
|
||||
f(&(s1));
|
||||
}
|
||||
)";
|
||||
|
||||
auto got =
|
||||
Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(LocalizeStructArrayAssignmentTest, ViaPointerArg_OutOfOrder) {
|
||||
auto* src = R"(
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
var s1 : OuterS;
|
||||
f(&s1);
|
||||
}
|
||||
|
||||
fn f(p : ptr<function, OuterS>) {
|
||||
var v : InnerS;
|
||||
(*p).a1[uniforms.i] = v;
|
||||
}
|
||||
|
||||
struct InnerS {
|
||||
v : i32;
|
||||
};
|
||||
struct OuterS {
|
||||
a1 : array<InnerS, 8>;
|
||||
};
|
||||
|
||||
@group(1) @binding(4) var<uniform> uniforms : Uniforms;
|
||||
|
||||
@block struct Uniforms {
|
||||
i : u32;
|
||||
};
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
var s1 : OuterS;
|
||||
f(&(s1));
|
||||
}
|
||||
|
||||
fn f(p : ptr<function, OuterS>) {
|
||||
var v : InnerS;
|
||||
{
|
||||
let tint_symbol = &((*(p)).a1);
|
||||
var tint_symbol_1 = *(tint_symbol);
|
||||
tint_symbol_1[uniforms.i] = v;
|
||||
*(tint_symbol) = tint_symbol_1;
|
||||
}
|
||||
}
|
||||
|
||||
struct InnerS {
|
||||
v : i32;
|
||||
}
|
||||
|
||||
struct OuterS {
|
||||
a1 : array<InnerS, 8>;
|
||||
}
|
||||
|
||||
@group(1) @binding(4) var<uniform> uniforms : Uniforms;
|
||||
|
||||
@block
|
||||
struct Uniforms {
|
||||
i : u32;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got =
|
||||
Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(LocalizeStructArrayAssignmentTest, ViaPointerVar) {
|
||||
auto* src = R"(
|
||||
@block
|
||||
struct Uniforms {
|
||||
i : u32;
|
||||
};
|
||||
|
||||
struct InnerS {
|
||||
v : i32;
|
||||
};
|
||||
|
||||
struct OuterS {
|
||||
a1 : array<InnerS, 8>;
|
||||
};
|
||||
|
||||
@group(1) @binding(4) var<uniform> uniforms : Uniforms;
|
||||
|
||||
fn f(p : ptr<function, InnerS>, v : InnerS) {
|
||||
*(p) = v;
|
||||
}
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
var v : InnerS;
|
||||
var s1 : OuterS;
|
||||
let p = &(s1.a1[uniforms.i]);
|
||||
*(p) = v;
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
@block
|
||||
struct Uniforms {
|
||||
i : u32;
|
||||
}
|
||||
|
||||
struct InnerS {
|
||||
v : i32;
|
||||
}
|
||||
|
||||
struct OuterS {
|
||||
a1 : array<InnerS, 8>;
|
||||
}
|
||||
|
||||
@group(1) @binding(4) var<uniform> uniforms : Uniforms;
|
||||
|
||||
fn f(p : ptr<function, InnerS>, v : InnerS) {
|
||||
*(p) = v;
|
||||
}
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
var v : InnerS;
|
||||
var s1 : OuterS;
|
||||
let p_save = uniforms.i;
|
||||
{
|
||||
let tint_symbol = &(s1.a1);
|
||||
var tint_symbol_1 = *(tint_symbol);
|
||||
tint_symbol_1[p_save] = v;
|
||||
*(tint_symbol) = tint_symbol_1;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto got =
|
||||
Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(LocalizeStructArrayAssignmentTest, VectorAssignment) {
|
||||
auto* src = R"(
|
||||
@block
|
||||
struct Uniforms {
|
||||
i : u32;
|
||||
}
|
||||
|
||||
@block
|
||||
struct OuterS {
|
||||
a1 : array<u32, 8>;
|
||||
}
|
||||
|
||||
@group(1) @binding(4) var<uniform> uniforms : Uniforms;
|
||||
|
||||
fn f(i : u32) -> u32 {
|
||||
return (i + 1u);
|
||||
}
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
var s1 : OuterS;
|
||||
var v : vec3<f32>;
|
||||
v[s1.a1[uniforms.i]] = 1.0;
|
||||
v[f(s1.a1[uniforms.i])] = 1.0;
|
||||
}
|
||||
)";
|
||||
|
||||
// Transform does nothing here as we're not actually assigning to the array in
|
||||
// the struct.
|
||||
auto* expect = src;
|
||||
|
||||
auto got =
|
||||
Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
145
src/tint/transform/loop_to_for_loop.cc
Normal file
145
src/tint/transform/loop_to_for_loop.cc
Normal file
@@ -0,0 +1,145 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/tint/transform/loop_to_for_loop.h"
|
||||
|
||||
#include "src/tint/ast/break_statement.h"
|
||||
#include "src/tint/ast/for_loop_statement.h"
|
||||
#include "src/tint/program_builder.h"
|
||||
#include "src/tint/sem/block_statement.h"
|
||||
#include "src/tint/sem/function.h"
|
||||
#include "src/tint/sem/statement.h"
|
||||
#include "src/tint/sem/variable.h"
|
||||
#include "src/tint/utils/scoped_assignment.h"
|
||||
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::transform::LoopToForLoop);
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
namespace {
|
||||
|
||||
bool IsBlockWithSingleBreak(const ast::BlockStatement* block) {
|
||||
if (block->statements.size() != 1) {
|
||||
return false;
|
||||
}
|
||||
return block->statements[0]->Is<ast::BreakStatement>();
|
||||
}
|
||||
|
||||
bool IsVarUsedByStmt(const sem::Info& sem,
|
||||
const ast::Variable* var,
|
||||
const ast::Statement* stmt) {
|
||||
auto* var_sem = sem.Get(var);
|
||||
for (auto* user : var_sem->Users()) {
|
||||
if (auto* s = user->Stmt()) {
|
||||
if (s->Declaration() == stmt) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
LoopToForLoop::LoopToForLoop() = default;
|
||||
|
||||
LoopToForLoop::~LoopToForLoop() = default;
|
||||
|
||||
bool LoopToForLoop::ShouldRun(const Program* program, const DataMap&) const {
|
||||
for (auto* node : program->ASTNodes().Objects()) {
|
||||
if (node->Is<ast::LoopStatement>()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void LoopToForLoop::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
|
||||
ctx.ReplaceAll([&](const ast::LoopStatement* loop) -> const ast::Statement* {
|
||||
// For loop condition is taken from the first statement in the loop.
|
||||
// This requires an if-statement with either:
|
||||
// * A true block with no else statements, and the true block contains a
|
||||
// single 'break' statement.
|
||||
// * An empty true block with a single, no-condition else statement
|
||||
// containing a single 'break' statement.
|
||||
// Examples:
|
||||
// loop { if (condition) { break; } ... }
|
||||
// loop { if (condition) {} else { break; } ... }
|
||||
auto& stmts = loop->body->statements;
|
||||
if (stmts.empty()) {
|
||||
return nullptr;
|
||||
}
|
||||
auto* if_stmt = stmts[0]->As<ast::IfStatement>();
|
||||
if (!if_stmt) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
bool negate_condition = false;
|
||||
if (IsBlockWithSingleBreak(if_stmt->body) &&
|
||||
if_stmt->else_statements.empty()) {
|
||||
negate_condition = true;
|
||||
} else if (if_stmt->body->Empty() && if_stmt->else_statements.size() == 1 &&
|
||||
if_stmt->else_statements[0]->condition == nullptr &&
|
||||
IsBlockWithSingleBreak(if_stmt->else_statements[0]->body)) {
|
||||
negate_condition = false;
|
||||
} else {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// The continuing block must be empty or contain a single, assignment or
|
||||
// function call statement.
|
||||
const ast::Statement* continuing = nullptr;
|
||||
if (auto* loop_cont = loop->continuing) {
|
||||
if (loop_cont->statements.size() != 1) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
continuing = loop_cont->statements[0];
|
||||
if (!continuing
|
||||
->IsAnyOf<ast::AssignmentStatement, ast::CallStatement>()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// And the continuing statement must not use any of the variables declared
|
||||
// in the loop body.
|
||||
for (auto* stmt : loop->body->statements) {
|
||||
if (auto* var_decl = stmt->As<ast::VariableDeclStatement>()) {
|
||||
if (IsVarUsedByStmt(ctx.src->Sem(), var_decl->variable, continuing)) {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
continuing = ctx.Clone(continuing);
|
||||
}
|
||||
|
||||
auto* condition = ctx.Clone(if_stmt->condition);
|
||||
if (negate_condition) {
|
||||
condition = ctx.dst->create<ast::UnaryOpExpression>(ast::UnaryOp::kNot,
|
||||
condition);
|
||||
}
|
||||
|
||||
ast::Statement* initializer = nullptr;
|
||||
|
||||
ctx.Remove(loop->body->statements, if_stmt);
|
||||
auto* body = ctx.Clone(loop->body);
|
||||
return ctx.dst->create<ast::ForLoopStatement>(initializer, condition,
|
||||
continuing, body);
|
||||
});
|
||||
|
||||
ctx.Clone();
|
||||
}
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
54
src/tint/transform/loop_to_for_loop.h
Normal file
54
src/tint/transform/loop_to_for_loop.h
Normal file
@@ -0,0 +1,54 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef SRC_TINT_TRANSFORM_LOOP_TO_FOR_LOOP_H_
|
||||
#define SRC_TINT_TRANSFORM_LOOP_TO_FOR_LOOP_H_
|
||||
|
||||
#include "src/tint/transform/transform.h"
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
|
||||
/// LoopToForLoop is a Transform that attempts to convert WGSL `loop {}`
|
||||
/// statements into a for-loop statement.
|
||||
class LoopToForLoop : public Castable<LoopToForLoop, Transform> {
|
||||
public:
|
||||
/// Constructor
|
||||
LoopToForLoop();
|
||||
|
||||
/// Destructor
|
||||
~LoopToForLoop() override;
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program,
|
||||
const DataMap& data = {}) const override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
};
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
|
||||
#endif // SRC_TINT_TRANSFORM_LOOP_TO_FOR_LOOP_H_
|
||||
308
src/tint/transform/loop_to_for_loop_test.cc
Normal file
308
src/tint/transform/loop_to_for_loop_test.cc
Normal file
@@ -0,0 +1,308 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/tint/transform/loop_to_for_loop.h"
|
||||
|
||||
#include "src/tint/transform/test_helper.h"
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
namespace {
|
||||
|
||||
using LoopToForLoopTest = TransformTest;
|
||||
|
||||
TEST_F(LoopToForLoopTest, ShouldRunEmptyModule) {
|
||||
auto* src = R"()";
|
||||
|
||||
EXPECT_FALSE(ShouldRun<LoopToForLoop>(src));
|
||||
}
|
||||
|
||||
TEST_F(LoopToForLoopTest, ShouldRunHasForLoop) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
loop {
|
||||
break;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
EXPECT_TRUE(ShouldRun<LoopToForLoop>(src));
|
||||
}
|
||||
|
||||
TEST_F(LoopToForLoopTest, EmptyModule) {
|
||||
auto* src = "";
|
||||
auto* expect = "";
|
||||
|
||||
auto got = Run<LoopToForLoop>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(LoopToForLoopTest, IfBreak) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
var i : i32;
|
||||
i = 0;
|
||||
loop {
|
||||
if (i > 15) {
|
||||
break;
|
||||
}
|
||||
|
||||
_ = 123;
|
||||
|
||||
continuing {
|
||||
i = i + 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn f() {
|
||||
var i : i32;
|
||||
i = 0;
|
||||
for(; !((i > 15)); i = (i + 1)) {
|
||||
_ = 123;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<LoopToForLoop>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(LoopToForLoopTest, IfElseBreak) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
var i : i32;
|
||||
i = 0;
|
||||
loop {
|
||||
if (i < 15) {
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
|
||||
_ = 123;
|
||||
|
||||
continuing {
|
||||
i = i + 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn f() {
|
||||
var i : i32;
|
||||
i = 0;
|
||||
for(; (i < 15); i = (i + 1)) {
|
||||
_ = 123;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<LoopToForLoop>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(LoopToForLoopTest, Nested) {
|
||||
auto* src = R"(
|
||||
let N = 16u;
|
||||
|
||||
fn f() {
|
||||
var i : u32 = 0u;
|
||||
loop {
|
||||
if (i >= N) {
|
||||
break;
|
||||
}
|
||||
{
|
||||
var j : u32 = 0u;
|
||||
loop {
|
||||
if (j >= N) {
|
||||
break;
|
||||
}
|
||||
|
||||
_ = i;
|
||||
_ = j;
|
||||
|
||||
continuing {
|
||||
j = (j + 1u);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
continuing {
|
||||
i = (i + 1u);
|
||||
}
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
let N = 16u;
|
||||
|
||||
fn f() {
|
||||
var i : u32 = 0u;
|
||||
for(; !((i >= N)); i = (i + 1u)) {
|
||||
{
|
||||
var j : u32 = 0u;
|
||||
for(; !((j >= N)); j = (j + 1u)) {
|
||||
_ = i;
|
||||
_ = j;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<LoopToForLoop>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(LoopToForLoopTest, NoTransform_IfMultipleStmts) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
var i : i32;
|
||||
i = 0;
|
||||
loop {
|
||||
if ((i < 15)) {
|
||||
_ = i;
|
||||
break;
|
||||
}
|
||||
_ = 123;
|
||||
|
||||
continuing {
|
||||
i = (i + 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = src;
|
||||
|
||||
auto got = Run<LoopToForLoop>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(LoopToForLoopTest, NoTransform_IfElseMultipleStmts) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
var i : i32;
|
||||
i = 0;
|
||||
loop {
|
||||
if ((i < 15)) {
|
||||
} else {
|
||||
_ = i;
|
||||
break;
|
||||
}
|
||||
_ = 123;
|
||||
|
||||
continuing {
|
||||
i = (i + 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = src;
|
||||
|
||||
auto got = Run<LoopToForLoop>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(LoopToForLoopTest, NoTransform_ContinuingIsCompound) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
var i : i32;
|
||||
i = 0;
|
||||
loop {
|
||||
if ((i < 15)) {
|
||||
break;
|
||||
}
|
||||
_ = 123;
|
||||
|
||||
continuing {
|
||||
if (false) {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = src;
|
||||
|
||||
auto got = Run<LoopToForLoop>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(LoopToForLoopTest, NoTransform_ContinuingMultipleStmts) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
var i : i32;
|
||||
i = 0;
|
||||
loop {
|
||||
if ((i < 15)) {
|
||||
break;
|
||||
}
|
||||
_ = 123;
|
||||
|
||||
continuing {
|
||||
i = (i + 1);
|
||||
_ = i;
|
||||
}
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = src;
|
||||
|
||||
auto got = Run<LoopToForLoop>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(LoopToForLoopTest, NoTransform_ContinuingUsesVarDeclInLoopBody) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
var i : i32;
|
||||
i = 0;
|
||||
loop {
|
||||
if ((i < 15)) {
|
||||
break;
|
||||
}
|
||||
var j : i32;
|
||||
|
||||
continuing {
|
||||
i = (i + j);
|
||||
}
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = src;
|
||||
|
||||
auto got = Run<LoopToForLoop>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
86
src/tint/transform/manager.cc
Normal file
86
src/tint/transform/manager.cc
Normal file
@@ -0,0 +1,86 @@
|
||||
// Copyright 2020 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/tint/transform/manager.h"
|
||||
|
||||
/// If set to 1 then the transform::Manager will dump the WGSL of the program
|
||||
/// before and after each transform. Helpful for debugging bad output.
|
||||
#define TINT_PRINT_PROGRAM_FOR_EACH_TRANSFORM 0
|
||||
|
||||
#if TINT_PRINT_PROGRAM_FOR_EACH_TRANSFORM
|
||||
#define TINT_IF_PRINT_PROGRAM(x) x
|
||||
#else // TINT_PRINT_PROGRAM_FOR_EACH_TRANSFORM
|
||||
#define TINT_IF_PRINT_PROGRAM(x)
|
||||
#endif // TINT_PRINT_PROGRAM_FOR_EACH_TRANSFORM
|
||||
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::transform::Manager);
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
|
||||
Manager::Manager() = default;
|
||||
Manager::~Manager() = default;
|
||||
|
||||
Output Manager::Run(const Program* program, const DataMap& data) const {
|
||||
const Program* in = program;
|
||||
|
||||
#if TINT_PRINT_PROGRAM_FOR_EACH_TRANSFORM
|
||||
auto print_program = [&](const char* msg, const Transform* transform) {
|
||||
auto wgsl = Program::printer(in);
|
||||
std::cout << "---------------------------------------------------------"
|
||||
<< std::endl;
|
||||
std::cout << "-- " << msg << " " << transform->TypeInfo().name << ":"
|
||||
<< std::endl;
|
||||
std::cout << "---------------------------------------------------------"
|
||||
<< std::endl;
|
||||
std::cout << wgsl << std::endl;
|
||||
std::cout << "---------------------------------------------------------"
|
||||
<< std::endl
|
||||
<< std::endl;
|
||||
};
|
||||
#endif
|
||||
|
||||
Output out;
|
||||
for (const auto& transform : transforms_) {
|
||||
if (!transform->ShouldRun(in, data)) {
|
||||
TINT_IF_PRINT_PROGRAM(std::cout << "Skipping "
|
||||
<< transform->TypeInfo().name);
|
||||
continue;
|
||||
}
|
||||
TINT_IF_PRINT_PROGRAM(print_program("Input to", transform.get()));
|
||||
|
||||
auto res = transform->Run(in, data);
|
||||
out.program = std::move(res.program);
|
||||
out.data.Add(std::move(res.data));
|
||||
in = &out.program;
|
||||
if (!in->IsValid()) {
|
||||
TINT_IF_PRINT_PROGRAM(
|
||||
print_program("Invalid output of", transform.get()));
|
||||
return out;
|
||||
}
|
||||
|
||||
if (transform == transforms_.back()) {
|
||||
TINT_IF_PRINT_PROGRAM(print_program("Output of", transform.get()));
|
||||
}
|
||||
}
|
||||
|
||||
if (program == in) {
|
||||
out.program = program->Clone();
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
64
src/tint/transform/manager.h
Normal file
64
src/tint/transform/manager.h
Normal file
@@ -0,0 +1,64 @@
|
||||
// Copyright 2020 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef SRC_TINT_TRANSFORM_MANAGER_H_
|
||||
#define SRC_TINT_TRANSFORM_MANAGER_H_
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "src/tint/transform/transform.h"
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
|
||||
/// A collection of Transforms that act as a single Transform.
|
||||
/// The inner transforms will execute in the appended order.
|
||||
/// If any inner transform fails the manager will return immediately and
|
||||
/// the error can be retrieved with the Output's diagnostics.
|
||||
class Manager : public Castable<Manager, Transform> {
|
||||
public:
|
||||
/// Constructor
|
||||
Manager();
|
||||
~Manager() override;
|
||||
|
||||
/// Add pass to the manager
|
||||
/// @param transform the transform to append
|
||||
void append(std::unique_ptr<Transform> transform) {
|
||||
transforms_.push_back(std::move(transform));
|
||||
}
|
||||
|
||||
/// Add pass to the manager of type `T`, constructed with the provided
|
||||
/// arguments.
|
||||
/// @param args the arguments to forward to the `T` constructor
|
||||
template <typename T, typename... ARGS>
|
||||
void Add(ARGS&&... args) {
|
||||
transforms_.emplace_back(std::make_unique<T>(std::forward<ARGS>(args)...));
|
||||
}
|
||||
|
||||
/// Runs the transforms on `program`, returning the transformation result.
|
||||
/// @param program the source program to transform
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns the transformed program and diagnostics
|
||||
Output Run(const Program* program, const DataMap& data = {}) const override;
|
||||
|
||||
private:
|
||||
std::vector<std::unique_ptr<Transform>> transforms_;
|
||||
};
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
|
||||
#endif // SRC_TINT_TRANSFORM_MANAGER_H_
|
||||
399
src/tint/transform/module_scope_var_to_entry_point_param.cc
Normal file
399
src/tint/transform/module_scope_var_to_entry_point_param.cc
Normal file
@@ -0,0 +1,399 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/tint/transform/module_scope_var_to_entry_point_param.h"
|
||||
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "src/tint/ast/disable_validation_attribute.h"
|
||||
#include "src/tint/program_builder.h"
|
||||
#include "src/tint/sem/call.h"
|
||||
#include "src/tint/sem/function.h"
|
||||
#include "src/tint/sem/statement.h"
|
||||
#include "src/tint/sem/variable.h"
|
||||
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::transform::ModuleScopeVarToEntryPointParam);
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
namespace {
|
||||
// Returns `true` if `type` is or contains a matrix type.
|
||||
bool ContainsMatrix(const sem::Type* type) {
|
||||
type = type->UnwrapRef();
|
||||
if (type->Is<sem::Matrix>()) {
|
||||
return true;
|
||||
} else if (auto* ary = type->As<sem::Array>()) {
|
||||
return ContainsMatrix(ary->ElemType());
|
||||
} else if (auto* str = type->As<sem::Struct>()) {
|
||||
for (auto* member : str->Members()) {
|
||||
if (ContainsMatrix(member->Type())) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
/// State holds the current transform state.
|
||||
struct ModuleScopeVarToEntryPointParam::State {
|
||||
/// The clone context.
|
||||
CloneContext& ctx;
|
||||
|
||||
/// Constructor
|
||||
/// @param context the clone context
|
||||
explicit State(CloneContext& context) : ctx(context) {}
|
||||
|
||||
/// Clone any struct types that are contained in `ty` (including `ty` itself),
|
||||
/// and add it to the global declarations now, so that they precede new global
|
||||
/// declarations that need to reference them.
|
||||
/// @param ty the type to clone
|
||||
void CloneStructTypes(const sem::Type* ty) {
|
||||
if (auto* str = ty->As<sem::Struct>()) {
|
||||
if (!cloned_structs_.emplace(str).second) {
|
||||
// The struct has already been cloned.
|
||||
return;
|
||||
}
|
||||
|
||||
// Recurse into members.
|
||||
for (auto* member : str->Members()) {
|
||||
CloneStructTypes(member->Type());
|
||||
}
|
||||
|
||||
// Clone the struct and add it to the global declaration list.
|
||||
// Remove the old declaration.
|
||||
auto* ast_str = str->Declaration();
|
||||
ctx.dst->AST().AddTypeDecl(ctx.Clone(ast_str));
|
||||
ctx.Remove(ctx.src->AST().GlobalDeclarations(), ast_str);
|
||||
} else if (auto* arr = ty->As<sem::Array>()) {
|
||||
CloneStructTypes(arr->ElemType());
|
||||
}
|
||||
}
|
||||
|
||||
/// Process the module.
|
||||
void Process() {
|
||||
// Predetermine the list of function calls that need to be replaced.
|
||||
using CallList = std::vector<const ast::CallExpression*>;
|
||||
std::unordered_map<const ast::Function*, CallList> calls_to_replace;
|
||||
|
||||
std::vector<const ast::Function*> functions_to_process;
|
||||
|
||||
// Build a list of functions that transitively reference any module-scope
|
||||
// variables.
|
||||
for (auto* func_ast : ctx.src->AST().Functions()) {
|
||||
auto* func_sem = ctx.src->Sem().Get(func_ast);
|
||||
|
||||
bool needs_processing = false;
|
||||
for (auto* var : func_sem->TransitivelyReferencedGlobals()) {
|
||||
if (var->StorageClass() != ast::StorageClass::kNone) {
|
||||
needs_processing = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (needs_processing) {
|
||||
functions_to_process.push_back(func_ast);
|
||||
|
||||
// Find all of the calls to this function that will need to be replaced.
|
||||
for (auto* call : func_sem->CallSites()) {
|
||||
calls_to_replace[call->Stmt()->Function()->Declaration()].push_back(
|
||||
call->Declaration());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build a list of `&ident` expressions. We'll use this later to avoid
|
||||
// generating expressions of the form `&*ident`, which break WGSL validation
|
||||
// rules when this expression is passed to a function.
|
||||
// TODO(jrprice): We should add support for bidirectional SEM tree traversal
|
||||
// so that we can do this on the fly instead.
|
||||
std::unordered_map<const ast::IdentifierExpression*,
|
||||
const ast::UnaryOpExpression*>
|
||||
ident_to_address_of;
|
||||
for (auto* node : ctx.src->ASTNodes().Objects()) {
|
||||
auto* address_of = node->As<ast::UnaryOpExpression>();
|
||||
if (!address_of || address_of->op != ast::UnaryOp::kAddressOf) {
|
||||
continue;
|
||||
}
|
||||
if (auto* ident = address_of->expr->As<ast::IdentifierExpression>()) {
|
||||
ident_to_address_of[ident] = address_of;
|
||||
}
|
||||
}
|
||||
|
||||
for (auto* func_ast : functions_to_process) {
|
||||
auto* func_sem = ctx.src->Sem().Get(func_ast);
|
||||
bool is_entry_point = func_ast->IsEntryPoint();
|
||||
|
||||
// Map module-scope variables onto their replacement.
|
||||
struct NewVar {
|
||||
Symbol symbol;
|
||||
bool is_pointer;
|
||||
bool is_wrapped;
|
||||
};
|
||||
const char* kWrappedArrayMemberName = "arr";
|
||||
std::unordered_map<const sem::Variable*, NewVar> var_to_newvar;
|
||||
|
||||
// We aggregate all workgroup variables into a struct to avoid hitting
|
||||
// MSL's limit for threadgroup memory arguments.
|
||||
Symbol workgroup_parameter_symbol;
|
||||
ast::StructMemberList workgroup_parameter_members;
|
||||
auto workgroup_param = [&]() {
|
||||
if (!workgroup_parameter_symbol.IsValid()) {
|
||||
workgroup_parameter_symbol = ctx.dst->Sym();
|
||||
}
|
||||
return workgroup_parameter_symbol;
|
||||
};
|
||||
|
||||
for (auto* var : func_sem->TransitivelyReferencedGlobals()) {
|
||||
auto sc = var->StorageClass();
|
||||
auto* ty = var->Type()->UnwrapRef();
|
||||
if (sc == ast::StorageClass::kNone) {
|
||||
continue;
|
||||
}
|
||||
if (sc != ast::StorageClass::kPrivate &&
|
||||
sc != ast::StorageClass::kStorage &&
|
||||
sc != ast::StorageClass::kUniform &&
|
||||
sc != ast::StorageClass::kUniformConstant &&
|
||||
sc != ast::StorageClass::kWorkgroup) {
|
||||
TINT_ICE(Transform, ctx.dst->Diagnostics())
|
||||
<< "unhandled module-scope storage class (" << sc << ")";
|
||||
}
|
||||
|
||||
// This is the symbol for the variable that replaces the module-scope
|
||||
// var.
|
||||
auto new_var_symbol = ctx.dst->Sym();
|
||||
|
||||
// Helper to create an AST node for the store type of the variable.
|
||||
auto store_type = [&]() { return CreateASTTypeFor(ctx, ty); };
|
||||
|
||||
// Track whether the new variable is a pointer or not.
|
||||
bool is_pointer = false;
|
||||
|
||||
// Track whether the new variable was wrapped in a struct or not.
|
||||
bool is_wrapped = false;
|
||||
|
||||
if (is_entry_point) {
|
||||
if (var->Type()->UnwrapRef()->is_handle()) {
|
||||
// For a texture or sampler variable, redeclare it as an entry point
|
||||
// parameter. Disable entry point parameter validation.
|
||||
auto* disable_validation =
|
||||
ctx.dst->Disable(ast::DisabledValidation::kEntryPointParameter);
|
||||
auto attrs = ctx.Clone(var->Declaration()->attributes);
|
||||
attrs.push_back(disable_validation);
|
||||
auto* param = ctx.dst->Param(new_var_symbol, store_type(), attrs);
|
||||
ctx.InsertFront(func_ast->params, param);
|
||||
} else if (sc == ast::StorageClass::kStorage ||
|
||||
sc == ast::StorageClass::kUniform) {
|
||||
// Variables into the Storage and Uniform storage classes are
|
||||
// redeclared as entry point parameters with a pointer type.
|
||||
auto attributes = ctx.Clone(var->Declaration()->attributes);
|
||||
attributes.push_back(ctx.dst->Disable(
|
||||
ast::DisabledValidation::kEntryPointParameter));
|
||||
attributes.push_back(
|
||||
ctx.dst->Disable(ast::DisabledValidation::kIgnoreStorageClass));
|
||||
|
||||
auto* param_type = store_type();
|
||||
if (auto* arr = ty->As<sem::Array>();
|
||||
arr && arr->IsRuntimeSized()) {
|
||||
// Wrap runtime-sized arrays in structures, so that we can declare
|
||||
// pointers to them. Ideally we'd just emit the array itself as a
|
||||
// pointer, but this is not representable in Tint's AST.
|
||||
CloneStructTypes(ty);
|
||||
auto* wrapper = ctx.dst->Structure(
|
||||
ctx.dst->Sym(),
|
||||
{ctx.dst->Member(kWrappedArrayMemberName, param_type)});
|
||||
param_type = ctx.dst->ty.Of(wrapper);
|
||||
is_wrapped = true;
|
||||
}
|
||||
|
||||
param_type = ctx.dst->ty.pointer(
|
||||
param_type, sc, var->Declaration()->declared_access);
|
||||
auto* param =
|
||||
ctx.dst->Param(new_var_symbol, param_type, attributes);
|
||||
ctx.InsertFront(func_ast->params, param);
|
||||
is_pointer = true;
|
||||
} else if (sc == ast::StorageClass::kWorkgroup &&
|
||||
ContainsMatrix(var->Type())) {
|
||||
// Due to a bug in the MSL compiler, we use a threadgroup memory
|
||||
// argument for any workgroup allocation that contains a matrix.
|
||||
// See crbug.com/tint/938.
|
||||
// TODO(jrprice): Do this for all other workgroup variables too.
|
||||
|
||||
// Create a member in the workgroup parameter struct.
|
||||
auto member = ctx.Clone(var->Declaration()->symbol);
|
||||
workgroup_parameter_members.push_back(
|
||||
ctx.dst->Member(member, store_type()));
|
||||
CloneStructTypes(var->Type()->UnwrapRef());
|
||||
|
||||
// Create a function-scope variable that is a pointer to the member.
|
||||
auto* member_ptr = ctx.dst->AddressOf(ctx.dst->MemberAccessor(
|
||||
ctx.dst->Deref(workgroup_param()), member));
|
||||
auto* local_var =
|
||||
ctx.dst->Const(new_var_symbol,
|
||||
ctx.dst->ty.pointer(
|
||||
store_type(), ast::StorageClass::kWorkgroup),
|
||||
member_ptr);
|
||||
ctx.InsertFront(func_ast->body->statements,
|
||||
ctx.dst->Decl(local_var));
|
||||
is_pointer = true;
|
||||
} else {
|
||||
// Variables in the Private and Workgroup storage classes are
|
||||
// redeclared at function scope. Disable storage class validation on
|
||||
// this variable.
|
||||
auto* disable_validation =
|
||||
ctx.dst->Disable(ast::DisabledValidation::kIgnoreStorageClass);
|
||||
auto* constructor = ctx.Clone(var->Declaration()->constructor);
|
||||
auto* local_var =
|
||||
ctx.dst->Var(new_var_symbol, store_type(), sc, constructor,
|
||||
ast::AttributeList{disable_validation});
|
||||
ctx.InsertFront(func_ast->body->statements,
|
||||
ctx.dst->Decl(local_var));
|
||||
}
|
||||
} else {
|
||||
// For a regular function, redeclare the variable as a parameter.
|
||||
// Use a pointer for non-handle types.
|
||||
auto* param_type = store_type();
|
||||
ast::AttributeList attributes;
|
||||
if (!var->Type()->UnwrapRef()->is_handle()) {
|
||||
param_type = ctx.dst->ty.pointer(
|
||||
param_type, sc, var->Declaration()->declared_access);
|
||||
is_pointer = true;
|
||||
|
||||
// Disable validation of the parameter's storage class and of
|
||||
// arguments passed it.
|
||||
attributes.push_back(
|
||||
ctx.dst->Disable(ast::DisabledValidation::kIgnoreStorageClass));
|
||||
attributes.push_back(ctx.dst->Disable(
|
||||
ast::DisabledValidation::kIgnoreInvalidPointerArgument));
|
||||
}
|
||||
ctx.InsertBack(
|
||||
func_ast->params,
|
||||
ctx.dst->Param(new_var_symbol, param_type, attributes));
|
||||
}
|
||||
|
||||
// Replace all uses of the module-scope variable.
|
||||
// For non-entry points, dereference non-handle pointer parameters.
|
||||
for (auto* user : var->Users()) {
|
||||
if (user->Stmt()->Function()->Declaration() == func_ast) {
|
||||
const ast::Expression* expr = ctx.dst->Expr(new_var_symbol);
|
||||
if (is_pointer) {
|
||||
// If this identifier is used by an address-of operator, just
|
||||
// remove the address-of instead of adding a deref, since we
|
||||
// already have a pointer.
|
||||
auto* ident =
|
||||
user->Declaration()->As<ast::IdentifierExpression>();
|
||||
if (ident_to_address_of.count(ident)) {
|
||||
ctx.Replace(ident_to_address_of[ident], expr);
|
||||
continue;
|
||||
}
|
||||
|
||||
expr = ctx.dst->Deref(expr);
|
||||
}
|
||||
if (is_wrapped) {
|
||||
// Get the member from the wrapper structure.
|
||||
expr = ctx.dst->MemberAccessor(expr, kWrappedArrayMemberName);
|
||||
}
|
||||
ctx.Replace(user->Declaration(), expr);
|
||||
}
|
||||
}
|
||||
|
||||
var_to_newvar[var] = {new_var_symbol, is_pointer, is_wrapped};
|
||||
}
|
||||
|
||||
if (!workgroup_parameter_members.empty()) {
|
||||
// Create the workgroup memory parameter.
|
||||
// The parameter is a struct that contains members for each workgroup
|
||||
// variable.
|
||||
auto* str = ctx.dst->Structure(ctx.dst->Sym(),
|
||||
std::move(workgroup_parameter_members));
|
||||
auto* param_type = ctx.dst->ty.pointer(ctx.dst->ty.Of(str),
|
||||
ast::StorageClass::kWorkgroup);
|
||||
auto* disable_validation =
|
||||
ctx.dst->Disable(ast::DisabledValidation::kEntryPointParameter);
|
||||
auto* param =
|
||||
ctx.dst->Param(workgroup_param(), param_type, {disable_validation});
|
||||
ctx.InsertFront(func_ast->params, param);
|
||||
}
|
||||
|
||||
// Pass the variables as pointers to any functions that need them.
|
||||
for (auto* call : calls_to_replace[func_ast]) {
|
||||
auto* target =
|
||||
ctx.src->AST().Functions().Find(call->target.name->symbol);
|
||||
auto* target_sem = ctx.src->Sem().Get(target);
|
||||
|
||||
// Add new arguments for any variables that are needed by the callee.
|
||||
// For entry points, pass non-handle types as pointers.
|
||||
for (auto* target_var : target_sem->TransitivelyReferencedGlobals()) {
|
||||
auto sc = target_var->StorageClass();
|
||||
if (sc == ast::StorageClass::kNone) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto new_var = var_to_newvar[target_var];
|
||||
bool is_handle = target_var->Type()->UnwrapRef()->is_handle();
|
||||
const ast::Expression* arg = ctx.dst->Expr(new_var.symbol);
|
||||
if (new_var.is_wrapped) {
|
||||
// The variable is wrapped in a struct, so we need to pass a pointer
|
||||
// to the struct member instead.
|
||||
arg = ctx.dst->AddressOf(ctx.dst->MemberAccessor(
|
||||
ctx.dst->Deref(arg), kWrappedArrayMemberName));
|
||||
} else if (is_entry_point && !is_handle && !new_var.is_pointer) {
|
||||
// We need to pass a pointer and we don't already have one, so take
|
||||
// the address of the new variable.
|
||||
arg = ctx.dst->AddressOf(arg);
|
||||
}
|
||||
ctx.InsertBack(call->args, arg);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Now remove all module-scope variables with these storage classes.
|
||||
for (auto* var_ast : ctx.src->AST().GlobalVariables()) {
|
||||
auto* var_sem = ctx.src->Sem().Get(var_ast);
|
||||
if (var_sem->StorageClass() != ast::StorageClass::kNone) {
|
||||
ctx.Remove(ctx.src->AST().GlobalDeclarations(), var_ast);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::unordered_set<const sem::Struct*> cloned_structs_;
|
||||
};
|
||||
|
||||
ModuleScopeVarToEntryPointParam::ModuleScopeVarToEntryPointParam() = default;
|
||||
|
||||
ModuleScopeVarToEntryPointParam::~ModuleScopeVarToEntryPointParam() = default;
|
||||
|
||||
bool ModuleScopeVarToEntryPointParam::ShouldRun(const Program* program,
|
||||
const DataMap&) const {
|
||||
for (auto* decl : program->AST().GlobalDeclarations()) {
|
||||
if (decl->Is<ast::Variable>()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void ModuleScopeVarToEntryPointParam::Run(CloneContext& ctx,
|
||||
const DataMap&,
|
||||
DataMap&) const {
|
||||
State state{ctx};
|
||||
state.Process();
|
||||
ctx.Clone();
|
||||
}
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
96
src/tint/transform/module_scope_var_to_entry_point_param.h
Normal file
96
src/tint/transform/module_scope_var_to_entry_point_param.h
Normal file
@@ -0,0 +1,96 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef SRC_TINT_TRANSFORM_MODULE_SCOPE_VAR_TO_ENTRY_POINT_PARAM_H_
|
||||
#define SRC_TINT_TRANSFORM_MODULE_SCOPE_VAR_TO_ENTRY_POINT_PARAM_H_
|
||||
|
||||
#include "src/tint/transform/transform.h"
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
|
||||
/// Move module-scope variables into the entry point as parameters.
|
||||
///
|
||||
/// MSL does not allow module-scope variables to have any address space other
|
||||
/// than `constant`. This transform moves all module-scope declarations into the
|
||||
/// entry point function (either as parameters or function-scope variables) and
|
||||
/// then passes them as pointer parameters to any function that references them.
|
||||
///
|
||||
/// Since WGSL does not allow entry point parameters or function-scope variables
|
||||
/// to have these storage classes, we annotate the new variable declarations
|
||||
/// with an attribute that bypasses that validation rule.
|
||||
///
|
||||
/// Before:
|
||||
/// ```
|
||||
/// struct S {
|
||||
/// f : f32;
|
||||
/// };
|
||||
/// @binding(0) @group(0)
|
||||
/// var<storage, read> s : S;
|
||||
/// var<private> p : f32 = 2.0;
|
||||
///
|
||||
/// fn foo() {
|
||||
/// p = p + f;
|
||||
/// }
|
||||
///
|
||||
/// @stage(compute) @workgroup_size(1)
|
||||
/// fn main() {
|
||||
/// foo();
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// After:
|
||||
/// ```
|
||||
/// fn foo(p : ptr<private, f32>, sptr : ptr<storage, S, read>) {
|
||||
/// *p = *p + (*sptr).f;
|
||||
/// }
|
||||
///
|
||||
/// @stage(compute) @workgroup_size(1)
|
||||
/// fn main(sptr : ptr<storage, S, read>) {
|
||||
/// var<private> p : f32 = 2.0;
|
||||
/// foo(&p, sptr);
|
||||
/// }
|
||||
/// ```
|
||||
class ModuleScopeVarToEntryPointParam
|
||||
: public Castable<ModuleScopeVarToEntryPointParam, Transform> {
|
||||
public:
|
||||
/// Constructor
|
||||
ModuleScopeVarToEntryPointParam();
|
||||
/// Destructor
|
||||
~ModuleScopeVarToEntryPointParam() override;
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program,
|
||||
const DataMap& data = {}) const override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
|
||||
struct State;
|
||||
};
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
|
||||
#endif // SRC_TINT_TRANSFORM_MODULE_SCOPE_VAR_TO_ENTRY_POINT_PARAM_H_
|
||||
1170
src/tint/transform/module_scope_var_to_entry_point_param_test.cc
Normal file
1170
src/tint/transform/module_scope_var_to_entry_point_param_test.cc
Normal file
File diff suppressed because it is too large
Load Diff
454
src/tint/transform/multiplanar_external_texture.cc
Normal file
454
src/tint/transform/multiplanar_external_texture.cc
Normal file
@@ -0,0 +1,454 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/tint/transform/multiplanar_external_texture.h"
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "src/tint/ast/function.h"
|
||||
#include "src/tint/program_builder.h"
|
||||
#include "src/tint/sem/call.h"
|
||||
#include "src/tint/sem/function.h"
|
||||
#include "src/tint/sem/variable.h"
|
||||
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::transform::MultiplanarExternalTexture);
|
||||
TINT_INSTANTIATE_TYPEINFO(
|
||||
tint::transform::MultiplanarExternalTexture::NewBindingPoints);
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
namespace {
|
||||
|
||||
/// This struct stores symbols for new bindings created as a result of
|
||||
/// transforming a texture_external instance.
|
||||
struct NewBindingSymbols {
|
||||
Symbol params;
|
||||
Symbol plane_0;
|
||||
Symbol plane_1;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
/// State holds the current transform state
|
||||
struct MultiplanarExternalTexture::State {
|
||||
/// The clone context.
|
||||
CloneContext& ctx;
|
||||
|
||||
/// ProgramBuilder for the context
|
||||
ProgramBuilder& b;
|
||||
|
||||
/// Destination binding locations for the expanded texture_external provided
|
||||
/// as input into the transform.
|
||||
const NewBindingPoints* new_binding_points;
|
||||
|
||||
/// Symbol for the ExternalTextureParams struct
|
||||
Symbol params_struct_sym;
|
||||
|
||||
/// Symbol for the textureLoadExternal function
|
||||
Symbol texture_load_external_sym;
|
||||
|
||||
/// Symbol for the textureSampleExternal function
|
||||
Symbol texture_sample_external_sym;
|
||||
|
||||
/// Storage for new bindings that have been created corresponding to an
|
||||
/// original texture_external binding.
|
||||
std::unordered_map<const sem::Variable*, NewBindingSymbols>
|
||||
new_binding_symbols;
|
||||
|
||||
/// Constructor
|
||||
/// @param context the clone
|
||||
/// @param newBindingPoints the input destination binding locations for the
|
||||
/// expanded texture_external
|
||||
State(CloneContext& context, const NewBindingPoints* newBindingPoints)
|
||||
: ctx(context), b(*context.dst), new_binding_points(newBindingPoints) {}
|
||||
|
||||
/// Processes the module
|
||||
void Process() {
|
||||
auto& sem = ctx.src->Sem();
|
||||
|
||||
// For each texture_external binding, we replace it with a texture_2d<f32>
|
||||
// binding and create two additional bindings (one texture_2d<f32> to
|
||||
// represent the secondary plane and one uniform buffer for the
|
||||
// ExternalTextureParams struct).
|
||||
for (auto* var : ctx.src->AST().GlobalVariables()) {
|
||||
auto* sem_var = sem.Get(var);
|
||||
if (!sem_var->Type()->UnwrapRef()->Is<sem::ExternalTexture>()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// If the attributes are empty, then this must be a texture_external
|
||||
// passed as a function parameter. These variables are transformed
|
||||
// elsewhere.
|
||||
if (var->attributes.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// If we find a texture_external binding, we know we must emit the
|
||||
// ExternalTextureParams struct.
|
||||
if (!params_struct_sym.IsValid()) {
|
||||
createExtTexParamsStruct();
|
||||
}
|
||||
|
||||
// The binding points for the newly introduced bindings must have been
|
||||
// provided to this transform. We fetch the new binding points by
|
||||
// providing the original texture_external binding points into the
|
||||
// passed map.
|
||||
BindingPoint bp = {var->BindingPoint().group->value,
|
||||
var->BindingPoint().binding->value};
|
||||
|
||||
BindingsMap::const_iterator it =
|
||||
new_binding_points->bindings_map.find(bp);
|
||||
if (it == new_binding_points->bindings_map.end()) {
|
||||
b.Diagnostics().add_error(
|
||||
diag::System::Transform,
|
||||
"missing new binding points for texture_external at binding {" +
|
||||
std::to_string(bp.group) + "," + std::to_string(bp.binding) +
|
||||
"}");
|
||||
continue;
|
||||
}
|
||||
|
||||
BindingPoints bps = it->second;
|
||||
|
||||
// Symbols for the newly created bindings must be saved so they can be
|
||||
// passed as parameters later. These are placed in a map and keyed by
|
||||
// the source symbol associated with the texture_external binding that
|
||||
// corresponds with the new destination bindings.
|
||||
// NewBindingSymbols new_binding_syms;
|
||||
auto& syms = new_binding_symbols[sem_var];
|
||||
syms.plane_0 = ctx.Clone(var->symbol);
|
||||
syms.plane_1 = b.Symbols().New("ext_tex_plane_1");
|
||||
b.Global(syms.plane_1,
|
||||
b.ty.sampled_texture(ast::TextureDimension::k2d, b.ty.f32()),
|
||||
b.GroupAndBinding(bps.plane_1.group, bps.plane_1.binding));
|
||||
syms.params = b.Symbols().New("ext_tex_params");
|
||||
b.Global(syms.params, b.ty.type_name("ExternalTextureParams"),
|
||||
ast::StorageClass::kUniform,
|
||||
b.GroupAndBinding(bps.params.group, bps.params.binding));
|
||||
|
||||
// Replace the original texture_external binding with a texture_2d<f32>
|
||||
// binding.
|
||||
ast::AttributeList cloned_attributes = ctx.Clone(var->attributes);
|
||||
const ast::Expression* cloned_constructor = ctx.Clone(var->constructor);
|
||||
|
||||
auto* replacement =
|
||||
b.Var(syms.plane_0,
|
||||
b.ty.sampled_texture(ast::TextureDimension::k2d, b.ty.f32()),
|
||||
cloned_constructor, cloned_attributes);
|
||||
ctx.Replace(var, replacement);
|
||||
}
|
||||
|
||||
// We must update all the texture_external parameters for user declared
|
||||
// functions.
|
||||
for (auto* fn : ctx.src->AST().Functions()) {
|
||||
for (const ast::Variable* param : fn->params) {
|
||||
if (auto* sem_var = sem.Get(param)) {
|
||||
if (!sem_var->Type()->UnwrapRef()->Is<sem::ExternalTexture>()) {
|
||||
continue;
|
||||
}
|
||||
// If we find a texture_external, we must ensure the
|
||||
// ExternalTextureParams struct exists.
|
||||
if (!params_struct_sym.IsValid()) {
|
||||
createExtTexParamsStruct();
|
||||
}
|
||||
// When a texture_external is found, we insert all components
|
||||
// the texture_external into the parameter list. We must also place
|
||||
// the new symbols into the transform state so they can be used when
|
||||
// transforming function calls.
|
||||
auto& syms = new_binding_symbols[sem_var];
|
||||
syms.plane_0 = ctx.Clone(param->symbol);
|
||||
syms.plane_1 = b.Symbols().New("ext_tex_plane_1");
|
||||
syms.params = b.Symbols().New("ext_tex_params");
|
||||
auto tex2d_f32 = [&] {
|
||||
return b.ty.sampled_texture(ast::TextureDimension::k2d, b.ty.f32());
|
||||
};
|
||||
ctx.Replace(param, b.Param(syms.plane_0, tex2d_f32()));
|
||||
ctx.InsertAfter(fn->params, param,
|
||||
b.Param(syms.plane_1, tex2d_f32()));
|
||||
ctx.InsertAfter(
|
||||
fn->params, param,
|
||||
b.Param(syms.params, b.ty.type_name(params_struct_sym)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Transform the original textureLoad and textureSampleLevel calls into
|
||||
// textureLoadExternal and textureSampleExternal calls.
|
||||
ctx.ReplaceAll(
|
||||
[&](const ast::CallExpression* expr) -> const ast::CallExpression* {
|
||||
auto* builtin = sem.Get(expr)->Target()->As<sem::Builtin>();
|
||||
|
||||
if (builtin && !builtin->Parameters().empty() &&
|
||||
builtin->Parameters()[0]->Type()->Is<sem::ExternalTexture>() &&
|
||||
builtin->Type() != sem::BuiltinType::kTextureDimensions) {
|
||||
if (auto* var_user = sem.Get<sem::VariableUser>(expr->args[0])) {
|
||||
auto it = new_binding_symbols.find(var_user->Variable());
|
||||
if (it == new_binding_symbols.end()) {
|
||||
// If valid new binding locations were not provided earlier, we
|
||||
// would have been unable to create these symbols. An error
|
||||
// message was emitted earlier, so just return early to avoid
|
||||
// internal compiler errors and retain a clean error message.
|
||||
return nullptr;
|
||||
}
|
||||
auto& syms = it->second;
|
||||
|
||||
if (builtin->Type() == sem::BuiltinType::kTextureLoad) {
|
||||
return createTexLdExt(expr, syms);
|
||||
}
|
||||
|
||||
if (builtin->Type() == sem::BuiltinType::kTextureSampleLevel) {
|
||||
return createTexSmpExt(expr, syms);
|
||||
}
|
||||
}
|
||||
|
||||
} else if (sem.Get(expr)->Target()->Is<sem::Function>()) {
|
||||
// The call expression may be to a user-defined function that
|
||||
// contains a texture_external parameter. These need to be expanded
|
||||
// out to multiple plane textures and the texture parameters
|
||||
// structure.
|
||||
for (auto* arg : expr->args) {
|
||||
if (auto* var_user = sem.Get<sem::VariableUser>(arg)) {
|
||||
// Check if a parameter is a texture_external by trying to find
|
||||
// it in the transform state.
|
||||
auto it = new_binding_symbols.find(var_user->Variable());
|
||||
if (it != new_binding_symbols.end()) {
|
||||
auto& syms = it->second;
|
||||
// When we find a texture_external, we must unpack it into its
|
||||
// components.
|
||||
ctx.Replace(arg, b.Expr(syms.plane_0));
|
||||
ctx.InsertAfter(expr->args, arg, b.Expr(syms.plane_1));
|
||||
ctx.InsertAfter(expr->args, arg, b.Expr(syms.params));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
});
|
||||
}
|
||||
|
||||
/// Creates the ExternalTextureParams struct.
|
||||
void createExtTexParamsStruct() {
|
||||
ast::StructMemberList member_list = {
|
||||
b.Member("numPlanes", b.ty.u32()), b.Member("vr", b.ty.f32()),
|
||||
b.Member("ug", b.ty.f32()), b.Member("vg", b.ty.f32()),
|
||||
b.Member("ub", b.ty.f32())};
|
||||
|
||||
params_struct_sym = b.Symbols().New("ExternalTextureParams");
|
||||
|
||||
b.Structure(params_struct_sym, member_list);
|
||||
}
|
||||
|
||||
/// Constructs a StatementList containing all the statements making up the
|
||||
/// bodies of the textureSampleExternal and textureLoadExternal functions.
|
||||
/// @param call_type determines which function body to generate
|
||||
/// @returns a statement list that makes of the body of the chosen function
|
||||
ast::StatementList createTexFnExtStatementList(sem::BuiltinType call_type) {
|
||||
using f32 = ProgramBuilder::f32;
|
||||
const ast::CallExpression* single_plane_call = nullptr;
|
||||
const ast::CallExpression* plane_0_call = nullptr;
|
||||
const ast::CallExpression* plane_1_call = nullptr;
|
||||
if (call_type == sem::BuiltinType::kTextureSampleLevel) {
|
||||
// textureSampleLevel(plane0, smp, coord.xy, 0.0);
|
||||
single_plane_call =
|
||||
b.Call("textureSampleLevel", "plane0", "smp", "coord", 0.0f);
|
||||
// textureSampleLevel(plane0, smp, coord.xy, 0.0);
|
||||
plane_0_call =
|
||||
b.Call("textureSampleLevel", "plane0", "smp", "coord", 0.0f);
|
||||
// textureSampleLevel(plane1, smp, coord.xy, 0.0);
|
||||
plane_1_call =
|
||||
b.Call("textureSampleLevel", "plane1", "smp", "coord", 0.0f);
|
||||
} else if (call_type == sem::BuiltinType::kTextureLoad) {
|
||||
// textureLoad(plane0, coords.xy, 0);
|
||||
single_plane_call = b.Call("textureLoad", "plane0", "coord", 0);
|
||||
// textureLoad(plane0, coords.xy, 0);
|
||||
plane_0_call = b.Call("textureLoad", "plane0", "coord", 0);
|
||||
// textureLoad(plane1, coords.xy, 0);
|
||||
plane_1_call = b.Call("textureLoad", "plane1", "coord", 0);
|
||||
} else {
|
||||
TINT_ICE(Transform, b.Diagnostics())
|
||||
<< "unhandled builtin: " << call_type;
|
||||
}
|
||||
|
||||
return {
|
||||
// if (params.numPlanes == 1u) {
|
||||
// return singlePlaneCall
|
||||
// }
|
||||
b.If(b.create<ast::BinaryExpression>(
|
||||
ast::BinaryOp::kEqual, b.MemberAccessor("params", "numPlanes"),
|
||||
b.Expr(1u)),
|
||||
b.Block(b.Return(single_plane_call))),
|
||||
// let y = plane0Call.r - 0.0625;
|
||||
b.Decl(b.Const("y", nullptr,
|
||||
b.Sub(b.MemberAccessor(plane_0_call, "r"), 0.0625f))),
|
||||
// let uv = plane1Call.rg - 0.5;
|
||||
b.Decl(b.Const("uv", nullptr,
|
||||
b.Sub(b.MemberAccessor(plane_1_call, "rg"), 0.5f))),
|
||||
// let u = uv.x;
|
||||
b.Decl(b.Const("u", nullptr, b.MemberAccessor("uv", "x"))),
|
||||
// let v = uv.y;
|
||||
b.Decl(b.Const("v", nullptr, b.MemberAccessor("uv", "y"))),
|
||||
// let r = 1.164 * y + params.vr * v;
|
||||
b.Decl(b.Const("r", nullptr,
|
||||
b.Add(b.Mul(1.164f, "y"),
|
||||
b.Mul(b.MemberAccessor("params", "vr"), "v")))),
|
||||
// let g = 1.164 * y - params.ug * u - params.vg * v;
|
||||
b.Decl(
|
||||
b.Const("g", nullptr,
|
||||
b.Sub(b.Sub(b.Mul(1.164f, "y"),
|
||||
b.Mul(b.MemberAccessor("params", "ug"), "u")),
|
||||
b.Mul(b.MemberAccessor("params", "vg"), "v")))),
|
||||
// let b = 1.164 * y + params.ub * u;
|
||||
b.Decl(b.Const("b", nullptr,
|
||||
b.Add(b.Mul(1.164f, "y"),
|
||||
b.Mul(b.MemberAccessor("params", "ub"), "u")))),
|
||||
// return vec4<f32>(r, g, b, 1.0);
|
||||
b.Return(b.vec4<f32>("r", "g", "b", 1.0f)),
|
||||
};
|
||||
}
|
||||
|
||||
/// Creates the textureSampleExternal function if needed and returns a call
|
||||
/// expression to it.
|
||||
/// @param expr the call expression being transformed
|
||||
/// @param syms the expanded symbols to be used in the new call
|
||||
/// @returns a call expression to textureSampleExternal
|
||||
const ast::CallExpression* createTexSmpExt(const ast::CallExpression* expr,
|
||||
NewBindingSymbols syms) {
|
||||
ast::ExpressionList params;
|
||||
const ast::Expression* plane_0_binding_param = ctx.Clone(expr->args[0]);
|
||||
|
||||
if (expr->args.size() != 3) {
|
||||
TINT_ICE(Transform, b.Diagnostics())
|
||||
<< "expected textureSampleLevel call with a "
|
||||
"texture_external to have 3 parameters, found "
|
||||
<< expr->args.size() << " parameters";
|
||||
}
|
||||
|
||||
if (!texture_sample_external_sym.IsValid()) {
|
||||
texture_sample_external_sym = b.Symbols().New("textureSampleExternal");
|
||||
|
||||
// Emit the textureSampleExternal function.
|
||||
ast::VariableList varList = {
|
||||
b.Param("plane0",
|
||||
b.ty.sampled_texture(ast::TextureDimension::k2d, b.ty.f32())),
|
||||
b.Param("plane1",
|
||||
b.ty.sampled_texture(ast::TextureDimension::k2d, b.ty.f32())),
|
||||
b.Param("smp", b.ty.sampler(ast::SamplerKind::kSampler)),
|
||||
b.Param("coord", b.ty.vec2(b.ty.f32())),
|
||||
b.Param("params", b.ty.type_name(params_struct_sym))};
|
||||
|
||||
ast::StatementList statementList =
|
||||
createTexFnExtStatementList(sem::BuiltinType::kTextureSampleLevel);
|
||||
|
||||
b.Func(texture_sample_external_sym, varList, b.ty.vec4(b.ty.f32()),
|
||||
statementList, {});
|
||||
}
|
||||
|
||||
const ast::IdentifierExpression* exp = b.Expr(texture_sample_external_sym);
|
||||
params = {plane_0_binding_param, b.Expr(syms.plane_1),
|
||||
ctx.Clone(expr->args[1]), ctx.Clone(expr->args[2]),
|
||||
b.Expr(syms.params)};
|
||||
return b.Call(exp, params);
|
||||
}
|
||||
|
||||
/// Creates the textureLoadExternal function if needed and returns a call
|
||||
/// expression to it.
|
||||
/// @param expr the call expression being transformed
|
||||
/// @param syms the expanded symbols to be used in the new call
|
||||
/// @returns a call expression to textureLoadExternal
|
||||
const ast::CallExpression* createTexLdExt(const ast::CallExpression* expr,
|
||||
NewBindingSymbols syms) {
|
||||
ast::ExpressionList params;
|
||||
const ast::Expression* plane_0_binding_param = ctx.Clone(expr->args[0]);
|
||||
|
||||
if (expr->args.size() != 2) {
|
||||
TINT_ICE(Transform, b.Diagnostics())
|
||||
<< "expected textureLoad call with a texture_external "
|
||||
"to have 2 parameters, found "
|
||||
<< expr->args.size() << " parameters";
|
||||
}
|
||||
|
||||
if (!texture_load_external_sym.IsValid()) {
|
||||
texture_load_external_sym = b.Symbols().New("textureLoadExternal");
|
||||
|
||||
// Emit the textureLoadExternal function.
|
||||
ast::VariableList var_list = {
|
||||
b.Param("plane0",
|
||||
b.ty.sampled_texture(ast::TextureDimension::k2d, b.ty.f32())),
|
||||
b.Param("plane1",
|
||||
b.ty.sampled_texture(ast::TextureDimension::k2d, b.ty.f32())),
|
||||
b.Param("coord", b.ty.vec2(b.ty.i32())),
|
||||
b.Param("params", b.ty.type_name(params_struct_sym))};
|
||||
|
||||
ast::StatementList statement_list =
|
||||
createTexFnExtStatementList(sem::BuiltinType::kTextureLoad);
|
||||
|
||||
b.Func(texture_load_external_sym, var_list, b.ty.vec4(b.ty.f32()),
|
||||
statement_list, {});
|
||||
}
|
||||
|
||||
const ast::IdentifierExpression* exp = b.Expr(texture_load_external_sym);
|
||||
params = {plane_0_binding_param, b.Expr(syms.plane_1),
|
||||
ctx.Clone(expr->args[1]), b.Expr(syms.params)};
|
||||
return b.Call(exp, params);
|
||||
}
|
||||
};
|
||||
|
||||
MultiplanarExternalTexture::NewBindingPoints::NewBindingPoints(
|
||||
BindingsMap inputBindingsMap)
|
||||
: bindings_map(std::move(inputBindingsMap)) {}
|
||||
MultiplanarExternalTexture::NewBindingPoints::~NewBindingPoints() = default;
|
||||
|
||||
MultiplanarExternalTexture::MultiplanarExternalTexture() = default;
|
||||
MultiplanarExternalTexture::~MultiplanarExternalTexture() = default;
|
||||
|
||||
bool MultiplanarExternalTexture::ShouldRun(const Program* program,
|
||||
const DataMap&) const {
|
||||
for (auto* node : program->ASTNodes().Objects()) {
|
||||
if (auto* ty = node->As<ast::Type>()) {
|
||||
if (program->Sem().Get<sem::ExternalTexture>(ty)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Within this transform, an instance of a texture_external binding is unpacked
|
||||
// into two texture_2d<f32> bindings representing two possible planes of a
|
||||
// single texture and a uniform buffer binding representing a struct of
|
||||
// parameters. Calls to textureLoad or textureSampleLevel that contain a
|
||||
// texture_external parameter will be transformed into a newly generated version
|
||||
// of the function, which can perform the desired operation on a single RGBA
|
||||
// plane or on separate Y and UV planes.
|
||||
void MultiplanarExternalTexture::Run(CloneContext& ctx,
|
||||
const DataMap& inputs,
|
||||
DataMap&) const {
|
||||
auto* new_binding_points = inputs.Get<NewBindingPoints>();
|
||||
|
||||
if (!new_binding_points) {
|
||||
ctx.dst->Diagnostics().add_error(
|
||||
diag::System::Transform,
|
||||
"missing new binding point data for " + std::string(TypeInfo().name));
|
||||
return;
|
||||
}
|
||||
|
||||
State state(ctx, new_binding_points);
|
||||
|
||||
state.Process();
|
||||
|
||||
ctx.Clone();
|
||||
}
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
102
src/tint/transform/multiplanar_external_texture.h
Normal file
102
src/tint/transform/multiplanar_external_texture.h
Normal file
@@ -0,0 +1,102 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef SRC_TINT_TRANSFORM_MULTIPLANAR_EXTERNAL_TEXTURE_H_
|
||||
#define SRC_TINT_TRANSFORM_MULTIPLANAR_EXTERNAL_TEXTURE_H_
|
||||
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
|
||||
#include "src/tint/ast/struct_member.h"
|
||||
#include "src/tint/sem/binding_point.h"
|
||||
#include "src/tint/sem/builtin_type.h"
|
||||
#include "src/tint/transform/transform.h"
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
|
||||
/// BindingPoint is an alias to sem::BindingPoint
|
||||
using BindingPoint = sem::BindingPoint;
|
||||
|
||||
/// This struct identifies the binding groups and locations for new bindings to
|
||||
/// use when transforming a texture_external instance.
|
||||
struct BindingPoints {
|
||||
/// The desired binding location of the texture_2d representing plane #1 when
|
||||
/// a texture_external binding is expanded.
|
||||
BindingPoint plane_1;
|
||||
/// The desired binding location of the ExternalTextureParams uniform when a
|
||||
/// texture_external binding is expanded.
|
||||
BindingPoint params;
|
||||
};
|
||||
|
||||
/// Within the MultiplanarExternalTexture transform, each instance of a
|
||||
/// texture_external binding is unpacked into two texture_2d<f32> bindings
|
||||
/// representing two possible planes of a texture and a uniform buffer binding
|
||||
/// representing a struct of parameters. Calls to textureLoad or
|
||||
/// textureSampleLevel that contain a texture_external parameter will be
|
||||
/// transformed into a newly generated version of the function, which can
|
||||
/// perform the desired operation on a single RGBA plane or on seperate Y and UV
|
||||
/// planes.
|
||||
class MultiplanarExternalTexture
|
||||
: public Castable<MultiplanarExternalTexture, Transform> {
|
||||
public:
|
||||
/// BindingsMap is a map where the key is the binding location of a
|
||||
/// texture_external and the value is a struct containing the desired
|
||||
/// locations for new bindings expanded from the texture_external instance.
|
||||
using BindingsMap = std::unordered_map<BindingPoint, BindingPoints>;
|
||||
|
||||
/// NewBindingPoints is consumed by the MultiplanarExternalTexture transform.
|
||||
/// Data holds information about location of each texture_external binding and
|
||||
/// which binding slots it should expand into.
|
||||
struct NewBindingPoints : public Castable<Data, transform::Data> {
|
||||
/// Constructor
|
||||
/// @param bm a map to the new binding slots to use.
|
||||
explicit NewBindingPoints(BindingsMap bm);
|
||||
|
||||
/// Destructor
|
||||
~NewBindingPoints() override;
|
||||
|
||||
/// A map of new binding points to use.
|
||||
const BindingsMap bindings_map;
|
||||
};
|
||||
|
||||
/// Constructor
|
||||
MultiplanarExternalTexture();
|
||||
/// Destructor
|
||||
~MultiplanarExternalTexture() override;
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program,
|
||||
const DataMap& data = {}) const override;
|
||||
|
||||
protected:
|
||||
struct State;
|
||||
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
};
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
|
||||
#endif // SRC_TINT_TRANSFORM_MULTIPLANAR_EXTERNAL_TEXTURE_H_
|
||||
1297
src/tint/transform/multiplanar_external_texture_test.cc
Normal file
1297
src/tint/transform/multiplanar_external_texture_test.cc
Normal file
File diff suppressed because it is too large
Load Diff
170
src/tint/transform/num_workgroups_from_uniform.cc
Normal file
170
src/tint/transform/num_workgroups_from_uniform.cc
Normal file
@@ -0,0 +1,170 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/tint/transform/num_workgroups_from_uniform.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
#include <utility>
|
||||
|
||||
#include "src/tint/program_builder.h"
|
||||
#include "src/tint/sem/function.h"
|
||||
#include "src/tint/transform/canonicalize_entry_point_io.h"
|
||||
#include "src/tint/utils/hash.h"
|
||||
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::transform::NumWorkgroupsFromUniform);
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::transform::NumWorkgroupsFromUniform::Config);
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
namespace {
|
||||
/// Accessor describes the identifiers used in a member accessor that is being
|
||||
/// used to retrieve the num_workgroups builtin from a parameter.
|
||||
struct Accessor {
|
||||
Symbol param;
|
||||
Symbol member;
|
||||
|
||||
/// Equality operator
|
||||
bool operator==(const Accessor& other) const {
|
||||
return param == other.param && member == other.member;
|
||||
}
|
||||
/// Hash function
|
||||
struct Hasher {
|
||||
size_t operator()(const Accessor& a) const {
|
||||
return utils::Hash(a.param, a.member);
|
||||
}
|
||||
};
|
||||
};
|
||||
} // namespace
|
||||
|
||||
NumWorkgroupsFromUniform::NumWorkgroupsFromUniform() = default;
|
||||
NumWorkgroupsFromUniform::~NumWorkgroupsFromUniform() = default;
|
||||
|
||||
bool NumWorkgroupsFromUniform::ShouldRun(const Program* program,
|
||||
const DataMap&) const {
|
||||
for (auto* node : program->ASTNodes().Objects()) {
|
||||
if (auto* attr = node->As<ast::BuiltinAttribute>()) {
|
||||
if (attr->builtin == ast::Builtin::kNumWorkgroups) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void NumWorkgroupsFromUniform::Run(CloneContext& ctx,
|
||||
const DataMap& inputs,
|
||||
DataMap&) const {
|
||||
auto* cfg = inputs.Get<Config>();
|
||||
if (cfg == nullptr) {
|
||||
ctx.dst->Diagnostics().add_error(
|
||||
diag::System::Transform,
|
||||
"missing transform data for " + std::string(TypeInfo().name));
|
||||
return;
|
||||
}
|
||||
|
||||
const char* kNumWorkgroupsMemberName = "num_workgroups";
|
||||
|
||||
// Find all entry point parameters that declare the num_workgroups builtin.
|
||||
std::unordered_set<Accessor, Accessor::Hasher> to_replace;
|
||||
for (auto* func : ctx.src->AST().Functions()) {
|
||||
// num_workgroups is only valid for compute stages.
|
||||
if (func->PipelineStage() != ast::PipelineStage::kCompute) {
|
||||
continue;
|
||||
}
|
||||
|
||||
for (auto* param : ctx.src->Sem().Get(func)->Parameters()) {
|
||||
// Because the CanonicalizeEntryPointIO transform has been run, builtins
|
||||
// will only appear as struct members.
|
||||
auto* str = param->Type()->As<sem::Struct>();
|
||||
if (!str) {
|
||||
continue;
|
||||
}
|
||||
|
||||
for (auto* member : str->Members()) {
|
||||
auto* builtin = ast::GetAttribute<ast::BuiltinAttribute>(
|
||||
member->Declaration()->attributes);
|
||||
if (!builtin || builtin->builtin != ast::Builtin::kNumWorkgroups) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Capture the symbols that would be used to access this member, which
|
||||
// we will replace later. We currently have no way to get from the
|
||||
// parameter directly to the member accessor expressions that use it.
|
||||
to_replace.insert(
|
||||
{param->Declaration()->symbol, member->Declaration()->symbol});
|
||||
|
||||
// Remove the struct member.
|
||||
// The CanonicalizeEntryPointIO transform will have generated this
|
||||
// struct uniquely for this particular entry point, so we know that
|
||||
// there will be no other uses of this struct in the module and that we
|
||||
// can safely modify it here.
|
||||
ctx.Remove(str->Declaration()->members, member->Declaration());
|
||||
|
||||
// If this is the only member, remove the struct and parameter too.
|
||||
if (str->Members().size() == 1) {
|
||||
ctx.Remove(func->params, param->Declaration());
|
||||
ctx.Remove(ctx.src->AST().GlobalDeclarations(), str->Declaration());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get (or create, on first call) the uniform buffer that will receive the
|
||||
// number of workgroups.
|
||||
const ast::Variable* num_workgroups_ubo = nullptr;
|
||||
auto get_ubo = [&]() {
|
||||
if (!num_workgroups_ubo) {
|
||||
auto* num_workgroups_struct = ctx.dst->Structure(
|
||||
ctx.dst->Sym(),
|
||||
{ctx.dst->Member(kNumWorkgroupsMemberName,
|
||||
ctx.dst->ty.vec3(ctx.dst->ty.u32()))});
|
||||
num_workgroups_ubo = ctx.dst->Global(
|
||||
ctx.dst->Sym(), ctx.dst->ty.Of(num_workgroups_struct),
|
||||
ast::StorageClass::kUniform,
|
||||
ast::AttributeList{ctx.dst->GroupAndBinding(
|
||||
cfg->ubo_binding.group, cfg->ubo_binding.binding)});
|
||||
}
|
||||
return num_workgroups_ubo;
|
||||
};
|
||||
|
||||
// Now replace all the places where the builtins are accessed with the value
|
||||
// loaded from the uniform buffer.
|
||||
for (auto* node : ctx.src->ASTNodes().Objects()) {
|
||||
auto* accessor = node->As<ast::MemberAccessorExpression>();
|
||||
if (!accessor) {
|
||||
continue;
|
||||
}
|
||||
auto* ident = accessor->structure->As<ast::IdentifierExpression>();
|
||||
if (!ident) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (to_replace.count({ident->symbol, accessor->member->symbol})) {
|
||||
ctx.Replace(accessor, ctx.dst->MemberAccessor(get_ubo()->symbol,
|
||||
kNumWorkgroupsMemberName));
|
||||
}
|
||||
}
|
||||
|
||||
ctx.Clone();
|
||||
}
|
||||
|
||||
NumWorkgroupsFromUniform::Config::Config(sem::BindingPoint ubo_bp)
|
||||
: ubo_binding(ubo_bp) {}
|
||||
NumWorkgroupsFromUniform::Config::Config(const Config&) = default;
|
||||
NumWorkgroupsFromUniform::Config::~Config() = default;
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
90
src/tint/transform/num_workgroups_from_uniform.h
Normal file
90
src/tint/transform/num_workgroups_from_uniform.h
Normal file
@@ -0,0 +1,90 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef SRC_TINT_TRANSFORM_NUM_WORKGROUPS_FROM_UNIFORM_H_
|
||||
#define SRC_TINT_TRANSFORM_NUM_WORKGROUPS_FROM_UNIFORM_H_
|
||||
|
||||
#include "src/tint/sem/binding_point.h"
|
||||
#include "src/tint/transform/transform.h"
|
||||
|
||||
namespace tint {
|
||||
|
||||
// Forward declarations
|
||||
class CloneContext;
|
||||
|
||||
namespace transform {
|
||||
|
||||
/// NumWorkgroupsFromUniform is a transform that implements the `num_workgroups`
|
||||
/// builtin by loading it from a uniform buffer.
|
||||
///
|
||||
/// The generated uniform buffer will have the form:
|
||||
/// ```
|
||||
/// struct num_workgroups_struct {
|
||||
/// num_workgroups : vec3<u32>;
|
||||
/// };
|
||||
///
|
||||
/// @group(0) @binding(0)
|
||||
/// var<uniform> num_workgroups_ubo : num_workgroups_struct;
|
||||
/// ```
|
||||
/// The binding group and number used for this uniform buffer is provided via
|
||||
/// the `Config` transform input.
|
||||
///
|
||||
/// @note Depends on the following transforms to have been run first:
|
||||
/// * CanonicalizeEntryPointIO
|
||||
class NumWorkgroupsFromUniform
|
||||
: public Castable<NumWorkgroupsFromUniform, Transform> {
|
||||
public:
|
||||
/// Constructor
|
||||
NumWorkgroupsFromUniform();
|
||||
/// Destructor
|
||||
~NumWorkgroupsFromUniform() override;
|
||||
|
||||
/// Configuration options for the NumWorkgroupsFromUniform transform.
|
||||
struct Config : public Castable<Data, transform::Data> {
|
||||
/// Constructor
|
||||
/// @param ubo_bp the binding point to use for the generated uniform buffer.
|
||||
explicit Config(sem::BindingPoint ubo_bp);
|
||||
|
||||
/// Copy constructor
|
||||
Config(const Config&);
|
||||
|
||||
/// Destructor
|
||||
~Config() override;
|
||||
|
||||
/// The binding point to use for the generated uniform buffer.
|
||||
sem::BindingPoint ubo_binding;
|
||||
};
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program,
|
||||
const DataMap& data = {}) const override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
};
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
|
||||
#endif // SRC_TINT_TRANSFORM_NUM_WORKGROUPS_FROM_UNIFORM_H_
|
||||
456
src/tint/transform/num_workgroups_from_uniform_test.cc
Normal file
456
src/tint/transform/num_workgroups_from_uniform_test.cc
Normal file
@@ -0,0 +1,456 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/tint/transform/num_workgroups_from_uniform.h"
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "src/tint/transform/canonicalize_entry_point_io.h"
|
||||
#include "src/tint/transform/test_helper.h"
|
||||
#include "src/tint/transform/unshadow.h"
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
namespace {
|
||||
|
||||
using NumWorkgroupsFromUniformTest = TransformTest;
|
||||
|
||||
TEST_F(NumWorkgroupsFromUniformTest, ShouldRunEmptyModule) {
|
||||
auto* src = R"()";
|
||||
|
||||
EXPECT_FALSE(ShouldRun<NumWorkgroupsFromUniform>(src));
|
||||
}
|
||||
|
||||
TEST_F(NumWorkgroupsFromUniformTest, ShouldRunHasNumWorkgroups) {
|
||||
auto* src = R"(
|
||||
[[stage(compute), workgroup_size(1)]]
|
||||
fn main([[builtin(num_workgroups)]] num_wgs : vec3<u32>) {
|
||||
}
|
||||
)";
|
||||
|
||||
EXPECT_TRUE(ShouldRun<NumWorkgroupsFromUniform>(src));
|
||||
}
|
||||
|
||||
TEST_F(NumWorkgroupsFromUniformTest, Error_MissingTransformData) {
|
||||
auto* src = R"(
|
||||
[[stage(compute), workgroup_size(1)]]
|
||||
fn main([[builtin(num_workgroups)]] num_wgs : vec3<u32>) {
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect =
|
||||
"error: missing transform data for "
|
||||
"tint::transform::NumWorkgroupsFromUniform";
|
||||
|
||||
DataMap data;
|
||||
data.Add<CanonicalizeEntryPointIO::Config>(
|
||||
CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
|
||||
auto got = Run<Unshadow, CanonicalizeEntryPointIO, NumWorkgroupsFromUniform>(
|
||||
src, data);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(NumWorkgroupsFromUniformTest, Basic) {
|
||||
auto* src = R"(
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main(@builtin(num_workgroups) num_wgs : vec3<u32>) {
|
||||
let groups_x = num_wgs.x;
|
||||
let groups_y = num_wgs.y;
|
||||
let groups_z = num_wgs.z;
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
struct tint_symbol_2 {
|
||||
num_workgroups : vec3<u32>;
|
||||
}
|
||||
|
||||
@group(0) @binding(30) var<uniform> tint_symbol_3 : tint_symbol_2;
|
||||
|
||||
fn main_inner(num_wgs : vec3<u32>) {
|
||||
let groups_x = num_wgs.x;
|
||||
let groups_y = num_wgs.y;
|
||||
let groups_z = num_wgs.z;
|
||||
}
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
main_inner(tint_symbol_3.num_workgroups);
|
||||
}
|
||||
)";
|
||||
|
||||
DataMap data;
|
||||
data.Add<CanonicalizeEntryPointIO::Config>(
|
||||
CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
|
||||
data.Add<NumWorkgroupsFromUniform::Config>(sem::BindingPoint{0, 30u});
|
||||
auto got = Run<Unshadow, CanonicalizeEntryPointIO, NumWorkgroupsFromUniform>(
|
||||
src, data);
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(NumWorkgroupsFromUniformTest, StructOnlyMember) {
|
||||
auto* src = R"(
|
||||
struct Builtins {
|
||||
@builtin(num_workgroups) num_wgs : vec3<u32>;
|
||||
};
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main(in : Builtins) {
|
||||
let groups_x = in.num_wgs.x;
|
||||
let groups_y = in.num_wgs.y;
|
||||
let groups_z = in.num_wgs.z;
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
struct tint_symbol_2 {
|
||||
num_workgroups : vec3<u32>;
|
||||
}
|
||||
|
||||
@group(0) @binding(30) var<uniform> tint_symbol_3 : tint_symbol_2;
|
||||
|
||||
struct Builtins {
|
||||
num_wgs : vec3<u32>;
|
||||
}
|
||||
|
||||
fn main_inner(in : Builtins) {
|
||||
let groups_x = in.num_wgs.x;
|
||||
let groups_y = in.num_wgs.y;
|
||||
let groups_z = in.num_wgs.z;
|
||||
}
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
main_inner(Builtins(tint_symbol_3.num_workgroups));
|
||||
}
|
||||
)";
|
||||
|
||||
DataMap data;
|
||||
data.Add<CanonicalizeEntryPointIO::Config>(
|
||||
CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
|
||||
data.Add<NumWorkgroupsFromUniform::Config>(sem::BindingPoint{0, 30u});
|
||||
auto got = Run<Unshadow, CanonicalizeEntryPointIO, NumWorkgroupsFromUniform>(
|
||||
src, data);
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(NumWorkgroupsFromUniformTest, StructOnlyMember_OutOfOrder) {
|
||||
auto* src = R"(
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main(in : Builtins) {
|
||||
let groups_x = in.num_wgs.x;
|
||||
let groups_y = in.num_wgs.y;
|
||||
let groups_z = in.num_wgs.z;
|
||||
}
|
||||
|
||||
struct Builtins {
|
||||
@builtin(num_workgroups) num_wgs : vec3<u32>;
|
||||
};
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
struct tint_symbol_2 {
|
||||
num_workgroups : vec3<u32>;
|
||||
}
|
||||
|
||||
@group(0) @binding(30) var<uniform> tint_symbol_3 : tint_symbol_2;
|
||||
|
||||
fn main_inner(in : Builtins) {
|
||||
let groups_x = in.num_wgs.x;
|
||||
let groups_y = in.num_wgs.y;
|
||||
let groups_z = in.num_wgs.z;
|
||||
}
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
main_inner(Builtins(tint_symbol_3.num_workgroups));
|
||||
}
|
||||
|
||||
struct Builtins {
|
||||
num_wgs : vec3<u32>;
|
||||
}
|
||||
)";
|
||||
|
||||
DataMap data;
|
||||
data.Add<CanonicalizeEntryPointIO::Config>(
|
||||
CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
|
||||
data.Add<NumWorkgroupsFromUniform::Config>(sem::BindingPoint{0, 30u});
|
||||
auto got = Run<Unshadow, CanonicalizeEntryPointIO, NumWorkgroupsFromUniform>(
|
||||
src, data);
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(NumWorkgroupsFromUniformTest, StructMultipleMembers) {
|
||||
auto* src = R"(
|
||||
struct Builtins {
|
||||
@builtin(global_invocation_id) gid : vec3<u32>;
|
||||
@builtin(num_workgroups) num_wgs : vec3<u32>;
|
||||
@builtin(workgroup_id) wgid : vec3<u32>;
|
||||
};
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main(in : Builtins) {
|
||||
let groups_x = in.num_wgs.x;
|
||||
let groups_y = in.num_wgs.y;
|
||||
let groups_z = in.num_wgs.z;
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
struct tint_symbol_2 {
|
||||
num_workgroups : vec3<u32>;
|
||||
}
|
||||
|
||||
@group(0) @binding(30) var<uniform> tint_symbol_3 : tint_symbol_2;
|
||||
|
||||
struct Builtins {
|
||||
gid : vec3<u32>;
|
||||
num_wgs : vec3<u32>;
|
||||
wgid : vec3<u32>;
|
||||
}
|
||||
|
||||
struct tint_symbol_1 {
|
||||
@builtin(global_invocation_id)
|
||||
gid : vec3<u32>;
|
||||
@builtin(workgroup_id)
|
||||
wgid : vec3<u32>;
|
||||
}
|
||||
|
||||
fn main_inner(in : Builtins) {
|
||||
let groups_x = in.num_wgs.x;
|
||||
let groups_y = in.num_wgs.y;
|
||||
let groups_z = in.num_wgs.z;
|
||||
}
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main(tint_symbol : tint_symbol_1) {
|
||||
main_inner(Builtins(tint_symbol.gid, tint_symbol_3.num_workgroups, tint_symbol.wgid));
|
||||
}
|
||||
)";
|
||||
|
||||
DataMap data;
|
||||
data.Add<CanonicalizeEntryPointIO::Config>(
|
||||
CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
|
||||
data.Add<NumWorkgroupsFromUniform::Config>(sem::BindingPoint{0, 30u});
|
||||
auto got = Run<Unshadow, CanonicalizeEntryPointIO, NumWorkgroupsFromUniform>(
|
||||
src, data);
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(NumWorkgroupsFromUniformTest, StructMultipleMembers_OutOfOrder) {
|
||||
auto* src = R"(
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main(in : Builtins) {
|
||||
let groups_x = in.num_wgs.x;
|
||||
let groups_y = in.num_wgs.y;
|
||||
let groups_z = in.num_wgs.z;
|
||||
}
|
||||
|
||||
struct Builtins {
|
||||
@builtin(global_invocation_id) gid : vec3<u32>;
|
||||
@builtin(num_workgroups) num_wgs : vec3<u32>;
|
||||
@builtin(workgroup_id) wgid : vec3<u32>;
|
||||
};
|
||||
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
struct tint_symbol_2 {
|
||||
num_workgroups : vec3<u32>;
|
||||
}
|
||||
|
||||
@group(0) @binding(30) var<uniform> tint_symbol_3 : tint_symbol_2;
|
||||
|
||||
struct tint_symbol_1 {
|
||||
@builtin(global_invocation_id)
|
||||
gid : vec3<u32>;
|
||||
@builtin(workgroup_id)
|
||||
wgid : vec3<u32>;
|
||||
}
|
||||
|
||||
fn main_inner(in : Builtins) {
|
||||
let groups_x = in.num_wgs.x;
|
||||
let groups_y = in.num_wgs.y;
|
||||
let groups_z = in.num_wgs.z;
|
||||
}
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main(tint_symbol : tint_symbol_1) {
|
||||
main_inner(Builtins(tint_symbol.gid, tint_symbol_3.num_workgroups, tint_symbol.wgid));
|
||||
}
|
||||
|
||||
struct Builtins {
|
||||
gid : vec3<u32>;
|
||||
num_wgs : vec3<u32>;
|
||||
wgid : vec3<u32>;
|
||||
}
|
||||
)";
|
||||
|
||||
DataMap data;
|
||||
data.Add<CanonicalizeEntryPointIO::Config>(
|
||||
CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
|
||||
data.Add<NumWorkgroupsFromUniform::Config>(sem::BindingPoint{0, 30u});
|
||||
auto got = Run<Unshadow, CanonicalizeEntryPointIO, NumWorkgroupsFromUniform>(
|
||||
src, data);
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(NumWorkgroupsFromUniformTest, MultipleEntryPoints) {
|
||||
auto* src = R"(
|
||||
struct Builtins1 {
|
||||
@builtin(num_workgroups) num_wgs : vec3<u32>;
|
||||
};
|
||||
|
||||
struct Builtins2 {
|
||||
@builtin(global_invocation_id) gid : vec3<u32>;
|
||||
@builtin(num_workgroups) num_wgs : vec3<u32>;
|
||||
@builtin(workgroup_id) wgid : vec3<u32>;
|
||||
};
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main1(in : Builtins1) {
|
||||
let groups_x = in.num_wgs.x;
|
||||
let groups_y = in.num_wgs.y;
|
||||
let groups_z = in.num_wgs.z;
|
||||
}
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main2(in : Builtins2) {
|
||||
let groups_x = in.num_wgs.x;
|
||||
let groups_y = in.num_wgs.y;
|
||||
let groups_z = in.num_wgs.z;
|
||||
}
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main3(@builtin(num_workgroups) num_wgs : vec3<u32>) {
|
||||
let groups_x = num_wgs.x;
|
||||
let groups_y = num_wgs.y;
|
||||
let groups_z = num_wgs.z;
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
struct tint_symbol_6 {
|
||||
num_workgroups : vec3<u32>;
|
||||
}
|
||||
|
||||
@group(0) @binding(30) var<uniform> tint_symbol_7 : tint_symbol_6;
|
||||
|
||||
struct Builtins1 {
|
||||
num_wgs : vec3<u32>;
|
||||
}
|
||||
|
||||
struct Builtins2 {
|
||||
gid : vec3<u32>;
|
||||
num_wgs : vec3<u32>;
|
||||
wgid : vec3<u32>;
|
||||
}
|
||||
|
||||
fn main1_inner(in : Builtins1) {
|
||||
let groups_x = in.num_wgs.x;
|
||||
let groups_y = in.num_wgs.y;
|
||||
let groups_z = in.num_wgs.z;
|
||||
}
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main1() {
|
||||
main1_inner(Builtins1(tint_symbol_7.num_workgroups));
|
||||
}
|
||||
|
||||
struct tint_symbol_3 {
|
||||
@builtin(global_invocation_id)
|
||||
gid : vec3<u32>;
|
||||
@builtin(workgroup_id)
|
||||
wgid : vec3<u32>;
|
||||
}
|
||||
|
||||
fn main2_inner(in : Builtins2) {
|
||||
let groups_x = in.num_wgs.x;
|
||||
let groups_y = in.num_wgs.y;
|
||||
let groups_z = in.num_wgs.z;
|
||||
}
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main2(tint_symbol_2 : tint_symbol_3) {
|
||||
main2_inner(Builtins2(tint_symbol_2.gid, tint_symbol_7.num_workgroups, tint_symbol_2.wgid));
|
||||
}
|
||||
|
||||
fn main3_inner(num_wgs : vec3<u32>) {
|
||||
let groups_x = num_wgs.x;
|
||||
let groups_y = num_wgs.y;
|
||||
let groups_z = num_wgs.z;
|
||||
}
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main3() {
|
||||
main3_inner(tint_symbol_7.num_workgroups);
|
||||
}
|
||||
)";
|
||||
|
||||
DataMap data;
|
||||
data.Add<CanonicalizeEntryPointIO::Config>(
|
||||
CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
|
||||
data.Add<NumWorkgroupsFromUniform::Config>(sem::BindingPoint{0, 30u});
|
||||
auto got = Run<Unshadow, CanonicalizeEntryPointIO, NumWorkgroupsFromUniform>(
|
||||
src, data);
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(NumWorkgroupsFromUniformTest, NoUsages) {
|
||||
auto* src = R"(
|
||||
struct Builtins {
|
||||
@builtin(global_invocation_id) gid : vec3<u32>;
|
||||
@builtin(workgroup_id) wgid : vec3<u32>;
|
||||
};
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main(in : Builtins) {
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
struct Builtins {
|
||||
gid : vec3<u32>;
|
||||
wgid : vec3<u32>;
|
||||
}
|
||||
|
||||
struct tint_symbol_1 {
|
||||
@builtin(global_invocation_id)
|
||||
gid : vec3<u32>;
|
||||
@builtin(workgroup_id)
|
||||
wgid : vec3<u32>;
|
||||
}
|
||||
|
||||
fn main_inner(in : Builtins) {
|
||||
}
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main(tint_symbol : tint_symbol_1) {
|
||||
main_inner(Builtins(tint_symbol.gid, tint_symbol.wgid));
|
||||
}
|
||||
)";
|
||||
|
||||
DataMap data;
|
||||
data.Add<CanonicalizeEntryPointIO::Config>(
|
||||
CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
|
||||
data.Add<NumWorkgroupsFromUniform::Config>(sem::BindingPoint{0, 30u});
|
||||
auto got = Run<Unshadow, CanonicalizeEntryPointIO, NumWorkgroupsFromUniform>(
|
||||
src, data);
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
177
src/tint/transform/pad_array_elements.cc
Normal file
177
src/tint/transform/pad_array_elements.cc
Normal file
@@ -0,0 +1,177 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/tint/transform/pad_array_elements.h"
|
||||
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
|
||||
#include "src/tint/program_builder.h"
|
||||
#include "src/tint/sem/array.h"
|
||||
#include "src/tint/sem/call.h"
|
||||
#include "src/tint/sem/expression.h"
|
||||
#include "src/tint/sem/type_constructor.h"
|
||||
#include "src/tint/utils/map.h"
|
||||
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::transform::PadArrayElements);
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
namespace {
|
||||
|
||||
using ArrayBuilder = std::function<const ast::Array*()>;
|
||||
|
||||
/// PadArray returns a function that constructs a new array in `ctx.dst` with
|
||||
/// the element type padded to account for the explicit stride. PadArray will
|
||||
/// recursively pad arrays-of-arrays. The new array element type will be added
|
||||
/// to module-scope type declarations of `ctx.dst`.
|
||||
/// @param ctx the CloneContext
|
||||
/// @param create_ast_type_for Transform::CreateASTTypeFor()
|
||||
/// @param padded_arrays a map of src array type to the new array name
|
||||
/// @param array the array type
|
||||
/// @return the new AST array
|
||||
template <typename CREATE_AST_TYPE_FOR>
|
||||
ArrayBuilder PadArray(
|
||||
CloneContext& ctx,
|
||||
CREATE_AST_TYPE_FOR&& create_ast_type_for,
|
||||
std::unordered_map<const sem::Array*, ArrayBuilder>& padded_arrays,
|
||||
const sem::Array* array) {
|
||||
if (array->IsStrideImplicit()) {
|
||||
// We don't want to wrap arrays that have an implicit stride
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return utils::GetOrCreate(padded_arrays, array, [&] {
|
||||
// Generate a unique name for the array element type
|
||||
auto name = ctx.dst->Symbols().New("tint_padded_array_element");
|
||||
|
||||
// Examine the element type. Is it also an array?
|
||||
const ast::Type* el_ty = nullptr;
|
||||
if (auto* el_array = array->ElemType()->As<sem::Array>()) {
|
||||
// Array of array - call PadArray() on the element type
|
||||
if (auto p =
|
||||
PadArray(ctx, create_ast_type_for, padded_arrays, el_array)) {
|
||||
el_ty = p();
|
||||
}
|
||||
}
|
||||
|
||||
// If the element wasn't a padded array, just create the typical AST type
|
||||
// for it
|
||||
if (el_ty == nullptr) {
|
||||
el_ty = create_ast_type_for(ctx, array->ElemType());
|
||||
}
|
||||
|
||||
// Structure() will create and append the ast::Struct to the
|
||||
// global declarations of `ctx.dst`. As we haven't finished building the
|
||||
// current module-scope statement or function, this will be placed
|
||||
// immediately before the usage.
|
||||
ctx.dst->Structure(
|
||||
name,
|
||||
{ctx.dst->Member("el", el_ty, {ctx.dst->MemberSize(array->Stride())})});
|
||||
|
||||
auto* dst = ctx.dst;
|
||||
return [=] {
|
||||
if (array->IsRuntimeSized()) {
|
||||
return dst->ty.array(dst->create<ast::TypeName>(name));
|
||||
} else {
|
||||
return dst->ty.array(dst->create<ast::TypeName>(name), array->Count());
|
||||
}
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
PadArrayElements::PadArrayElements() = default;
|
||||
|
||||
PadArrayElements::~PadArrayElements() = default;
|
||||
|
||||
bool PadArrayElements::ShouldRun(const Program* program, const DataMap&) const {
|
||||
for (auto* node : program->ASTNodes().Objects()) {
|
||||
if (auto* var = node->As<ast::Type>()) {
|
||||
if (auto* arr = program->Sem().Get<sem::Array>(var)) {
|
||||
if (!arr->IsStrideImplicit()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void PadArrayElements::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
|
||||
auto& sem = ctx.src->Sem();
|
||||
|
||||
std::unordered_map<const sem::Array*, ArrayBuilder> padded_arrays;
|
||||
auto pad = [&](const sem::Array* array) {
|
||||
return PadArray(ctx, CreateASTTypeFor, padded_arrays, array);
|
||||
};
|
||||
|
||||
// Replace all array types with their corresponding padded array type
|
||||
ctx.ReplaceAll([&](const ast::Type* ast_type) -> const ast::Type* {
|
||||
auto* type = ctx.src->TypeOf(ast_type);
|
||||
if (auto* array = type->UnwrapRef()->As<sem::Array>()) {
|
||||
if (auto p = pad(array)) {
|
||||
return p();
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
});
|
||||
|
||||
// Fix up index accessors so `a[1]` becomes `a[1].el`
|
||||
ctx.ReplaceAll([&](const ast::IndexAccessorExpression* accessor)
|
||||
-> const ast::Expression* {
|
||||
if (auto* array = tint::As<sem::Array>(
|
||||
sem.Get(accessor->object)->Type()->UnwrapRef())) {
|
||||
if (pad(array)) {
|
||||
// Array element is wrapped in a structure. Emit a member accessor
|
||||
// to get to the actual array element.
|
||||
auto* idx = ctx.CloneWithoutTransform(accessor);
|
||||
return ctx.dst->MemberAccessor(idx, "el");
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
});
|
||||
|
||||
// Fix up array constructors so `A(1,2)` becomes
|
||||
// `A(padded(1), padded(2))`
|
||||
ctx.ReplaceAll(
|
||||
[&](const ast::CallExpression* expr) -> const ast::Expression* {
|
||||
auto* call = sem.Get(expr);
|
||||
if (auto* ctor = call->Target()->As<sem::TypeConstructor>()) {
|
||||
if (auto* array = ctor->ReturnType()->As<sem::Array>()) {
|
||||
if (auto p = pad(array)) {
|
||||
auto* arr_ty = p();
|
||||
auto el_typename = arr_ty->type->As<ast::TypeName>()->name;
|
||||
|
||||
ast::ExpressionList args;
|
||||
args.reserve(call->Arguments().size());
|
||||
for (auto* arg : call->Arguments()) {
|
||||
auto* val = ctx.Clone(arg->Declaration());
|
||||
args.emplace_back(ctx.dst->Construct(
|
||||
ctx.dst->create<ast::TypeName>(el_typename), val));
|
||||
}
|
||||
|
||||
return ctx.dst->Construct(arr_ty, args);
|
||||
}
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
});
|
||||
|
||||
ctx.Clone();
|
||||
}
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
62
src/tint/transform/pad_array_elements.h
Normal file
62
src/tint/transform/pad_array_elements.h
Normal file
@@ -0,0 +1,62 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef SRC_TINT_TRANSFORM_PAD_ARRAY_ELEMENTS_H_
|
||||
#define SRC_TINT_TRANSFORM_PAD_ARRAY_ELEMENTS_H_
|
||||
|
||||
#include "src/tint/transform/transform.h"
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
|
||||
/// PadArrayElements is a transform that replaces array types with an explicit
|
||||
/// stride that is larger than the implicit stride, with an array of a new
|
||||
/// structure type. This structure holds with a single field of the element
|
||||
/// type, decorated with a `@size` attribute to pad the structure to the
|
||||
/// required array stride. The new array types have no explicit stride,
|
||||
/// structure size is equal to the desired stride.
|
||||
/// Array index expressions and constructors are also adjusted to deal with this
|
||||
/// structure element type.
|
||||
/// This transform helps with backends that cannot directly return arrays or use
|
||||
/// them as parameters.
|
||||
class PadArrayElements : public Castable<PadArrayElements, Transform> {
|
||||
public:
|
||||
/// Constructor
|
||||
PadArrayElements();
|
||||
|
||||
/// Destructor
|
||||
~PadArrayElements() override;
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program,
|
||||
const DataMap& data = {}) const override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
};
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
|
||||
#endif // SRC_TINT_TRANSFORM_PAD_ARRAY_ELEMENTS_H_
|
||||
518
src/tint/transform/pad_array_elements_test.cc
Normal file
518
src/tint/transform/pad_array_elements_test.cc
Normal file
@@ -0,0 +1,518 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/tint/transform/pad_array_elements.h"
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "src/tint/transform/test_helper.h"
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
namespace {
|
||||
|
||||
using PadArrayElementsTest = TransformTest;
|
||||
|
||||
TEST_F(PadArrayElementsTest, ShouldRunEmptyModule) {
|
||||
auto* src = R"()";
|
||||
|
||||
EXPECT_FALSE(ShouldRun<PadArrayElements>(src));
|
||||
}
|
||||
|
||||
TEST_F(PadArrayElementsTest, ShouldRunHasImplicitArrayStride) {
|
||||
auto* src = R"(
|
||||
var<private> arr : array<i32, 4>;
|
||||
)";
|
||||
|
||||
EXPECT_FALSE(ShouldRun<PadArrayElements>(src));
|
||||
}
|
||||
|
||||
TEST_F(PadArrayElementsTest, ShouldRunHasExplicitArrayStride) {
|
||||
auto* src = R"(
|
||||
var<private> arr : [[stride(8)]] array<i32, 4>;
|
||||
)";
|
||||
|
||||
EXPECT_TRUE(ShouldRun<PadArrayElements>(src));
|
||||
}
|
||||
|
||||
TEST_F(PadArrayElementsTest, EmptyModule) {
|
||||
auto* src = "";
|
||||
auto* expect = "";
|
||||
|
||||
auto got = Run<PadArrayElements>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(PadArrayElementsTest, ImplicitArrayStride) {
|
||||
auto* src = R"(
|
||||
var<private> arr : array<i32, 4>;
|
||||
)";
|
||||
auto* expect = src;
|
||||
|
||||
auto got = Run<PadArrayElements>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(PadArrayElementsTest, ArrayAsGlobal) {
|
||||
auto* src = R"(
|
||||
var<private> arr : @stride(8) array<i32, 4>;
|
||||
)";
|
||||
auto* expect = R"(
|
||||
struct tint_padded_array_element {
|
||||
@size(8)
|
||||
el : i32;
|
||||
}
|
||||
|
||||
var<private> arr : array<tint_padded_array_element, 4u>;
|
||||
)";
|
||||
|
||||
auto got = Run<PadArrayElements>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(PadArrayElementsTest, RuntimeArray) {
|
||||
auto* src = R"(
|
||||
struct S {
|
||||
rta : @stride(8) array<i32>;
|
||||
};
|
||||
)";
|
||||
auto* expect = R"(
|
||||
struct tint_padded_array_element {
|
||||
@size(8)
|
||||
el : i32;
|
||||
}
|
||||
|
||||
struct S {
|
||||
rta : array<tint_padded_array_element>;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<PadArrayElements>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(PadArrayElementsTest, ArrayFunctionVar) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
var arr : @stride(16) array<i32, 4>;
|
||||
arr = @stride(16) array<i32, 4>();
|
||||
arr = @stride(16) array<i32, 4>(1, 2, 3, 4);
|
||||
let x = arr[3];
|
||||
}
|
||||
)";
|
||||
auto* expect = R"(
|
||||
struct tint_padded_array_element {
|
||||
@size(16)
|
||||
el : i32;
|
||||
}
|
||||
|
||||
fn f() {
|
||||
var arr : array<tint_padded_array_element, 4u>;
|
||||
arr = array<tint_padded_array_element, 4u>();
|
||||
arr = array<tint_padded_array_element, 4u>(tint_padded_array_element(1), tint_padded_array_element(2), tint_padded_array_element(3), tint_padded_array_element(4));
|
||||
let x = arr[3].el;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<PadArrayElements>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(PadArrayElementsTest, ArrayAsParam) {
|
||||
auto* src = R"(
|
||||
fn f(a : @stride(12) array<i32, 4>) -> i32 {
|
||||
return a[2];
|
||||
}
|
||||
)";
|
||||
auto* expect = R"(
|
||||
struct tint_padded_array_element {
|
||||
@size(12)
|
||||
el : i32;
|
||||
}
|
||||
|
||||
fn f(a : array<tint_padded_array_element, 4u>) -> i32 {
|
||||
return a[2].el;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<PadArrayElements>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
// TODO(crbug.com/tint/781): Cannot parse the stride on the return array type.
|
||||
TEST_F(PadArrayElementsTest, DISABLED_ArrayAsReturn) {
|
||||
auto* src = R"(
|
||||
fn f() -> @stride(8) array<i32, 4> {
|
||||
return array<i32, 4>(1, 2, 3, 4);
|
||||
}
|
||||
)";
|
||||
auto* expect = R"(
|
||||
struct tint_padded_array_element {
|
||||
el : i32;
|
||||
@size(4)
|
||||
padding : u32;
|
||||
};
|
||||
|
||||
fn f() -> array<tint_padded_array_element, 4> {
|
||||
return array<tint_padded_array_element, 4>(tint_padded_array_element(1, 0u), tint_padded_array_element(2, 0u), tint_padded_array_element(3, 0u), tint_padded_array_element(4, 0u));
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<PadArrayElements>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(PadArrayElementsTest, ArrayAlias) {
|
||||
auto* src = R"(
|
||||
type Array = @stride(16) array<i32, 4>;
|
||||
|
||||
fn f() {
|
||||
var arr : Array;
|
||||
arr = Array();
|
||||
arr = Array(1, 2, 3, 4);
|
||||
let vals : Array = Array(1, 2, 3, 4);
|
||||
arr = vals;
|
||||
let x = arr[3];
|
||||
}
|
||||
)";
|
||||
auto* expect = R"(
|
||||
struct tint_padded_array_element {
|
||||
@size(16)
|
||||
el : i32;
|
||||
}
|
||||
|
||||
type Array = array<tint_padded_array_element, 4u>;
|
||||
|
||||
fn f() {
|
||||
var arr : array<tint_padded_array_element, 4u>;
|
||||
arr = array<tint_padded_array_element, 4u>();
|
||||
arr = array<tint_padded_array_element, 4u>(tint_padded_array_element(1), tint_padded_array_element(2), tint_padded_array_element(3), tint_padded_array_element(4));
|
||||
let vals : array<tint_padded_array_element, 4u> = array<tint_padded_array_element, 4u>(tint_padded_array_element(1), tint_padded_array_element(2), tint_padded_array_element(3), tint_padded_array_element(4));
|
||||
arr = vals;
|
||||
let x = arr[3].el;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<PadArrayElements>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(PadArrayElementsTest, ArrayAlias_OutOfOrder) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
var arr : Array;
|
||||
arr = Array();
|
||||
arr = Array(1, 2, 3, 4);
|
||||
let vals : Array = Array(1, 2, 3, 4);
|
||||
arr = vals;
|
||||
let x = arr[3];
|
||||
}
|
||||
|
||||
type Array = @stride(16) array<i32, 4>;
|
||||
)";
|
||||
auto* expect = R"(
|
||||
struct tint_padded_array_element {
|
||||
@size(16)
|
||||
el : i32;
|
||||
}
|
||||
|
||||
fn f() {
|
||||
var arr : array<tint_padded_array_element, 4u>;
|
||||
arr = array<tint_padded_array_element, 4u>();
|
||||
arr = array<tint_padded_array_element, 4u>(tint_padded_array_element(1), tint_padded_array_element(2), tint_padded_array_element(3), tint_padded_array_element(4));
|
||||
let vals : array<tint_padded_array_element, 4u> = array<tint_padded_array_element, 4u>(tint_padded_array_element(1), tint_padded_array_element(2), tint_padded_array_element(3), tint_padded_array_element(4));
|
||||
arr = vals;
|
||||
let x = arr[3].el;
|
||||
}
|
||||
|
||||
type Array = array<tint_padded_array_element, 4u>;
|
||||
)";
|
||||
|
||||
auto got = Run<PadArrayElements>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(PadArrayElementsTest, ArraysInStruct) {
|
||||
auto* src = R"(
|
||||
struct S {
|
||||
a : @stride(8) array<i32, 4>;
|
||||
b : @stride(8) array<i32, 8>;
|
||||
c : @stride(8) array<i32, 4>;
|
||||
d : @stride(12) array<i32, 8>;
|
||||
};
|
||||
)";
|
||||
auto* expect = R"(
|
||||
struct tint_padded_array_element {
|
||||
@size(8)
|
||||
el : i32;
|
||||
}
|
||||
|
||||
struct tint_padded_array_element_1 {
|
||||
@size(8)
|
||||
el : i32;
|
||||
}
|
||||
|
||||
struct tint_padded_array_element_2 {
|
||||
@size(12)
|
||||
el : i32;
|
||||
}
|
||||
|
||||
struct S {
|
||||
a : array<tint_padded_array_element, 4u>;
|
||||
b : array<tint_padded_array_element_1, 8u>;
|
||||
c : array<tint_padded_array_element, 4u>;
|
||||
d : array<tint_padded_array_element_2, 8u>;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<PadArrayElements>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(PadArrayElementsTest, ArraysOfArraysInStruct) {
|
||||
auto* src = R"(
|
||||
struct S {
|
||||
a : @stride(512) array<i32, 4>;
|
||||
b : @stride(512) array<@stride(32) array<i32, 4>, 4>;
|
||||
c : @stride(512) array<@stride(64) array<@stride(8) array<i32, 4>, 4>, 4>;
|
||||
};
|
||||
)";
|
||||
auto* expect = R"(
|
||||
struct tint_padded_array_element {
|
||||
@size(512)
|
||||
el : i32;
|
||||
}
|
||||
|
||||
struct tint_padded_array_element_2 {
|
||||
@size(32)
|
||||
el : i32;
|
||||
}
|
||||
|
||||
struct tint_padded_array_element_1 {
|
||||
@size(512)
|
||||
el : array<tint_padded_array_element_2, 4u>;
|
||||
}
|
||||
|
||||
struct tint_padded_array_element_5 {
|
||||
@size(8)
|
||||
el : i32;
|
||||
}
|
||||
|
||||
struct tint_padded_array_element_4 {
|
||||
@size(64)
|
||||
el : array<tint_padded_array_element_5, 4u>;
|
||||
}
|
||||
|
||||
struct tint_padded_array_element_3 {
|
||||
@size(512)
|
||||
el : array<tint_padded_array_element_4, 4u>;
|
||||
}
|
||||
|
||||
struct S {
|
||||
a : array<tint_padded_array_element, 4u>;
|
||||
b : array<tint_padded_array_element_1, 4u>;
|
||||
c : array<tint_padded_array_element_3, 4u>;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<PadArrayElements>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(PadArrayElementsTest, AccessArraysOfArraysInStruct) {
|
||||
auto* src = R"(
|
||||
struct S {
|
||||
a : @stride(512) array<i32, 4>;
|
||||
b : @stride(512) array<@stride(32) array<i32, 4>, 4>;
|
||||
c : @stride(512) array<@stride(64) array<@stride(8) array<i32, 4>, 4>, 4>;
|
||||
};
|
||||
|
||||
fn f(s : S) -> i32 {
|
||||
return s.a[2] + s.b[1][2] + s.c[3][1][2];
|
||||
}
|
||||
)";
|
||||
auto* expect = R"(
|
||||
struct tint_padded_array_element {
|
||||
@size(512)
|
||||
el : i32;
|
||||
}
|
||||
|
||||
struct tint_padded_array_element_2 {
|
||||
@size(32)
|
||||
el : i32;
|
||||
}
|
||||
|
||||
struct tint_padded_array_element_1 {
|
||||
@size(512)
|
||||
el : array<tint_padded_array_element_2, 4u>;
|
||||
}
|
||||
|
||||
struct tint_padded_array_element_5 {
|
||||
@size(8)
|
||||
el : i32;
|
||||
}
|
||||
|
||||
struct tint_padded_array_element_4 {
|
||||
@size(64)
|
||||
el : array<tint_padded_array_element_5, 4u>;
|
||||
}
|
||||
|
||||
struct tint_padded_array_element_3 {
|
||||
@size(512)
|
||||
el : array<tint_padded_array_element_4, 4u>;
|
||||
}
|
||||
|
||||
struct S {
|
||||
a : array<tint_padded_array_element, 4u>;
|
||||
b : array<tint_padded_array_element_1, 4u>;
|
||||
c : array<tint_padded_array_element_3, 4u>;
|
||||
}
|
||||
|
||||
fn f(s : S) -> i32 {
|
||||
return ((s.a[2].el + s.b[1].el[2].el) + s.c[3].el[1].el[2].el);
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<PadArrayElements>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(PadArrayElementsTest, AccessArraysOfArraysInStruct_OutOfOrder) {
|
||||
auto* src = R"(
|
||||
fn f(s : S) -> i32 {
|
||||
return s.a[2] + s.b[1][2] + s.c[3][1][2];
|
||||
}
|
||||
|
||||
struct S {
|
||||
a : @stride(512) array<i32, 4>;
|
||||
b : @stride(512) array<@stride(32) array<i32, 4>, 4>;
|
||||
c : @stride(512) array<@stride(64) array<@stride(8) array<i32, 4>, 4>, 4>;
|
||||
};
|
||||
)";
|
||||
auto* expect = R"(
|
||||
struct tint_padded_array_element {
|
||||
@size(512)
|
||||
el : i32;
|
||||
}
|
||||
|
||||
struct tint_padded_array_element_1 {
|
||||
@size(32)
|
||||
el : i32;
|
||||
}
|
||||
|
||||
struct tint_padded_array_element_2 {
|
||||
@size(512)
|
||||
el : array<tint_padded_array_element_1, 4u>;
|
||||
}
|
||||
|
||||
struct tint_padded_array_element_3 {
|
||||
@size(8)
|
||||
el : i32;
|
||||
}
|
||||
|
||||
struct tint_padded_array_element_4 {
|
||||
@size(64)
|
||||
el : array<tint_padded_array_element_3, 4u>;
|
||||
}
|
||||
|
||||
struct tint_padded_array_element_5 {
|
||||
@size(512)
|
||||
el : array<tint_padded_array_element_4, 4u>;
|
||||
}
|
||||
|
||||
fn f(s : S) -> i32 {
|
||||
return ((s.a[2].el + s.b[1].el[2].el) + s.c[3].el[1].el[2].el);
|
||||
}
|
||||
|
||||
struct S {
|
||||
a : array<tint_padded_array_element, 4u>;
|
||||
b : array<tint_padded_array_element_2, 4u>;
|
||||
c : array<tint_padded_array_element_5, 4u>;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<PadArrayElements>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(PadArrayElementsTest, DeclarationOrder) {
|
||||
auto* src = R"(
|
||||
type T0 = i32;
|
||||
|
||||
type T1 = @stride(8) array<i32, 1>;
|
||||
|
||||
type T2 = i32;
|
||||
|
||||
fn f1(a : @stride(8) array<i32, 2>) {
|
||||
}
|
||||
|
||||
type T3 = i32;
|
||||
|
||||
fn f2() {
|
||||
var v : @stride(8) array<i32, 3>;
|
||||
}
|
||||
)";
|
||||
auto* expect = R"(
|
||||
type T0 = i32;
|
||||
|
||||
struct tint_padded_array_element {
|
||||
@size(8)
|
||||
el : i32;
|
||||
}
|
||||
|
||||
type T1 = array<tint_padded_array_element, 1u>;
|
||||
|
||||
type T2 = i32;
|
||||
|
||||
struct tint_padded_array_element_1 {
|
||||
@size(8)
|
||||
el : i32;
|
||||
}
|
||||
|
||||
fn f1(a : array<tint_padded_array_element_1, 2u>) {
|
||||
}
|
||||
|
||||
type T3 = i32;
|
||||
|
||||
struct tint_padded_array_element_2 {
|
||||
@size(8)
|
||||
el : i32;
|
||||
}
|
||||
|
||||
fn f2() {
|
||||
var v : array<tint_padded_array_element_2, 3u>;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<PadArrayElements>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
83
src/tint/transform/promote_initializers_to_const_var.cc
Normal file
83
src/tint/transform/promote_initializers_to_const_var.cc
Normal file
@@ -0,0 +1,83 @@
|
||||
// Copyright 2022 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/tint/transform/promote_initializers_to_const_var.h"
|
||||
#include "src/tint/program_builder.h"
|
||||
#include "src/tint/sem/call.h"
|
||||
#include "src/tint/sem/statement.h"
|
||||
#include "src/tint/sem/type_constructor.h"
|
||||
#include "src/tint/transform/utils/hoist_to_decl_before.h"
|
||||
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::transform::PromoteInitializersToConstVar);
|
||||
|
||||
namespace tint::transform {
|
||||
|
||||
PromoteInitializersToConstVar::PromoteInitializersToConstVar() = default;
|
||||
|
||||
PromoteInitializersToConstVar::~PromoteInitializersToConstVar() = default;
|
||||
|
||||
void PromoteInitializersToConstVar::Run(CloneContext& ctx,
|
||||
const DataMap&,
|
||||
DataMap&) const {
|
||||
HoistToDeclBefore hoist_to_decl_before(ctx);
|
||||
|
||||
// Hoists array and structure initializers to a constant variable, declared
|
||||
// just before the statement of usage.
|
||||
auto type_ctor_to_let = [&](const ast::CallExpression* expr) {
|
||||
auto* ctor = ctx.src->Sem().Get(expr);
|
||||
if (!ctor->Target()->Is<sem::TypeConstructor>()) {
|
||||
return true;
|
||||
}
|
||||
auto* sem_stmt = ctor->Stmt();
|
||||
if (!sem_stmt) {
|
||||
// Expression is outside of a statement. This usually means the
|
||||
// expression is part of a global (module-scope) constant declaration.
|
||||
// These must be constexpr, and so cannot contain the type of
|
||||
// expressions that must be sanitized.
|
||||
return true;
|
||||
}
|
||||
|
||||
auto* stmt = sem_stmt->Declaration();
|
||||
|
||||
if (auto* src_var_decl = stmt->As<ast::VariableDeclStatement>()) {
|
||||
if (src_var_decl->variable->constructor == expr) {
|
||||
// This statement is just a variable declaration with the
|
||||
// initializer as the constructor value. This is what we're
|
||||
// attempting to transform to, and so ignore.
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
auto* src_ty = ctor->Type();
|
||||
if (!src_ty->IsAnyOf<sem::Array, sem::Struct>()) {
|
||||
// We only care about array and struct initializers
|
||||
return true;
|
||||
}
|
||||
|
||||
return hoist_to_decl_before.Add(ctor, expr, true);
|
||||
};
|
||||
|
||||
for (auto* node : ctx.src->ASTNodes().Objects()) {
|
||||
if (auto* call_expr = node->As<ast::CallExpression>()) {
|
||||
if (!type_ctor_to_let(call_expr)) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
hoist_to_decl_before.Apply();
|
||||
ctx.Clone();
|
||||
}
|
||||
|
||||
} // namespace tint::transform
|
||||
48
src/tint/transform/promote_initializers_to_const_var.h
Normal file
48
src/tint/transform/promote_initializers_to_const_var.h
Normal file
@@ -0,0 +1,48 @@
|
||||
// Copyright 2022 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef SRC_TINT_TRANSFORM_PROMOTE_INITIALIZERS_TO_CONST_VAR_H_
|
||||
#define SRC_TINT_TRANSFORM_PROMOTE_INITIALIZERS_TO_CONST_VAR_H_
|
||||
|
||||
#include "src/tint/transform/transform.h"
|
||||
|
||||
namespace tint::transform {
|
||||
|
||||
/// A transform that hoists the array and structure initializers to a constant
|
||||
/// variable, declared just before the statement of usage.
|
||||
/// @see crbug.com/tint/406
|
||||
class PromoteInitializersToConstVar
|
||||
: public Castable<PromoteInitializersToConstVar, Transform> {
|
||||
public:
|
||||
/// Constructor
|
||||
PromoteInitializersToConstVar();
|
||||
|
||||
/// Destructor
|
||||
~PromoteInitializersToConstVar() override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
};
|
||||
|
||||
} // namespace tint::transform
|
||||
|
||||
#endif // SRC_TINT_TRANSFORM_PROMOTE_INITIALIZERS_TO_CONST_VAR_H_
|
||||
627
src/tint/transform/promote_initializers_to_const_var_test.cc
Normal file
627
src/tint/transform/promote_initializers_to_const_var_test.cc
Normal file
@@ -0,0 +1,627 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/tint/transform/promote_initializers_to_const_var.h"
|
||||
|
||||
#include "src/tint/transform/test_helper.h"
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
namespace {
|
||||
|
||||
using PromoteInitializersToConstVarTest = TransformTest;
|
||||
|
||||
TEST_F(PromoteInitializersToConstVarTest, EmptyModule) {
|
||||
auto* src = "";
|
||||
auto* expect = "";
|
||||
|
||||
auto got = Run<PromoteInitializersToConstVar>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(PromoteInitializersToConstVarTest, BasicArray) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
var f0 = 1.0;
|
||||
var f1 = 2.0;
|
||||
var f2 = 3.0;
|
||||
var f3 = 4.0;
|
||||
var i = array<f32, 4u>(f0, f1, f2, f3)[2];
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn f() {
|
||||
var f0 = 1.0;
|
||||
var f1 = 2.0;
|
||||
var f2 = 3.0;
|
||||
var f3 = 4.0;
|
||||
let tint_symbol = array<f32, 4u>(f0, f1, f2, f3);
|
||||
var i = tint_symbol[2];
|
||||
}
|
||||
)";
|
||||
|
||||
DataMap data;
|
||||
auto got = Run<PromoteInitializersToConstVar>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(PromoteInitializersToConstVarTest, BasicStruct) {
|
||||
auto* src = R"(
|
||||
struct S {
|
||||
a : i32;
|
||||
b : f32;
|
||||
c : vec3<f32>;
|
||||
};
|
||||
|
||||
fn f() {
|
||||
var x = S(1, 2.0, vec3<f32>()).b;
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
struct S {
|
||||
a : i32;
|
||||
b : f32;
|
||||
c : vec3<f32>;
|
||||
}
|
||||
|
||||
fn f() {
|
||||
let tint_symbol = S(1, 2.0, vec3<f32>());
|
||||
var x = tint_symbol.b;
|
||||
}
|
||||
)";
|
||||
|
||||
DataMap data;
|
||||
auto got = Run<PromoteInitializersToConstVar>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(PromoteInitializersToConstVarTest, BasicStruct_OutOfOrder) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
var x = S(1, 2.0, vec3<f32>()).b;
|
||||
}
|
||||
|
||||
struct S {
|
||||
a : i32;
|
||||
b : f32;
|
||||
c : vec3<f32>;
|
||||
};
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn f() {
|
||||
let tint_symbol = S(1, 2.0, vec3<f32>());
|
||||
var x = tint_symbol.b;
|
||||
}
|
||||
|
||||
struct S {
|
||||
a : i32;
|
||||
b : f32;
|
||||
c : vec3<f32>;
|
||||
}
|
||||
)";
|
||||
|
||||
DataMap data;
|
||||
auto got = Run<PromoteInitializersToConstVar>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(PromoteInitializersToConstVarTest, ArrayInForLoopInit) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
var insert_after = 1;
|
||||
for(var i = array<f32, 4u>(0.0, 1.0, 2.0, 3.0)[2]; ; ) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn f() {
|
||||
var insert_after = 1;
|
||||
let tint_symbol = array<f32, 4u>(0.0, 1.0, 2.0, 3.0);
|
||||
for(var i = tint_symbol[2]; ; ) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
DataMap data;
|
||||
auto got = Run<PromoteInitializersToConstVar>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(PromoteInitializersToConstVarTest, StructInForLoopInit) {
|
||||
auto* src = R"(
|
||||
struct S {
|
||||
a : i32;
|
||||
b : f32;
|
||||
c : vec3<f32>;
|
||||
};
|
||||
|
||||
fn f() {
|
||||
var insert_after = 1;
|
||||
for(var x = S(1, 2.0, vec3<f32>()).b; ; ) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
struct S {
|
||||
a : i32;
|
||||
b : f32;
|
||||
c : vec3<f32>;
|
||||
}
|
||||
|
||||
fn f() {
|
||||
var insert_after = 1;
|
||||
let tint_symbol = S(1, 2.0, vec3<f32>());
|
||||
for(var x = tint_symbol.b; ; ) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
DataMap data;
|
||||
auto got = Run<PromoteInitializersToConstVar>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(PromoteInitializersToConstVarTest, StructInForLoopInit_OutOfOrder) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
var insert_after = 1;
|
||||
for(var x = S(1, 2.0, vec3<f32>()).b; ; ) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
struct S {
|
||||
a : i32;
|
||||
b : f32;
|
||||
c : vec3<f32>;
|
||||
};
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn f() {
|
||||
var insert_after = 1;
|
||||
let tint_symbol = S(1, 2.0, vec3<f32>());
|
||||
for(var x = tint_symbol.b; ; ) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
struct S {
|
||||
a : i32;
|
||||
b : f32;
|
||||
c : vec3<f32>;
|
||||
}
|
||||
)";
|
||||
|
||||
DataMap data;
|
||||
auto got = Run<PromoteInitializersToConstVar>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(PromoteInitializersToConstVarTest, ArrayInForLoopCond) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
var f = 1.0;
|
||||
for(; f == array<f32, 1u>(f)[0]; f = f + 1.0) {
|
||||
var marker = 1;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn f() {
|
||||
var f = 1.0;
|
||||
loop {
|
||||
let tint_symbol = array<f32, 1u>(f);
|
||||
if (!((f == tint_symbol[0]))) {
|
||||
break;
|
||||
}
|
||||
{
|
||||
var marker = 1;
|
||||
}
|
||||
|
||||
continuing {
|
||||
f = (f + 1.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
DataMap data;
|
||||
auto got = Run<PromoteInitializersToConstVar>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(PromoteInitializersToConstVarTest, ArrayInForLoopCont) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
var f = 0.0;
|
||||
for(; f < 10.0; f = f + array<f32, 1u>(1.0)[0]) {
|
||||
var marker = 1;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn f() {
|
||||
var f = 0.0;
|
||||
loop {
|
||||
if (!((f < 10.0))) {
|
||||
break;
|
||||
}
|
||||
{
|
||||
var marker = 1;
|
||||
}
|
||||
|
||||
continuing {
|
||||
let tint_symbol = array<f32, 1u>(1.0);
|
||||
f = (f + tint_symbol[0]);
|
||||
}
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
DataMap data;
|
||||
auto got = Run<PromoteInitializersToConstVar>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(PromoteInitializersToConstVarTest, ArrayInForLoopInitCondCont) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
for(var f = array<f32, 1u>(0.0)[0];
|
||||
f < array<f32, 1u>(1.0)[0];
|
||||
f = f + array<f32, 1u>(2.0)[0]) {
|
||||
var marker = 1;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn f() {
|
||||
let tint_symbol = array<f32, 1u>(0.0);
|
||||
{
|
||||
var f = tint_symbol[0];
|
||||
loop {
|
||||
let tint_symbol_1 = array<f32, 1u>(1.0);
|
||||
if (!((f < tint_symbol_1[0]))) {
|
||||
break;
|
||||
}
|
||||
{
|
||||
var marker = 1;
|
||||
}
|
||||
|
||||
continuing {
|
||||
let tint_symbol_2 = array<f32, 1u>(2.0);
|
||||
f = (f + tint_symbol_2[0]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
DataMap data;
|
||||
auto got = Run<PromoteInitializersToConstVar>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(PromoteInitializersToConstVarTest, ArrayInElseIf) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
var f = 1.0;
|
||||
if (true) {
|
||||
var marker = 0;
|
||||
} else if (f == array<f32, 2u>(f, f)[0]) {
|
||||
var marker = 1;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn f() {
|
||||
var f = 1.0;
|
||||
if (true) {
|
||||
var marker = 0;
|
||||
} else {
|
||||
let tint_symbol = array<f32, 2u>(f, f);
|
||||
if ((f == tint_symbol[0])) {
|
||||
var marker = 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
DataMap data;
|
||||
auto got = Run<PromoteInitializersToConstVar>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(PromoteInitializersToConstVarTest, ArrayInElseIfChain) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
var f = 1.0;
|
||||
if (true) {
|
||||
var marker = 0;
|
||||
} else if (true) {
|
||||
var marker = 1;
|
||||
} else if (f == array<f32, 2u>(f, f)[0]) {
|
||||
var marker = 2;
|
||||
} else if (f == array<f32, 2u>(f, f)[1]) {
|
||||
var marker = 3;
|
||||
} else if (true) {
|
||||
var marker = 4;
|
||||
} else {
|
||||
var marker = 5;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn f() {
|
||||
var f = 1.0;
|
||||
if (true) {
|
||||
var marker = 0;
|
||||
} else if (true) {
|
||||
var marker = 1;
|
||||
} else {
|
||||
let tint_symbol = array<f32, 2u>(f, f);
|
||||
if ((f == tint_symbol[0])) {
|
||||
var marker = 2;
|
||||
} else {
|
||||
let tint_symbol_1 = array<f32, 2u>(f, f);
|
||||
if ((f == tint_symbol_1[1])) {
|
||||
var marker = 3;
|
||||
} else if (true) {
|
||||
var marker = 4;
|
||||
} else {
|
||||
var marker = 5;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
DataMap data;
|
||||
auto got = Run<PromoteInitializersToConstVar>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(PromoteInitializersToConstVarTest, ArrayInArrayArray) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
var i = array<array<f32, 2u>, 2u>(array<f32, 2u>(1.0, 2.0), array<f32, 2u>(3.0, 4.0))[0][1];
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn f() {
|
||||
let tint_symbol = array<f32, 2u>(1.0, 2.0);
|
||||
let tint_symbol_1 = array<f32, 2u>(3.0, 4.0);
|
||||
let tint_symbol_2 = array<array<f32, 2u>, 2u>(tint_symbol, tint_symbol_1);
|
||||
var i = tint_symbol_2[0][1];
|
||||
}
|
||||
)";
|
||||
|
||||
DataMap data;
|
||||
auto got = Run<PromoteInitializersToConstVar>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(PromoteInitializersToConstVarTest, StructNested) {
|
||||
auto* src = R"(
|
||||
struct S1 {
|
||||
a : i32;
|
||||
};
|
||||
|
||||
struct S2 {
|
||||
a : i32;
|
||||
b : S1;
|
||||
c : i32;
|
||||
};
|
||||
|
||||
struct S3 {
|
||||
a : S2;
|
||||
};
|
||||
|
||||
fn f() {
|
||||
var x = S3(S2(1, S1(2), 3)).a.b.a;
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
struct S1 {
|
||||
a : i32;
|
||||
}
|
||||
|
||||
struct S2 {
|
||||
a : i32;
|
||||
b : S1;
|
||||
c : i32;
|
||||
}
|
||||
|
||||
struct S3 {
|
||||
a : S2;
|
||||
}
|
||||
|
||||
fn f() {
|
||||
let tint_symbol = S1(2);
|
||||
let tint_symbol_1 = S2(1, tint_symbol, 3);
|
||||
let tint_symbol_2 = S3(tint_symbol_1);
|
||||
var x = tint_symbol_2.a.b.a;
|
||||
}
|
||||
)";
|
||||
|
||||
DataMap data;
|
||||
auto got = Run<PromoteInitializersToConstVar>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(PromoteInitializersToConstVarTest, Mixed) {
|
||||
auto* src = R"(
|
||||
struct S1 {
|
||||
a : i32;
|
||||
};
|
||||
|
||||
struct S2 {
|
||||
a : array<S1, 3u>;
|
||||
};
|
||||
|
||||
fn f() {
|
||||
var x = S2(array<S1, 3u>(S1(1), S1(2), S1(3))).a[1].a;
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
struct S1 {
|
||||
a : i32;
|
||||
}
|
||||
|
||||
struct S2 {
|
||||
a : array<S1, 3u>;
|
||||
}
|
||||
|
||||
fn f() {
|
||||
let tint_symbol = S1(1);
|
||||
let tint_symbol_1 = S1(2);
|
||||
let tint_symbol_2 = S1(3);
|
||||
let tint_symbol_3 = array<S1, 3u>(tint_symbol, tint_symbol_1, tint_symbol_2);
|
||||
let tint_symbol_4 = S2(tint_symbol_3);
|
||||
var x = tint_symbol_4.a[1].a;
|
||||
}
|
||||
)";
|
||||
|
||||
DataMap data;
|
||||
auto got = Run<PromoteInitializersToConstVar>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(PromoteInitializersToConstVarTest, Mixed_OutOfOrder) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
var x = S2(array<S1, 3u>(S1(1), S1(2), S1(3))).a[1].a;
|
||||
}
|
||||
|
||||
struct S2 {
|
||||
a : array<S1, 3u>;
|
||||
};
|
||||
|
||||
struct S1 {
|
||||
a : i32;
|
||||
};
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn f() {
|
||||
let tint_symbol = S1(1);
|
||||
let tint_symbol_1 = S1(2);
|
||||
let tint_symbol_2 = S1(3);
|
||||
let tint_symbol_3 = array<S1, 3u>(tint_symbol, tint_symbol_1, tint_symbol_2);
|
||||
let tint_symbol_4 = S2(tint_symbol_3);
|
||||
var x = tint_symbol_4.a[1].a;
|
||||
}
|
||||
|
||||
struct S2 {
|
||||
a : array<S1, 3u>;
|
||||
}
|
||||
|
||||
struct S1 {
|
||||
a : i32;
|
||||
}
|
||||
)";
|
||||
|
||||
DataMap data;
|
||||
auto got = Run<PromoteInitializersToConstVar>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(PromoteInitializersToConstVarTest, NoChangeOnVarDecl) {
|
||||
auto* src = R"(
|
||||
struct S {
|
||||
a : i32;
|
||||
b : f32;
|
||||
c : i32;
|
||||
}
|
||||
|
||||
fn f() {
|
||||
var local_arr = array<f32, 4u>(0.0, 1.0, 2.0, 3.0);
|
||||
var local_str = S(1, 2.0, 3);
|
||||
}
|
||||
|
||||
let module_arr : array<f32, 4u> = array<f32, 4u>(0.0, 1.0, 2.0, 3.0);
|
||||
|
||||
let module_str : S = S(1, 2.0, 3);
|
||||
)";
|
||||
|
||||
auto* expect = src;
|
||||
|
||||
DataMap data;
|
||||
auto got = Run<PromoteInitializersToConstVar>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(PromoteInitializersToConstVarTest, NoChangeOnVarDecl_OutOfOrder) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
var local_arr = array<f32, 4u>(0.0, 1.0, 2.0, 3.0);
|
||||
var local_str = S(1, 2.0, 3);
|
||||
}
|
||||
|
||||
let module_str : S = S(1, 2.0, 3);
|
||||
|
||||
struct S {
|
||||
a : i32;
|
||||
b : f32;
|
||||
c : i32;
|
||||
}
|
||||
|
||||
let module_arr : array<f32, 4u> = array<f32, 4u>(0.0, 1.0, 2.0, 3.0);
|
||||
)";
|
||||
|
||||
auto* expect = src;
|
||||
|
||||
DataMap data;
|
||||
auto got = Run<PromoteInitializersToConstVar>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
156
src/tint/transform/remove_phonies.cc
Normal file
156
src/tint/transform/remove_phonies.cc
Normal file
@@ -0,0 +1,156 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/tint/transform/remove_phonies.h"
|
||||
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "src/tint/ast/traverse_expressions.h"
|
||||
#include "src/tint/program_builder.h"
|
||||
#include "src/tint/sem/block_statement.h"
|
||||
#include "src/tint/sem/function.h"
|
||||
#include "src/tint/sem/statement.h"
|
||||
#include "src/tint/sem/variable.h"
|
||||
#include "src/tint/utils/map.h"
|
||||
#include "src/tint/utils/scoped_assignment.h"
|
||||
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::transform::RemovePhonies);
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
namespace {
|
||||
|
||||
struct SinkSignature {
|
||||
std::vector<const sem::Type*> types;
|
||||
|
||||
bool operator==(const SinkSignature& other) const {
|
||||
if (types.size() != other.types.size()) {
|
||||
return false;
|
||||
}
|
||||
for (size_t i = 0; i < types.size(); i++) {
|
||||
if (types[i] != other.types[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
struct Hasher {
|
||||
/// @param sig the CallTargetSignature to hash
|
||||
/// @return the hash value
|
||||
std::size_t operator()(const SinkSignature& sig) const {
|
||||
size_t hash = tint::utils::Hash(sig.types.size());
|
||||
for (auto* ty : sig.types) {
|
||||
tint::utils::HashCombine(&hash, ty);
|
||||
}
|
||||
return hash;
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
RemovePhonies::RemovePhonies() = default;
|
||||
|
||||
RemovePhonies::~RemovePhonies() = default;
|
||||
|
||||
bool RemovePhonies::ShouldRun(const Program* program, const DataMap&) const {
|
||||
for (auto* node : program->ASTNodes().Objects()) {
|
||||
if (node->Is<ast::PhonyExpression>()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void RemovePhonies::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
|
||||
auto& sem = ctx.src->Sem();
|
||||
|
||||
std::unordered_map<SinkSignature, Symbol, SinkSignature::Hasher> sinks;
|
||||
|
||||
for (auto* node : ctx.src->ASTNodes().Objects()) {
|
||||
if (auto* stmt = node->As<ast::AssignmentStatement>()) {
|
||||
if (stmt->lhs->Is<ast::PhonyExpression>()) {
|
||||
std::vector<const ast::Expression*> side_effects;
|
||||
if (!ast::TraverseExpressions(
|
||||
stmt->rhs, ctx.dst->Diagnostics(),
|
||||
[&](const ast::CallExpression* call) {
|
||||
// ast::CallExpression may map to a function or builtin call
|
||||
// (both may have side-effects), or a type constructor or
|
||||
// type conversion (both do not have side effects).
|
||||
if (sem.Get(call)
|
||||
->Target()
|
||||
->IsAnyOf<sem::Function, sem::Builtin>()) {
|
||||
side_effects.push_back(call);
|
||||
return ast::TraverseAction::Skip;
|
||||
}
|
||||
return ast::TraverseAction::Descend;
|
||||
})) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (side_effects.empty()) {
|
||||
// Phony assignment with no side effects.
|
||||
// Just remove it.
|
||||
RemoveStatement(ctx, stmt);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (side_effects.size() == 1) {
|
||||
if (auto* call = side_effects[0]->As<ast::CallExpression>()) {
|
||||
// Phony assignment with single call side effect.
|
||||
// Replace phony assignment with call.
|
||||
ctx.Replace(
|
||||
stmt, [&, call] { return ctx.dst->CallStmt(ctx.Clone(call)); });
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Phony assignment with multiple side effects.
|
||||
// Generate a call to a placeholder function with the side
|
||||
// effects as arguments.
|
||||
ctx.Replace(stmt, [&, side_effects] {
|
||||
SinkSignature sig;
|
||||
for (auto* arg : side_effects) {
|
||||
sig.types.push_back(sem.Get(arg)->Type()->UnwrapRef());
|
||||
}
|
||||
auto sink = utils::GetOrCreate(sinks, sig, [&] {
|
||||
auto name = ctx.dst->Symbols().New("phony_sink");
|
||||
ast::VariableList params;
|
||||
for (auto* ty : sig.types) {
|
||||
auto* ast_ty = CreateASTTypeFor(ctx, ty);
|
||||
params.push_back(
|
||||
ctx.dst->Param("p" + std::to_string(params.size()), ast_ty));
|
||||
}
|
||||
ctx.dst->Func(name, params, ctx.dst->ty.void_(), {});
|
||||
return name;
|
||||
});
|
||||
ast::ExpressionList args;
|
||||
for (auto* arg : side_effects) {
|
||||
args.push_back(ctx.Clone(arg));
|
||||
}
|
||||
return ctx.dst->CallStmt(ctx.dst->Call(sink, args));
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ctx.Clone();
|
||||
}
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
58
src/tint/transform/remove_phonies.h
Normal file
58
src/tint/transform/remove_phonies.h
Normal file
@@ -0,0 +1,58 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef SRC_TINT_TRANSFORM_REMOVE_PHONIES_H_
|
||||
#define SRC_TINT_TRANSFORM_REMOVE_PHONIES_H_
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "src/tint/transform/transform.h"
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
|
||||
/// RemovePhonies is a Transform that removes all phony-assignment statements,
|
||||
/// while preserving function call expressions in the RHS of the assignment that
|
||||
/// may have side-effects.
|
||||
class RemovePhonies : public Castable<RemovePhonies, Transform> {
|
||||
public:
|
||||
/// Constructor
|
||||
RemovePhonies();
|
||||
|
||||
/// Destructor
|
||||
~RemovePhonies() override;
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program,
|
||||
const DataMap& data = {}) const override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
};
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
|
||||
#endif // SRC_TINT_TRANSFORM_REMOVE_PHONIES_H_
|
||||
431
src/tint/transform/remove_phonies_test.cc
Normal file
431
src/tint/transform/remove_phonies_test.cc
Normal file
@@ -0,0 +1,431 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/tint/transform/remove_phonies.h"
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "src/tint/transform/test_helper.h"
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
namespace {
|
||||
|
||||
using RemovePhoniesTest = TransformTest;
|
||||
|
||||
TEST_F(RemovePhoniesTest, ShouldRunEmptyModule) {
|
||||
auto* src = R"()";
|
||||
|
||||
EXPECT_FALSE(ShouldRun<RemovePhonies>(src));
|
||||
}
|
||||
|
||||
TEST_F(RemovePhoniesTest, ShouldRunHasPhony) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
_ = 1;
|
||||
}
|
||||
)";
|
||||
|
||||
EXPECT_TRUE(ShouldRun<RemovePhonies>(src));
|
||||
}
|
||||
|
||||
TEST_F(RemovePhoniesTest, EmptyModule) {
|
||||
auto* src = "";
|
||||
auto* expect = "";
|
||||
|
||||
auto got = Run<RemovePhonies>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(RemovePhoniesTest, NoSideEffects) {
|
||||
auto* src = R"(
|
||||
@group(0) @binding(0) var t : texture_2d<f32>;
|
||||
|
||||
fn f() {
|
||||
var v : i32;
|
||||
_ = &v;
|
||||
_ = 1;
|
||||
_ = 1 + 2;
|
||||
_ = t;
|
||||
_ = u32(3.0);
|
||||
_ = f32(i32(4u));
|
||||
_ = vec2<f32>(5.0);
|
||||
_ = vec3<i32>(6, 7, 8);
|
||||
_ = mat2x2<f32>(9.0, 10.0, 11.0, 12.0);
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
@group(0) @binding(0) var t : texture_2d<f32>;
|
||||
|
||||
fn f() {
|
||||
var v : i32;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<RemovePhonies>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(RemovePhoniesTest, SingleSideEffects) {
|
||||
auto* src = R"(
|
||||
fn neg(a : i32) -> i32 {
|
||||
return -(a);
|
||||
}
|
||||
|
||||
fn add(a : i32, b : i32) -> i32 {
|
||||
return (a + b);
|
||||
}
|
||||
|
||||
fn f() {
|
||||
_ = neg(1);
|
||||
_ = add(2, 3);
|
||||
_ = add(neg(4), neg(5));
|
||||
_ = u32(neg(6));
|
||||
_ = f32(add(7, 8));
|
||||
_ = vec2<f32>(f32(neg(9)));
|
||||
_ = vec3<i32>(1, neg(10), 3);
|
||||
_ = mat2x2<f32>(1.0, f32(add(11, 12)), 3.0, 4.0);
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn neg(a : i32) -> i32 {
|
||||
return -(a);
|
||||
}
|
||||
|
||||
fn add(a : i32, b : i32) -> i32 {
|
||||
return (a + b);
|
||||
}
|
||||
|
||||
fn f() {
|
||||
neg(1);
|
||||
add(2, 3);
|
||||
add(neg(4), neg(5));
|
||||
neg(6);
|
||||
add(7, 8);
|
||||
neg(9);
|
||||
neg(10);
|
||||
add(11, 12);
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<RemovePhonies>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(RemovePhoniesTest, SingleSideEffects_OutOfOrder) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
_ = neg(1);
|
||||
_ = add(2, 3);
|
||||
_ = add(neg(4), neg(5));
|
||||
_ = u32(neg(6));
|
||||
_ = f32(add(7, 8));
|
||||
_ = vec2<f32>(f32(neg(9)));
|
||||
_ = vec3<i32>(1, neg(10), 3);
|
||||
_ = mat2x2<f32>(1.0, f32(add(11, 12)), 3.0, 4.0);
|
||||
}
|
||||
|
||||
fn add(a : i32, b : i32) -> i32 {
|
||||
return (a + b);
|
||||
}
|
||||
|
||||
fn neg(a : i32) -> i32 {
|
||||
return -(a);
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn f() {
|
||||
neg(1);
|
||||
add(2, 3);
|
||||
add(neg(4), neg(5));
|
||||
neg(6);
|
||||
add(7, 8);
|
||||
neg(9);
|
||||
neg(10);
|
||||
add(11, 12);
|
||||
}
|
||||
|
||||
fn add(a : i32, b : i32) -> i32 {
|
||||
return (a + b);
|
||||
}
|
||||
|
||||
fn neg(a : i32) -> i32 {
|
||||
return -(a);
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<RemovePhonies>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(RemovePhoniesTest, MultipleSideEffects) {
|
||||
auto* src = R"(
|
||||
fn neg(a : i32) -> i32 {
|
||||
return -(a);
|
||||
}
|
||||
|
||||
fn add(a : i32, b : i32) -> i32 {
|
||||
return (a + b);
|
||||
}
|
||||
|
||||
fn xor(a : u32, b : u32) -> u32 {
|
||||
return (a ^ b);
|
||||
}
|
||||
|
||||
fn f() {
|
||||
_ = (1 + add(2 + add(3, 4), 5)) * add(6, 7) * neg(8);
|
||||
_ = add(9, neg(10)) + neg(11);
|
||||
_ = xor(12u, 13u) + xor(14u, 15u);
|
||||
_ = neg(16) / neg(17) + add(18, 19);
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn neg(a : i32) -> i32 {
|
||||
return -(a);
|
||||
}
|
||||
|
||||
fn add(a : i32, b : i32) -> i32 {
|
||||
return (a + b);
|
||||
}
|
||||
|
||||
fn xor(a : u32, b : u32) -> u32 {
|
||||
return (a ^ b);
|
||||
}
|
||||
|
||||
fn phony_sink(p0 : i32, p1 : i32, p2 : i32) {
|
||||
}
|
||||
|
||||
fn phony_sink_1(p0 : i32, p1 : i32) {
|
||||
}
|
||||
|
||||
fn phony_sink_2(p0 : u32, p1 : u32) {
|
||||
}
|
||||
|
||||
fn f() {
|
||||
phony_sink(add((2 + add(3, 4)), 5), add(6, 7), neg(8));
|
||||
phony_sink_1(add(9, neg(10)), neg(11));
|
||||
phony_sink_2(xor(12u, 13u), xor(14u, 15u));
|
||||
phony_sink(neg(16), neg(17), add(18, 19));
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<RemovePhonies>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(RemovePhoniesTest, MultipleSideEffects_OutOfOrder) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
_ = (1 + add(2 + add(3, 4), 5)) * add(6, 7) * neg(8);
|
||||
_ = add(9, neg(10)) + neg(11);
|
||||
_ = xor(12u, 13u) + xor(14u, 15u);
|
||||
_ = neg(16) / neg(17) + add(18, 19);
|
||||
}
|
||||
|
||||
fn neg(a : i32) -> i32 {
|
||||
return -(a);
|
||||
}
|
||||
|
||||
fn add(a : i32, b : i32) -> i32 {
|
||||
return (a + b);
|
||||
}
|
||||
|
||||
fn xor(a : u32, b : u32) -> u32 {
|
||||
return (a ^ b);
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn phony_sink(p0 : i32, p1 : i32, p2 : i32) {
|
||||
}
|
||||
|
||||
fn phony_sink_1(p0 : i32, p1 : i32) {
|
||||
}
|
||||
|
||||
fn phony_sink_2(p0 : u32, p1 : u32) {
|
||||
}
|
||||
|
||||
fn f() {
|
||||
phony_sink(add((2 + add(3, 4)), 5), add(6, 7), neg(8));
|
||||
phony_sink_1(add(9, neg(10)), neg(11));
|
||||
phony_sink_2(xor(12u, 13u), xor(14u, 15u));
|
||||
phony_sink(neg(16), neg(17), add(18, 19));
|
||||
}
|
||||
|
||||
fn neg(a : i32) -> i32 {
|
||||
return -(a);
|
||||
}
|
||||
|
||||
fn add(a : i32, b : i32) -> i32 {
|
||||
return (a + b);
|
||||
}
|
||||
|
||||
fn xor(a : u32, b : u32) -> u32 {
|
||||
return (a ^ b);
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<RemovePhonies>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(RemovePhoniesTest, ForLoop) {
|
||||
auto* src = R"(
|
||||
struct S {
|
||||
arr : array<i32>;
|
||||
};
|
||||
|
||||
@group(0) @binding(0) var<storage, read_write> s : S;
|
||||
|
||||
fn x() -> i32 {
|
||||
return 0;
|
||||
}
|
||||
|
||||
fn y() -> i32 {
|
||||
return 0;
|
||||
}
|
||||
|
||||
fn z() -> i32 {
|
||||
return 0;
|
||||
}
|
||||
|
||||
fn f() {
|
||||
for (_ = &s.arr; ;_ = &s.arr) {
|
||||
break;
|
||||
}
|
||||
for (_ = x(); ;_ = y() + z()) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
struct S {
|
||||
arr : array<i32>;
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var<storage, read_write> s : S;
|
||||
|
||||
fn x() -> i32 {
|
||||
return 0;
|
||||
}
|
||||
|
||||
fn y() -> i32 {
|
||||
return 0;
|
||||
}
|
||||
|
||||
fn z() -> i32 {
|
||||
return 0;
|
||||
}
|
||||
|
||||
fn phony_sink(p0 : i32, p1 : i32) {
|
||||
}
|
||||
|
||||
fn f() {
|
||||
for(; ; ) {
|
||||
break;
|
||||
}
|
||||
for(x(); ; phony_sink(y(), z())) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<RemovePhonies>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(RemovePhoniesTest, ForLoop_OutOfOrder) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
for (_ = &s.arr; ;_ = &s.arr) {
|
||||
break;
|
||||
}
|
||||
for (_ = x(); ;_ = y() + z()) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
fn x() -> i32 {
|
||||
return 0;
|
||||
}
|
||||
|
||||
fn y() -> i32 {
|
||||
return 0;
|
||||
}
|
||||
|
||||
fn z() -> i32 {
|
||||
return 0;
|
||||
}
|
||||
|
||||
struct S {
|
||||
arr : array<i32>;
|
||||
};
|
||||
|
||||
@group(0) @binding(0) var<storage, read_write> s : S;
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn phony_sink(p0 : i32, p1 : i32) {
|
||||
}
|
||||
|
||||
fn f() {
|
||||
for(; ; ) {
|
||||
break;
|
||||
}
|
||||
for(x(); ; phony_sink(y(), z())) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
fn x() -> i32 {
|
||||
return 0;
|
||||
}
|
||||
|
||||
fn y() -> i32 {
|
||||
return 0;
|
||||
}
|
||||
|
||||
fn z() -> i32 {
|
||||
return 0;
|
||||
}
|
||||
|
||||
struct S {
|
||||
arr : array<i32>;
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var<storage, read_write> s : S;
|
||||
)";
|
||||
|
||||
auto got = Run<RemovePhonies>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
67
src/tint/transform/remove_unreachable_statements.cc
Normal file
67
src/tint/transform/remove_unreachable_statements.cc
Normal file
@@ -0,0 +1,67 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/tint/transform/remove_unreachable_statements.h"
|
||||
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "src/tint/ast/traverse_expressions.h"
|
||||
#include "src/tint/program_builder.h"
|
||||
#include "src/tint/sem/block_statement.h"
|
||||
#include "src/tint/sem/function.h"
|
||||
#include "src/tint/sem/statement.h"
|
||||
#include "src/tint/sem/variable.h"
|
||||
#include "src/tint/utils/map.h"
|
||||
#include "src/tint/utils/scoped_assignment.h"
|
||||
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::transform::RemoveUnreachableStatements);
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
|
||||
RemoveUnreachableStatements::RemoveUnreachableStatements() = default;
|
||||
|
||||
RemoveUnreachableStatements::~RemoveUnreachableStatements() = default;
|
||||
|
||||
bool RemoveUnreachableStatements::ShouldRun(const Program* program,
|
||||
const DataMap&) const {
|
||||
for (auto* node : program->ASTNodes().Objects()) {
|
||||
if (auto* stmt = program->Sem().Get<sem::Statement>(node)) {
|
||||
if (!stmt->IsReachable()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void RemoveUnreachableStatements::Run(CloneContext& ctx,
|
||||
const DataMap&,
|
||||
DataMap&) const {
|
||||
for (auto* node : ctx.src->ASTNodes().Objects()) {
|
||||
if (auto* stmt = ctx.src->Sem().Get<sem::Statement>(node)) {
|
||||
if (!stmt->IsReachable()) {
|
||||
RemoveStatement(ctx, stmt->Declaration());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ctx.Clone();
|
||||
}
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
58
src/tint/transform/remove_unreachable_statements.h
Normal file
58
src/tint/transform/remove_unreachable_statements.h
Normal file
@@ -0,0 +1,58 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef SRC_TINT_TRANSFORM_REMOVE_UNREACHABLE_STATEMENTS_H_
|
||||
#define SRC_TINT_TRANSFORM_REMOVE_UNREACHABLE_STATEMENTS_H_
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "src/tint/transform/transform.h"
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
|
||||
/// RemoveUnreachableStatements is a Transform that removes all statements
|
||||
/// marked as unreachable.
|
||||
class RemoveUnreachableStatements
|
||||
: public Castable<RemoveUnreachableStatements, Transform> {
|
||||
public:
|
||||
/// Constructor
|
||||
RemoveUnreachableStatements();
|
||||
|
||||
/// Destructor
|
||||
~RemoveUnreachableStatements() override;
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program,
|
||||
const DataMap& data = {}) const override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
};
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
|
||||
#endif // SRC_TINT_TRANSFORM_REMOVE_UNREACHABLE_STATEMENTS_H_
|
||||
571
src/tint/transform/remove_unreachable_statements_test.cc
Normal file
571
src/tint/transform/remove_unreachable_statements_test.cc
Normal file
@@ -0,0 +1,571 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/tint/transform/remove_unreachable_statements.h"
|
||||
|
||||
#include "src/tint/transform/test_helper.h"
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
namespace {
|
||||
|
||||
using RemoveUnreachableStatementsTest = TransformTest;
|
||||
|
||||
TEST_F(RemoveUnreachableStatementsTest, ShouldRunEmptyModule) {
|
||||
auto* src = R"()";
|
||||
|
||||
EXPECT_FALSE(ShouldRun<RemoveUnreachableStatements>(src));
|
||||
}
|
||||
|
||||
TEST_F(RemoveUnreachableStatementsTest, ShouldRunHasNoUnreachable) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
if (true) {
|
||||
var x = 1;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
EXPECT_FALSE(ShouldRun<RemoveUnreachableStatements>(src));
|
||||
}
|
||||
|
||||
TEST_F(RemoveUnreachableStatementsTest, ShouldRunHasUnreachable) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
return;
|
||||
if (true) {
|
||||
var x = 1;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
EXPECT_TRUE(ShouldRun<RemoveUnreachableStatements>(src));
|
||||
}
|
||||
|
||||
TEST_F(RemoveUnreachableStatementsTest, EmptyModule) {
|
||||
auto* src = "";
|
||||
auto* expect = "";
|
||||
|
||||
auto got = Run<RemoveUnreachableStatements>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(RemoveUnreachableStatementsTest, Return) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
return;
|
||||
var remove_me = 1;
|
||||
if (true) {
|
||||
var remove_me_too = 1;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn f() {
|
||||
return;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<RemoveUnreachableStatements>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(RemoveUnreachableStatementsTest, NestedReturn) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
{
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
var remove_me = 1;
|
||||
if (true) {
|
||||
var remove_me_too = 1;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn f() {
|
||||
{
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<RemoveUnreachableStatements>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(RemoveUnreachableStatementsTest, Discard) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
discard;
|
||||
var remove_me = 1;
|
||||
if (true) {
|
||||
var remove_me_too = 1;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn f() {
|
||||
discard;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<RemoveUnreachableStatements>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(RemoveUnreachableStatementsTest, NestedDiscard) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
{
|
||||
{
|
||||
discard;
|
||||
}
|
||||
}
|
||||
var remove_me = 1;
|
||||
if (true) {
|
||||
var remove_me_too = 1;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn f() {
|
||||
{
|
||||
{
|
||||
discard;
|
||||
}
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<RemoveUnreachableStatements>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(RemoveUnreachableStatementsTest, CallToFuncWithDiscard) {
|
||||
auto* src = R"(
|
||||
fn DISCARD() {
|
||||
discard;
|
||||
}
|
||||
|
||||
fn f() {
|
||||
DISCARD();
|
||||
var remove_me = 1;
|
||||
if (true) {
|
||||
var remove_me_too = 1;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn DISCARD() {
|
||||
discard;
|
||||
}
|
||||
|
||||
fn f() {
|
||||
DISCARD();
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<RemoveUnreachableStatements>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(RemoveUnreachableStatementsTest, CallToFuncWithIfDiscard) {
|
||||
auto* src = R"(
|
||||
fn DISCARD() {
|
||||
if (true) {
|
||||
discard;
|
||||
}
|
||||
}
|
||||
|
||||
fn f() {
|
||||
DISCARD();
|
||||
var preserve_me = 1;
|
||||
if (true) {
|
||||
var preserve_me_too = 1;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = src;
|
||||
|
||||
auto got = Run<RemoveUnreachableStatements>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(RemoveUnreachableStatementsTest, IfDiscardElseDiscard) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
if (true) {
|
||||
discard;
|
||||
} else {
|
||||
discard;
|
||||
}
|
||||
var remove_me = 1;
|
||||
if (true) {
|
||||
var remove_me_too = 1;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn f() {
|
||||
if (true) {
|
||||
discard;
|
||||
} else {
|
||||
discard;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<RemoveUnreachableStatements>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(RemoveUnreachableStatementsTest, IfDiscardElseReturn) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
if (true) {
|
||||
discard;
|
||||
} else {
|
||||
return;
|
||||
}
|
||||
var remove_me = 1;
|
||||
if (true) {
|
||||
var remove_me_too = 1;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn f() {
|
||||
if (true) {
|
||||
discard;
|
||||
} else {
|
||||
return;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<RemoveUnreachableStatements>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(RemoveUnreachableStatementsTest, IfDiscard) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
if (true) {
|
||||
discard;
|
||||
}
|
||||
var preserve_me = 1;
|
||||
if (true) {
|
||||
var preserve_me_too = 1;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = src;
|
||||
|
||||
auto got = Run<RemoveUnreachableStatements>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(RemoveUnreachableStatementsTest, IfReturn) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
if (true) {
|
||||
return;
|
||||
}
|
||||
var preserve_me = 1;
|
||||
if (true) {
|
||||
var preserve_me_too = 1;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = src;
|
||||
|
||||
auto got = Run<RemoveUnreachableStatements>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(RemoveUnreachableStatementsTest, IfElseDiscard) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
if (true) {
|
||||
} else {
|
||||
discard;
|
||||
}
|
||||
var preserve_me = 1;
|
||||
if (true) {
|
||||
var preserve_me_too = 1;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = src;
|
||||
|
||||
auto got = Run<RemoveUnreachableStatements>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(RemoveUnreachableStatementsTest, IfElseReturn) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
if (true) {
|
||||
} else {
|
||||
return;
|
||||
}
|
||||
var preserve_me = 1;
|
||||
if (true) {
|
||||
var preserve_me_too = 1;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = src;
|
||||
|
||||
auto got = Run<RemoveUnreachableStatements>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(RemoveUnreachableStatementsTest, LoopWithDiscard) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
loop {
|
||||
var a = 1;
|
||||
discard;
|
||||
|
||||
continuing {
|
||||
var b = 2;
|
||||
}
|
||||
}
|
||||
var remove_me = 1;
|
||||
if (true) {
|
||||
var remove_me_too = 1;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn f() {
|
||||
loop {
|
||||
var a = 1;
|
||||
discard;
|
||||
|
||||
continuing {
|
||||
var b = 2;
|
||||
}
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<RemoveUnreachableStatements>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(RemoveUnreachableStatementsTest, LoopWithConditionalBreak) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
loop {
|
||||
var a = 1;
|
||||
if (true) {
|
||||
break;
|
||||
}
|
||||
|
||||
continuing {
|
||||
var b = 2;
|
||||
}
|
||||
}
|
||||
var preserve_me = 1;
|
||||
if (true) {
|
||||
var preserve_me_too = 1;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = src;
|
||||
|
||||
auto got = Run<RemoveUnreachableStatements>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(RemoveUnreachableStatementsTest, LoopWithConditionalBreakInContinuing) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
loop {
|
||||
|
||||
continuing {
|
||||
if (true) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
var preserve_me = 1;
|
||||
if (true) {
|
||||
var preserve_me_too = 1;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = src;
|
||||
|
||||
auto got = Run<RemoveUnreachableStatements>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(RemoveUnreachableStatementsTest, SwitchDefaultDiscard) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
switch(1) {
|
||||
default: {
|
||||
discard;
|
||||
}
|
||||
}
|
||||
var remove_me = 1;
|
||||
if (true) {
|
||||
var remove_me_too = 1;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn f() {
|
||||
switch(1) {
|
||||
default: {
|
||||
discard;
|
||||
}
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<RemoveUnreachableStatements>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(RemoveUnreachableStatementsTest, SwitchCaseReturnDefaultDiscard) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
switch(1) {
|
||||
case 0: {
|
||||
return;
|
||||
}
|
||||
default: {
|
||||
discard;
|
||||
}
|
||||
}
|
||||
var remove_me = 1;
|
||||
if (true) {
|
||||
var remove_me_too = 1;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn f() {
|
||||
switch(1) {
|
||||
case 0: {
|
||||
return;
|
||||
}
|
||||
default: {
|
||||
discard;
|
||||
}
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<RemoveUnreachableStatements>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(RemoveUnreachableStatementsTest, SwitchCaseBreakDefaultDiscard) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
switch(1) {
|
||||
case 0: {
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
discard;
|
||||
}
|
||||
}
|
||||
var preserve_me = 1;
|
||||
if (true) {
|
||||
var preserve_me_too = 1;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = src;
|
||||
|
||||
auto got = Run<RemoveUnreachableStatements>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(RemoveUnreachableStatementsTest, SwitchCaseReturnDefaultBreak) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
switch(1) {
|
||||
case 0: {
|
||||
return;
|
||||
}
|
||||
default: {
|
||||
break;
|
||||
}
|
||||
}
|
||||
var preserve_me = 1;
|
||||
if (true) {
|
||||
var preserve_me_too = 1;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = src;
|
||||
|
||||
auto got = Run<RemoveUnreachableStatements>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
1366
src/tint/transform/renamer.cc
Normal file
1366
src/tint/transform/renamer.cc
Normal file
File diff suppressed because it is too large
Load Diff
97
src/tint/transform/renamer.h
Normal file
97
src/tint/transform/renamer.h
Normal file
@@ -0,0 +1,97 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef SRC_TINT_TRANSFORM_RENAMER_H_
|
||||
#define SRC_TINT_TRANSFORM_RENAMER_H_
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "src/tint/transform/transform.h"
|
||||
|
||||
namespace tint::transform {
|
||||
|
||||
/// Renamer is a Transform that renames all the symbols in a program.
|
||||
class Renamer : public Castable<Renamer, Transform> {
|
||||
public:
|
||||
/// Data is outputted by the Renamer transform.
|
||||
/// Data holds information about shader usage and constant buffer offsets.
|
||||
struct Data : public Castable<Data, transform::Data> {
|
||||
/// Remappings is a map of old symbol name to new symbol name
|
||||
using Remappings = std::unordered_map<std::string, std::string>;
|
||||
|
||||
/// Constructor
|
||||
/// @param remappings the symbol remappings
|
||||
explicit Data(Remappings&& remappings);
|
||||
|
||||
/// Copy constructor
|
||||
Data(const Data&);
|
||||
|
||||
/// Destructor
|
||||
~Data() override;
|
||||
|
||||
/// A map of old symbol name to new symbol name
|
||||
const Remappings remappings;
|
||||
};
|
||||
|
||||
/// Target is an enumerator of rename targets that can be used
|
||||
enum class Target {
|
||||
/// Rename every symbol.
|
||||
kAll,
|
||||
/// Only rename symbols that are reserved keywords in GLSL.
|
||||
kGlslKeywords,
|
||||
/// Only rename symbols that are reserved keywords in HLSL.
|
||||
kHlslKeywords,
|
||||
/// Only rename symbols that are reserved keywords in MSL.
|
||||
kMslKeywords,
|
||||
};
|
||||
|
||||
/// Optional configuration options for the transform.
|
||||
/// If omitted, then the renamer will use Target::kAll.
|
||||
struct Config : public Castable<Config, transform::Data> {
|
||||
/// Constructor
|
||||
/// @param tgt the targets to rename
|
||||
/// @param keep_unicode if false, symbols with non-ascii code-points are
|
||||
/// renamed
|
||||
explicit Config(Target tgt, bool keep_unicode = false);
|
||||
|
||||
/// Copy constructor
|
||||
Config(const Config&);
|
||||
|
||||
/// Destructor
|
||||
~Config() override;
|
||||
|
||||
/// The targets to rename
|
||||
Target const target = Target::kAll;
|
||||
|
||||
/// If false, symbols with non-ascii code-points are renamed.
|
||||
bool preserve_unicode = false;
|
||||
};
|
||||
|
||||
/// Constructor using a the configuration provided in the input Data
|
||||
Renamer();
|
||||
|
||||
/// Destructor
|
||||
~Renamer() override;
|
||||
|
||||
/// Runs the transform on `program`, returning the transformation result.
|
||||
/// @param program the source program to transform
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns the transformation result
|
||||
Output Run(const Program* program, const DataMap& data = {}) const override;
|
||||
};
|
||||
|
||||
} // namespace tint::transform
|
||||
|
||||
#endif // SRC_TINT_TRANSFORM_RENAMER_H_
|
||||
1463
src/tint/transform/renamer_test.cc
Normal file
1463
src/tint/transform/renamer_test.cc
Normal file
File diff suppressed because it is too large
Load Diff
321
src/tint/transform/robustness.cc
Normal file
321
src/tint/transform/robustness.cc
Normal file
@@ -0,0 +1,321 @@
|
||||
// Copyright 2020 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/tint/transform/robustness.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
#include <utility>
|
||||
|
||||
#include "src/tint/program_builder.h"
|
||||
#include "src/tint/sem/block_statement.h"
|
||||
#include "src/tint/sem/call.h"
|
||||
#include "src/tint/sem/expression.h"
|
||||
#include "src/tint/sem/reference_type.h"
|
||||
#include "src/tint/sem/statement.h"
|
||||
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::transform::Robustness);
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::transform::Robustness::Config);
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
|
||||
/// State holds the current transform state
|
||||
struct Robustness::State {
|
||||
/// The clone context
|
||||
CloneContext& ctx;
|
||||
|
||||
/// Set of storage classes to not apply the transform to
|
||||
std::unordered_set<ast::StorageClass> omitted_classes;
|
||||
|
||||
/// Applies the transformation state to `ctx`.
|
||||
void Transform() {
|
||||
ctx.ReplaceAll([&](const ast::IndexAccessorExpression* expr) {
|
||||
return Transform(expr);
|
||||
});
|
||||
ctx.ReplaceAll(
|
||||
[&](const ast::CallExpression* expr) { return Transform(expr); });
|
||||
}
|
||||
|
||||
/// Apply bounds clamping to array, vector and matrix indexing
|
||||
/// @param expr the array, vector or matrix index expression
|
||||
/// @return the clamped replacement expression, or nullptr if `expr` should be
|
||||
/// cloned without changes.
|
||||
const ast::IndexAccessorExpression* Transform(
|
||||
const ast::IndexAccessorExpression* expr) {
|
||||
auto* ret_type = ctx.src->Sem().Get(expr->object)->Type();
|
||||
|
||||
auto* ref = ret_type->As<sem::Reference>();
|
||||
if (ref && omitted_classes.count(ref->StorageClass()) != 0) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto* ret_unwrapped = ret_type->UnwrapRef();
|
||||
|
||||
ProgramBuilder& b = *ctx.dst;
|
||||
using u32 = ProgramBuilder::u32;
|
||||
|
||||
struct Value {
|
||||
const ast::Expression* expr = nullptr; // If null, then is a constant
|
||||
union {
|
||||
uint32_t u32 = 0; // use if is_signed == false
|
||||
int32_t i32; // use if is_signed == true
|
||||
};
|
||||
bool is_signed = false;
|
||||
};
|
||||
|
||||
Value size; // size of the array, vector or matrix
|
||||
size.is_signed = false; // size is always unsigned
|
||||
if (auto* vec = ret_unwrapped->As<sem::Vector>()) {
|
||||
size.u32 = vec->Width();
|
||||
|
||||
} else if (auto* arr = ret_unwrapped->As<sem::Array>()) {
|
||||
size.u32 = arr->Count();
|
||||
} else if (auto* mat = ret_unwrapped->As<sem::Matrix>()) {
|
||||
// The row accessor would have been an embedded index accessor and already
|
||||
// handled, so we just need to do columns here.
|
||||
size.u32 = mat->columns();
|
||||
} else {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (size.u32 == 0) {
|
||||
if (!ret_unwrapped->Is<sem::Array>()) {
|
||||
b.Diagnostics().add_error(diag::System::Transform,
|
||||
"invalid 0 sized non-array", expr->source);
|
||||
return nullptr;
|
||||
}
|
||||
// Runtime sized array
|
||||
auto* arr = ctx.Clone(expr->object);
|
||||
size.expr = b.Call("arrayLength", b.AddressOf(arr));
|
||||
}
|
||||
|
||||
// Calculate the maximum possible index value (size-1u)
|
||||
// Size must be positive (non-zero), so we can safely subtract 1 here
|
||||
// without underflow.
|
||||
Value limit;
|
||||
limit.is_signed = false; // Like size, limit is always unsigned.
|
||||
if (size.expr) {
|
||||
// Dynamic size
|
||||
limit.expr = b.Sub(size.expr, 1u);
|
||||
} else {
|
||||
// Constant size
|
||||
limit.u32 = size.u32 - 1u;
|
||||
}
|
||||
|
||||
Value idx; // index value
|
||||
|
||||
auto* idx_sem = ctx.src->Sem().Get(expr->index);
|
||||
auto* idx_ty = idx_sem->Type()->UnwrapRef();
|
||||
if (!idx_ty->IsAnyOf<sem::I32, sem::U32>()) {
|
||||
TINT_ICE(Transform, b.Diagnostics())
|
||||
<< "index must be u32 or i32, got " << idx_sem->Type()->type_name();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (auto idx_constant = idx_sem->ConstantValue()) {
|
||||
// Constant value index
|
||||
if (idx_constant.Type()->Is<sem::I32>()) {
|
||||
idx.i32 = idx_constant.Elements()[0].i32;
|
||||
idx.is_signed = true;
|
||||
} else if (idx_constant.Type()->Is<sem::U32>()) {
|
||||
idx.u32 = idx_constant.Elements()[0].u32;
|
||||
idx.is_signed = false;
|
||||
} else {
|
||||
b.Diagnostics().add_error(diag::System::Transform,
|
||||
"unsupported constant value for accessor: " +
|
||||
idx_constant.Type()->type_name(),
|
||||
expr->source);
|
||||
return nullptr;
|
||||
}
|
||||
} else {
|
||||
// Dynamic value index
|
||||
idx.expr = ctx.Clone(expr->index);
|
||||
idx.is_signed = idx_ty->Is<sem::I32>();
|
||||
}
|
||||
|
||||
// Clamp the index so that it cannot exceed limit.
|
||||
if (idx.expr || limit.expr) {
|
||||
// One of, or both of idx and limit are non-constant.
|
||||
|
||||
// If the index is signed, cast it to a u32 (with clamping if constant).
|
||||
if (idx.is_signed) {
|
||||
if (idx.expr) {
|
||||
// We don't use a max(idx, 0) here, as that incurs a runtime
|
||||
// performance cost, and if the unsigned value will be clamped by
|
||||
// limit, resulting in a value between [0..limit)
|
||||
idx.expr = b.Construct<u32>(idx.expr);
|
||||
idx.is_signed = false;
|
||||
} else {
|
||||
idx.u32 = static_cast<uint32_t>(std::max(idx.i32, 0));
|
||||
idx.is_signed = false;
|
||||
}
|
||||
}
|
||||
|
||||
// Convert idx and limit to expressions, so we can emit `min(idx, limit)`.
|
||||
if (!idx.expr) {
|
||||
idx.expr = b.Expr(idx.u32);
|
||||
}
|
||||
if (!limit.expr) {
|
||||
limit.expr = b.Expr(limit.u32);
|
||||
}
|
||||
|
||||
// Perform the clamp with `min(idx, limit)`
|
||||
idx.expr = b.Call("min", idx.expr, limit.expr);
|
||||
} else {
|
||||
// Both idx and max are constant.
|
||||
if (idx.is_signed) {
|
||||
// The index is signed. Calculate limit as signed.
|
||||
int32_t signed_limit = static_cast<int32_t>(
|
||||
std::min<uint32_t>(limit.u32, std::numeric_limits<int32_t>::max()));
|
||||
idx.i32 = std::max(idx.i32, 0);
|
||||
idx.i32 = std::min(idx.i32, signed_limit);
|
||||
} else {
|
||||
// The index is unsigned.
|
||||
idx.u32 = std::min(idx.u32, limit.u32);
|
||||
}
|
||||
}
|
||||
|
||||
// Convert idx to an expression, so we can emit the new accessor.
|
||||
if (!idx.expr) {
|
||||
idx.expr = idx.is_signed
|
||||
? static_cast<const ast::Expression*>(b.Expr(idx.i32))
|
||||
: static_cast<const ast::Expression*>(b.Expr(idx.u32));
|
||||
}
|
||||
|
||||
// Clone arguments outside of create() call to have deterministic ordering
|
||||
auto src = ctx.Clone(expr->source);
|
||||
auto* obj = ctx.Clone(expr->object);
|
||||
return b.IndexAccessor(src, obj, idx.expr);
|
||||
}
|
||||
|
||||
/// @param type builtin type
|
||||
/// @returns true if the given builtin is a texture function that requires
|
||||
/// argument clamping,
|
||||
bool TextureBuiltinNeedsClamping(sem::BuiltinType type) {
|
||||
return type == sem::BuiltinType::kTextureLoad ||
|
||||
type == sem::BuiltinType::kTextureStore;
|
||||
}
|
||||
|
||||
/// Apply bounds clamping to the coordinates, array index and level arguments
|
||||
/// of the `textureLoad()` and `textureStore()` builtins.
|
||||
/// @param expr the builtin call expression
|
||||
/// @return the clamped replacement call expression, or nullptr if `expr`
|
||||
/// should be cloned without changes.
|
||||
const ast::CallExpression* Transform(const ast::CallExpression* expr) {
|
||||
auto* call = ctx.src->Sem().Get(expr);
|
||||
auto* call_target = call->Target();
|
||||
auto* builtin = call_target->As<sem::Builtin>();
|
||||
if (!builtin || !TextureBuiltinNeedsClamping(builtin->Type())) {
|
||||
return nullptr; // No transform, just clone.
|
||||
}
|
||||
|
||||
ProgramBuilder& b = *ctx.dst;
|
||||
|
||||
// Indices of the mandatory texture and coords parameters, and the optional
|
||||
// array and level parameters.
|
||||
auto& signature = builtin->Signature();
|
||||
auto texture_idx = signature.IndexOf(sem::ParameterUsage::kTexture);
|
||||
auto coords_idx = signature.IndexOf(sem::ParameterUsage::kCoords);
|
||||
auto array_idx = signature.IndexOf(sem::ParameterUsage::kArrayIndex);
|
||||
auto level_idx = signature.IndexOf(sem::ParameterUsage::kLevel);
|
||||
|
||||
auto* texture_arg = expr->args[texture_idx];
|
||||
auto* coords_arg = expr->args[coords_idx];
|
||||
auto* coords_ty = builtin->Parameters()[coords_idx]->Type();
|
||||
|
||||
// If the level is provided, then we need to clamp this. As the level is
|
||||
// used by textureDimensions() and the texture[Load|Store]() calls, we need
|
||||
// to clamp both usages.
|
||||
// TODO(bclayton): We probably want to place this into a let so that the
|
||||
// calculation can be reused. This is fiddly to get right.
|
||||
std::function<const ast::Expression*()> level_arg;
|
||||
if (level_idx >= 0) {
|
||||
level_arg = [&] {
|
||||
auto* arg = expr->args[level_idx];
|
||||
auto* num_levels = b.Call("textureNumLevels", ctx.Clone(texture_arg));
|
||||
auto* zero = b.Expr(0);
|
||||
auto* max = ctx.dst->Sub(num_levels, 1);
|
||||
auto* clamped = b.Call("clamp", ctx.Clone(arg), zero, max);
|
||||
return clamped;
|
||||
};
|
||||
}
|
||||
|
||||
// Clamp the coordinates argument
|
||||
{
|
||||
auto* texture_dims =
|
||||
level_arg
|
||||
? b.Call("textureDimensions", ctx.Clone(texture_arg), level_arg())
|
||||
: b.Call("textureDimensions", ctx.Clone(texture_arg));
|
||||
auto* zero = b.Construct(CreateASTTypeFor(ctx, coords_ty));
|
||||
auto* max = ctx.dst->Sub(
|
||||
texture_dims, b.Construct(CreateASTTypeFor(ctx, coords_ty), 1));
|
||||
auto* clamped_coords = b.Call("clamp", ctx.Clone(coords_arg), zero, max);
|
||||
ctx.Replace(coords_arg, clamped_coords);
|
||||
}
|
||||
|
||||
// Clamp the array_index argument, if provided
|
||||
if (array_idx >= 0) {
|
||||
auto* arg = expr->args[array_idx];
|
||||
auto* num_layers = b.Call("textureNumLayers", ctx.Clone(texture_arg));
|
||||
auto* zero = b.Expr(0);
|
||||
auto* max = ctx.dst->Sub(num_layers, 1);
|
||||
auto* clamped = b.Call("clamp", ctx.Clone(arg), zero, max);
|
||||
ctx.Replace(arg, clamped);
|
||||
}
|
||||
|
||||
// Clamp the level argument, if provided
|
||||
if (level_idx >= 0) {
|
||||
auto* arg = expr->args[level_idx];
|
||||
ctx.Replace(arg, level_arg ? level_arg() : ctx.dst->Expr(0));
|
||||
}
|
||||
|
||||
return nullptr; // Clone, which will use the argument replacements above.
|
||||
}
|
||||
};
|
||||
|
||||
Robustness::Config::Config() = default;
|
||||
Robustness::Config::Config(const Config&) = default;
|
||||
Robustness::Config::~Config() = default;
|
||||
Robustness::Config& Robustness::Config::operator=(const Config&) = default;
|
||||
|
||||
Robustness::Robustness() = default;
|
||||
Robustness::~Robustness() = default;
|
||||
|
||||
void Robustness::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const {
|
||||
Config cfg;
|
||||
if (auto* cfg_data = inputs.Get<Config>()) {
|
||||
cfg = *cfg_data;
|
||||
}
|
||||
|
||||
std::unordered_set<ast::StorageClass> omitted_classes;
|
||||
for (auto sc : cfg.omitted_classes) {
|
||||
switch (sc) {
|
||||
case StorageClass::kUniform:
|
||||
omitted_classes.insert(ast::StorageClass::kUniform);
|
||||
break;
|
||||
case StorageClass::kStorage:
|
||||
omitted_classes.insert(ast::StorageClass::kStorage);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
State state{ctx, std::move(omitted_classes)};
|
||||
|
||||
state.Transform();
|
||||
ctx.Clone();
|
||||
}
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
88
src/tint/transform/robustness.h
Normal file
88
src/tint/transform/robustness.h
Normal file
@@ -0,0 +1,88 @@
|
||||
// Copyright 2020 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef SRC_TINT_TRANSFORM_ROBUSTNESS_H_
|
||||
#define SRC_TINT_TRANSFORM_ROBUSTNESS_H_
|
||||
|
||||
#include <unordered_set>
|
||||
|
||||
#include "src/tint/transform/transform.h"
|
||||
|
||||
// Forward declarations
|
||||
namespace tint {
|
||||
namespace ast {
|
||||
class IndexAccessorExpression;
|
||||
class CallExpression;
|
||||
} // namespace ast
|
||||
} // namespace tint
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
|
||||
/// This transform is responsible for clamping all array accesses to be within
|
||||
/// the bounds of the array. Any access before the start of the array will clamp
|
||||
/// to zero and any access past the end of the array will clamp to
|
||||
/// (array length - 1).
|
||||
class Robustness : public Castable<Robustness, Transform> {
|
||||
public:
|
||||
/// Storage class to be skipped in the transform
|
||||
enum class StorageClass {
|
||||
kUniform,
|
||||
kStorage,
|
||||
};
|
||||
|
||||
/// Configuration options for the transform
|
||||
struct Config : public Castable<Config, Data> {
|
||||
/// Constructor
|
||||
Config();
|
||||
|
||||
/// Copy constructor
|
||||
Config(const Config&);
|
||||
|
||||
/// Destructor
|
||||
~Config() override;
|
||||
|
||||
/// Assignment operator
|
||||
/// @returns this Config
|
||||
Config& operator=(const Config&);
|
||||
|
||||
/// Storage classes to omit from apply the transform to.
|
||||
/// This allows for optimizing on hardware that provide safe accesses.
|
||||
std::unordered_set<StorageClass> omitted_classes;
|
||||
};
|
||||
|
||||
/// Constructor
|
||||
Robustness();
|
||||
/// Destructor
|
||||
~Robustness() override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
|
||||
private:
|
||||
struct State;
|
||||
};
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
|
||||
#endif // SRC_TINT_TRANSFORM_ROBUSTNESS_H_
|
||||
1745
src/tint/transform/robustness_test.cc
Normal file
1745
src/tint/transform/robustness_test.cc
Normal file
File diff suppressed because it is too large
Load Diff
239
src/tint/transform/simplify_pointers.cc
Normal file
239
src/tint/transform/simplify_pointers.cc
Normal file
@@ -0,0 +1,239 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/tint/transform/simplify_pointers.h"
|
||||
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "src/tint/program_builder.h"
|
||||
#include "src/tint/sem/block_statement.h"
|
||||
#include "src/tint/sem/function.h"
|
||||
#include "src/tint/sem/statement.h"
|
||||
#include "src/tint/sem/variable.h"
|
||||
#include "src/tint/transform/unshadow.h"
|
||||
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::transform::SimplifyPointers);
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
|
||||
namespace {
|
||||
|
||||
/// PointerOp describes either possible indirection or address-of action on an
|
||||
/// expression.
|
||||
struct PointerOp {
|
||||
/// Positive: Number of times the `expr` was dereferenced (*expr)
|
||||
/// Negative: Number of times the `expr` was 'addressed-of' (&expr)
|
||||
/// Zero: no pointer op on `expr`
|
||||
int indirections = 0;
|
||||
/// The expression being operated on
|
||||
const ast::Expression* expr = nullptr;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
/// The PIMPL state for the SimplifyPointers transform
|
||||
struct SimplifyPointers::State {
|
||||
/// The clone context
|
||||
CloneContext& ctx;
|
||||
|
||||
/// Constructor
|
||||
/// @param context the clone context
|
||||
explicit State(CloneContext& context) : ctx(context) {}
|
||||
|
||||
/// Traverses the expression `expr` looking for non-literal array indexing
|
||||
/// expressions that would affect the computed address of a pointer
|
||||
/// expression. The function-like argument `cb` is called for each found.
|
||||
/// @param expr the expression to traverse
|
||||
/// @param cb a function-like object with the signature
|
||||
/// `void(const ast::Expression*)`, which is called for each array index
|
||||
/// expression
|
||||
template <typename F>
|
||||
static void CollectSavedArrayIndices(const ast::Expression* expr, F&& cb) {
|
||||
if (auto* a = expr->As<ast::IndexAccessorExpression>()) {
|
||||
CollectSavedArrayIndices(a->object, cb);
|
||||
if (!a->index->Is<ast::LiteralExpression>()) {
|
||||
cb(a->index);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (auto* m = expr->As<ast::MemberAccessorExpression>()) {
|
||||
CollectSavedArrayIndices(m->structure, cb);
|
||||
return;
|
||||
}
|
||||
|
||||
if (auto* u = expr->As<ast::UnaryOpExpression>()) {
|
||||
CollectSavedArrayIndices(u->expr, cb);
|
||||
return;
|
||||
}
|
||||
|
||||
// Note: Other ast::Expression types can be safely ignored as they cannot be
|
||||
// used to generate a reference or pointer.
|
||||
// See https://gpuweb.github.io/gpuweb/wgsl/#forming-references-and-pointers
|
||||
}
|
||||
|
||||
/// Reduce walks the expression chain, collapsing all address-of and
|
||||
/// indirection ops into a PointerOp.
|
||||
/// @param in the expression to walk
|
||||
/// @returns the reduced PointerOp
|
||||
PointerOp Reduce(const ast::Expression* in) const {
|
||||
PointerOp op{0, in};
|
||||
while (true) {
|
||||
if (auto* unary = op.expr->As<ast::UnaryOpExpression>()) {
|
||||
switch (unary->op) {
|
||||
case ast::UnaryOp::kIndirection:
|
||||
op.indirections++;
|
||||
op.expr = unary->expr;
|
||||
continue;
|
||||
case ast::UnaryOp::kAddressOf:
|
||||
op.indirections--;
|
||||
op.expr = unary->expr;
|
||||
continue;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (auto* user = ctx.src->Sem().Get<sem::VariableUser>(op.expr)) {
|
||||
auto* var = user->Variable();
|
||||
if (var->Is<sem::LocalVariable>() && //
|
||||
var->Declaration()->is_const && //
|
||||
var->Type()->Is<sem::Pointer>()) {
|
||||
op.expr = var->Declaration()->constructor;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
return op;
|
||||
}
|
||||
}
|
||||
|
||||
/// Performs the transformation
|
||||
void Run() {
|
||||
// A map of saved expressions to their saved variable name
|
||||
std::unordered_map<const ast::Expression*, Symbol> saved_vars;
|
||||
|
||||
// Register the ast::Expression transform handler.
|
||||
// This performs two different transformations:
|
||||
// * Identifiers that resolve to the pointer-typed `let` declarations are
|
||||
// replaced with the recursively inlined initializer expression for the
|
||||
// `let` declaration.
|
||||
// * Sub-expressions inside the pointer-typed `let` initializer expression
|
||||
// that have been hoisted to a saved variable are replaced with the saved
|
||||
// variable identifier.
|
||||
ctx.ReplaceAll([&](const ast::Expression* expr) -> const ast::Expression* {
|
||||
// Look to see if we need to swap this Expression with a saved variable.
|
||||
auto it = saved_vars.find(expr);
|
||||
if (it != saved_vars.end()) {
|
||||
return ctx.dst->Expr(it->second);
|
||||
}
|
||||
|
||||
// Reduce the expression, folding away chains of address-of / indirections
|
||||
auto op = Reduce(expr);
|
||||
|
||||
// Clone the reduced root expression
|
||||
expr = ctx.CloneWithoutTransform(op.expr);
|
||||
|
||||
// And reapply the minimum number of address-of / indirections
|
||||
for (int i = 0; i < op.indirections; i++) {
|
||||
expr = ctx.dst->Deref(expr);
|
||||
}
|
||||
for (int i = 0; i > op.indirections; i--) {
|
||||
expr = ctx.dst->AddressOf(expr);
|
||||
}
|
||||
return expr;
|
||||
});
|
||||
|
||||
// Find all the pointer-typed `let` declarations.
|
||||
// Note that these must be function-scoped, as module-scoped `let`s are not
|
||||
// permitted.
|
||||
for (auto* node : ctx.src->ASTNodes().Objects()) {
|
||||
if (auto* let = node->As<ast::VariableDeclStatement>()) {
|
||||
if (!let->variable->is_const) {
|
||||
continue; // Not a `let` declaration. Ignore.
|
||||
}
|
||||
|
||||
auto* var = ctx.src->Sem().Get(let->variable);
|
||||
if (!var->Type()->Is<sem::Pointer>()) {
|
||||
continue; // Not a pointer type. Ignore.
|
||||
}
|
||||
|
||||
// We're dealing with a pointer-typed `let` declaration.
|
||||
|
||||
// Scan the initializer expression for array index expressions that need
|
||||
// to be hoist to temporary "saved" variables.
|
||||
std::vector<const ast::VariableDeclStatement*> saved;
|
||||
CollectSavedArrayIndices(
|
||||
var->Declaration()->constructor,
|
||||
[&](const ast::Expression* idx_expr) {
|
||||
// We have a sub-expression that needs to be saved.
|
||||
// Create a new variable
|
||||
auto saved_name = ctx.dst->Symbols().New(
|
||||
ctx.src->Symbols().NameFor(var->Declaration()->symbol) +
|
||||
"_save");
|
||||
auto* decl = ctx.dst->Decl(
|
||||
ctx.dst->Const(saved_name, nullptr, ctx.Clone(idx_expr)));
|
||||
saved.emplace_back(decl);
|
||||
// Record the substitution of `idx_expr` to the saved variable
|
||||
// with the symbol `saved_name`. This will be used by the
|
||||
// ReplaceAll() handler above.
|
||||
saved_vars.emplace(idx_expr, saved_name);
|
||||
});
|
||||
|
||||
// Find the place to insert the saved declarations.
|
||||
// Special care needs to be made for lets declared as the initializer
|
||||
// part of for-loops. In this case the block will hold the for-loop
|
||||
// statement, not the let.
|
||||
if (!saved.empty()) {
|
||||
auto* stmt = ctx.src->Sem().Get(let);
|
||||
auto* block = stmt->Block();
|
||||
// Find the statement owned by the block (either the let decl or a
|
||||
// for-loop)
|
||||
while (block != stmt->Parent()) {
|
||||
stmt = stmt->Parent();
|
||||
}
|
||||
// Declare the stored variables just before stmt. Order here is
|
||||
// important as order-of-operations needs to be preserved.
|
||||
// CollectSavedArrayIndices() visits the LHS of an index accessor
|
||||
// before the index expression.
|
||||
for (auto* decl : saved) {
|
||||
// Note that repeated calls to InsertBefore() with the same `before`
|
||||
// argument will result in nodes to inserted in the order the
|
||||
// calls are made (last call is inserted last).
|
||||
ctx.InsertBefore(block->Declaration()->statements,
|
||||
stmt->Declaration(), decl);
|
||||
}
|
||||
}
|
||||
|
||||
// As the original `let` declaration will be fully inlined, there's no
|
||||
// need for the original declaration to exist. Remove it.
|
||||
RemoveStatement(ctx, let);
|
||||
}
|
||||
}
|
||||
ctx.Clone();
|
||||
}
|
||||
};
|
||||
|
||||
SimplifyPointers::SimplifyPointers() = default;
|
||||
|
||||
SimplifyPointers::~SimplifyPointers() = default;
|
||||
|
||||
void SimplifyPointers::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
|
||||
State(ctx).Run();
|
||||
}
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
60
src/tint/transform/simplify_pointers.h
Normal file
60
src/tint/transform/simplify_pointers.h
Normal file
@@ -0,0 +1,60 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef SRC_TINT_TRANSFORM_SIMPLIFY_POINTERS_H_
|
||||
#define SRC_TINT_TRANSFORM_SIMPLIFY_POINTERS_H_
|
||||
|
||||
#include "src/tint/transform/transform.h"
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
|
||||
/// SimplifyPointers is a Transform that moves all usage of function-scope
|
||||
/// `let` statements of a pointer type into their places of usage, while also
|
||||
/// simplifying any chains of address-of or indirections operators.
|
||||
///
|
||||
/// Parameters of a pointer type are not adjusted.
|
||||
///
|
||||
/// Note: SimplifyPointers does not operate on module-scope `let`s, as these
|
||||
/// cannot be pointers: https://gpuweb.github.io/gpuweb/wgsl/#module-constants
|
||||
/// `A module-scope let-declared constant must be of constructible type.`
|
||||
///
|
||||
/// @note Depends on the following transforms to have been run first:
|
||||
/// * Unshadow
|
||||
class SimplifyPointers : public Castable<SimplifyPointers, Transform> {
|
||||
public:
|
||||
/// Constructor
|
||||
SimplifyPointers();
|
||||
|
||||
/// Destructor
|
||||
~SimplifyPointers() override;
|
||||
|
||||
protected:
|
||||
struct State;
|
||||
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
};
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
|
||||
#endif // SRC_TINT_TRANSFORM_SIMPLIFY_POINTERS_H_
|
||||
370
src/tint/transform/simplify_pointers_test.cc
Normal file
370
src/tint/transform/simplify_pointers_test.cc
Normal file
@@ -0,0 +1,370 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/tint/transform/simplify_pointers.h"
|
||||
|
||||
#include "src/tint/transform/test_helper.h"
|
||||
#include "src/tint/transform/unshadow.h"
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
namespace {
|
||||
|
||||
using SimplifyPointersTest = TransformTest;
|
||||
|
||||
TEST_F(SimplifyPointersTest, EmptyModule) {
|
||||
auto* src = "";
|
||||
auto* expect = "";
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(SimplifyPointersTest, FoldPointer) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
var v : i32;
|
||||
let p : ptr<function, i32> = &v;
|
||||
let x : i32 = *p;
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn f() {
|
||||
var v : i32;
|
||||
let x : i32 = v;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(SimplifyPointersTest, AddressOfDeref) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
var v : i32;
|
||||
let p : ptr<function, i32> = &(v);
|
||||
let x : ptr<function, i32> = &(*(p));
|
||||
let y : ptr<function, i32> = &(*(&(*(p))));
|
||||
let z : ptr<function, i32> = &(*(&(*(&(*(&(*(p))))))));
|
||||
var a = *x;
|
||||
var b = *y;
|
||||
var c = *z;
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn f() {
|
||||
var v : i32;
|
||||
var a = v;
|
||||
var b = v;
|
||||
var c = v;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(SimplifyPointersTest, DerefAddressOf) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
var v : i32;
|
||||
let x : i32 = *(&(v));
|
||||
let y : i32 = *(&(*(&(v))));
|
||||
let z : i32 = *(&(*(&(*(&(*(&(v))))))));
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn f() {
|
||||
var v : i32;
|
||||
let x : i32 = v;
|
||||
let y : i32 = v;
|
||||
let z : i32 = v;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(SimplifyPointersTest, ComplexChain) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
var a : array<mat4x4<f32>, 4>;
|
||||
let ap : ptr<function, array<mat4x4<f32>, 4>> = &a;
|
||||
let mp : ptr<function, mat4x4<f32>> = &(*ap)[3];
|
||||
let vp : ptr<function, vec4<f32>> = &(*mp)[2];
|
||||
let v : vec4<f32> = *vp;
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn f() {
|
||||
var a : array<mat4x4<f32>, 4>;
|
||||
let v : vec4<f32> = a[3][2];
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(SimplifyPointersTest, SavedVars) {
|
||||
auto* src = R"(
|
||||
struct S {
|
||||
i : i32;
|
||||
};
|
||||
|
||||
fn arr() {
|
||||
var a : array<S, 2>;
|
||||
var i : i32 = 0;
|
||||
var j : i32 = 0;
|
||||
let p : ptr<function, i32> = &a[i + j].i;
|
||||
i = 2;
|
||||
*p = 4;
|
||||
}
|
||||
|
||||
fn matrix() {
|
||||
var m : mat3x3<f32>;
|
||||
var i : i32 = 0;
|
||||
var j : i32 = 0;
|
||||
let p : ptr<function, vec3<f32>> = &m[i + j];
|
||||
i = 2;
|
||||
*p = vec3<f32>(4.0, 5.0, 6.0);
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
struct S {
|
||||
i : i32;
|
||||
}
|
||||
|
||||
fn arr() {
|
||||
var a : array<S, 2>;
|
||||
var i : i32 = 0;
|
||||
var j : i32 = 0;
|
||||
let p_save = (i + j);
|
||||
i = 2;
|
||||
a[p_save].i = 4;
|
||||
}
|
||||
|
||||
fn matrix() {
|
||||
var m : mat3x3<f32>;
|
||||
var i : i32 = 0;
|
||||
var j : i32 = 0;
|
||||
let p_save_1 = (i + j);
|
||||
i = 2;
|
||||
m[p_save_1] = vec3<f32>(4.0, 5.0, 6.0);
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(SimplifyPointersTest, DontSaveLiterals) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
var arr : array<i32, 2>;
|
||||
let p1 : ptr<function, i32> = &arr[1];
|
||||
*p1 = 4;
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn f() {
|
||||
var arr : array<i32, 2>;
|
||||
arr[1] = 4;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(SimplifyPointersTest, SavedVarsChain) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
var arr : array<array<i32, 2>, 2>;
|
||||
let i : i32 = 0;
|
||||
let j : i32 = 1;
|
||||
let p : ptr<function, array<i32, 2>> = &arr[i];
|
||||
let q : ptr<function, i32> = &(*p)[j];
|
||||
*q = 12;
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn f() {
|
||||
var arr : array<array<i32, 2>, 2>;
|
||||
let i : i32 = 0;
|
||||
let j : i32 = 1;
|
||||
let p_save = i;
|
||||
let q_save = j;
|
||||
arr[p_save][q_save] = 12;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(SimplifyPointersTest, ForLoopInit) {
|
||||
auto* src = R"(
|
||||
fn foo() -> i32 {
|
||||
return 1;
|
||||
}
|
||||
|
||||
@stage(fragment)
|
||||
fn main() {
|
||||
var arr = array<f32, 4>();
|
||||
for (let a = &arr[foo()]; ;) {
|
||||
let x = *a;
|
||||
break;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn foo() -> i32 {
|
||||
return 1;
|
||||
}
|
||||
|
||||
@stage(fragment)
|
||||
fn main() {
|
||||
var arr = array<f32, 4>();
|
||||
let a_save = foo();
|
||||
for(; ; ) {
|
||||
let x = arr[a_save];
|
||||
break;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(SimplifyPointersTest, MultiSavedVarsInSinglePtrLetExpr) {
|
||||
auto* src = R"(
|
||||
fn x() -> i32 {
|
||||
return 1;
|
||||
}
|
||||
|
||||
fn y() -> i32 {
|
||||
return 1;
|
||||
}
|
||||
|
||||
fn z() -> i32 {
|
||||
return 1;
|
||||
}
|
||||
|
||||
struct Inner {
|
||||
a : array<i32, 2>;
|
||||
};
|
||||
|
||||
struct Outer {
|
||||
a : array<Inner, 2>;
|
||||
};
|
||||
|
||||
fn f() {
|
||||
var arr : array<Outer, 2>;
|
||||
let p : ptr<function, i32> = &arr[x()].a[y()].a[z()];
|
||||
*p = 1;
|
||||
*p = 2;
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn x() -> i32 {
|
||||
return 1;
|
||||
}
|
||||
|
||||
fn y() -> i32 {
|
||||
return 1;
|
||||
}
|
||||
|
||||
fn z() -> i32 {
|
||||
return 1;
|
||||
}
|
||||
|
||||
struct Inner {
|
||||
a : array<i32, 2>;
|
||||
}
|
||||
|
||||
struct Outer {
|
||||
a : array<Inner, 2>;
|
||||
}
|
||||
|
||||
fn f() {
|
||||
var arr : array<Outer, 2>;
|
||||
let p_save = x();
|
||||
let p_save_1 = y();
|
||||
let p_save_2 = z();
|
||||
arr[p_save].a[p_save_1].a[p_save_2] = 1;
|
||||
arr[p_save].a[p_save_1].a[p_save_2] = 2;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(SimplifyPointersTest, ShadowPointer) {
|
||||
auto* src = R"(
|
||||
var<private> a : array<i32, 2>;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
let x = &a;
|
||||
var a : i32 = (*x)[0];
|
||||
{
|
||||
var a : i32 = (*x)[1];
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
var<private> a : array<i32, 2>;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
var a_1 : i32 = a[0];
|
||||
{
|
||||
var a_2 : i32 = a[1];
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
117
src/tint/transform/single_entry_point.cc
Normal file
117
src/tint/transform/single_entry_point.cc
Normal file
@@ -0,0 +1,117 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/tint/transform/single_entry_point.h"
|
||||
|
||||
#include <unordered_set>
|
||||
#include <utility>
|
||||
|
||||
#include "src/tint/program_builder.h"
|
||||
#include "src/tint/sem/function.h"
|
||||
#include "src/tint/sem/variable.h"
|
||||
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::transform::SingleEntryPoint);
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::transform::SingleEntryPoint::Config);
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
|
||||
SingleEntryPoint::SingleEntryPoint() = default;
|
||||
|
||||
SingleEntryPoint::~SingleEntryPoint() = default;
|
||||
|
||||
void SingleEntryPoint::Run(CloneContext& ctx,
|
||||
const DataMap& inputs,
|
||||
DataMap&) const {
|
||||
auto* cfg = inputs.Get<Config>();
|
||||
if (cfg == nullptr) {
|
||||
ctx.dst->Diagnostics().add_error(
|
||||
diag::System::Transform,
|
||||
"missing transform data for " + std::string(TypeInfo().name));
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
// Find the target entry point.
|
||||
const ast::Function* entry_point = nullptr;
|
||||
for (auto* f : ctx.src->AST().Functions()) {
|
||||
if (!f->IsEntryPoint()) {
|
||||
continue;
|
||||
}
|
||||
if (ctx.src->Symbols().NameFor(f->symbol) == cfg->entry_point_name) {
|
||||
entry_point = f;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (entry_point == nullptr) {
|
||||
ctx.dst->Diagnostics().add_error(
|
||||
diag::System::Transform,
|
||||
"entry point '" + cfg->entry_point_name + "' not found");
|
||||
return;
|
||||
}
|
||||
|
||||
auto& sem = ctx.src->Sem();
|
||||
|
||||
// Build set of referenced module-scope variables for faster lookups later.
|
||||
std::unordered_set<const ast::Variable*> referenced_vars;
|
||||
for (auto* var : sem.Get(entry_point)->TransitivelyReferencedGlobals()) {
|
||||
referenced_vars.emplace(var->Declaration());
|
||||
}
|
||||
|
||||
// Clone any module-scope variables, types, and functions that are statically
|
||||
// referenced by the target entry point.
|
||||
for (auto* decl : ctx.src->AST().GlobalDeclarations()) {
|
||||
if (auto* ty = decl->As<ast::TypeDecl>()) {
|
||||
// TODO(jrprice): Strip unused types.
|
||||
ctx.dst->AST().AddTypeDecl(ctx.Clone(ty));
|
||||
} else if (auto* var = decl->As<ast::Variable>()) {
|
||||
if (referenced_vars.count(var)) {
|
||||
if (var->is_overridable) {
|
||||
// It is an overridable constant
|
||||
if (!ast::HasAttribute<ast::IdAttribute>(var->attributes)) {
|
||||
// If the constant doesn't already have an @id() attribute, add one
|
||||
// so that its allocated ID so that it won't be affected by other
|
||||
// stripped away constants
|
||||
auto* global = sem.Get(var)->As<sem::GlobalVariable>();
|
||||
const auto* id = ctx.dst->Id(global->ConstantId());
|
||||
ctx.InsertFront(var->attributes, id);
|
||||
}
|
||||
}
|
||||
ctx.dst->AST().AddGlobalVariable(ctx.Clone(var));
|
||||
}
|
||||
} else if (auto* func = decl->As<ast::Function>()) {
|
||||
if (sem.Get(func)->HasAncestorEntryPoint(entry_point->symbol)) {
|
||||
ctx.dst->AST().AddFunction(ctx.Clone(func));
|
||||
}
|
||||
} else {
|
||||
TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics())
|
||||
<< "unhandled global declaration: " << decl->TypeInfo().name;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Clone the entry point.
|
||||
ctx.dst->AST().AddFunction(ctx.Clone(entry_point));
|
||||
}
|
||||
|
||||
SingleEntryPoint::Config::Config(std::string entry_point)
|
||||
: entry_point_name(entry_point) {}
|
||||
|
||||
SingleEntryPoint::Config::Config(const Config&) = default;
|
||||
SingleEntryPoint::Config::~Config() = default;
|
||||
SingleEntryPoint::Config& SingleEntryPoint::Config::operator=(const Config&) =
|
||||
default;
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
72
src/tint/transform/single_entry_point.h
Normal file
72
src/tint/transform/single_entry_point.h
Normal file
@@ -0,0 +1,72 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef SRC_TINT_TRANSFORM_SINGLE_ENTRY_POINT_H_
|
||||
#define SRC_TINT_TRANSFORM_SINGLE_ENTRY_POINT_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "src/tint/transform/transform.h"
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
|
||||
/// Strip all but one entry point a module.
|
||||
///
|
||||
/// All module-scope variables, types, and functions that are not used by the
|
||||
/// target entry point will also be removed.
|
||||
class SingleEntryPoint : public Castable<SingleEntryPoint, Transform> {
|
||||
public:
|
||||
/// Configuration options for the transform
|
||||
struct Config : public Castable<Config, Data> {
|
||||
/// Constructor
|
||||
/// @param entry_point the name of the entry point to keep
|
||||
explicit Config(std::string entry_point = "");
|
||||
|
||||
/// Copy constructor
|
||||
Config(const Config&);
|
||||
|
||||
/// Destructor
|
||||
~Config() override;
|
||||
|
||||
/// Assignment operator
|
||||
/// @returns this Config
|
||||
Config& operator=(const Config&);
|
||||
|
||||
/// The name of the entry point to keep.
|
||||
std::string entry_point_name;
|
||||
};
|
||||
|
||||
/// Constructor
|
||||
SingleEntryPoint();
|
||||
|
||||
/// Destructor
|
||||
~SingleEntryPoint() override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
};
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
|
||||
#endif // SRC_TINT_TRANSFORM_SINGLE_ENTRY_POINT_H_
|
||||
517
src/tint/transform/single_entry_point_test.cc
Normal file
517
src/tint/transform/single_entry_point_test.cc
Normal file
@@ -0,0 +1,517 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/tint/transform/single_entry_point.h"
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "src/tint/transform/test_helper.h"
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
namespace {
|
||||
|
||||
using SingleEntryPointTest = TransformTest;
|
||||
|
||||
TEST_F(SingleEntryPointTest, Error_MissingTransformData) {
|
||||
auto* src = "";
|
||||
|
||||
auto* expect =
|
||||
"error: missing transform data for tint::transform::SingleEntryPoint";
|
||||
|
||||
auto got = Run<SingleEntryPoint>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(SingleEntryPointTest, Error_NoEntryPoints) {
|
||||
auto* src = "";
|
||||
|
||||
auto* expect = "error: entry point 'main' not found";
|
||||
|
||||
DataMap data;
|
||||
data.Add<SingleEntryPoint::Config>("main");
|
||||
auto got = Run<SingleEntryPoint>(src, data);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(SingleEntryPointTest, Error_InvalidEntryPoint) {
|
||||
auto* src = R"(
|
||||
@stage(vertex)
|
||||
fn main() -> @builtin(position) vec4<f32> {
|
||||
return vec4<f32>();
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = "error: entry point '_' not found";
|
||||
|
||||
SingleEntryPoint::Config cfg("_");
|
||||
|
||||
DataMap data;
|
||||
data.Add<SingleEntryPoint::Config>(cfg);
|
||||
auto got = Run<SingleEntryPoint>(src, data);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(SingleEntryPointTest, Error_NotAnEntryPoint) {
|
||||
auto* src = R"(
|
||||
fn foo() {}
|
||||
|
||||
@stage(fragment)
|
||||
fn main() {}
|
||||
)";
|
||||
|
||||
auto* expect = "error: entry point 'foo' not found";
|
||||
|
||||
SingleEntryPoint::Config cfg("foo");
|
||||
|
||||
DataMap data;
|
||||
data.Add<SingleEntryPoint::Config>(cfg);
|
||||
auto got = Run<SingleEntryPoint>(src, data);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(SingleEntryPointTest, SingleEntryPoint) {
|
||||
auto* src = R"(
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn main() {
|
||||
}
|
||||
)";
|
||||
|
||||
SingleEntryPoint::Config cfg("main");
|
||||
|
||||
DataMap data;
|
||||
data.Add<SingleEntryPoint::Config>(cfg);
|
||||
auto got = Run<SingleEntryPoint>(src, data);
|
||||
|
||||
EXPECT_EQ(src, str(got));
|
||||
}
|
||||
|
||||
TEST_F(SingleEntryPointTest, MultipleEntryPoints) {
|
||||
auto* src = R"(
|
||||
@stage(vertex)
|
||||
fn vert_main() -> @builtin(position) vec4<f32> {
|
||||
return vec4<f32>();
|
||||
}
|
||||
|
||||
@stage(fragment)
|
||||
fn frag_main() {
|
||||
}
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn comp_main1() {
|
||||
}
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn comp_main2() {
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn comp_main1() {
|
||||
}
|
||||
)";
|
||||
|
||||
SingleEntryPoint::Config cfg("comp_main1");
|
||||
|
||||
DataMap data;
|
||||
data.Add<SingleEntryPoint::Config>(cfg);
|
||||
auto got = Run<SingleEntryPoint>(src, data);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(SingleEntryPointTest, GlobalVariables) {
|
||||
auto* src = R"(
|
||||
var<private> a : f32;
|
||||
|
||||
var<private> b : f32;
|
||||
|
||||
var<private> c : f32;
|
||||
|
||||
var<private> d : f32;
|
||||
|
||||
@stage(vertex)
|
||||
fn vert_main() -> @builtin(position) vec4<f32> {
|
||||
a = 0.0;
|
||||
return vec4<f32>();
|
||||
}
|
||||
|
||||
@stage(fragment)
|
||||
fn frag_main() {
|
||||
b = 0.0;
|
||||
}
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn comp_main1() {
|
||||
c = 0.0;
|
||||
}
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn comp_main2() {
|
||||
d = 0.0;
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
var<private> c : f32;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn comp_main1() {
|
||||
c = 0.0;
|
||||
}
|
||||
)";
|
||||
|
||||
SingleEntryPoint::Config cfg("comp_main1");
|
||||
|
||||
DataMap data;
|
||||
data.Add<SingleEntryPoint::Config>(cfg);
|
||||
auto got = Run<SingleEntryPoint>(src, data);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(SingleEntryPointTest, GlobalConstants) {
|
||||
auto* src = R"(
|
||||
let a : f32 = 1.0;
|
||||
|
||||
let b : f32 = 1.0;
|
||||
|
||||
let c : f32 = 1.0;
|
||||
|
||||
let d : f32 = 1.0;
|
||||
|
||||
@stage(vertex)
|
||||
fn vert_main() -> @builtin(position) vec4<f32> {
|
||||
let local_a : f32 = a;
|
||||
return vec4<f32>();
|
||||
}
|
||||
|
||||
@stage(fragment)
|
||||
fn frag_main() {
|
||||
let local_b : f32 = b;
|
||||
}
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn comp_main1() {
|
||||
let local_c : f32 = c;
|
||||
}
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn comp_main2() {
|
||||
let local_d : f32 = d;
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
let c : f32 = 1.0;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn comp_main1() {
|
||||
let local_c : f32 = c;
|
||||
}
|
||||
)";
|
||||
|
||||
SingleEntryPoint::Config cfg("comp_main1");
|
||||
|
||||
DataMap data;
|
||||
data.Add<SingleEntryPoint::Config>(cfg);
|
||||
auto got = Run<SingleEntryPoint>(src, data);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(SingleEntryPointTest, WorkgroupSizeLetPreserved) {
|
||||
auto* src = R"(
|
||||
let size : i32 = 1;
|
||||
|
||||
@stage(compute) @workgroup_size(size)
|
||||
fn main() {
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = src;
|
||||
|
||||
SingleEntryPoint::Config cfg("main");
|
||||
|
||||
DataMap data;
|
||||
data.Add<SingleEntryPoint::Config>(cfg);
|
||||
auto got = Run<SingleEntryPoint>(src, data);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(SingleEntryPointTest, OverridableConstants) {
|
||||
auto* src = R"(
|
||||
@id(1001) override c1 : u32 = 1u;
|
||||
override c2 : u32 = 1u;
|
||||
@id(0) override c3 : u32 = 1u;
|
||||
@id(9999) override c4 : u32 = 1u;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn comp_main1() {
|
||||
let local_d = c1;
|
||||
}
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn comp_main2() {
|
||||
let local_d = c2;
|
||||
}
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn comp_main3() {
|
||||
let local_d = c3;
|
||||
}
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn comp_main4() {
|
||||
let local_d = c4;
|
||||
}
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn comp_main5() {
|
||||
let local_d = 1u;
|
||||
}
|
||||
)";
|
||||
|
||||
{
|
||||
SingleEntryPoint::Config cfg("comp_main1");
|
||||
auto* expect = R"(
|
||||
@id(1001) override c1 : u32 = 1u;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn comp_main1() {
|
||||
let local_d = c1;
|
||||
}
|
||||
)";
|
||||
DataMap data;
|
||||
data.Add<SingleEntryPoint::Config>(cfg);
|
||||
auto got = Run<SingleEntryPoint>(src, data);
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
{
|
||||
SingleEntryPoint::Config cfg("comp_main2");
|
||||
// The decorator is replaced with the one with explicit id
|
||||
// And should not be affected by other constants stripped away
|
||||
auto* expect = R"(
|
||||
@id(1) override c2 : u32 = 1u;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn comp_main2() {
|
||||
let local_d = c2;
|
||||
}
|
||||
)";
|
||||
DataMap data;
|
||||
data.Add<SingleEntryPoint::Config>(cfg);
|
||||
auto got = Run<SingleEntryPoint>(src, data);
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
{
|
||||
SingleEntryPoint::Config cfg("comp_main3");
|
||||
auto* expect = R"(
|
||||
@id(0) override c3 : u32 = 1u;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn comp_main3() {
|
||||
let local_d = c3;
|
||||
}
|
||||
)";
|
||||
DataMap data;
|
||||
data.Add<SingleEntryPoint::Config>(cfg);
|
||||
auto got = Run<SingleEntryPoint>(src, data);
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
{
|
||||
SingleEntryPoint::Config cfg("comp_main4");
|
||||
auto* expect = R"(
|
||||
@id(9999) override c4 : u32 = 1u;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn comp_main4() {
|
||||
let local_d = c4;
|
||||
}
|
||||
)";
|
||||
DataMap data;
|
||||
data.Add<SingleEntryPoint::Config>(cfg);
|
||||
auto got = Run<SingleEntryPoint>(src, data);
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
{
|
||||
SingleEntryPoint::Config cfg("comp_main5");
|
||||
auto* expect = R"(
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn comp_main5() {
|
||||
let local_d = 1u;
|
||||
}
|
||||
)";
|
||||
DataMap data;
|
||||
data.Add<SingleEntryPoint::Config>(cfg);
|
||||
auto got = Run<SingleEntryPoint>(src, data);
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(SingleEntryPointTest, CalledFunctions) {
|
||||
auto* src = R"(
|
||||
fn inner1() {
|
||||
}
|
||||
|
||||
fn inner2() {
|
||||
}
|
||||
|
||||
fn inner_shared() {
|
||||
}
|
||||
|
||||
fn outer1() {
|
||||
inner1();
|
||||
inner_shared();
|
||||
}
|
||||
|
||||
fn outer2() {
|
||||
inner2();
|
||||
inner_shared();
|
||||
}
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn comp_main1() {
|
||||
outer1();
|
||||
}
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn comp_main2() {
|
||||
outer2();
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn inner1() {
|
||||
}
|
||||
|
||||
fn inner_shared() {
|
||||
}
|
||||
|
||||
fn outer1() {
|
||||
inner1();
|
||||
inner_shared();
|
||||
}
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn comp_main1() {
|
||||
outer1();
|
||||
}
|
||||
)";
|
||||
|
||||
SingleEntryPoint::Config cfg("comp_main1");
|
||||
|
||||
DataMap data;
|
||||
data.Add<SingleEntryPoint::Config>(cfg);
|
||||
auto got = Run<SingleEntryPoint>(src, data);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(SingleEntryPointTest, GlobalsReferencedByCalledFunctions) {
|
||||
auto* src = R"(
|
||||
var<private> inner1_var : f32;
|
||||
|
||||
var<private> inner2_var : f32;
|
||||
|
||||
var<private> inner_shared_var : f32;
|
||||
|
||||
var<private> outer1_var : f32;
|
||||
|
||||
var<private> outer2_var : f32;
|
||||
|
||||
fn inner1() {
|
||||
inner1_var = 0.0;
|
||||
}
|
||||
|
||||
fn inner2() {
|
||||
inner2_var = 0.0;
|
||||
}
|
||||
|
||||
fn inner_shared() {
|
||||
inner_shared_var = 0.0;
|
||||
}
|
||||
|
||||
fn outer1() {
|
||||
inner1();
|
||||
inner_shared();
|
||||
outer1_var = 0.0;
|
||||
}
|
||||
|
||||
fn outer2() {
|
||||
inner2();
|
||||
inner_shared();
|
||||
outer2_var = 0.0;
|
||||
}
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn comp_main1() {
|
||||
outer1();
|
||||
}
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn comp_main2() {
|
||||
outer2();
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
var<private> inner1_var : f32;
|
||||
|
||||
var<private> inner_shared_var : f32;
|
||||
|
||||
var<private> outer1_var : f32;
|
||||
|
||||
fn inner1() {
|
||||
inner1_var = 0.0;
|
||||
}
|
||||
|
||||
fn inner_shared() {
|
||||
inner_shared_var = 0.0;
|
||||
}
|
||||
|
||||
fn outer1() {
|
||||
inner1();
|
||||
inner_shared();
|
||||
outer1_var = 0.0;
|
||||
}
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn comp_main1() {
|
||||
outer1();
|
||||
}
|
||||
)";
|
||||
|
||||
SingleEntryPoint::Config cfg("comp_main1");
|
||||
|
||||
DataMap data;
|
||||
data.Add<SingleEntryPoint::Config>(cfg);
|
||||
auto got = Run<SingleEntryPoint>(src, data);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
153
src/tint/transform/test_helper.h
Normal file
153
src/tint/transform/test_helper.h
Normal file
@@ -0,0 +1,153 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef SRC_TINT_TRANSFORM_TEST_HELPER_H_
|
||||
#define SRC_TINT_TRANSFORM_TEST_HELPER_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "src/tint/reader/wgsl/parser.h"
|
||||
#include "src/tint/transform/manager.h"
|
||||
#include "src/tint/writer/wgsl/generator.h"
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
|
||||
/// @param program the program to get an output WGSL string from
|
||||
/// @returns the output program as a WGSL string, or an error string if the
|
||||
/// program is not valid.
|
||||
inline std::string str(const Program& program) {
|
||||
diag::Formatter::Style style;
|
||||
style.print_newline_at_end = false;
|
||||
|
||||
if (!program.IsValid()) {
|
||||
return diag::Formatter(style).format(program.Diagnostics());
|
||||
}
|
||||
|
||||
writer::wgsl::Options options;
|
||||
auto result = writer::wgsl::Generate(&program, options);
|
||||
if (!result.success) {
|
||||
return "WGSL writer failed:\n" + result.error;
|
||||
}
|
||||
|
||||
auto res = result.wgsl;
|
||||
if (res.empty()) {
|
||||
return res;
|
||||
}
|
||||
// The WGSL sometimes has two trailing newlines. Strip them
|
||||
while (res.back() == '\n') {
|
||||
res.pop_back();
|
||||
}
|
||||
if (res.empty()) {
|
||||
return res;
|
||||
}
|
||||
return "\n" + res + "\n";
|
||||
}
|
||||
|
||||
/// Helper class for testing transforms
|
||||
template <typename BASE>
|
||||
class TransformTestBase : public BASE {
|
||||
public:
|
||||
/// Transforms and returns the WGSL source `in`, transformed using
|
||||
/// `transform`.
|
||||
/// @param transform the transform to apply
|
||||
/// @param in the input WGSL source
|
||||
/// @param data the optional DataMap to pass to Transform::Run()
|
||||
/// @return the transformed output
|
||||
Output Run(std::string in,
|
||||
std::unique_ptr<transform::Transform> transform,
|
||||
const DataMap& data = {}) {
|
||||
std::vector<std::unique_ptr<transform::Transform>> transforms;
|
||||
transforms.emplace_back(std::move(transform));
|
||||
return Run(std::move(in), std::move(transforms), data);
|
||||
}
|
||||
|
||||
/// Transforms and returns the WGSL source `in`, transformed using
|
||||
/// a transform of type `TRANSFORM`.
|
||||
/// @param in the input WGSL source
|
||||
/// @param data the optional DataMap to pass to Transform::Run()
|
||||
/// @return the transformed output
|
||||
template <typename... TRANSFORMS>
|
||||
Output Run(std::string in, const DataMap& data = {}) {
|
||||
auto file = std::make_unique<Source::File>("test", in);
|
||||
auto program = reader::wgsl::Parse(file.get());
|
||||
|
||||
// Keep this pointer alive after Transform() returns
|
||||
files_.emplace_back(std::move(file));
|
||||
|
||||
return Run<TRANSFORMS...>(std::move(program), data);
|
||||
}
|
||||
|
||||
/// Transforms and returns program `program`, transformed using a transform of
|
||||
/// type `TRANSFORM`.
|
||||
/// @param program the input Program
|
||||
/// @param data the optional DataMap to pass to Transform::Run()
|
||||
/// @return the transformed output
|
||||
template <typename... TRANSFORMS>
|
||||
Output Run(Program&& program, const DataMap& data = {}) {
|
||||
if (!program.IsValid()) {
|
||||
return Output(std::move(program));
|
||||
}
|
||||
|
||||
Manager manager;
|
||||
for (auto* transform_ptr :
|
||||
std::initializer_list<Transform*>{new TRANSFORMS()...}) {
|
||||
manager.append(std::unique_ptr<Transform>(transform_ptr));
|
||||
}
|
||||
return manager.Run(&program, data);
|
||||
}
|
||||
|
||||
/// @param program the input program
|
||||
/// @param data the optional DataMap to pass to Transform::Run()
|
||||
/// @return true if the transform should be run for the given input.
|
||||
template <typename TRANSFORM>
|
||||
bool ShouldRun(Program&& program, const DataMap& data = {}) {
|
||||
EXPECT_TRUE(program.IsValid()) << program.Diagnostics().str();
|
||||
return TRANSFORM().ShouldRun(&program, data);
|
||||
}
|
||||
|
||||
/// @param in the input WGSL source
|
||||
/// @param data the optional DataMap to pass to Transform::Run()
|
||||
/// @return true if the transform should be run for the given input.
|
||||
template <typename TRANSFORM>
|
||||
bool ShouldRun(std::string in, const DataMap& data = {}) {
|
||||
auto file = std::make_unique<Source::File>("test", in);
|
||||
auto program = reader::wgsl::Parse(file.get());
|
||||
return ShouldRun<TRANSFORM>(std::move(program), data);
|
||||
}
|
||||
|
||||
/// @param output the output of the transform
|
||||
/// @returns the output program as a WGSL string, or an error string if the
|
||||
/// program is not valid.
|
||||
std::string str(const Output& output) {
|
||||
return transform::str(output.program);
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<std::unique_ptr<Source::File>> files_;
|
||||
};
|
||||
|
||||
using TransformTest = TransformTestBase<testing::Test>;
|
||||
|
||||
template <typename T>
|
||||
using TransformTestWithParam = TransformTestBase<testing::TestWithParam<T>>;
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
|
||||
#endif // SRC_TINT_TRANSFORM_TEST_HELPER_H_
|
||||
160
src/tint/transform/transform.cc
Normal file
160
src/tint/transform/transform.cc
Normal file
@@ -0,0 +1,160 @@
|
||||
// Copyright 2020 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/tint/transform/transform.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
|
||||
#include "src/tint/program_builder.h"
|
||||
#include "src/tint/sem/atomic_type.h"
|
||||
#include "src/tint/sem/block_statement.h"
|
||||
#include "src/tint/sem/depth_multisampled_texture_type.h"
|
||||
#include "src/tint/sem/for_loop_statement.h"
|
||||
#include "src/tint/sem/reference_type.h"
|
||||
#include "src/tint/sem/sampler_type.h"
|
||||
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::transform::Transform);
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::transform::Data);
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
|
||||
Data::Data() = default;
|
||||
Data::Data(const Data&) = default;
|
||||
Data::~Data() = default;
|
||||
Data& Data::operator=(const Data&) = default;
|
||||
|
||||
DataMap::DataMap() = default;
|
||||
DataMap::DataMap(DataMap&&) = default;
|
||||
DataMap::~DataMap() = default;
|
||||
DataMap& DataMap::operator=(DataMap&&) = default;
|
||||
|
||||
Output::Output() = default;
|
||||
Output::Output(Program&& p) : program(std::move(p)) {}
|
||||
Transform::Transform() = default;
|
||||
Transform::~Transform() = default;
|
||||
|
||||
Output Transform::Run(const Program* program,
|
||||
const DataMap& data /* = {} */) const {
|
||||
ProgramBuilder builder;
|
||||
CloneContext ctx(&builder, program);
|
||||
Output output;
|
||||
Run(ctx, data, output.data);
|
||||
output.program = Program(std::move(builder));
|
||||
return output;
|
||||
}
|
||||
|
||||
void Transform::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
|
||||
TINT_UNIMPLEMENTED(Transform, ctx.dst->Diagnostics())
|
||||
<< "Transform::Run() unimplemented for " << TypeInfo().name;
|
||||
}
|
||||
|
||||
bool Transform::ShouldRun(const Program*, const DataMap&) const {
|
||||
return true;
|
||||
}
|
||||
|
||||
void Transform::RemoveStatement(CloneContext& ctx, const ast::Statement* stmt) {
|
||||
auto* sem = ctx.src->Sem().Get(stmt);
|
||||
if (auto* block = tint::As<sem::BlockStatement>(sem->Parent())) {
|
||||
ctx.Remove(block->Declaration()->statements, stmt);
|
||||
return;
|
||||
}
|
||||
if (tint::Is<sem::ForLoopStatement>(sem->Parent())) {
|
||||
ctx.Replace(stmt, static_cast<ast::Expression*>(nullptr));
|
||||
return;
|
||||
}
|
||||
TINT_ICE(Transform, ctx.dst->Diagnostics())
|
||||
<< "unable to remove statement from parent of type "
|
||||
<< sem->TypeInfo().name;
|
||||
}
|
||||
|
||||
const ast::Type* Transform::CreateASTTypeFor(CloneContext& ctx,
|
||||
const sem::Type* ty) {
|
||||
if (ty->Is<sem::Void>()) {
|
||||
return ctx.dst->create<ast::Void>();
|
||||
}
|
||||
if (ty->Is<sem::I32>()) {
|
||||
return ctx.dst->create<ast::I32>();
|
||||
}
|
||||
if (ty->Is<sem::U32>()) {
|
||||
return ctx.dst->create<ast::U32>();
|
||||
}
|
||||
if (ty->Is<sem::F32>()) {
|
||||
return ctx.dst->create<ast::F32>();
|
||||
}
|
||||
if (ty->Is<sem::Bool>()) {
|
||||
return ctx.dst->create<ast::Bool>();
|
||||
}
|
||||
if (auto* m = ty->As<sem::Matrix>()) {
|
||||
auto* el = CreateASTTypeFor(ctx, m->type());
|
||||
return ctx.dst->create<ast::Matrix>(el, m->rows(), m->columns());
|
||||
}
|
||||
if (auto* v = ty->As<sem::Vector>()) {
|
||||
auto* el = CreateASTTypeFor(ctx, v->type());
|
||||
return ctx.dst->create<ast::Vector>(el, v->Width());
|
||||
}
|
||||
if (auto* a = ty->As<sem::Array>()) {
|
||||
auto* el = CreateASTTypeFor(ctx, a->ElemType());
|
||||
ast::AttributeList attrs;
|
||||
if (!a->IsStrideImplicit()) {
|
||||
attrs.emplace_back(ctx.dst->create<ast::StrideAttribute>(a->Stride()));
|
||||
}
|
||||
if (a->IsRuntimeSized()) {
|
||||
return ctx.dst->ty.array(el, nullptr, std::move(attrs));
|
||||
} else {
|
||||
return ctx.dst->ty.array(el, a->Count(), std::move(attrs));
|
||||
}
|
||||
}
|
||||
if (auto* s = ty->As<sem::Struct>()) {
|
||||
return ctx.dst->create<ast::TypeName>(ctx.Clone(s->Declaration()->name));
|
||||
}
|
||||
if (auto* s = ty->As<sem::Reference>()) {
|
||||
return CreateASTTypeFor(ctx, s->StoreType());
|
||||
}
|
||||
if (auto* a = ty->As<sem::Atomic>()) {
|
||||
return ctx.dst->create<ast::Atomic>(CreateASTTypeFor(ctx, a->Type()));
|
||||
}
|
||||
if (auto* t = ty->As<sem::DepthTexture>()) {
|
||||
return ctx.dst->create<ast::DepthTexture>(t->dim());
|
||||
}
|
||||
if (auto* t = ty->As<sem::DepthMultisampledTexture>()) {
|
||||
return ctx.dst->create<ast::DepthMultisampledTexture>(t->dim());
|
||||
}
|
||||
if (ty->Is<sem::ExternalTexture>()) {
|
||||
return ctx.dst->create<ast::ExternalTexture>();
|
||||
}
|
||||
if (auto* t = ty->As<sem::MultisampledTexture>()) {
|
||||
return ctx.dst->create<ast::MultisampledTexture>(
|
||||
t->dim(), CreateASTTypeFor(ctx, t->type()));
|
||||
}
|
||||
if (auto* t = ty->As<sem::SampledTexture>()) {
|
||||
return ctx.dst->create<ast::SampledTexture>(
|
||||
t->dim(), CreateASTTypeFor(ctx, t->type()));
|
||||
}
|
||||
if (auto* t = ty->As<sem::StorageTexture>()) {
|
||||
return ctx.dst->create<ast::StorageTexture>(
|
||||
t->dim(), t->texel_format(), CreateASTTypeFor(ctx, t->type()),
|
||||
t->access());
|
||||
}
|
||||
if (auto* s = ty->As<sem::Sampler>()) {
|
||||
return ctx.dst->create<ast::Sampler>(s->kind());
|
||||
}
|
||||
TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics())
|
||||
<< "Unhandled type: " << ty->TypeInfo().name;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
199
src/tint/transform/transform.h
Normal file
199
src/tint/transform/transform.h
Normal file
@@ -0,0 +1,199 @@
|
||||
// Copyright 2020 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef SRC_TINT_TRANSFORM_TRANSFORM_H_
|
||||
#define SRC_TINT_TRANSFORM_TRANSFORM_H_
|
||||
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
|
||||
#include "src/tint/castable.h"
|
||||
#include "src/tint/program.h"
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
|
||||
/// Data is the base class for transforms that accept extra input or emit extra
|
||||
/// output information along with a Program.
|
||||
class Data : public Castable<Data> {
|
||||
public:
|
||||
/// Constructor
|
||||
Data();
|
||||
|
||||
/// Copy constructor
|
||||
Data(const Data&);
|
||||
|
||||
/// Destructor
|
||||
~Data() override;
|
||||
|
||||
/// Assignment operator
|
||||
/// @returns this Data
|
||||
Data& operator=(const Data&);
|
||||
};
|
||||
|
||||
/// DataMap is a map of Data unique pointers keyed by the Data's ClassID.
|
||||
class DataMap {
|
||||
public:
|
||||
/// Constructor
|
||||
DataMap();
|
||||
|
||||
/// Move constructor
|
||||
DataMap(DataMap&&);
|
||||
|
||||
/// Constructor
|
||||
/// @param data_unique_ptrs a variadic list of additional data unique_ptrs
|
||||
/// produced by the transform
|
||||
template <typename... DATA>
|
||||
explicit DataMap(DATA... data_unique_ptrs) {
|
||||
PutAll(std::forward<DATA>(data_unique_ptrs)...);
|
||||
}
|
||||
|
||||
/// Destructor
|
||||
~DataMap();
|
||||
|
||||
/// Move assignment operator
|
||||
/// @param rhs the DataMap to move into this DataMap
|
||||
/// @return this DataMap
|
||||
DataMap& operator=(DataMap&& rhs);
|
||||
|
||||
/// Adds the data into DataMap keyed by the ClassID of type T.
|
||||
/// @param data the data to add to the DataMap
|
||||
template <typename T>
|
||||
void Put(std::unique_ptr<T>&& data) {
|
||||
static_assert(std::is_base_of<Data, T>::value,
|
||||
"T does not derive from Data");
|
||||
map_[&TypeInfo::Of<T>()] = std::move(data);
|
||||
}
|
||||
|
||||
/// Creates the data of type `T` with the provided arguments and adds it into
|
||||
/// DataMap keyed by the ClassID of type T.
|
||||
/// @param args the arguments forwarded to the constructor for type T
|
||||
template <typename T, typename... ARGS>
|
||||
void Add(ARGS&&... args) {
|
||||
Put(std::make_unique<T>(std::forward<ARGS>(args)...));
|
||||
}
|
||||
|
||||
/// @returns a pointer to the Data placed into the DataMap with a call to
|
||||
/// Put()
|
||||
template <typename T>
|
||||
T const* Get() const {
|
||||
auto it = map_.find(&TypeInfo::Of<T>());
|
||||
if (it == map_.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
return static_cast<T*>(it->second.get());
|
||||
}
|
||||
|
||||
/// Add moves all the data from other into this DataMap
|
||||
/// @param other the DataMap to move into this DataMap
|
||||
void Add(DataMap&& other) {
|
||||
for (auto& it : other.map_) {
|
||||
map_.emplace(it.first, std::move(it.second));
|
||||
}
|
||||
other.map_.clear();
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename T0>
|
||||
void PutAll(T0&& first) {
|
||||
Put(std::forward<T0>(first));
|
||||
}
|
||||
|
||||
template <typename T0, typename... Tn>
|
||||
void PutAll(T0&& first, Tn&&... remainder) {
|
||||
Put(std::forward<T0>(first));
|
||||
PutAll(std::forward<Tn>(remainder)...);
|
||||
}
|
||||
|
||||
std::unordered_map<const TypeInfo*, std::unique_ptr<Data>> map_;
|
||||
};
|
||||
|
||||
/// The return type of Run()
|
||||
class Output {
|
||||
public:
|
||||
/// Constructor
|
||||
Output();
|
||||
|
||||
/// Constructor
|
||||
/// @param program the program to move into this Output
|
||||
explicit Output(Program&& program);
|
||||
|
||||
/// Constructor
|
||||
/// @param program_ the program to move into this Output
|
||||
/// @param data_ a variadic list of additional data unique_ptrs produced by
|
||||
/// the transform
|
||||
template <typename... DATA>
|
||||
Output(Program&& program_, DATA... data_)
|
||||
: program(std::move(program_)), data(std::forward<DATA>(data_)...) {}
|
||||
|
||||
/// The transformed program. May be empty on error.
|
||||
Program program;
|
||||
|
||||
/// Extra output generated by the transforms.
|
||||
DataMap data;
|
||||
};
|
||||
|
||||
/// Interface for Program transforms
|
||||
class Transform : public Castable<Transform> {
|
||||
public:
|
||||
/// Constructor
|
||||
Transform();
|
||||
/// Destructor
|
||||
~Transform() override;
|
||||
|
||||
/// Runs the transform on `program`, returning the transformation result.
|
||||
/// @param program the source program to transform
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns the transformation result
|
||||
virtual Output Run(const Program* program, const DataMap& data = {}) const;
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
virtual bool ShouldRun(const Program* program,
|
||||
const DataMap& data = {}) const;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
virtual void Run(CloneContext& ctx,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const;
|
||||
|
||||
/// Removes the statement `stmt` from the transformed program.
|
||||
/// RemoveStatement handles edge cases, like statements in the initializer and
|
||||
/// continuing of for-loops.
|
||||
/// @param ctx the clone context
|
||||
/// @param stmt the statement to remove when the program is cloned
|
||||
static void RemoveStatement(CloneContext& ctx, const ast::Statement* stmt);
|
||||
|
||||
/// CreateASTTypeFor constructs new ast::Type nodes that reconstructs the
|
||||
/// semantic type `ty`.
|
||||
/// @param ctx the clone context
|
||||
/// @param ty the semantic type to reconstruct
|
||||
/// @returns a ast::Type that when resolved, will produce the semantic type
|
||||
/// `ty`.
|
||||
static const ast::Type* CreateASTTypeFor(CloneContext& ctx,
|
||||
const sem::Type* ty);
|
||||
};
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
|
||||
#endif // SRC_TINT_TRANSFORM_TRANSFORM_H_
|
||||
123
src/tint/transform/transform_test.cc
Normal file
123
src/tint/transform/transform_test.cc
Normal file
@@ -0,0 +1,123 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/tint/transform/transform.h"
|
||||
#include "src/tint/clone_context.h"
|
||||
#include "src/tint/program_builder.h"
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
namespace {
|
||||
|
||||
// Inherit from Transform so we have access to protected methods
|
||||
struct CreateASTTypeForTest : public testing::Test, public Transform {
|
||||
Output Run(const Program*, const DataMap&) const override { return {}; }
|
||||
|
||||
const ast::Type* create(
|
||||
std::function<sem::Type*(ProgramBuilder&)> create_sem_type) {
|
||||
ProgramBuilder sem_type_builder;
|
||||
auto* sem_type = create_sem_type(sem_type_builder);
|
||||
Program program(std::move(sem_type_builder));
|
||||
CloneContext ctx(&ast_type_builder, &program, false);
|
||||
return CreateASTTypeFor(ctx, sem_type);
|
||||
}
|
||||
|
||||
ProgramBuilder ast_type_builder;
|
||||
};
|
||||
|
||||
TEST_F(CreateASTTypeForTest, Basic) {
|
||||
EXPECT_TRUE(create([](ProgramBuilder& b) {
|
||||
return b.create<sem::I32>();
|
||||
})->Is<ast::I32>());
|
||||
EXPECT_TRUE(create([](ProgramBuilder& b) {
|
||||
return b.create<sem::U32>();
|
||||
})->Is<ast::U32>());
|
||||
EXPECT_TRUE(create([](ProgramBuilder& b) {
|
||||
return b.create<sem::F32>();
|
||||
})->Is<ast::F32>());
|
||||
EXPECT_TRUE(create([](ProgramBuilder& b) {
|
||||
return b.create<sem::Bool>();
|
||||
})->Is<ast::Bool>());
|
||||
EXPECT_TRUE(create([](ProgramBuilder& b) {
|
||||
return b.create<sem::Void>();
|
||||
})->Is<ast::Void>());
|
||||
}
|
||||
|
||||
TEST_F(CreateASTTypeForTest, Matrix) {
|
||||
auto* mat = create([](ProgramBuilder& b) {
|
||||
auto* column_type = b.create<sem::Vector>(b.create<sem::F32>(), 2u);
|
||||
return b.create<sem::Matrix>(column_type, 3u);
|
||||
});
|
||||
ASSERT_TRUE(mat->Is<ast::Matrix>());
|
||||
ASSERT_TRUE(mat->As<ast::Matrix>()->type->Is<ast::F32>());
|
||||
ASSERT_EQ(mat->As<ast::Matrix>()->columns, 3u);
|
||||
ASSERT_EQ(mat->As<ast::Matrix>()->rows, 2u);
|
||||
}
|
||||
|
||||
TEST_F(CreateASTTypeForTest, Vector) {
|
||||
auto* vec = create([](ProgramBuilder& b) {
|
||||
return b.create<sem::Vector>(b.create<sem::F32>(), 2);
|
||||
});
|
||||
ASSERT_TRUE(vec->Is<ast::Vector>());
|
||||
ASSERT_TRUE(vec->As<ast::Vector>()->type->Is<ast::F32>());
|
||||
ASSERT_EQ(vec->As<ast::Vector>()->width, 2u);
|
||||
}
|
||||
|
||||
TEST_F(CreateASTTypeForTest, ArrayImplicitStride) {
|
||||
auto* arr = create([](ProgramBuilder& b) {
|
||||
return b.create<sem::Array>(b.create<sem::F32>(), 2, 4, 4, 32u, 32u);
|
||||
});
|
||||
ASSERT_TRUE(arr->Is<ast::Array>());
|
||||
ASSERT_TRUE(arr->As<ast::Array>()->type->Is<ast::F32>());
|
||||
ASSERT_EQ(arr->As<ast::Array>()->attributes.size(), 0u);
|
||||
|
||||
auto* size = arr->As<ast::Array>()->count->As<ast::IntLiteralExpression>();
|
||||
ASSERT_NE(size, nullptr);
|
||||
EXPECT_EQ(size->ValueAsI32(), 2);
|
||||
}
|
||||
|
||||
TEST_F(CreateASTTypeForTest, ArrayNonImplicitStride) {
|
||||
auto* arr = create([](ProgramBuilder& b) {
|
||||
return b.create<sem::Array>(b.create<sem::F32>(), 2, 4, 4, 64u, 32u);
|
||||
});
|
||||
ASSERT_TRUE(arr->Is<ast::Array>());
|
||||
ASSERT_TRUE(arr->As<ast::Array>()->type->Is<ast::F32>());
|
||||
ASSERT_EQ(arr->As<ast::Array>()->attributes.size(), 1u);
|
||||
ASSERT_TRUE(arr->As<ast::Array>()->attributes[0]->Is<ast::StrideAttribute>());
|
||||
ASSERT_EQ(
|
||||
arr->As<ast::Array>()->attributes[0]->As<ast::StrideAttribute>()->stride,
|
||||
64u);
|
||||
|
||||
auto* size = arr->As<ast::Array>()->count->As<ast::IntLiteralExpression>();
|
||||
ASSERT_NE(size, nullptr);
|
||||
EXPECT_EQ(size->ValueAsI32(), 2);
|
||||
}
|
||||
|
||||
TEST_F(CreateASTTypeForTest, Struct) {
|
||||
auto* str = create([](ProgramBuilder& b) {
|
||||
auto* decl = b.Structure("S", {}, {});
|
||||
return b.create<sem::Struct>(decl, decl->name, sem::StructMemberList{},
|
||||
4 /* align */, 4 /* size */,
|
||||
4 /* size_no_padding */);
|
||||
});
|
||||
ASSERT_TRUE(str->Is<ast::TypeName>());
|
||||
EXPECT_EQ(ast_type_builder.Symbols().NameFor(str->As<ast::TypeName>()->name),
|
||||
"S");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
99
src/tint/transform/unshadow.cc
Normal file
99
src/tint/transform/unshadow.cc
Normal file
@@ -0,0 +1,99 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/tint/transform/unshadow.h"
|
||||
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
|
||||
#include "src/tint/program_builder.h"
|
||||
#include "src/tint/sem/block_statement.h"
|
||||
#include "src/tint/sem/function.h"
|
||||
#include "src/tint/sem/statement.h"
|
||||
#include "src/tint/sem/variable.h"
|
||||
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::transform::Unshadow);
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
|
||||
/// The PIMPL state for the Unshadow transform
|
||||
struct Unshadow::State {
|
||||
/// The clone context
|
||||
CloneContext& ctx;
|
||||
|
||||
/// Constructor
|
||||
/// @param context the clone context
|
||||
explicit State(CloneContext& context) : ctx(context) {}
|
||||
|
||||
/// Performs the transformation
|
||||
void Run() {
|
||||
auto& sem = ctx.src->Sem();
|
||||
|
||||
// Maps a variable to its new name.
|
||||
std::unordered_map<const sem::Variable*, Symbol> renamed_to;
|
||||
|
||||
auto rename = [&](const sem::Variable* var) -> const ast::Variable* {
|
||||
auto* decl = var->Declaration();
|
||||
auto name = ctx.src->Symbols().NameFor(decl->symbol);
|
||||
auto symbol = ctx.dst->Symbols().New(name);
|
||||
renamed_to.emplace(var, symbol);
|
||||
|
||||
auto source = ctx.Clone(decl->source);
|
||||
auto* type = ctx.Clone(decl->type);
|
||||
auto* constructor = ctx.Clone(decl->constructor);
|
||||
auto attributes = ctx.Clone(decl->attributes);
|
||||
return ctx.dst->create<ast::Variable>(
|
||||
source, symbol, decl->declared_storage_class, decl->declared_access,
|
||||
type, decl->is_const, decl->is_overridable, constructor, attributes);
|
||||
};
|
||||
|
||||
ctx.ReplaceAll([&](const ast::Variable* var) -> const ast::Variable* {
|
||||
if (auto* local = sem.Get<sem::LocalVariable>(var)) {
|
||||
if (local->Shadows()) {
|
||||
return rename(local);
|
||||
}
|
||||
}
|
||||
if (auto* param = sem.Get<sem::Parameter>(var)) {
|
||||
if (param->Shadows()) {
|
||||
return rename(param);
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
});
|
||||
ctx.ReplaceAll([&](const ast::IdentifierExpression* ident)
|
||||
-> const tint::ast::IdentifierExpression* {
|
||||
if (auto* user = sem.Get<sem::VariableUser>(ident)) {
|
||||
auto it = renamed_to.find(user->Variable());
|
||||
if (it != renamed_to.end()) {
|
||||
return ctx.dst->Expr(it->second);
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
});
|
||||
ctx.Clone();
|
||||
}
|
||||
};
|
||||
|
||||
Unshadow::Unshadow() = default;
|
||||
|
||||
Unshadow::~Unshadow() = default;
|
||||
|
||||
void Unshadow::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
|
||||
State(ctx).Run();
|
||||
}
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
50
src/tint/transform/unshadow.h
Normal file
50
src/tint/transform/unshadow.h
Normal file
@@ -0,0 +1,50 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef SRC_TINT_TRANSFORM_UNSHADOW_H_
|
||||
#define SRC_TINT_TRANSFORM_UNSHADOW_H_
|
||||
|
||||
#include "src/tint/transform/transform.h"
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
|
||||
/// Unshadow is a Transform that renames any variables that shadow another
|
||||
/// variable.
|
||||
class Unshadow : public Castable<Unshadow, Transform> {
|
||||
public:
|
||||
/// Constructor
|
||||
Unshadow();
|
||||
|
||||
/// Destructor
|
||||
~Unshadow() override;
|
||||
|
||||
protected:
|
||||
struct State;
|
||||
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
};
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
|
||||
#endif // SRC_TINT_TRANSFORM_UNSHADOW_H_
|
||||
609
src/tint/transform/unshadow_test.cc
Normal file
609
src/tint/transform/unshadow_test.cc
Normal file
@@ -0,0 +1,609 @@
|
||||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/tint/transform/unshadow.h"
|
||||
|
||||
#include "src/tint/transform/test_helper.h"
|
||||
|
||||
namespace tint {
|
||||
namespace transform {
|
||||
namespace {
|
||||
|
||||
using UnshadowTest = TransformTest;
|
||||
|
||||
TEST_F(UnshadowTest, EmptyModule) {
|
||||
auto* src = "";
|
||||
auto* expect = "";
|
||||
|
||||
auto got = Run<Unshadow>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(UnshadowTest, Noop) {
|
||||
auto* src = R"(
|
||||
var<private> a : i32;
|
||||
|
||||
let b : i32 = 1;
|
||||
|
||||
fn F(c : i32) {
|
||||
var d : i32;
|
||||
let e : i32 = 1;
|
||||
{
|
||||
var f : i32;
|
||||
let g : i32 = 1;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = src;
|
||||
|
||||
auto got = Run<Unshadow>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(UnshadowTest, LocalShadowsAlias) {
|
||||
auto* src = R"(
|
||||
type a = i32;
|
||||
|
||||
fn X() {
|
||||
var a = false;
|
||||
}
|
||||
|
||||
fn Y() {
|
||||
let a = true;
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
type a = i32;
|
||||
|
||||
fn X() {
|
||||
var a_1 = false;
|
||||
}
|
||||
|
||||
fn Y() {
|
||||
let a_2 = true;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(UnshadowTest, LocalShadowsAlias_OutOfOrder) {
|
||||
auto* src = R"(
|
||||
fn X() {
|
||||
var a = false;
|
||||
}
|
||||
|
||||
fn Y() {
|
||||
let a = true;
|
||||
}
|
||||
|
||||
type a = i32;
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn X() {
|
||||
var a_1 = false;
|
||||
}
|
||||
|
||||
fn Y() {
|
||||
let a_2 = true;
|
||||
}
|
||||
|
||||
type a = i32;
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(UnshadowTest, LocalShadowsStruct) {
|
||||
auto* src = R"(
|
||||
struct a {
|
||||
m : i32;
|
||||
};
|
||||
|
||||
fn X() {
|
||||
var a = true;
|
||||
}
|
||||
|
||||
fn Y() {
|
||||
let a = false;
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
struct a {
|
||||
m : i32;
|
||||
}
|
||||
|
||||
fn X() {
|
||||
var a_1 = true;
|
||||
}
|
||||
|
||||
fn Y() {
|
||||
let a_2 = false;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(UnshadowTest, LocalShadowsStruct_OutOfOrder) {
|
||||
auto* src = R"(
|
||||
fn X() {
|
||||
var a = true;
|
||||
}
|
||||
|
||||
fn Y() {
|
||||
let a = false;
|
||||
}
|
||||
|
||||
struct a {
|
||||
m : i32;
|
||||
};
|
||||
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn X() {
|
||||
var a_1 = true;
|
||||
}
|
||||
|
||||
fn Y() {
|
||||
let a_2 = false;
|
||||
}
|
||||
|
||||
struct a {
|
||||
m : i32;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(UnshadowTest, LocalShadowsFunction) {
|
||||
auto* src = R"(
|
||||
fn a() {
|
||||
var a = true;
|
||||
var b = false;
|
||||
}
|
||||
|
||||
fn b() {
|
||||
let a = true;
|
||||
let b = false;
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn a() {
|
||||
var a_1 = true;
|
||||
var b_1 = false;
|
||||
}
|
||||
|
||||
fn b() {
|
||||
let a_2 = true;
|
||||
let b_2 = false;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(UnshadowTest, LocalShadowsFunction_OutOfOrder) {
|
||||
auto* src = R"(
|
||||
fn b() {
|
||||
let a = true;
|
||||
let b = false;
|
||||
}
|
||||
|
||||
fn a() {
|
||||
var a = true;
|
||||
var b = false;
|
||||
}
|
||||
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn b() {
|
||||
let a_1 = true;
|
||||
let b_1 = false;
|
||||
}
|
||||
|
||||
fn a() {
|
||||
var a_2 = true;
|
||||
var b_2 = false;
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(UnshadowTest, LocalShadowsGlobalVar) {
|
||||
auto* src = R"(
|
||||
var<private> a : i32;
|
||||
|
||||
fn X() {
|
||||
var a = (a == 123);
|
||||
}
|
||||
|
||||
fn Y() {
|
||||
let a = (a == 321);
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
var<private> a : i32;
|
||||
|
||||
fn X() {
|
||||
var a_1 = (a == 123);
|
||||
}
|
||||
|
||||
fn Y() {
|
||||
let a_2 = (a == 321);
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(UnshadowTest, LocalShadowsGlobalVar_OutOfOrder) {
|
||||
auto* src = R"(
|
||||
fn X() {
|
||||
var a = (a == 123);
|
||||
}
|
||||
|
||||
fn Y() {
|
||||
let a = (a == 321);
|
||||
}
|
||||
|
||||
var<private> a : i32;
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn X() {
|
||||
var a_1 = (a == 123);
|
||||
}
|
||||
|
||||
fn Y() {
|
||||
let a_2 = (a == 321);
|
||||
}
|
||||
|
||||
var<private> a : i32;
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(UnshadowTest, LocalShadowsGlobalLet) {
|
||||
auto* src = R"(
|
||||
let a : i32 = 1;
|
||||
|
||||
fn X() {
|
||||
var a = (a == 123);
|
||||
}
|
||||
|
||||
fn Y() {
|
||||
let a = (a == 321);
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
let a : i32 = 1;
|
||||
|
||||
fn X() {
|
||||
var a_1 = (a == 123);
|
||||
}
|
||||
|
||||
fn Y() {
|
||||
let a_2 = (a == 321);
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(UnshadowTest, LocalShadowsGlobalLet_OutOfOrder) {
|
||||
auto* src = R"(
|
||||
fn X() {
|
||||
var a = (a == 123);
|
||||
}
|
||||
|
||||
fn Y() {
|
||||
let a = (a == 321);
|
||||
}
|
||||
|
||||
let a : i32 = 1;
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn X() {
|
||||
var a_1 = (a == 123);
|
||||
}
|
||||
|
||||
fn Y() {
|
||||
let a_2 = (a == 321);
|
||||
}
|
||||
|
||||
let a : i32 = 1;
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(UnshadowTest, LocalShadowsLocalVar) {
|
||||
auto* src = R"(
|
||||
fn X() {
|
||||
var a : i32;
|
||||
{
|
||||
var a = (a == 123);
|
||||
}
|
||||
{
|
||||
let a = (a == 321);
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn X() {
|
||||
var a : i32;
|
||||
{
|
||||
var a_1 = (a == 123);
|
||||
}
|
||||
{
|
||||
let a_2 = (a == 321);
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(UnshadowTest, LocalShadowsLocalLet) {
|
||||
auto* src = R"(
|
||||
fn X() {
|
||||
let a = 1;
|
||||
{
|
||||
var a = (a == 123);
|
||||
}
|
||||
{
|
||||
let a = (a == 321);
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn X() {
|
||||
let a = 1;
|
||||
{
|
||||
var a_1 = (a == 123);
|
||||
}
|
||||
{
|
||||
let a_2 = (a == 321);
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(UnshadowTest, LocalShadowsParam) {
|
||||
auto* src = R"(
|
||||
fn F(a : i32) {
|
||||
{
|
||||
var a = (a == 123);
|
||||
}
|
||||
{
|
||||
let a = (a == 321);
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn F(a : i32) {
|
||||
{
|
||||
var a_1 = (a == 123);
|
||||
}
|
||||
{
|
||||
let a_2 = (a == 321);
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(UnshadowTest, ParamShadowsFunction) {
|
||||
auto* src = R"(
|
||||
fn a(a : i32) {
|
||||
{
|
||||
var a = (a == 123);
|
||||
}
|
||||
{
|
||||
let a = (a == 321);
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn a(a_1 : i32) {
|
||||
{
|
||||
var a_2 = (a_1 == 123);
|
||||
}
|
||||
{
|
||||
let a_3 = (a_1 == 321);
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(UnshadowTest, ParamShadowsGlobalVar) {
|
||||
auto* src = R"(
|
||||
var<private> a : i32;
|
||||
|
||||
fn F(a : bool) {
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
var<private> a : i32;
|
||||
|
||||
fn F(a_1 : bool) {
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(UnshadowTest, ParamShadowsGlobalLet) {
|
||||
auto* src = R"(
|
||||
let a : i32 = 1;
|
||||
|
||||
fn F(a : bool) {
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
let a : i32 = 1;
|
||||
|
||||
fn F(a_1 : bool) {
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(UnshadowTest, ParamShadowsGlobalLet_OutOfOrder) {
|
||||
auto* src = R"(
|
||||
fn F(a : bool) {
|
||||
}
|
||||
|
||||
let a : i32 = 1;
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn F(a_1 : bool) {
|
||||
}
|
||||
|
||||
let a : i32 = 1;
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(UnshadowTest, ParamShadowsAlias) {
|
||||
auto* src = R"(
|
||||
type a = i32;
|
||||
|
||||
fn F(a : a) {
|
||||
{
|
||||
var a = (a == 123);
|
||||
}
|
||||
{
|
||||
let a = (a == 321);
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
type a = i32;
|
||||
|
||||
fn F(a_1 : a) {
|
||||
{
|
||||
var a_2 = (a_1 == 123);
|
||||
}
|
||||
{
|
||||
let a_3 = (a_1 == 321);
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(UnshadowTest, ParamShadowsAlias_OutOfOrder) {
|
||||
auto* src = R"(
|
||||
fn F(a : a) {
|
||||
{
|
||||
var a = (a == 123);
|
||||
}
|
||||
{
|
||||
let a = (a == 321);
|
||||
}
|
||||
}
|
||||
|
||||
type a = i32;
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn F(a_1 : a) {
|
||||
{
|
||||
var a_2 = (a_1 == 123);
|
||||
}
|
||||
{
|
||||
let a_3 = (a_1 == 321);
|
||||
}
|
||||
}
|
||||
|
||||
type a = i32;
|
||||
)";
|
||||
|
||||
auto got = Run<Unshadow>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
327
src/tint/transform/utils/hoist_to_decl_before.cc
Normal file
327
src/tint/transform/utils/hoist_to_decl_before.cc
Normal file
@@ -0,0 +1,327 @@
|
||||
// Copyright 2022 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/tint/transform/utils/hoist_to_decl_before.h"
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
#include "src/tint/ast/variable_decl_statement.h"
|
||||
#include "src/tint/sem/block_statement.h"
|
||||
#include "src/tint/sem/for_loop_statement.h"
|
||||
#include "src/tint/sem/if_statement.h"
|
||||
#include "src/tint/sem/reference_type.h"
|
||||
#include "src/tint/sem/variable.h"
|
||||
#include "src/tint/utils/reverse.h"
|
||||
|
||||
namespace tint::transform {
|
||||
|
||||
/// Private implementation of HoistToDeclBefore transform
|
||||
class HoistToDeclBefore::State {
|
||||
CloneContext& ctx;
|
||||
ProgramBuilder& b;
|
||||
|
||||
/// Holds information about a for-loop that needs to be decomposed into a
|
||||
/// loop, so that declaration statements can be inserted before the
|
||||
/// condition expression or continuing statement.
|
||||
struct LoopInfo {
|
||||
ast::StatementList cond_decls;
|
||||
ast::StatementList cont_decls;
|
||||
};
|
||||
|
||||
/// Holds information about 'if's with 'else-if' statements that need to be
|
||||
/// decomposed into 'if {else}' so that declaration statements can be
|
||||
/// inserted before the condition expression.
|
||||
struct IfInfo {
|
||||
/// Info for each else-if that needs decomposing
|
||||
struct ElseIfInfo {
|
||||
/// Decls to insert before condition
|
||||
ast::StatementList cond_decls;
|
||||
};
|
||||
|
||||
/// 'else if's that need to be decomposed to 'else { if }'
|
||||
std::unordered_map<const sem::ElseStatement*, ElseIfInfo> else_ifs;
|
||||
};
|
||||
|
||||
/// For-loops that need to be decomposed to loops.
|
||||
std::unordered_map<const sem::ForLoopStatement*, LoopInfo> loops;
|
||||
|
||||
/// If statements with 'else if's that need to be decomposed to 'else { if
|
||||
/// }'
|
||||
std::unordered_map<const sem::IfStatement*, IfInfo> ifs;
|
||||
|
||||
// Inserts `decl` before `sem_expr`, possibly marking a for-loop to be
|
||||
// converted to a loop, or an else-if to an else { if }.
|
||||
bool InsertBefore(const sem::Expression* sem_expr,
|
||||
const ast::VariableDeclStatement* decl) {
|
||||
auto* sem_stmt = sem_expr->Stmt();
|
||||
auto* stmt = sem_stmt->Declaration();
|
||||
|
||||
if (auto* else_if = sem_stmt->As<sem::ElseStatement>()) {
|
||||
// Expression used in 'else if' condition.
|
||||
// Need to convert 'else if' to 'else { if }'.
|
||||
auto& if_info = ifs[else_if->Parent()->As<sem::IfStatement>()];
|
||||
if_info.else_ifs[else_if].cond_decls.push_back(decl);
|
||||
return true;
|
||||
}
|
||||
|
||||
if (auto* fl = sem_stmt->As<sem::ForLoopStatement>()) {
|
||||
// Expression used in for-loop condition.
|
||||
// For-loop needs to be decomposed to a loop.
|
||||
loops[fl].cond_decls.emplace_back(decl);
|
||||
return true;
|
||||
}
|
||||
|
||||
auto* parent = sem_stmt->Parent(); // The statement's parent
|
||||
if (auto* block = parent->As<sem::BlockStatement>()) {
|
||||
// Expression's statement sits in a block. Simple case.
|
||||
// Insert the decl before the parent statement
|
||||
ctx.InsertBefore(block->Declaration()->statements, stmt, decl);
|
||||
return true;
|
||||
}
|
||||
|
||||
if (auto* fl = parent->As<sem::ForLoopStatement>()) {
|
||||
// Expression is used in a for-loop. These require special care.
|
||||
if (fl->Declaration()->initializer == stmt) {
|
||||
// Expression used in for-loop initializer.
|
||||
// Insert the let above the for-loop.
|
||||
ctx.InsertBefore(fl->Block()->Declaration()->statements,
|
||||
fl->Declaration(), decl);
|
||||
return true;
|
||||
}
|
||||
|
||||
if (fl->Declaration()->continuing == stmt) {
|
||||
// Expression used in for-loop continuing.
|
||||
// For-loop needs to be decomposed to a loop.
|
||||
loops[fl].cont_decls.emplace_back(decl);
|
||||
return true;
|
||||
}
|
||||
|
||||
TINT_ICE(Transform, b.Diagnostics())
|
||||
<< "unhandled use of expression in for-loop";
|
||||
return false;
|
||||
}
|
||||
|
||||
TINT_ICE(Transform, b.Diagnostics())
|
||||
<< "unhandled expression parent statement type: "
|
||||
<< parent->TypeInfo().name;
|
||||
return false;
|
||||
}
|
||||
|
||||
// Converts any for-loops marked for conversion to loops, inserting
|
||||
// registered declaration statements before the condition or continuing
|
||||
// statement.
|
||||
void ForLoopsToLoops() {
|
||||
if (loops.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
// At least one for-loop needs to be transformed into a loop.
|
||||
ctx.ReplaceAll(
|
||||
[&](const ast::ForLoopStatement* stmt) -> const ast::Statement* {
|
||||
auto& sem = ctx.src->Sem();
|
||||
|
||||
if (auto* fl = sem.Get(stmt)) {
|
||||
if (auto it = loops.find(fl); it != loops.end()) {
|
||||
auto& info = it->second;
|
||||
auto* for_loop = fl->Declaration();
|
||||
// For-loop needs to be decomposed to a loop.
|
||||
// Build the loop body's statements.
|
||||
// Start with any let declarations for the conditional
|
||||
// expression.
|
||||
auto body_stmts = info.cond_decls;
|
||||
// If the for-loop has a condition, emit this next as:
|
||||
// if (!cond) { break; }
|
||||
if (auto* cond = for_loop->condition) {
|
||||
// !condition
|
||||
auto* not_cond = b.create<ast::UnaryOpExpression>(
|
||||
ast::UnaryOp::kNot, ctx.Clone(cond));
|
||||
// { break; }
|
||||
auto* break_body = b.Block(b.create<ast::BreakStatement>());
|
||||
// if (!condition) { break; }
|
||||
body_stmts.emplace_back(b.If(not_cond, break_body));
|
||||
}
|
||||
// Next emit the for-loop body
|
||||
body_stmts.emplace_back(ctx.Clone(for_loop->body));
|
||||
|
||||
// Finally create the continuing block if there was one.
|
||||
const ast::BlockStatement* continuing = nullptr;
|
||||
if (auto* cont = for_loop->continuing) {
|
||||
// Continuing block starts with any let declarations used by
|
||||
// the continuing.
|
||||
auto cont_stmts = info.cont_decls;
|
||||
cont_stmts.emplace_back(ctx.Clone(cont));
|
||||
continuing = b.Block(cont_stmts);
|
||||
}
|
||||
|
||||
auto* body = b.Block(body_stmts);
|
||||
auto* loop = b.Loop(body, continuing);
|
||||
if (auto* init = for_loop->initializer) {
|
||||
return b.Block(ctx.Clone(init), loop);
|
||||
}
|
||||
return loop;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
});
|
||||
}
|
||||
|
||||
void ElseIfsToElseWithNestedIfs() {
|
||||
if (ifs.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
ctx.ReplaceAll([&](const ast::IfStatement* if_stmt) //
|
||||
-> const ast::IfStatement* {
|
||||
auto& sem = ctx.src->Sem();
|
||||
auto* sem_if = sem.Get(if_stmt);
|
||||
if (!sem_if) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto it = ifs.find(sem_if);
|
||||
if (it == ifs.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
auto& if_info = it->second;
|
||||
|
||||
// This if statement has "else if"s that need to be converted to "else
|
||||
// { if }"s
|
||||
|
||||
ast::ElseStatementList next_else_stmts;
|
||||
next_else_stmts.reserve(if_stmt->else_statements.size());
|
||||
|
||||
for (auto* else_stmt : utils::Reverse(if_stmt->else_statements)) {
|
||||
if (else_stmt->condition == nullptr) {
|
||||
// The last 'else', keep as is
|
||||
next_else_stmts.insert(next_else_stmts.begin(), ctx.Clone(else_stmt));
|
||||
|
||||
} else {
|
||||
auto* sem_else_if = sem.Get(else_stmt);
|
||||
|
||||
auto it2 = if_info.else_ifs.find(sem_else_if);
|
||||
if (it2 == if_info.else_ifs.end()) {
|
||||
// 'else if' we don't need to modify (no decls to insert), so
|
||||
// keep as is
|
||||
next_else_stmts.insert(next_else_stmts.begin(),
|
||||
ctx.Clone(else_stmt));
|
||||
|
||||
} else {
|
||||
// 'else if' we need to replace with 'else <decls> { if }'
|
||||
auto& else_if_info = it2->second;
|
||||
|
||||
// Build the else body's statements, starting with let decls for
|
||||
// the conditional expression
|
||||
auto& body_stmts = else_if_info.cond_decls;
|
||||
|
||||
// Build nested if
|
||||
auto* cond = ctx.Clone(else_stmt->condition);
|
||||
auto* body = ctx.Clone(else_stmt->body);
|
||||
body_stmts.emplace_back(b.If(cond, body, next_else_stmts));
|
||||
|
||||
// Build else
|
||||
auto* else_with_nested_if = b.Else(b.Block(body_stmts));
|
||||
|
||||
// This will be used in parent if (either another nested if, or
|
||||
// top-level if)
|
||||
next_else_stmts = {else_with_nested_if};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build a new top-level if with new else statements
|
||||
if (next_else_stmts.empty()) {
|
||||
TINT_ICE(Transform, b.Diagnostics())
|
||||
<< "Expected else statements to insert into new if";
|
||||
}
|
||||
auto* cond = ctx.Clone(if_stmt->condition);
|
||||
auto* body = ctx.Clone(if_stmt->body);
|
||||
auto* new_if = b.If(cond, body, next_else_stmts);
|
||||
return new_if;
|
||||
});
|
||||
}
|
||||
|
||||
public:
|
||||
/// Constructor
|
||||
/// @param ctx_in the clone context
|
||||
explicit State(CloneContext& ctx_in) : ctx(ctx_in), b(*ctx_in.dst) {}
|
||||
|
||||
/// Hoists `expr` to a `let` or `var` with optional `decl_name`, inserting it
|
||||
/// before `before_expr`.
|
||||
/// @param before_expr expression to insert `expr` before
|
||||
/// @param expr expression to hoist
|
||||
/// @param as_const hoist to `let` if true, otherwise to `var`
|
||||
/// @param decl_name optional name to use for the variable/constant name
|
||||
/// @return true on success
|
||||
bool HoistToDeclBefore(const sem::Expression* before_expr,
|
||||
const ast::Expression* expr,
|
||||
bool as_const,
|
||||
const char* decl_name = "") {
|
||||
auto name = b.Symbols().New(decl_name);
|
||||
|
||||
auto* sem_expr = ctx.src->Sem().Get(expr);
|
||||
bool is_ref =
|
||||
sem_expr &&
|
||||
!sem_expr->Is<sem::VariableUser>() // Don't need to take a ref to a var
|
||||
&& sem_expr->Type()->Is<sem::Reference>();
|
||||
|
||||
auto* expr_clone = ctx.Clone(expr);
|
||||
if (is_ref) {
|
||||
expr_clone = b.AddressOf(expr_clone);
|
||||
}
|
||||
|
||||
// Construct the let/var that holds the hoisted expr
|
||||
auto* v = as_const ? b.Const(name, nullptr, expr_clone)
|
||||
: b.Var(name, nullptr, expr_clone);
|
||||
auto* decl = b.Decl(v);
|
||||
|
||||
if (!InsertBefore(before_expr, decl)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Replace the initializer expression with a reference to the let
|
||||
const ast::Expression* new_expr = b.Expr(name);
|
||||
if (is_ref) {
|
||||
new_expr = b.Deref(new_expr);
|
||||
}
|
||||
ctx.Replace(expr, new_expr);
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Applies any scheduled insertions from previous calls to Add() to
|
||||
/// CloneContext. Call this once before ctx.Clone().
|
||||
/// @return true on success
|
||||
bool Apply() {
|
||||
ForLoopsToLoops();
|
||||
ElseIfsToElseWithNestedIfs();
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
HoistToDeclBefore::HoistToDeclBefore(CloneContext& ctx)
|
||||
: state_(std::make_unique<State>(ctx)) {}
|
||||
|
||||
HoistToDeclBefore::~HoistToDeclBefore() {}
|
||||
|
||||
bool HoistToDeclBefore::Add(const sem::Expression* before_expr,
|
||||
const ast::Expression* expr,
|
||||
bool as_const,
|
||||
const char* decl_name) {
|
||||
return state_->HoistToDeclBefore(before_expr, expr, as_const, decl_name);
|
||||
}
|
||||
|
||||
bool HoistToDeclBefore::Apply() {
|
||||
return state_->Apply();
|
||||
}
|
||||
|
||||
} // namespace tint::transform
|
||||
61
src/tint/transform/utils/hoist_to_decl_before.h
Normal file
61
src/tint/transform/utils/hoist_to_decl_before.h
Normal file
@@ -0,0 +1,61 @@
|
||||
// Copyright 2022 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef SRC_TINT_TRANSFORM_UTILS_HOIST_TO_DECL_BEFORE_H_
|
||||
#define SRC_TINT_TRANSFORM_UTILS_HOIST_TO_DECL_BEFORE_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "src/tint/sem/expression.h"
|
||||
#include "src/tint/transform/transform.h"
|
||||
|
||||
namespace tint::transform {
|
||||
|
||||
/// Utility class that can be used to hoist expressions before other
|
||||
/// expressions, possibly converting 'for' loops to 'loop's and 'else if to
|
||||
// 'else if'.
|
||||
class HoistToDeclBefore {
|
||||
public:
|
||||
/// Constructor
|
||||
/// @param ctx the clone context
|
||||
explicit HoistToDeclBefore(CloneContext& ctx);
|
||||
|
||||
/// Destructor
|
||||
~HoistToDeclBefore();
|
||||
|
||||
/// Hoists `expr` to a `let` or `var` with optional `decl_name`, inserting it
|
||||
/// before `before_expr`.
|
||||
/// @param before_expr expression to insert `expr` before
|
||||
/// @param expr expression to hoist
|
||||
/// @param as_const hoist to `let` if true, otherwise to `var`
|
||||
/// @param decl_name optional name to use for the variable/constant name
|
||||
/// @return true on success
|
||||
bool Add(const sem::Expression* before_expr,
|
||||
const ast::Expression* expr,
|
||||
bool as_const,
|
||||
const char* decl_name = "");
|
||||
|
||||
/// Applies any scheduled insertions from previous calls to Add() to
|
||||
/// CloneContext. Call this once before ctx.Clone().
|
||||
/// @return true on success
|
||||
bool Apply();
|
||||
|
||||
private:
|
||||
class State;
|
||||
std::unique_ptr<State> state_;
|
||||
};
|
||||
|
||||
} // namespace tint::transform
|
||||
|
||||
#endif // SRC_TINT_TRANSFORM_UTILS_HOIST_TO_DECL_BEFORE_H_
|
||||
291
src/tint/transform/utils/hoist_to_decl_before_test.cc
Normal file
291
src/tint/transform/utils/hoist_to_decl_before_test.cc
Normal file
@@ -0,0 +1,291 @@
|
||||
// Copyright 2022 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "gtest/gtest-spi.h"
|
||||
#include "src/tint/program_builder.h"
|
||||
#include "src/tint/transform/test_helper.h"
|
||||
#include "src/tint/transform/utils/hoist_to_decl_before.h"
|
||||
|
||||
namespace tint::transform {
|
||||
namespace {
|
||||
|
||||
using HoistToDeclBeforeTest = ::testing::Test;
|
||||
|
||||
TEST_F(HoistToDeclBeforeTest, VarInit) {
|
||||
// fn f() {
|
||||
// var a = 1;
|
||||
// }
|
||||
ProgramBuilder b;
|
||||
auto* expr = b.Expr(1);
|
||||
auto* var = b.Decl(b.Var("a", nullptr, expr));
|
||||
b.Func("f", {}, b.ty.void_(), {var});
|
||||
|
||||
Program original(std::move(b));
|
||||
ProgramBuilder cloned_b;
|
||||
CloneContext ctx(&cloned_b, &original);
|
||||
|
||||
HoistToDeclBefore hoistToDeclBefore(ctx);
|
||||
auto* sem_expr = ctx.src->Sem().Get(expr);
|
||||
hoistToDeclBefore.Add(sem_expr, expr, true);
|
||||
hoistToDeclBefore.Apply();
|
||||
|
||||
ctx.Clone();
|
||||
Program cloned(std::move(cloned_b));
|
||||
|
||||
auto* expect = R"(
|
||||
fn f() {
|
||||
let tint_symbol = 1;
|
||||
var a = tint_symbol;
|
||||
}
|
||||
)";
|
||||
|
||||
EXPECT_EQ(expect, str(cloned));
|
||||
}
|
||||
|
||||
TEST_F(HoistToDeclBeforeTest, ForLoopInit) {
|
||||
// fn f() {
|
||||
// for(var a = 1; true; ) {
|
||||
// }
|
||||
// }
|
||||
ProgramBuilder b;
|
||||
auto* expr = b.Expr(1);
|
||||
auto* s =
|
||||
b.For(b.Decl(b.Var("a", nullptr, expr)), b.Expr(true), {}, b.Block());
|
||||
b.Func("f", {}, b.ty.void_(), {s});
|
||||
|
||||
Program original(std::move(b));
|
||||
ProgramBuilder cloned_b;
|
||||
CloneContext ctx(&cloned_b, &original);
|
||||
|
||||
HoistToDeclBefore hoistToDeclBefore(ctx);
|
||||
auto* sem_expr = ctx.src->Sem().Get(expr);
|
||||
hoistToDeclBefore.Add(sem_expr, expr, true);
|
||||
hoistToDeclBefore.Apply();
|
||||
|
||||
ctx.Clone();
|
||||
Program cloned(std::move(cloned_b));
|
||||
|
||||
auto* expect = R"(
|
||||
fn f() {
|
||||
let tint_symbol = 1;
|
||||
for(var a = tint_symbol; true; ) {
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
EXPECT_EQ(expect, str(cloned));
|
||||
}
|
||||
|
||||
TEST_F(HoistToDeclBeforeTest, ForLoopCond) {
|
||||
// fn f() {
|
||||
// var a : bool;
|
||||
// for(; a; ) {
|
||||
// }
|
||||
// }
|
||||
ProgramBuilder b;
|
||||
auto* var = b.Decl(b.Var("a", b.ty.bool_()));
|
||||
auto* expr = b.Expr("a");
|
||||
auto* s = b.For({}, expr, {}, b.Block());
|
||||
b.Func("f", {}, b.ty.void_(), {var, s});
|
||||
|
||||
Program original(std::move(b));
|
||||
ProgramBuilder cloned_b;
|
||||
CloneContext ctx(&cloned_b, &original);
|
||||
|
||||
HoistToDeclBefore hoistToDeclBefore(ctx);
|
||||
auto* sem_expr = ctx.src->Sem().Get(expr);
|
||||
hoistToDeclBefore.Add(sem_expr, expr, true);
|
||||
hoistToDeclBefore.Apply();
|
||||
|
||||
ctx.Clone();
|
||||
Program cloned(std::move(cloned_b));
|
||||
|
||||
auto* expect = R"(
|
||||
fn f() {
|
||||
var a : bool;
|
||||
loop {
|
||||
let tint_symbol = a;
|
||||
if (!(tint_symbol)) {
|
||||
break;
|
||||
}
|
||||
{
|
||||
}
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
EXPECT_EQ(expect, str(cloned));
|
||||
}
|
||||
|
||||
TEST_F(HoistToDeclBeforeTest, ForLoopCont) {
|
||||
// fn f() {
|
||||
// for(; true; var a = 1) {
|
||||
// }
|
||||
// }
|
||||
ProgramBuilder b;
|
||||
auto* expr = b.Expr(1);
|
||||
auto* s =
|
||||
b.For({}, b.Expr(true), b.Decl(b.Var("a", nullptr, expr)), b.Block());
|
||||
b.Func("f", {}, b.ty.void_(), {s});
|
||||
|
||||
Program original(std::move(b));
|
||||
ProgramBuilder cloned_b;
|
||||
CloneContext ctx(&cloned_b, &original);
|
||||
|
||||
HoistToDeclBefore hoistToDeclBefore(ctx);
|
||||
auto* sem_expr = ctx.src->Sem().Get(expr);
|
||||
hoistToDeclBefore.Add(sem_expr, expr, true);
|
||||
hoistToDeclBefore.Apply();
|
||||
|
||||
ctx.Clone();
|
||||
Program cloned(std::move(cloned_b));
|
||||
|
||||
auto* expect = R"(
|
||||
fn f() {
|
||||
loop {
|
||||
if (!(true)) {
|
||||
break;
|
||||
}
|
||||
{
|
||||
}
|
||||
|
||||
continuing {
|
||||
let tint_symbol = 1;
|
||||
var a = tint_symbol;
|
||||
}
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
EXPECT_EQ(expect, str(cloned));
|
||||
}
|
||||
|
||||
TEST_F(HoistToDeclBeforeTest, ElseIf) {
|
||||
// fn f() {
|
||||
// var a : bool;
|
||||
// if (true) {
|
||||
// } else if (a) {
|
||||
// } else {
|
||||
// }
|
||||
// }
|
||||
ProgramBuilder b;
|
||||
auto* var = b.Decl(b.Var("a", b.ty.bool_()));
|
||||
auto* expr = b.Expr("a");
|
||||
auto* s = b.If(b.Expr(true), b.Block(), //
|
||||
b.Else(expr, b.Block()), //
|
||||
b.Else(b.Block()));
|
||||
b.Func("f", {}, b.ty.void_(), {var, s});
|
||||
|
||||
Program original(std::move(b));
|
||||
ProgramBuilder cloned_b;
|
||||
CloneContext ctx(&cloned_b, &original);
|
||||
|
||||
HoistToDeclBefore hoistToDeclBefore(ctx);
|
||||
auto* sem_expr = ctx.src->Sem().Get(expr);
|
||||
hoistToDeclBefore.Add(sem_expr, expr, true);
|
||||
hoistToDeclBefore.Apply();
|
||||
|
||||
ctx.Clone();
|
||||
Program cloned(std::move(cloned_b));
|
||||
|
||||
auto* expect = R"(
|
||||
fn f() {
|
||||
var a : bool;
|
||||
if (true) {
|
||||
} else {
|
||||
let tint_symbol = a;
|
||||
if (tint_symbol) {
|
||||
} else {
|
||||
}
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
EXPECT_EQ(expect, str(cloned));
|
||||
}
|
||||
|
||||
TEST_F(HoistToDeclBeforeTest, Array1D) {
|
||||
// fn f() {
|
||||
// var a : array<i32, 10>;
|
||||
// var b = a[0];
|
||||
// }
|
||||
ProgramBuilder b;
|
||||
auto* var1 = b.Decl(b.Var("a", b.ty.array<ProgramBuilder::i32, 10>()));
|
||||
auto* expr = b.IndexAccessor("a", 0);
|
||||
auto* var2 = b.Decl(b.Var("b", nullptr, expr));
|
||||
b.Func("f", {}, b.ty.void_(), {var1, var2});
|
||||
|
||||
Program original(std::move(b));
|
||||
ProgramBuilder cloned_b;
|
||||
CloneContext ctx(&cloned_b, &original);
|
||||
|
||||
HoistToDeclBefore hoistToDeclBefore(ctx);
|
||||
auto* sem_expr = ctx.src->Sem().Get(expr);
|
||||
hoistToDeclBefore.Add(sem_expr, expr, true);
|
||||
hoistToDeclBefore.Apply();
|
||||
|
||||
ctx.Clone();
|
||||
Program cloned(std::move(cloned_b));
|
||||
|
||||
auto* expect = R"(
|
||||
fn f() {
|
||||
var a : array<i32, 10>;
|
||||
let tint_symbol = &(a[0]);
|
||||
var b = *(tint_symbol);
|
||||
}
|
||||
)";
|
||||
|
||||
EXPECT_EQ(expect, str(cloned));
|
||||
}
|
||||
|
||||
TEST_F(HoistToDeclBeforeTest, Array2D) {
|
||||
// fn f() {
|
||||
// var a : array<array<i32, 10>, 10>;
|
||||
// var b = a[0][0];
|
||||
// }
|
||||
ProgramBuilder b;
|
||||
|
||||
auto* var1 =
|
||||
b.Decl(b.Var("a", b.ty.array(b.ty.array<ProgramBuilder::i32, 10>(), 10)));
|
||||
auto* expr = b.IndexAccessor(b.IndexAccessor("a", 0), 0);
|
||||
auto* var2 = b.Decl(b.Var("b", nullptr, expr));
|
||||
b.Func("f", {}, b.ty.void_(), {var1, var2});
|
||||
|
||||
Program original(std::move(b));
|
||||
ProgramBuilder cloned_b;
|
||||
CloneContext ctx(&cloned_b, &original);
|
||||
|
||||
HoistToDeclBefore hoistToDeclBefore(ctx);
|
||||
auto* sem_expr = ctx.src->Sem().Get(expr);
|
||||
hoistToDeclBefore.Add(sem_expr, expr, true);
|
||||
hoistToDeclBefore.Apply();
|
||||
|
||||
ctx.Clone();
|
||||
Program cloned(std::move(cloned_b));
|
||||
|
||||
auto* expect = R"(
|
||||
fn f() {
|
||||
var a : array<array<i32, 10>, 10>;
|
||||
let tint_symbol = &(a[0][0]);
|
||||
var b = *(tint_symbol);
|
||||
}
|
||||
)";
|
||||
|
||||
EXPECT_EQ(expect, str(cloned));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tint::transform
|
||||
68
src/tint/transform/var_for_dynamic_index.cc
Normal file
68
src/tint/transform/var_for_dynamic_index.cc
Normal file
@@ -0,0 +1,68 @@
|
||||
// Copyright 2022 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/tint/transform/var_for_dynamic_index.h"
|
||||
#include "src/tint/program_builder.h"
|
||||
#include "src/tint/transform/utils/hoist_to_decl_before.h"
|
||||
|
||||
namespace tint::transform {
|
||||
|
||||
VarForDynamicIndex::VarForDynamicIndex() = default;
|
||||
|
||||
VarForDynamicIndex::~VarForDynamicIndex() = default;
|
||||
|
||||
void VarForDynamicIndex::Run(CloneContext& ctx,
|
||||
const DataMap&,
|
||||
DataMap&) const {
|
||||
HoistToDeclBefore hoist_to_decl_before(ctx);
|
||||
|
||||
// Extracts array and matrix values that are dynamically indexed to a
|
||||
// temporary `var` local that is then indexed.
|
||||
auto dynamic_index_to_var =
|
||||
[&](const ast::IndexAccessorExpression* access_expr) {
|
||||
auto* index_expr = access_expr->index;
|
||||
auto* object_expr = access_expr->object;
|
||||
auto& sem = ctx.src->Sem();
|
||||
|
||||
if (sem.Get(index_expr)->ConstantValue()) {
|
||||
// Index expression resolves to a compile time value.
|
||||
// As this isn't a dynamic index, we can ignore this.
|
||||
return true;
|
||||
}
|
||||
|
||||
auto* indexed = sem.Get(object_expr);
|
||||
if (!indexed->Type()->IsAnyOf<sem::Array, sem::Matrix>()) {
|
||||
// We only care about array and matrices.
|
||||
return true;
|
||||
}
|
||||
|
||||
// TODO(bclayton): group multiple accesses in the same object.
|
||||
// e.g. arr[i] + arr[i+1] // Don't create two vars for this
|
||||
return hoist_to_decl_before.Add(indexed, object_expr, false,
|
||||
"var_for_index");
|
||||
};
|
||||
|
||||
for (auto* node : ctx.src->ASTNodes().Objects()) {
|
||||
if (auto* access_expr = node->As<ast::IndexAccessorExpression>()) {
|
||||
if (!dynamic_index_to_var(access_expr)) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
hoist_to_decl_before.Apply();
|
||||
ctx.Clone();
|
||||
}
|
||||
|
||||
} // namespace tint::transform
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user