tint->dawn: Shuffle source tree in preperation of merging repos

docs/    -> docs/tint/
fuzzers/ -> src/tint/fuzzers/
samples/ -> src/tint/cmd/
src/     -> src/tint/
test/    -> test/tint/

BUG=tint:1418,tint:1433

Change-Id: Id2aa79f989aef3245b80ef4aa37a27ff16cd700b
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/80482
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Ryan Harrison <rharrison@chromium.org>
This commit is contained in:
Ryan Harrison
2022-02-21 15:19:07 +00:00
committed by Tint LUCI CQ
parent 38f1e9c75c
commit dbc13af287
12231 changed files with 4897 additions and 4871 deletions

View File

@@ -0,0 +1,51 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/add_empty_entry_point.h"
#include <utility>
#include "src/tint/program_builder.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::AddEmptyEntryPoint);
namespace tint {
namespace transform {
AddEmptyEntryPoint::AddEmptyEntryPoint() = default;
AddEmptyEntryPoint::~AddEmptyEntryPoint() = default;
bool AddEmptyEntryPoint::ShouldRun(const Program* program,
const DataMap&) const {
for (auto* func : program->AST().Functions()) {
if (func->IsEntryPoint()) {
return false;
}
}
return true;
}
void AddEmptyEntryPoint::Run(CloneContext& ctx,
const DataMap&,
DataMap&) const {
ctx.dst->Func(ctx.dst->Symbols().New("unused_entry_point"), {},
ctx.dst->ty.void_(), {},
{ctx.dst->Stage(ast::PipelineStage::kCompute),
ctx.dst->WorkgroupSize(1)});
ctx.Clone();
}
} // namespace transform
} // namespace tint

View File

@@ -0,0 +1,52 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TINT_TRANSFORM_ADD_EMPTY_ENTRY_POINT_H_
#define SRC_TINT_TRANSFORM_ADD_EMPTY_ENTRY_POINT_H_
#include "src/tint/transform/transform.h"
namespace tint {
namespace transform {
/// Add an empty entry point to the module, if no other entry points exist.
class AddEmptyEntryPoint : public Castable<AddEmptyEntryPoint, Transform> {
public:
/// Constructor
AddEmptyEntryPoint();
/// Destructor
~AddEmptyEntryPoint() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program,
const DataMap& data = {}) const override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const override;
};
} // namespace transform
} // namespace tint
#endif // SRC_TINT_TRANSFORM_ADD_EMPTY_ENTRY_POINT_H_

View File

@@ -0,0 +1,88 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/add_empty_entry_point.h"
#include <utility>
#include "src/tint/transform/test_helper.h"
namespace tint {
namespace transform {
namespace {
using AddEmptyEntryPointTest = TransformTest;
TEST_F(AddEmptyEntryPointTest, ShouldRunEmptyModule) {
auto* src = R"()";
EXPECT_TRUE(ShouldRun<AddEmptyEntryPoint>(src));
}
TEST_F(AddEmptyEntryPointTest, ShouldRunExistingEntryPoint) {
auto* src = R"(
[[stage(compute), workgroup_size(1)]]
fn existing() {}
)";
EXPECT_FALSE(ShouldRun<AddEmptyEntryPoint>(src));
}
TEST_F(AddEmptyEntryPointTest, EmptyModule) {
auto* src = R"()";
auto* expect = R"(
@stage(compute) @workgroup_size(1)
fn unused_entry_point() {
}
)";
auto got = Run<AddEmptyEntryPoint>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(AddEmptyEntryPointTest, ExistingEntryPoint) {
auto* src = R"(
@stage(fragment)
fn main() {
}
)";
auto* expect = src;
auto got = Run<AddEmptyEntryPoint>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(AddEmptyEntryPointTest, NameClash) {
auto* src = R"(var<private> unused_entry_point : f32;)";
auto* expect = R"(
@stage(compute) @workgroup_size(1)
fn unused_entry_point_1() {
}
var<private> unused_entry_point : f32;
)";
auto got = Run<AddEmptyEntryPoint>(src);
EXPECT_EQ(expect, str(got));
}
} // namespace
} // namespace transform
} // namespace tint

View File

@@ -0,0 +1,122 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/add_spirv_block_attribute.h"
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include "src/tint/program_builder.h"
#include "src/tint/sem/variable.h"
#include "src/tint/utils/map.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::AddSpirvBlockAttribute);
TINT_INSTANTIATE_TYPEINFO(
tint::transform::AddSpirvBlockAttribute::SpirvBlockAttribute);
namespace tint {
namespace transform {
AddSpirvBlockAttribute::AddSpirvBlockAttribute() = default;
AddSpirvBlockAttribute::~AddSpirvBlockAttribute() = default;
void AddSpirvBlockAttribute::Run(CloneContext& ctx,
const DataMap&,
DataMap&) const {
auto& sem = ctx.src->Sem();
// Collect the set of structs that are nested in other types.
std::unordered_set<const sem::Struct*> nested_structs;
for (auto* node : ctx.src->ASTNodes().Objects()) {
if (auto* arr = sem.Get<sem::Array>(node->As<ast::Array>())) {
if (auto* nested_str = arr->ElemType()->As<sem::Struct>()) {
nested_structs.insert(nested_str);
}
} else if (auto* str = sem.Get<sem::Struct>(node->As<ast::Struct>())) {
for (auto* member : str->Members()) {
if (auto* nested_str = member->Type()->As<sem::Struct>()) {
nested_structs.insert(nested_str);
}
}
}
}
// A map from a type in the source program to a block-decorated wrapper that
// contains it in the destination program.
std::unordered_map<const sem::Type*, const ast::Struct*> wrapper_structs;
// Process global variables that are buffers.
for (auto* var : ctx.src->AST().GlobalVariables()) {
auto* sem_var = sem.Get<sem::GlobalVariable>(var);
if (var->declared_storage_class != ast::StorageClass::kStorage &&
var->declared_storage_class != ast::StorageClass::kUniform) {
continue;
}
auto* ty = sem.Get(var->type);
auto* str = ty->As<sem::Struct>();
if (!str || nested_structs.count(str)) {
const char* kMemberName = "inner";
// This is a non-struct or a struct that is nested somewhere else, so we
// need to wrap it first.
auto* wrapper = utils::GetOrCreate(wrapper_structs, ty, [&]() {
auto* block =
ctx.dst->ASTNodes().Create<SpirvBlockAttribute>(ctx.dst->ID());
auto wrapper_name = ctx.src->Symbols().NameFor(var->symbol) + "_block";
auto* ret = ctx.dst->create<ast::Struct>(
ctx.dst->Symbols().New(wrapper_name),
ast::StructMemberList{
ctx.dst->Member(kMemberName, CreateASTTypeFor(ctx, ty))},
ast::AttributeList{block});
ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), var, ret);
return ret;
});
ctx.Replace(var->type, ctx.dst->ty.Of(wrapper));
// Insert a member accessor to get the original type from the wrapper at
// any usage of the original variable.
for (auto* user : sem_var->Users()) {
ctx.Replace(
user->Declaration(),
ctx.dst->MemberAccessor(ctx.Clone(var->symbol), kMemberName));
}
} else {
// Add a block attribute to this struct directly.
auto* block =
ctx.dst->ASTNodes().Create<SpirvBlockAttribute>(ctx.dst->ID());
ctx.InsertFront(str->Declaration()->attributes, block);
}
}
ctx.Clone();
}
AddSpirvBlockAttribute::SpirvBlockAttribute::SpirvBlockAttribute(ProgramID pid)
: Base(pid) {}
AddSpirvBlockAttribute::SpirvBlockAttribute::~SpirvBlockAttribute() = default;
std::string AddSpirvBlockAttribute::SpirvBlockAttribute::InternalName() const {
return "spirv_block";
}
const AddSpirvBlockAttribute::SpirvBlockAttribute*
AddSpirvBlockAttribute::SpirvBlockAttribute::Clone(CloneContext* ctx) const {
return ctx->dst->ASTNodes()
.Create<AddSpirvBlockAttribute::SpirvBlockAttribute>(ctx->dst->ID());
}
} // namespace transform
} // namespace tint

View File

@@ -0,0 +1,76 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TINT_TRANSFORM_ADD_SPIRV_BLOCK_ATTRIBUTE_H_
#define SRC_TINT_TRANSFORM_ADD_SPIRV_BLOCK_ATTRIBUTE_H_
#include <string>
#include "src/tint/ast/internal_attribute.h"
#include "src/tint/transform/transform.h"
namespace tint {
namespace transform {
/// AddSpirvBlockAttribute is a transform that adds an
/// `@internal(spirv_block)` attribute to any structure that is used as the
/// store type of a buffer. If that structure is nested inside another structure
/// or an array, then it is wrapped inside another structure which gets the
/// `@internal(spirv_block)` attribute instead.
class AddSpirvBlockAttribute
: public Castable<AddSpirvBlockAttribute, Transform> {
public:
/// SpirvBlockAttribute is an InternalAttribute that is used to decorate a
// structure that needs a SPIR-V block attribute.
class SpirvBlockAttribute
: public Castable<SpirvBlockAttribute, ast::InternalAttribute> {
public:
/// Constructor
/// @param program_id the identifier of the program that owns this node
explicit SpirvBlockAttribute(ProgramID program_id);
/// Destructor
~SpirvBlockAttribute() override;
/// @return a short description of the internal attribute which will be
/// displayed as `@internal(<name>)`
std::string InternalName() const override;
/// Performs a deep clone of this object using the CloneContext `ctx`.
/// @param ctx the clone context
/// @return the newly cloned object
const SpirvBlockAttribute* Clone(CloneContext* ctx) const override;
};
/// Constructor
AddSpirvBlockAttribute();
/// Destructor
~AddSpirvBlockAttribute() override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const override;
};
} // namespace transform
} // namespace tint
#endif // SRC_TINT_TRANSFORM_ADD_SPIRV_BLOCK_ATTRIBUTE_H_

View File

@@ -0,0 +1,615 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/add_spirv_block_attribute.h"
#include <memory>
#include <utility>
#include "src/tint/transform/test_helper.h"
namespace tint {
namespace transform {
namespace {
using AddSpirvBlockAttributeTest = TransformTest;
TEST_F(AddSpirvBlockAttributeTest, EmptyModule) {
auto* src = "";
auto* expect = "";
auto got = Run<AddSpirvBlockAttribute>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(AddSpirvBlockAttributeTest, Noop_UsedForPrivateVar) {
auto* src = R"(
struct S {
f : f32;
}
var<private> p : S;
@stage(fragment)
fn main() {
p.f = 1.0;
}
)";
auto* expect = src;
auto got = Run<AddSpirvBlockAttribute>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(AddSpirvBlockAttributeTest, Noop_UsedForShaderIO) {
auto* src = R"(
struct S {
@location(0)
f : f32;
}
@stage(fragment)
fn main() -> S {
return S();
}
)";
auto* expect = src;
auto got = Run<AddSpirvBlockAttribute>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(AddSpirvBlockAttributeTest, BasicScalar) {
auto* src = R"(
@group(0) @binding(0)
var<uniform> u : f32;
@stage(fragment)
fn main() {
let f = u;
}
)";
auto* expect = R"(
@internal(spirv_block)
struct u_block {
inner : f32;
}
@group(0) @binding(0) var<uniform> u : u_block;
@stage(fragment)
fn main() {
let f = u.inner;
}
)";
auto got = Run<AddSpirvBlockAttribute>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(AddSpirvBlockAttributeTest, BasicArray) {
auto* src = R"(
@group(0) @binding(0)
var<uniform> u : array<vec4<f32>, 4u>;
@stage(fragment)
fn main() {
let a = u;
}
)";
auto* expect = R"(
@internal(spirv_block)
struct u_block {
inner : array<vec4<f32>, 4u>;
}
@group(0) @binding(0) var<uniform> u : u_block;
@stage(fragment)
fn main() {
let a = u.inner;
}
)";
auto got = Run<AddSpirvBlockAttribute>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(AddSpirvBlockAttributeTest, BasicArray_Alias) {
auto* src = R"(
type Numbers = array<vec4<f32>, 4u>;
@group(0) @binding(0)
var<uniform> u : Numbers;
@stage(fragment)
fn main() {
let a = u;
}
)";
auto* expect = R"(
type Numbers = array<vec4<f32>, 4u>;
@internal(spirv_block)
struct u_block {
inner : array<vec4<f32>, 4u>;
}
@group(0) @binding(0) var<uniform> u : u_block;
@stage(fragment)
fn main() {
let a = u.inner;
}
)";
auto got = Run<AddSpirvBlockAttribute>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(AddSpirvBlockAttributeTest, BasicStruct) {
auto* src = R"(
struct S {
f : f32;
};
@group(0) @binding(0)
var<uniform> u : S;
@stage(fragment)
fn main() {
let f = u.f;
}
)";
auto* expect = R"(
@internal(spirv_block)
struct S {
f : f32;
}
@group(0) @binding(0) var<uniform> u : S;
@stage(fragment)
fn main() {
let f = u.f;
}
)";
auto got = Run<AddSpirvBlockAttribute>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(AddSpirvBlockAttributeTest, Nested_OuterBuffer_InnerNotBuffer) {
auto* src = R"(
struct Inner {
f : f32;
};
struct Outer {
i : Inner;
};
@group(0) @binding(0)
var<uniform> u : Outer;
@stage(fragment)
fn main() {
let f = u.i.f;
}
)";
auto* expect = R"(
struct Inner {
f : f32;
}
@internal(spirv_block)
struct Outer {
i : Inner;
}
@group(0) @binding(0) var<uniform> u : Outer;
@stage(fragment)
fn main() {
let f = u.i.f;
}
)";
auto got = Run<AddSpirvBlockAttribute>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(AddSpirvBlockAttributeTest, Nested_OuterBuffer_InnerBuffer) {
auto* src = R"(
struct Inner {
f : f32;
};
struct Outer {
i : Inner;
};
@group(0) @binding(0)
var<uniform> u0 : Outer;
@group(0) @binding(1)
var<uniform> u1 : Inner;
@stage(fragment)
fn main() {
let f0 = u0.i.f;
let f1 = u1.f;
}
)";
auto* expect = R"(
struct Inner {
f : f32;
}
@internal(spirv_block)
struct Outer {
i : Inner;
}
@group(0) @binding(0) var<uniform> u0 : Outer;
@internal(spirv_block)
struct u1_block {
inner : Inner;
}
@group(0) @binding(1) var<uniform> u1 : u1_block;
@stage(fragment)
fn main() {
let f0 = u0.i.f;
let f1 = u1.inner.f;
}
)";
auto got = Run<AddSpirvBlockAttribute>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(AddSpirvBlockAttributeTest, Nested_OuterNotBuffer_InnerBuffer) {
auto* src = R"(
struct Inner {
f : f32;
};
struct Outer {
i : Inner;
};
var<private> p : Outer;
@group(0) @binding(1)
var<uniform> u : Inner;
@stage(fragment)
fn main() {
let f0 = p.i.f;
let f1 = u.f;
}
)";
auto* expect = R"(
struct Inner {
f : f32;
}
struct Outer {
i : Inner;
}
var<private> p : Outer;
@internal(spirv_block)
struct u_block {
inner : Inner;
}
@group(0) @binding(1) var<uniform> u : u_block;
@stage(fragment)
fn main() {
let f0 = p.i.f;
let f1 = u.inner.f;
}
)";
auto got = Run<AddSpirvBlockAttribute>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(AddSpirvBlockAttributeTest, Nested_InnerUsedForMultipleBuffers) {
auto* src = R"(
struct Inner {
f : f32;
};
struct S {
i : Inner;
};
@group(0) @binding(0)
var<uniform> u0 : S;
@group(0) @binding(1)
var<uniform> u1 : Inner;
@group(0) @binding(2)
var<uniform> u2 : Inner;
@stage(fragment)
fn main() {
let f0 = u0.i.f;
let f1 = u1.f;
let f2 = u2.f;
}
)";
auto* expect = R"(
struct Inner {
f : f32;
}
@internal(spirv_block)
struct S {
i : Inner;
}
@group(0) @binding(0) var<uniform> u0 : S;
@internal(spirv_block)
struct u1_block {
inner : Inner;
}
@group(0) @binding(1) var<uniform> u1 : u1_block;
@group(0) @binding(2) var<uniform> u2 : u1_block;
@stage(fragment)
fn main() {
let f0 = u0.i.f;
let f1 = u1.inner.f;
let f2 = u2.inner.f;
}
)";
auto got = Run<AddSpirvBlockAttribute>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(AddSpirvBlockAttributeTest, StructInArray) {
auto* src = R"(
struct S {
f : f32;
};
@group(0) @binding(0)
var<uniform> u : S;
@stage(fragment)
fn main() {
let f = u.f;
let a = array<S, 4>();
}
)";
auto* expect = R"(
struct S {
f : f32;
}
@internal(spirv_block)
struct u_block {
inner : S;
}
@group(0) @binding(0) var<uniform> u : u_block;
@stage(fragment)
fn main() {
let f = u.inner.f;
let a = array<S, 4>();
}
)";
auto got = Run<AddSpirvBlockAttribute>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(AddSpirvBlockAttributeTest, StructInArray_MultipleBuffers) {
auto* src = R"(
struct S {
f : f32;
};
@group(0) @binding(0)
var<uniform> u0 : S;
@group(0) @binding(1)
var<uniform> u1 : S;
@stage(fragment)
fn main() {
let f0 = u0.f;
let f1 = u1.f;
let a = array<S, 4>();
}
)";
auto* expect = R"(
struct S {
f : f32;
}
@internal(spirv_block)
struct u0_block {
inner : S;
}
@group(0) @binding(0) var<uniform> u0 : u0_block;
@group(0) @binding(1) var<uniform> u1 : u0_block;
@stage(fragment)
fn main() {
let f0 = u0.inner.f;
let f1 = u1.inner.f;
let a = array<S, 4>();
}
)";
auto got = Run<AddSpirvBlockAttribute>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(AddSpirvBlockAttributeTest, Aliases_Nested_OuterBuffer_InnerBuffer) {
auto* src = R"(
struct Inner {
f : f32;
};
type MyInner = Inner;
struct Outer {
i : MyInner;
};
type MyOuter = Outer;
@group(0) @binding(0)
var<uniform> u0 : MyOuter;
@group(0) @binding(1)
var<uniform> u1 : MyInner;
@stage(fragment)
fn main() {
let f0 = u0.i.f;
let f1 = u1.f;
}
)";
auto* expect = R"(
struct Inner {
f : f32;
}
type MyInner = Inner;
@internal(spirv_block)
struct Outer {
i : MyInner;
}
type MyOuter = Outer;
@group(0) @binding(0) var<uniform> u0 : MyOuter;
@internal(spirv_block)
struct u1_block {
inner : Inner;
}
@group(0) @binding(1) var<uniform> u1 : u1_block;
@stage(fragment)
fn main() {
let f0 = u0.i.f;
let f1 = u1.inner.f;
}
)";
auto got = Run<AddSpirvBlockAttribute>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(AddSpirvBlockAttributeTest,
Aliases_Nested_OuterBuffer_InnerBuffer_OutOfOrder) {
auto* src = R"(
@stage(fragment)
fn main() {
let f0 = u0.i.f;
let f1 = u1.f;
}
@group(0) @binding(1)
var<uniform> u1 : MyInner;
type MyInner = Inner;
@group(0) @binding(0)
var<uniform> u0 : MyOuter;
type MyOuter = Outer;
struct Outer {
i : MyInner;
};
struct Inner {
f : f32;
};
)";
auto* expect = R"(
@stage(fragment)
fn main() {
let f0 = u0.i.f;
let f1 = u1.inner.f;
}
@internal(spirv_block)
struct u1_block {
inner : Inner;
}
@group(0) @binding(1) var<uniform> u1 : u1_block;
type MyInner = Inner;
@group(0) @binding(0) var<uniform> u0 : MyOuter;
type MyOuter = Outer;
@internal(spirv_block)
struct Outer {
i : MyInner;
}
struct Inner {
f : f32;
}
)";
auto got = Run<AddSpirvBlockAttribute>(src);
EXPECT_EQ(expect, str(got));
}
} // namespace
} // namespace transform
} // namespace tint

View File

@@ -0,0 +1,233 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/array_length_from_uniform.h"
#include <memory>
#include <string>
#include <utility>
#include "src/tint/program_builder.h"
#include "src/tint/sem/call.h"
#include "src/tint/sem/function.h"
#include "src/tint/sem/variable.h"
#include "src/tint/transform/simplify_pointers.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::ArrayLengthFromUniform);
TINT_INSTANTIATE_TYPEINFO(tint::transform::ArrayLengthFromUniform::Config);
TINT_INSTANTIATE_TYPEINFO(tint::transform::ArrayLengthFromUniform::Result);
namespace tint {
namespace transform {
ArrayLengthFromUniform::ArrayLengthFromUniform() = default;
ArrayLengthFromUniform::~ArrayLengthFromUniform() = default;
/// Iterate over all arrayLength() builtins that operate on
/// storage buffer variables.
/// @param ctx the CloneContext.
/// @param functor of type void(const ast::CallExpression*, const
/// sem::VariableUser, const sem::GlobalVariable*). It takes in an
/// ast::CallExpression of the arrayLength call expression node, a
/// sem::VariableUser of the used storage buffer variable, and the
/// sem::GlobalVariable for the storage buffer.
template <typename F>
static void IterateArrayLengthOnStorageVar(CloneContext& ctx, F&& functor) {
auto& sem = ctx.src->Sem();
// Find all calls to the arrayLength() builtin.
for (auto* node : ctx.src->ASTNodes().Objects()) {
auto* call_expr = node->As<ast::CallExpression>();
if (!call_expr) {
continue;
}
auto* call = sem.Get(call_expr);
auto* builtin = call->Target()->As<sem::Builtin>();
if (!builtin || builtin->Type() != sem::BuiltinType::kArrayLength) {
continue;
}
// Get the storage buffer that contains the runtime array.
// Since we require SimplifyPointers, we can assume that the arrayLength()
// call has one of two forms:
// arrayLength(&struct_var.array_member)
// arrayLength(&array_var)
auto* param = call_expr->args[0]->As<ast::UnaryOpExpression>();
if (!param || param->op != ast::UnaryOp::kAddressOf) {
TINT_ICE(Transform, ctx.dst->Diagnostics())
<< "expected form of arrayLength argument to be &array_var or "
"&struct_var.array_member";
break;
}
auto* storage_buffer_expr = param->expr;
if (auto* accessor = param->expr->As<ast::MemberAccessorExpression>()) {
storage_buffer_expr = accessor->structure;
}
auto* storage_buffer_sem = sem.Get<sem::VariableUser>(storage_buffer_expr);
if (!storage_buffer_sem) {
TINT_ICE(Transform, ctx.dst->Diagnostics())
<< "expected form of arrayLength argument to be &array_var or "
"&struct_var.array_member";
break;
}
// Get the index to use for the buffer size array.
auto* var = tint::As<sem::GlobalVariable>(storage_buffer_sem->Variable());
if (!var) {
TINT_ICE(Transform, ctx.dst->Diagnostics())
<< "storage buffer is not a global variable";
break;
}
functor(call_expr, storage_buffer_sem, var);
}
}
bool ArrayLengthFromUniform::ShouldRun(const Program* program,
const DataMap&) const {
for (auto* fn : program->AST().Functions()) {
if (auto* sem_fn = program->Sem().Get(fn)) {
for (auto* builtin : sem_fn->DirectlyCalledBuiltins()) {
if (builtin->Type() == sem::BuiltinType::kArrayLength) {
return true;
}
}
}
}
return false;
}
void ArrayLengthFromUniform::Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const {
auto* cfg = inputs.Get<Config>();
if (cfg == nullptr) {
ctx.dst->Diagnostics().add_error(
diag::System::Transform,
"missing transform data for " + std::string(TypeInfo().name));
return;
}
const char* kBufferSizeMemberName = "buffer_size";
// Determine the size of the buffer size array.
uint32_t max_buffer_size_index = 0;
IterateArrayLengthOnStorageVar(
ctx, [&](const ast::CallExpression*, const sem::VariableUser*,
const sem::GlobalVariable* var) {
auto binding = var->BindingPoint();
auto idx_itr = cfg->bindpoint_to_size_index.find(binding);
if (idx_itr == cfg->bindpoint_to_size_index.end()) {
return;
}
if (idx_itr->second > max_buffer_size_index) {
max_buffer_size_index = idx_itr->second;
}
});
// Get (or create, on first call) the uniform buffer that will receive the
// size of each storage buffer in the module.
const ast::Variable* buffer_size_ubo = nullptr;
auto get_ubo = [&]() {
if (!buffer_size_ubo) {
// Emit an array<vec4<u32>, N>, where N is 1/4 number of elements.
// We do this because UBOs require an element stride that is 16-byte
// aligned.
auto* buffer_size_struct = ctx.dst->Structure(
ctx.dst->Sym(),
{ctx.dst->Member(
kBufferSizeMemberName,
ctx.dst->ty.array(ctx.dst->ty.vec4(ctx.dst->ty.u32()),
(max_buffer_size_index / 4) + 1))});
buffer_size_ubo = ctx.dst->Global(
ctx.dst->Sym(), ctx.dst->ty.Of(buffer_size_struct),
ast::StorageClass::kUniform,
ast::AttributeList{ctx.dst->GroupAndBinding(
cfg->ubo_binding.group, cfg->ubo_binding.binding)});
}
return buffer_size_ubo;
};
std::unordered_set<uint32_t> used_size_indices;
IterateArrayLengthOnStorageVar(
ctx, [&](const ast::CallExpression* call_expr,
const sem::VariableUser* storage_buffer_sem,
const sem::GlobalVariable* var) {
auto binding = var->BindingPoint();
auto idx_itr = cfg->bindpoint_to_size_index.find(binding);
if (idx_itr == cfg->bindpoint_to_size_index.end()) {
return;
}
uint32_t size_index = idx_itr->second;
used_size_indices.insert(size_index);
// Load the total storage buffer size from the UBO.
uint32_t array_index = size_index / 4;
auto* vec_expr = ctx.dst->IndexAccessor(
ctx.dst->MemberAccessor(get_ubo()->symbol, kBufferSizeMemberName),
array_index);
uint32_t vec_index = size_index % 4;
auto* total_storage_buffer_size =
ctx.dst->IndexAccessor(vec_expr, vec_index);
// Calculate actual array length
// total_storage_buffer_size - array_offset
// array_length = ----------------------------------------
// array_stride
const ast::Expression* total_size = total_storage_buffer_size;
auto* storage_buffer_type = storage_buffer_sem->Type()->UnwrapRef();
const sem::Array* array_type = nullptr;
if (auto* str = storage_buffer_type->As<sem::Struct>()) {
// The variable is a struct, so subtract the byte offset of the array
// member.
auto* array_member_sem = str->Members().back();
array_type = array_member_sem->Type()->As<sem::Array>();
total_size = ctx.dst->Sub(total_storage_buffer_size,
array_member_sem->Offset());
} else if (auto* arr = storage_buffer_type->As<sem::Array>()) {
array_type = arr;
} else {
TINT_ICE(Transform, ctx.dst->Diagnostics())
<< "expected form of arrayLength argument to be &array_var or "
"&struct_var.array_member";
return;
}
auto* array_length = ctx.dst->Div(total_size, array_type->Stride());
ctx.Replace(call_expr, array_length);
});
ctx.Clone();
outputs.Add<Result>(used_size_indices);
}
ArrayLengthFromUniform::Config::Config(sem::BindingPoint ubo_bp)
: ubo_binding(ubo_bp) {}
ArrayLengthFromUniform::Config::Config(const Config&) = default;
ArrayLengthFromUniform::Config& ArrayLengthFromUniform::Config::operator=(
const Config&) = default;
ArrayLengthFromUniform::Config::~Config() = default;
ArrayLengthFromUniform::Result::Result(
std::unordered_set<uint32_t> used_size_indices_in)
: used_size_indices(std::move(used_size_indices_in)) {}
ArrayLengthFromUniform::Result::Result(const Result&) = default;
ArrayLengthFromUniform::Result::~Result() = default;
} // namespace transform
} // namespace tint

View File

@@ -0,0 +1,125 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TINT_TRANSFORM_ARRAY_LENGTH_FROM_UNIFORM_H_
#define SRC_TINT_TRANSFORM_ARRAY_LENGTH_FROM_UNIFORM_H_
#include <unordered_map>
#include <unordered_set>
#include "src/tint/sem/binding_point.h"
#include "src/tint/transform/transform.h"
namespace tint {
// Forward declarations
class CloneContext;
namespace transform {
/// ArrayLengthFromUniform is a transform that implements calls to arrayLength()
/// by calculating the length from the total size of the storage buffer, which
/// is received via a uniform buffer.
///
/// The generated uniform buffer will have the form:
/// ```
/// struct buffer_size_struct {
/// buffer_size : array<u32, 8>;
/// };
///
/// @group(0) @binding(30)
/// var<uniform> buffer_size_ubo : buffer_size_struct;
/// ```
/// The binding group and number used for this uniform buffer is provided via
/// the `Config` transform input. The `Config` struct also defines the mapping
/// from a storage buffer's `BindingPoint` to the array index that will be used
/// to get the size of that buffer.
///
/// This transform assumes that the `SimplifyPointers`
/// transforms have been run before it so that arguments to the arrayLength
/// builtin always have the form `&resource.array`.
///
/// @note Depends on the following transforms to have been run first:
/// * SimplifyPointers
class ArrayLengthFromUniform
: public Castable<ArrayLengthFromUniform, Transform> {
public:
/// Constructor
ArrayLengthFromUniform();
/// Destructor
~ArrayLengthFromUniform() override;
/// Configuration options for the ArrayLengthFromUniform transform.
struct Config : public Castable<Data, transform::Data> {
/// Constructor
/// @param ubo_bp the binding point to use for the generated uniform buffer.
explicit Config(sem::BindingPoint ubo_bp);
/// Copy constructor
Config(const Config&);
/// Copy assignment
/// @return this Config
Config& operator=(const Config&);
/// Destructor
~Config() override;
/// The binding point to use for the generated uniform buffer.
sem::BindingPoint ubo_binding;
/// The mapping from binding point to the index for the buffer size lookup.
std::unordered_map<sem::BindingPoint, uint32_t> bindpoint_to_size_index;
};
/// Information produced about what the transform did.
/// If there were no calls to the arrayLength() builtin, then no Result will
/// be emitted.
struct Result : public Castable<Result, transform::Data> {
/// Constructor
/// @param used_size_indices Indices into the UBO that are statically used.
explicit Result(std::unordered_set<uint32_t> used_size_indices);
/// Copy constructor
Result(const Result&);
/// Destructor
~Result() override;
/// Indices into the UBO that are statically used.
const std::unordered_set<uint32_t> used_size_indices;
};
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program,
const DataMap& data = {}) const override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const override;
};
} // namespace transform
} // namespace tint
#endif // SRC_TINT_TRANSFORM_ARRAY_LENGTH_FROM_UNIFORM_H_

View File

@@ -0,0 +1,589 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/array_length_from_uniform.h"
#include <utility>
#include "src/tint/transform/simplify_pointers.h"
#include "src/tint/transform/test_helper.h"
#include "src/tint/transform/unshadow.h"
namespace tint {
namespace transform {
namespace {
using ArrayLengthFromUniformTest = TransformTest;
TEST_F(ArrayLengthFromUniformTest, ShouldRunEmptyModule) {
auto* src = R"()";
EXPECT_FALSE(ShouldRun<ArrayLengthFromUniform>(src));
}
TEST_F(ArrayLengthFromUniformTest, ShouldRunNoArrayLength) {
auto* src = R"(
struct SB {
x : i32;
arr : array<i32>;
};
[[group(0), binding(0)]] var<storage, read> sb : SB;
[[stage(compute), workgroup_size(1)]]
fn main() {
}
)";
EXPECT_FALSE(ShouldRun<ArrayLengthFromUniform>(src));
}
TEST_F(ArrayLengthFromUniformTest, ShouldRunWithArrayLength) {
auto* src = R"(
struct SB {
x : i32;
arr : array<i32>;
};
[[group(0), binding(0)]] var<storage, read> sb : SB;
[[stage(compute), workgroup_size(1)]]
fn main() {
var len : u32 = arrayLength(&sb.arr);
}
)";
EXPECT_TRUE(ShouldRun<ArrayLengthFromUniform>(src));
}
TEST_F(ArrayLengthFromUniformTest, Error_MissingTransformData) {
auto* src = R"(
struct SB {
x : i32;
arr : array<i32>;
};
[[group(0), binding(0)]] var<storage, read> sb : SB;
[[stage(compute), workgroup_size(1)]]
fn main() {
var len : u32 = arrayLength(&sb.arr);
}
)";
auto* expect =
"error: missing transform data for "
"tint::transform::ArrayLengthFromUniform";
auto got = Run<Unshadow, SimplifyPointers, ArrayLengthFromUniform>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(ArrayLengthFromUniformTest, Basic) {
auto* src = R"(
@group(0) @binding(0) var<storage, read> sb : array<i32>;
@stage(compute) @workgroup_size(1)
fn main() {
var len : u32 = arrayLength(&sb);
}
)";
auto* expect = R"(
struct tint_symbol {
buffer_size : array<vec4<u32>, 1u>;
}
@group(0) @binding(30) var<uniform> tint_symbol_1 : tint_symbol;
@group(0) @binding(0) var<storage, read> sb : array<i32>;
@stage(compute) @workgroup_size(1)
fn main() {
var len : u32 = (tint_symbol_1.buffer_size[0u][0u] / 4u);
}
)";
ArrayLengthFromUniform::Config cfg({0, 30u});
cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 0}, 0);
DataMap data;
data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));
auto got = Run<Unshadow, SimplifyPointers, ArrayLengthFromUniform>(src, data);
EXPECT_EQ(expect, str(got));
EXPECT_EQ(std::unordered_set<uint32_t>({0}),
got.data.Get<ArrayLengthFromUniform::Result>()->used_size_indices);
}
TEST_F(ArrayLengthFromUniformTest, BasicInStruct) {
auto* src = R"(
struct SB {
x : i32;
arr : array<i32>;
};
@group(0) @binding(0) var<storage, read> sb : SB;
@stage(compute) @workgroup_size(1)
fn main() {
var len : u32 = arrayLength(&sb.arr);
}
)";
auto* expect = R"(
struct tint_symbol {
buffer_size : array<vec4<u32>, 1u>;
}
@group(0) @binding(30) var<uniform> tint_symbol_1 : tint_symbol;
struct SB {
x : i32;
arr : array<i32>;
}
@group(0) @binding(0) var<storage, read> sb : SB;
@stage(compute) @workgroup_size(1)
fn main() {
var len : u32 = ((tint_symbol_1.buffer_size[0u][0u] - 4u) / 4u);
}
)";
ArrayLengthFromUniform::Config cfg({0, 30u});
cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 0}, 0);
DataMap data;
data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));
auto got = Run<Unshadow, SimplifyPointers, ArrayLengthFromUniform>(src, data);
EXPECT_EQ(expect, str(got));
EXPECT_EQ(std::unordered_set<uint32_t>({0}),
got.data.Get<ArrayLengthFromUniform::Result>()->used_size_indices);
}
TEST_F(ArrayLengthFromUniformTest, WithStride) {
auto* src = R"(
@group(0) @binding(0) var<storage, read> sb : @stride(64) array<i32>;
@stage(compute) @workgroup_size(1)
fn main() {
var len : u32 = arrayLength(&sb);
}
)";
auto* expect = R"(
struct tint_symbol {
buffer_size : array<vec4<u32>, 1u>;
}
@group(0) @binding(30) var<uniform> tint_symbol_1 : tint_symbol;
@group(0) @binding(0) var<storage, read> sb : @stride(64) array<i32>;
@stage(compute) @workgroup_size(1)
fn main() {
var len : u32 = (tint_symbol_1.buffer_size[0u][0u] / 64u);
}
)";
ArrayLengthFromUniform::Config cfg({0, 30u});
cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 0}, 0);
DataMap data;
data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));
auto got = Run<Unshadow, SimplifyPointers, ArrayLengthFromUniform>(src, data);
EXPECT_EQ(expect, str(got));
EXPECT_EQ(std::unordered_set<uint32_t>({0}),
got.data.Get<ArrayLengthFromUniform::Result>()->used_size_indices);
}
TEST_F(ArrayLengthFromUniformTest, WithStride_InStruct) {
auto* src = R"(
struct SB {
x : i32;
y : f32;
arr : @stride(64) array<i32>;
};
@group(0) @binding(0) var<storage, read> sb : SB;
@stage(compute) @workgroup_size(1)
fn main() {
var len : u32 = arrayLength(&sb.arr);
}
)";
auto* expect = R"(
struct tint_symbol {
buffer_size : array<vec4<u32>, 1u>;
}
@group(0) @binding(30) var<uniform> tint_symbol_1 : tint_symbol;
struct SB {
x : i32;
y : f32;
arr : @stride(64) array<i32>;
}
@group(0) @binding(0) var<storage, read> sb : SB;
@stage(compute) @workgroup_size(1)
fn main() {
var len : u32 = ((tint_symbol_1.buffer_size[0u][0u] - 8u) / 64u);
}
)";
ArrayLengthFromUniform::Config cfg({0, 30u});
cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 0}, 0);
DataMap data;
data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));
auto got = Run<Unshadow, SimplifyPointers, ArrayLengthFromUniform>(src, data);
EXPECT_EQ(expect, str(got));
EXPECT_EQ(std::unordered_set<uint32_t>({0}),
got.data.Get<ArrayLengthFromUniform::Result>()->used_size_indices);
}
TEST_F(ArrayLengthFromUniformTest, MultipleStorageBuffers) {
auto* src = R"(
struct SB1 {
x : i32;
arr1 : array<i32>;
};
struct SB2 {
x : i32;
arr2 : array<vec4<f32>>;
};
struct SB4 {
x : i32;
arr4 : array<vec4<f32>>;
};
@group(0) @binding(2) var<storage, read> sb1 : SB1;
@group(1) @binding(2) var<storage, read> sb2 : SB2;
@group(2) @binding(2) var<storage, read> sb3 : array<vec4<f32>>;
@group(3) @binding(2) var<storage, read> sb4 : SB4;
@group(4) @binding(2) var<storage, read> sb5 : array<vec4<f32>>;
@stage(compute) @workgroup_size(1)
fn main() {
var len1 : u32 = arrayLength(&(sb1.arr1));
var len2 : u32 = arrayLength(&(sb2.arr2));
var len3 : u32 = arrayLength(&sb3);
var len4 : u32 = arrayLength(&(sb4.arr4));
var len5 : u32 = arrayLength(&sb5);
var x : u32 = (len1 + len2 + len3 + len4 + len5);
}
)";
auto* expect = R"(
struct tint_symbol {
buffer_size : array<vec4<u32>, 2u>;
}
@group(0) @binding(30) var<uniform> tint_symbol_1 : tint_symbol;
struct SB1 {
x : i32;
arr1 : array<i32>;
}
struct SB2 {
x : i32;
arr2 : array<vec4<f32>>;
}
struct SB4 {
x : i32;
arr4 : array<vec4<f32>>;
}
@group(0) @binding(2) var<storage, read> sb1 : SB1;
@group(1) @binding(2) var<storage, read> sb2 : SB2;
@group(2) @binding(2) var<storage, read> sb3 : array<vec4<f32>>;
@group(3) @binding(2) var<storage, read> sb4 : SB4;
@group(4) @binding(2) var<storage, read> sb5 : array<vec4<f32>>;
@stage(compute) @workgroup_size(1)
fn main() {
var len1 : u32 = ((tint_symbol_1.buffer_size[0u][0u] - 4u) / 4u);
var len2 : u32 = ((tint_symbol_1.buffer_size[0u][1u] - 16u) / 16u);
var len3 : u32 = (tint_symbol_1.buffer_size[0u][2u] / 16u);
var len4 : u32 = ((tint_symbol_1.buffer_size[0u][3u] - 16u) / 16u);
var len5 : u32 = (tint_symbol_1.buffer_size[1u][0u] / 16u);
var x : u32 = ((((len1 + len2) + len3) + len4) + len5);
}
)";
ArrayLengthFromUniform::Config cfg({0, 30u});
cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 2u}, 0);
cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{1u, 2u}, 1);
cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{2u, 2u}, 2);
cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{3u, 2u}, 3);
cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{4u, 2u}, 4);
DataMap data;
data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));
auto got = Run<Unshadow, SimplifyPointers, ArrayLengthFromUniform>(src, data);
EXPECT_EQ(expect, str(got));
EXPECT_EQ(std::unordered_set<uint32_t>({0, 1, 2, 3, 4}),
got.data.Get<ArrayLengthFromUniform::Result>()->used_size_indices);
}
TEST_F(ArrayLengthFromUniformTest, MultipleUnusedStorageBuffers) {
auto* src = R"(
struct SB1 {
x : i32;
arr1 : array<i32>;
};
struct SB2 {
x : i32;
arr2 : array<vec4<f32>>;
};
struct SB4 {
x : i32;
arr4 : array<vec4<f32>>;
};
@group(0) @binding(2) var<storage, read> sb1 : SB1;
@group(1) @binding(2) var<storage, read> sb2 : SB2;
@group(2) @binding(2) var<storage, read> sb3 : array<vec4<f32>>;
@group(3) @binding(2) var<storage, read> sb4 : SB4;
@group(4) @binding(2) var<storage, read> sb5 : array<vec4<f32>>;
@stage(compute) @workgroup_size(1)
fn main() {
var len1 : u32 = arrayLength(&(sb1.arr1));
var len3 : u32 = arrayLength(&sb3);
var x : u32 = (len1 + len3);
}
)";
auto* expect = R"(
struct tint_symbol {
buffer_size : array<vec4<u32>, 1u>;
}
@group(0) @binding(30) var<uniform> tint_symbol_1 : tint_symbol;
struct SB1 {
x : i32;
arr1 : array<i32>;
}
struct SB2 {
x : i32;
arr2 : array<vec4<f32>>;
}
struct SB4 {
x : i32;
arr4 : array<vec4<f32>>;
}
@group(0) @binding(2) var<storage, read> sb1 : SB1;
@group(1) @binding(2) var<storage, read> sb2 : SB2;
@group(2) @binding(2) var<storage, read> sb3 : array<vec4<f32>>;
@group(3) @binding(2) var<storage, read> sb4 : SB4;
@group(4) @binding(2) var<storage, read> sb5 : array<vec4<f32>>;
@stage(compute) @workgroup_size(1)
fn main() {
var len1 : u32 = ((tint_symbol_1.buffer_size[0u][0u] - 4u) / 4u);
var len3 : u32 = (tint_symbol_1.buffer_size[0u][2u] / 16u);
var x : u32 = (len1 + len3);
}
)";
ArrayLengthFromUniform::Config cfg({0, 30u});
cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 2u}, 0);
cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{1u, 2u}, 1);
cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{2u, 2u}, 2);
cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{3u, 2u}, 3);
cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{4u, 2u}, 4);
DataMap data;
data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));
auto got = Run<Unshadow, SimplifyPointers, ArrayLengthFromUniform>(src, data);
EXPECT_EQ(expect, str(got));
EXPECT_EQ(std::unordered_set<uint32_t>({0, 2}),
got.data.Get<ArrayLengthFromUniform::Result>()->used_size_indices);
}
TEST_F(ArrayLengthFromUniformTest, NoArrayLengthCalls) {
auto* src = R"(
struct SB {
x : i32;
arr : array<i32>;
}
@group(0) @binding(0) var<storage, read> sb : SB;
@stage(compute) @workgroup_size(1)
fn main() {
_ = &(sb.arr);
}
)";
ArrayLengthFromUniform::Config cfg({0, 30u});
cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 0}, 0);
DataMap data;
data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));
auto got = Run<Unshadow, SimplifyPointers, ArrayLengthFromUniform>(src, data);
EXPECT_EQ(src, str(got));
EXPECT_EQ(got.data.Get<ArrayLengthFromUniform::Result>(), nullptr);
}
TEST_F(ArrayLengthFromUniformTest, MissingBindingPointToIndexMapping) {
auto* src = R"(
struct SB1 {
x : i32;
arr1 : array<i32>;
};
struct SB2 {
x : i32;
arr2 : array<vec4<f32>>;
};
@group(0) @binding(2) var<storage, read> sb1 : SB1;
@group(1) @binding(2) var<storage, read> sb2 : SB2;
@stage(compute) @workgroup_size(1)
fn main() {
var len1 : u32 = arrayLength(&(sb1.arr1));
var len2 : u32 = arrayLength(&(sb2.arr2));
var x : u32 = (len1 + len2);
}
)";
auto* expect = R"(
struct tint_symbol {
buffer_size : array<vec4<u32>, 1u>;
}
@group(0) @binding(30) var<uniform> tint_symbol_1 : tint_symbol;
struct SB1 {
x : i32;
arr1 : array<i32>;
}
struct SB2 {
x : i32;
arr2 : array<vec4<f32>>;
}
@group(0) @binding(2) var<storage, read> sb1 : SB1;
@group(1) @binding(2) var<storage, read> sb2 : SB2;
@stage(compute) @workgroup_size(1)
fn main() {
var len1 : u32 = ((tint_symbol_1.buffer_size[0u][0u] - 4u) / 4u);
var len2 : u32 = arrayLength(&(sb2.arr2));
var x : u32 = (len1 + len2);
}
)";
ArrayLengthFromUniform::Config cfg({0, 30u});
cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 2}, 0);
DataMap data;
data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));
auto got = Run<Unshadow, SimplifyPointers, ArrayLengthFromUniform>(src, data);
EXPECT_EQ(expect, str(got));
EXPECT_EQ(std::unordered_set<uint32_t>({0}),
got.data.Get<ArrayLengthFromUniform::Result>()->used_size_indices);
}
TEST_F(ArrayLengthFromUniformTest, OutOfOrder) {
auto* src = R"(
@stage(compute) @workgroup_size(1)
fn main() {
var len : u32 = arrayLength(&sb.arr);
}
@group(0) @binding(0) var<storage, read> sb : SB;
struct SB {
x : i32;
arr : array<i32>;
};
)";
auto* expect = R"(
struct tint_symbol {
buffer_size : array<vec4<u32>, 1u>;
}
@group(0) @binding(30) var<uniform> tint_symbol_1 : tint_symbol;
@stage(compute) @workgroup_size(1)
fn main() {
var len : u32 = ((tint_symbol_1.buffer_size[0u][0u] - 4u) / 4u);
}
@group(0) @binding(0) var<storage, read> sb : SB;
struct SB {
x : i32;
arr : array<i32>;
}
)";
ArrayLengthFromUniform::Config cfg({0, 30u});
cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 0}, 0);
DataMap data;
data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));
auto got = Run<Unshadow, SimplifyPointers, ArrayLengthFromUniform>(src, data);
EXPECT_EQ(expect, str(got));
EXPECT_EQ(std::unordered_set<uint32_t>({0}),
got.data.Get<ArrayLengthFromUniform::Result>()->used_size_indices);
}
} // namespace
} // namespace transform
} // namespace tint

View File

@@ -0,0 +1,163 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/binding_remapper.h"
#include <string>
#include <unordered_set>
#include <utility>
#include "src/tint/ast/disable_validation_attribute.h"
#include "src/tint/program_builder.h"
#include "src/tint/sem/function.h"
#include "src/tint/sem/variable.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::BindingRemapper);
TINT_INSTANTIATE_TYPEINFO(tint::transform::BindingRemapper::Remappings);
namespace tint {
namespace transform {
BindingRemapper::Remappings::Remappings(BindingPoints bp,
AccessControls ac,
bool may_collide)
: binding_points(std::move(bp)),
access_controls(std::move(ac)),
allow_collisions(may_collide) {}
BindingRemapper::Remappings::Remappings(const Remappings&) = default;
BindingRemapper::Remappings::~Remappings() = default;
BindingRemapper::BindingRemapper() = default;
BindingRemapper::~BindingRemapper() = default;
bool BindingRemapper::ShouldRun(const Program*, const DataMap& inputs) const {
if (auto* remappings = inputs.Get<Remappings>()) {
return !remappings->binding_points.empty() ||
!remappings->access_controls.empty();
}
return false;
}
void BindingRemapper::Run(CloneContext& ctx,
const DataMap& inputs,
DataMap&) const {
auto* remappings = inputs.Get<Remappings>();
if (!remappings) {
ctx.dst->Diagnostics().add_error(
diag::System::Transform,
"missing transform data for " + std::string(TypeInfo().name));
return;
}
// A set of post-remapped binding points that need to be decorated with a
// DisableValidationAttribute to disable binding-point-collision validation
std::unordered_set<sem::BindingPoint> add_collision_attr;
if (remappings->allow_collisions) {
// Scan for binding point collisions generated by this transform.
// Populate all collisions in the `add_collision_attr` set.
for (auto* func_ast : ctx.src->AST().Functions()) {
if (!func_ast->IsEntryPoint()) {
continue;
}
auto* func = ctx.src->Sem().Get(func_ast);
std::unordered_map<sem::BindingPoint, int> binding_point_counts;
for (auto* var : func->TransitivelyReferencedGlobals()) {
if (auto binding_point = var->Declaration()->BindingPoint()) {
BindingPoint from{binding_point.group->value,
binding_point.binding->value};
auto bp_it = remappings->binding_points.find(from);
if (bp_it != remappings->binding_points.end()) {
// Remapped
BindingPoint to = bp_it->second;
if (binding_point_counts[to]++) {
add_collision_attr.emplace(to);
}
} else {
// No remapping
if (binding_point_counts[from]++) {
add_collision_attr.emplace(from);
}
}
}
}
}
}
for (auto* var : ctx.src->AST().GlobalVariables()) {
if (auto binding_point = var->BindingPoint()) {
// The original binding point
BindingPoint from{binding_point.group->value,
binding_point.binding->value};
// The binding point after remapping
BindingPoint bp = from;
// Replace any group or binding attributes.
// Note: This has to be performed *before* remapping access controls, as
// `ctx.Clone(var->attributes)` depend on these replacements.
auto bp_it = remappings->binding_points.find(from);
if (bp_it != remappings->binding_points.end()) {
BindingPoint to = bp_it->second;
auto* new_group = ctx.dst->create<ast::GroupAttribute>(to.group);
auto* new_binding = ctx.dst->create<ast::BindingAttribute>(to.binding);
ctx.Replace(binding_point.group, new_group);
ctx.Replace(binding_point.binding, new_binding);
bp = to;
}
// Replace any access controls.
auto ac_it = remappings->access_controls.find(from);
if (ac_it != remappings->access_controls.end()) {
ast::Access ac = ac_it->second;
if (ac > ast::Access::kLastValid) {
ctx.dst->Diagnostics().add_error(
diag::System::Transform,
"invalid access mode (" +
std::to_string(static_cast<uint32_t>(ac)) + ")");
return;
}
auto* sem = ctx.src->Sem().Get(var);
if (sem->StorageClass() != ast::StorageClass::kStorage) {
ctx.dst->Diagnostics().add_error(
diag::System::Transform,
"cannot apply access control to variable with storage class " +
std::string(ast::ToString(sem->StorageClass())));
return;
}
auto* ty = sem->Type()->UnwrapRef();
const ast::Type* inner_ty = CreateASTTypeFor(ctx, ty);
auto* new_var = ctx.dst->create<ast::Variable>(
ctx.Clone(var->source), ctx.Clone(var->symbol),
var->declared_storage_class, ac, inner_ty, false, false,
ctx.Clone(var->constructor), ctx.Clone(var->attributes));
ctx.Replace(var, new_var);
}
// Add `DisableValidationAttribute`s if required
if (add_collision_attr.count(bp)) {
auto* attribute =
ctx.dst->Disable(ast::DisabledValidation::kBindingPointCollision);
ctx.InsertBefore(var->attributes, *var->attributes.begin(), attribute);
}
}
}
ctx.Clone();
}
} // namespace transform
} // namespace tint

View File

@@ -0,0 +1,92 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TINT_TRANSFORM_BINDING_REMAPPER_H_
#define SRC_TINT_TRANSFORM_BINDING_REMAPPER_H_
#include <unordered_map>
#include "src/tint/ast/access.h"
#include "src/tint/sem/binding_point.h"
#include "src/tint/transform/transform.h"
namespace tint {
namespace transform {
/// BindingPoint is an alias to sem::BindingPoint
using BindingPoint = sem::BindingPoint;
/// BindingRemapper is a transform used to remap resource binding points and
/// access controls.
class BindingRemapper : public Castable<BindingRemapper, Transform> {
public:
/// BindingPoints is a map of old binding point to new binding point
using BindingPoints = std::unordered_map<BindingPoint, BindingPoint>;
/// AccessControls is a map of old binding point to new access control
using AccessControls = std::unordered_map<BindingPoint, ast::Access>;
/// Remappings is consumed by the BindingRemapper transform.
/// Data holds information about shader usage and constant buffer offsets.
struct Remappings : public Castable<Data, transform::Data> {
/// Constructor
/// @param bp a map of new binding points
/// @param ac a map of new access controls
/// @param may_collide If true, then validation will be disabled for
/// binding point collisions generated by this transform
Remappings(BindingPoints bp, AccessControls ac, bool may_collide = true);
/// Copy constructor
Remappings(const Remappings&);
/// Destructor
~Remappings() override;
/// A map of old binding point to new binding point
const BindingPoints binding_points;
/// A map of old binding point to new access controls
const AccessControls access_controls;
/// If true, then validation will be disabled for binding point collisions
/// generated by this transform
const bool allow_collisions;
};
/// Constructor
BindingRemapper();
~BindingRemapper() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program,
const DataMap& data = {}) const override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const override;
};
} // namespace transform
} // namespace tint
#endif // SRC_TINT_TRANSFORM_BINDING_REMAPPER_H_

View File

@@ -0,0 +1,423 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/binding_remapper.h"
#include <utility>
#include "src/tint/transform/test_helper.h"
namespace tint {
namespace transform {
namespace {
using BindingRemapperTest = TransformTest;
TEST_F(BindingRemapperTest, ShouldRunNoRemappings) {
auto* src = R"()";
EXPECT_FALSE(ShouldRun<BindingRemapper>(src));
}
TEST_F(BindingRemapperTest, ShouldRunEmptyRemappings) {
auto* src = R"()";
DataMap data;
data.Add<BindingRemapper::Remappings>(BindingRemapper::BindingPoints{},
BindingRemapper::AccessControls{});
EXPECT_FALSE(ShouldRun<BindingRemapper>(src, data));
}
TEST_F(BindingRemapperTest, ShouldRunBindingPointRemappings) {
auto* src = R"()";
DataMap data;
data.Add<BindingRemapper::Remappings>(
BindingRemapper::BindingPoints{
{{2, 1}, {1, 2}},
},
BindingRemapper::AccessControls{});
EXPECT_TRUE(ShouldRun<BindingRemapper>(src, data));
}
TEST_F(BindingRemapperTest, ShouldRunAccessControlRemappings) {
auto* src = R"()";
DataMap data;
data.Add<BindingRemapper::Remappings>(BindingRemapper::BindingPoints{},
BindingRemapper::AccessControls{
{{2, 1}, ast::Access::kWrite},
});
EXPECT_TRUE(ShouldRun<BindingRemapper>(src, data));
}
TEST_F(BindingRemapperTest, NoRemappings) {
auto* src = R"(
struct S {
a : f32;
}
@group(2) @binding(1) var<storage, read> a : S;
@group(3) @binding(2) var<storage, read> b : S;
@stage(compute) @workgroup_size(1)
fn f() {
}
)";
auto* expect = src;
DataMap data;
data.Add<BindingRemapper::Remappings>(BindingRemapper::BindingPoints{},
BindingRemapper::AccessControls{});
auto got = Run<BindingRemapper>(src, data);
EXPECT_EQ(expect, str(got));
}
TEST_F(BindingRemapperTest, RemapBindingPoints) {
auto* src = R"(
struct S {
a : f32;
};
@group(2) @binding(1) var<storage, read> a : S;
@group(3) @binding(2) var<storage, read> b : S;
@stage(compute) @workgroup_size(1)
fn f() {
}
)";
auto* expect = R"(
struct S {
a : f32;
}
@group(1) @binding(2) var<storage, read> a : S;
@group(3) @binding(2) var<storage, read> b : S;
@stage(compute) @workgroup_size(1)
fn f() {
}
)";
DataMap data;
data.Add<BindingRemapper::Remappings>(
BindingRemapper::BindingPoints{
{{2, 1}, {1, 2}}, // Remap
{{4, 5}, {6, 7}}, // Not found
// Keep @group(3) @binding(2) as is
},
BindingRemapper::AccessControls{});
auto got = Run<BindingRemapper>(src, data);
EXPECT_EQ(expect, str(got));
}
TEST_F(BindingRemapperTest, RemapAccessControls) {
auto* src = R"(
struct S {
a : f32;
};
@group(2) @binding(1) var<storage, read> a : S;
@group(3) @binding(2) var<storage, write> b : S;
@group(4) @binding(3) var<storage, read> c : S;
@stage(compute) @workgroup_size(1)
fn f() {
}
)";
auto* expect = R"(
struct S {
a : f32;
}
@group(2) @binding(1) var<storage, write> a : S;
@group(3) @binding(2) var<storage, write> b : S;
@group(4) @binding(3) var<storage, read> c : S;
@stage(compute) @workgroup_size(1)
fn f() {
}
)";
DataMap data;
data.Add<BindingRemapper::Remappings>(
BindingRemapper::BindingPoints{},
BindingRemapper::AccessControls{
{{2, 1}, ast::Access::kWrite}, // Modify access control
// Keep @group(3) @binding(2) as is
{{4, 3}, ast::Access::kRead}, // Add access control
});
auto got = Run<BindingRemapper>(src, data);
EXPECT_EQ(expect, str(got));
}
// TODO(crbug.com/676): Possibly enable if the spec allows for access
// attributes in type aliases. If not, just remove.
TEST_F(BindingRemapperTest, DISABLED_RemapAccessControlsWithAliases) {
auto* src = R"(
struct S {
a : f32;
};
type, read ReadOnlyS = S;
type, write WriteOnlyS = S;
type A = S;
@group(2) @binding(1) var<storage> a : ReadOnlyS;
@group(3) @binding(2) var<storage> b : WriteOnlyS;
@group(4) @binding(3) var<storage> c : A;
@stage(compute) @workgroup_size(1)
fn f() {
}
)";
auto* expect = R"(
struct S {
a : f32;
};
type, read ReadOnlyS = S;
type, write WriteOnlyS = S;
type A = S;
@group(2) @binding(1) var<storage, write> a : S;
@group(3) @binding(2) var<storage> b : WriteOnlyS;
@group(4) @binding(3) var<storage, write> c : S;
@stage(compute) @workgroup_size(1)
fn f() {
}
)";
DataMap data;
data.Add<BindingRemapper::Remappings>(
BindingRemapper::BindingPoints{},
BindingRemapper::AccessControls{
{{2, 1}, ast::Access::kWrite}, // Modify access control
// Keep @group(3) @binding(2) as is
{{4, 3}, ast::Access::kRead}, // Add access control
});
auto got = Run<BindingRemapper>(src, data);
EXPECT_EQ(expect, str(got));
}
TEST_F(BindingRemapperTest, RemapAll) {
auto* src = R"(
struct S {
a : f32;
};
@group(2) @binding(1) var<storage, read> a : S;
@group(3) @binding(2) var<storage, read> b : S;
@stage(compute) @workgroup_size(1)
fn f() {
}
)";
auto* expect = R"(
struct S {
a : f32;
}
@group(4) @binding(5) var<storage, write> a : S;
@group(6) @binding(7) var<storage, write> b : S;
@stage(compute) @workgroup_size(1)
fn f() {
}
)";
DataMap data;
data.Add<BindingRemapper::Remappings>(
BindingRemapper::BindingPoints{
{{2, 1}, {4, 5}},
{{3, 2}, {6, 7}},
},
BindingRemapper::AccessControls{
{{2, 1}, ast::Access::kWrite},
{{3, 2}, ast::Access::kWrite},
});
auto got = Run<BindingRemapper>(src, data);
EXPECT_EQ(expect, str(got));
}
TEST_F(BindingRemapperTest, BindingCollisionsSameEntryPoint) {
auto* src = R"(
struct S {
i : i32;
};
@group(2) @binding(1) var<storage, read> a : S;
@group(3) @binding(2) var<storage, read> b : S;
@group(4) @binding(3) var<storage, read> c : S;
@group(5) @binding(4) var<storage, read> d : S;
@stage(compute) @workgroup_size(1)
fn f() {
let x : i32 = (((a.i + b.i) + c.i) + d.i);
}
)";
auto* expect = R"(
struct S {
i : i32;
}
@internal(disable_validation__binding_point_collision) @group(1) @binding(1) var<storage, read> a : S;
@internal(disable_validation__binding_point_collision) @group(1) @binding(1) var<storage, read> b : S;
@internal(disable_validation__binding_point_collision) @group(5) @binding(4) var<storage, read> c : S;
@internal(disable_validation__binding_point_collision) @group(5) @binding(4) var<storage, read> d : S;
@stage(compute) @workgroup_size(1)
fn f() {
let x : i32 = (((a.i + b.i) + c.i) + d.i);
}
)";
DataMap data;
data.Add<BindingRemapper::Remappings>(
BindingRemapper::BindingPoints{
{{2, 1}, {1, 1}},
{{3, 2}, {1, 1}},
{{4, 3}, {5, 4}},
},
BindingRemapper::AccessControls{}, true);
auto got = Run<BindingRemapper>(src, data);
EXPECT_EQ(expect, str(got));
}
TEST_F(BindingRemapperTest, BindingCollisionsDifferentEntryPoints) {
auto* src = R"(
struct S {
i : i32;
};
@group(2) @binding(1) var<storage, read> a : S;
@group(3) @binding(2) var<storage, read> b : S;
@group(4) @binding(3) var<storage, read> c : S;
@group(5) @binding(4) var<storage, read> d : S;
@stage(compute) @workgroup_size(1)
fn f1() {
let x : i32 = (a.i + c.i);
}
@stage(compute) @workgroup_size(1)
fn f2() {
let x : i32 = (b.i + d.i);
}
)";
auto* expect = R"(
struct S {
i : i32;
}
@group(1) @binding(1) var<storage, read> a : S;
@group(1) @binding(1) var<storage, read> b : S;
@group(5) @binding(4) var<storage, read> c : S;
@group(5) @binding(4) var<storage, read> d : S;
@stage(compute) @workgroup_size(1)
fn f1() {
let x : i32 = (a.i + c.i);
}
@stage(compute) @workgroup_size(1)
fn f2() {
let x : i32 = (b.i + d.i);
}
)";
DataMap data;
data.Add<BindingRemapper::Remappings>(
BindingRemapper::BindingPoints{
{{2, 1}, {1, 1}},
{{3, 2}, {1, 1}},
{{4, 3}, {5, 4}},
},
BindingRemapper::AccessControls{}, true);
auto got = Run<BindingRemapper>(src, data);
EXPECT_EQ(expect, str(got));
}
TEST_F(BindingRemapperTest, NoData) {
auto* src = R"(
struct S {
a : f32;
}
@group(2) @binding(1) var<storage, read> a : S;
@group(3) @binding(2) var<storage, read> b : S;
@stage(compute) @workgroup_size(1)
fn f() {
}
)";
auto* expect = src;
auto got = Run<BindingRemapper>(src);
EXPECT_EQ(expect, str(got));
}
} // namespace
} // namespace transform
} // namespace tint

View File

@@ -0,0 +1,243 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/calculate_array_length.h"
#include <unordered_map>
#include <utility>
#include "src/tint/ast/call_statement.h"
#include "src/tint/ast/disable_validation_attribute.h"
#include "src/tint/program_builder.h"
#include "src/tint/sem/block_statement.h"
#include "src/tint/sem/call.h"
#include "src/tint/sem/function.h"
#include "src/tint/sem/statement.h"
#include "src/tint/sem/struct.h"
#include "src/tint/sem/variable.h"
#include "src/tint/transform/simplify_pointers.h"
#include "src/tint/utils/hash.h"
#include "src/tint/utils/map.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::CalculateArrayLength);
TINT_INSTANTIATE_TYPEINFO(
tint::transform::CalculateArrayLength::BufferSizeIntrinsic);
namespace tint {
namespace transform {
namespace {
/// ArrayUsage describes a runtime array usage.
/// It is used as a key by the array_length_by_usage map.
struct ArrayUsage {
ast::BlockStatement const* const block;
sem::Variable const* const buffer;
bool operator==(const ArrayUsage& rhs) const {
return block == rhs.block && buffer == rhs.buffer;
}
struct Hasher {
inline std::size_t operator()(const ArrayUsage& u) const {
return utils::Hash(u.block, u.buffer);
}
};
};
} // namespace
CalculateArrayLength::BufferSizeIntrinsic::BufferSizeIntrinsic(ProgramID pid)
: Base(pid) {}
CalculateArrayLength::BufferSizeIntrinsic::~BufferSizeIntrinsic() = default;
std::string CalculateArrayLength::BufferSizeIntrinsic::InternalName() const {
return "intrinsic_buffer_size";
}
const CalculateArrayLength::BufferSizeIntrinsic*
CalculateArrayLength::BufferSizeIntrinsic::Clone(CloneContext* ctx) const {
return ctx->dst->ASTNodes().Create<CalculateArrayLength::BufferSizeIntrinsic>(
ctx->dst->ID());
}
CalculateArrayLength::CalculateArrayLength() = default;
CalculateArrayLength::~CalculateArrayLength() = default;
bool CalculateArrayLength::ShouldRun(const Program* program,
const DataMap&) const {
for (auto* fn : program->AST().Functions()) {
if (auto* sem_fn = program->Sem().Get(fn)) {
for (auto* builtin : sem_fn->DirectlyCalledBuiltins()) {
if (builtin->Type() == sem::BuiltinType::kArrayLength) {
return true;
}
}
}
}
return false;
}
void CalculateArrayLength::Run(CloneContext& ctx,
const DataMap&,
DataMap&) const {
auto& sem = ctx.src->Sem();
// get_buffer_size_intrinsic() emits the function decorated with
// BufferSizeIntrinsic that is transformed by the HLSL writer into a call to
// [RW]ByteAddressBuffer.GetDimensions().
std::unordered_map<const sem::Type*, Symbol> buffer_size_intrinsics;
auto get_buffer_size_intrinsic = [&](const sem::Type* buffer_type) {
return utils::GetOrCreate(buffer_size_intrinsics, buffer_type, [&] {
auto name = ctx.dst->Sym();
auto* type = CreateASTTypeFor(ctx, buffer_type);
auto* disable_validation = ctx.dst->Disable(
ast::DisabledValidation::kIgnoreConstructibleFunctionParameter);
ctx.dst->AST().AddFunction(ctx.dst->create<ast::Function>(
name,
ast::VariableList{
// Note: The buffer parameter requires the kStorage StorageClass
// in order for HLSL to emit this as a ByteAddressBuffer.
ctx.dst->create<ast::Variable>(
ctx.dst->Sym("buffer"), ast::StorageClass::kStorage,
ast::Access::kUndefined, type, true, false, nullptr,
ast::AttributeList{disable_validation}),
ctx.dst->Param("result",
ctx.dst->ty.pointer(ctx.dst->ty.u32(),
ast::StorageClass::kFunction)),
},
ctx.dst->ty.void_(), nullptr,
ast::AttributeList{
ctx.dst->ASTNodes().Create<BufferSizeIntrinsic>(ctx.dst->ID()),
},
ast::AttributeList{}));
return name;
});
};
std::unordered_map<ArrayUsage, Symbol, ArrayUsage::Hasher>
array_length_by_usage;
// Find all the arrayLength() calls...
for (auto* node : ctx.src->ASTNodes().Objects()) {
if (auto* call_expr = node->As<ast::CallExpression>()) {
auto* call = sem.Get(call_expr);
if (auto* builtin = call->Target()->As<sem::Builtin>()) {
if (builtin->Type() == sem::BuiltinType::kArrayLength) {
// We're dealing with an arrayLength() call
// A runtime-sized array can only appear as the store type of a
// variable, or the last element of a structure (which cannot itself
// be nested). Given that we require SimplifyPointers, we can assume
// that the arrayLength() call has one of two forms:
// arrayLength(&struct_var.array_member)
// arrayLength(&array_var)
auto* arg = call_expr->args[0];
auto* address_of = arg->As<ast::UnaryOpExpression>();
if (!address_of || address_of->op != ast::UnaryOp::kAddressOf) {
TINT_ICE(Transform, ctx.dst->Diagnostics())
<< "arrayLength() expected address-of, got "
<< arg->TypeInfo().name;
}
auto* storage_buffer_expr = address_of->expr;
if (auto* accessor =
storage_buffer_expr->As<ast::MemberAccessorExpression>()) {
storage_buffer_expr = accessor->structure;
}
auto* storage_buffer_sem =
sem.Get<sem::VariableUser>(storage_buffer_expr);
if (!storage_buffer_sem) {
TINT_ICE(Transform, ctx.dst->Diagnostics())
<< "expected form of arrayLength argument to be &array_var or "
"&struct_var.array_member";
break;
}
auto* storage_buffer_var = storage_buffer_sem->Variable();
auto* storage_buffer_type = storage_buffer_sem->Type()->UnwrapRef();
// Generate BufferSizeIntrinsic for this storage type if we haven't
// already
auto buffer_size = get_buffer_size_intrinsic(storage_buffer_type);
// Find the current statement block
auto* block = call->Stmt()->Block()->Declaration();
auto array_length = utils::GetOrCreate(
array_length_by_usage, {block, storage_buffer_var}, [&] {
// First time this array length is used for this block.
// Let's calculate it.
// Construct the variable that'll hold the result of
// RWByteAddressBuffer.GetDimensions()
auto* buffer_size_result = ctx.dst->Decl(
ctx.dst->Var(ctx.dst->Sym(), ctx.dst->ty.u32(),
ast::StorageClass::kNone, ctx.dst->Expr(0u)));
// Call storage_buffer.GetDimensions(&buffer_size_result)
auto* call_get_dims = ctx.dst->CallStmt(ctx.dst->Call(
// BufferSizeIntrinsic(X, ARGS...) is
// translated to:
// X.GetDimensions(ARGS..) by the writer
buffer_size, ctx.Clone(storage_buffer_expr),
ctx.dst->AddressOf(
ctx.dst->Expr(buffer_size_result->variable->symbol))));
// Calculate actual array length
// total_storage_buffer_size - array_offset
// array_length = ----------------------------------------
// array_stride
auto name = ctx.dst->Sym();
const ast::Expression* total_size =
ctx.dst->Expr(buffer_size_result->variable);
const sem::Array* array_type = nullptr;
if (auto* str = storage_buffer_type->As<sem::Struct>()) {
// The variable is a struct, so subtract the byte offset of
// the array member.
auto* array_member_sem = str->Members().back();
array_type = array_member_sem->Type()->As<sem::Array>();
total_size =
ctx.dst->Sub(total_size, array_member_sem->Offset());
} else if (auto* arr = storage_buffer_type->As<sem::Array>()) {
array_type = arr;
} else {
TINT_ICE(Transform, ctx.dst->Diagnostics())
<< "expected form of arrayLength argument to be "
"&array_var or &struct_var.array_member";
return name;
}
uint32_t array_stride = array_type->Size();
auto* array_length_var = ctx.dst->Decl(
ctx.dst->Const(name, ctx.dst->ty.u32(),
ctx.dst->Div(total_size, array_stride)));
// Insert the array length calculations at the top of the block
ctx.InsertBefore(block->statements, block->statements[0],
buffer_size_result);
ctx.InsertBefore(block->statements, block->statements[0],
call_get_dims);
ctx.InsertBefore(block->statements, block->statements[0],
array_length_var);
return name;
});
// Replace the call to arrayLength() with the array length variable
ctx.Replace(call_expr, ctx.dst->Expr(array_length));
}
}
}
}
ctx.Clone();
}
} // namespace transform
} // namespace tint

View File

@@ -0,0 +1,83 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TINT_TRANSFORM_CALCULATE_ARRAY_LENGTH_H_
#define SRC_TINT_TRANSFORM_CALCULATE_ARRAY_LENGTH_H_
#include <string>
#include "src/tint/ast/internal_attribute.h"
#include "src/tint/transform/transform.h"
namespace tint {
// Forward declarations
class CloneContext;
namespace transform {
/// CalculateArrayLength is a transform used to replace calls to arrayLength()
/// with a value calculated from the size of the storage buffer.
///
/// @note Depends on the following transforms to have been run first:
/// * SimplifyPointers
class CalculateArrayLength : public Castable<CalculateArrayLength, Transform> {
public:
/// BufferSizeIntrinsic is an InternalAttribute that's applied to intrinsic
/// functions used to obtain the runtime size of a storage buffer.
class BufferSizeIntrinsic
: public Castable<BufferSizeIntrinsic, ast::InternalAttribute> {
public:
/// Constructor
/// @param program_id the identifier of the program that owns this node
explicit BufferSizeIntrinsic(ProgramID program_id);
/// Destructor
~BufferSizeIntrinsic() override;
/// @return "buffer_size"
std::string InternalName() const override;
/// Performs a deep clone of this object using the CloneContext `ctx`.
/// @param ctx the clone context
/// @return the newly cloned object
const BufferSizeIntrinsic* Clone(CloneContext* ctx) const override;
};
/// Constructor
CalculateArrayLength();
/// Destructor
~CalculateArrayLength() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program,
const DataMap& data = {}) const override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const override;
};
} // namespace transform
} // namespace tint
#endif // SRC_TINT_TRANSFORM_CALCULATE_ARRAY_LENGTH_H_

View File

@@ -0,0 +1,625 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/calculate_array_length.h"
#include "src/tint/transform/simplify_pointers.h"
#include "src/tint/transform/test_helper.h"
#include "src/tint/transform/unshadow.h"
namespace tint {
namespace transform {
namespace {
using CalculateArrayLengthTest = TransformTest;
TEST_F(CalculateArrayLengthTest, ShouldRunEmptyModule) {
auto* src = R"()";
EXPECT_FALSE(ShouldRun<CalculateArrayLength>(src));
}
TEST_F(CalculateArrayLengthTest, ShouldRunNoArrayLength) {
auto* src = R"(
struct SB {
x : i32;
arr : array<i32>;
};
[[group(0), binding(0)]] var<storage, read> sb : SB;
[[stage(compute), workgroup_size(1)]]
fn main() {
}
)";
EXPECT_FALSE(ShouldRun<CalculateArrayLength>(src));
}
TEST_F(CalculateArrayLengthTest, ShouldRunWithArrayLength) {
auto* src = R"(
struct SB {
x : i32;
arr : array<i32>;
};
[[group(0), binding(0)]] var<storage, read> sb : SB;
[[stage(compute), workgroup_size(1)]]
fn main() {
var len : u32 = arrayLength(&sb.arr);
}
)";
EXPECT_TRUE(ShouldRun<CalculateArrayLength>(src));
}
TEST_F(CalculateArrayLengthTest, BasicArray) {
auto* src = R"(
@group(0) @binding(0) var<storage, read> sb : array<i32>;
@stage(compute) @workgroup_size(1)
fn main() {
var len : u32 = arrayLength(&sb);
}
)";
auto* expect = R"(
@internal(intrinsic_buffer_size)
fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : array<i32>, result : ptr<function, u32>)
@group(0) @binding(0) var<storage, read> sb : array<i32>;
@stage(compute) @workgroup_size(1)
fn main() {
var tint_symbol_1 : u32 = 0u;
tint_symbol(sb, &(tint_symbol_1));
let tint_symbol_2 : u32 = (tint_symbol_1 / 4u);
var len : u32 = tint_symbol_2;
}
)";
auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(CalculateArrayLengthTest, BasicInStruct) {
auto* src = R"(
struct SB {
x : i32;
arr : array<i32>;
};
@group(0) @binding(0) var<storage, read> sb : SB;
@stage(compute) @workgroup_size(1)
fn main() {
var len : u32 = arrayLength(&sb.arr);
}
)";
auto* expect = R"(
@internal(intrinsic_buffer_size)
fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : SB, result : ptr<function, u32>)
struct SB {
x : i32;
arr : array<i32>;
}
@group(0) @binding(0) var<storage, read> sb : SB;
@stage(compute) @workgroup_size(1)
fn main() {
var tint_symbol_1 : u32 = 0u;
tint_symbol(sb, &(tint_symbol_1));
let tint_symbol_2 : u32 = ((tint_symbol_1 - 4u) / 4u);
var len : u32 = tint_symbol_2;
}
)";
auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(CalculateArrayLengthTest, ArrayOfStruct) {
auto* src = R"(
struct S {
f : f32;
}
@group(0) @binding(0) var<storage, read> arr : array<S>;
@stage(compute) @workgroup_size(1)
fn main() {
let len = arrayLength(&arr);
}
)";
auto* expect = R"(
@internal(intrinsic_buffer_size)
fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : array<S>, result : ptr<function, u32>)
struct S {
f : f32;
}
@group(0) @binding(0) var<storage, read> arr : array<S>;
@stage(compute) @workgroup_size(1)
fn main() {
var tint_symbol_1 : u32 = 0u;
tint_symbol(arr, &(tint_symbol_1));
let tint_symbol_2 : u32 = (tint_symbol_1 / 4u);
let len = tint_symbol_2;
}
)";
auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(CalculateArrayLengthTest, ArrayOfArrayOfStruct) {
auto* src = R"(
struct S {
f : f32;
}
@group(0) @binding(0) var<storage, read> arr : array<array<S, 4>>;
@stage(compute) @workgroup_size(1)
fn main() {
let len = arrayLength(&arr);
}
)";
auto* expect = R"(
@internal(intrinsic_buffer_size)
fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : array<array<S, 4u>>, result : ptr<function, u32>)
struct S {
f : f32;
}
@group(0) @binding(0) var<storage, read> arr : array<array<S, 4>>;
@stage(compute) @workgroup_size(1)
fn main() {
var tint_symbol_1 : u32 = 0u;
tint_symbol(arr, &(tint_symbol_1));
let tint_symbol_2 : u32 = (tint_symbol_1 / 16u);
let len = tint_symbol_2;
}
)";
auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(CalculateArrayLengthTest, InSameBlock) {
auto* src = R"(
@group(0) @binding(0) var<storage, read> sb : array<i32>;;
@stage(compute) @workgroup_size(1)
fn main() {
var a : u32 = arrayLength(&sb);
var b : u32 = arrayLength(&sb);
var c : u32 = arrayLength(&sb);
}
)";
auto* expect = R"(
@internal(intrinsic_buffer_size)
fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : array<i32>, result : ptr<function, u32>)
@group(0) @binding(0) var<storage, read> sb : array<i32>;
@stage(compute) @workgroup_size(1)
fn main() {
var tint_symbol_1 : u32 = 0u;
tint_symbol(sb, &(tint_symbol_1));
let tint_symbol_2 : u32 = (tint_symbol_1 / 4u);
var a : u32 = tint_symbol_2;
var b : u32 = tint_symbol_2;
var c : u32 = tint_symbol_2;
}
)";
auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(CalculateArrayLengthTest, InSameBlock_Struct) {
auto* src = R"(
struct SB {
x : i32;
arr : array<i32>;
};
@group(0) @binding(0) var<storage, read> sb : SB;
@stage(compute) @workgroup_size(1)
fn main() {
var a : u32 = arrayLength(&sb.arr);
var b : u32 = arrayLength(&sb.arr);
var c : u32 = arrayLength(&sb.arr);
}
)";
auto* expect = R"(
@internal(intrinsic_buffer_size)
fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : SB, result : ptr<function, u32>)
struct SB {
x : i32;
arr : array<i32>;
}
@group(0) @binding(0) var<storage, read> sb : SB;
@stage(compute) @workgroup_size(1)
fn main() {
var tint_symbol_1 : u32 = 0u;
tint_symbol(sb, &(tint_symbol_1));
let tint_symbol_2 : u32 = ((tint_symbol_1 - 4u) / 4u);
var a : u32 = tint_symbol_2;
var b : u32 = tint_symbol_2;
var c : u32 = tint_symbol_2;
}
)";
auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(CalculateArrayLengthTest, WithStride) {
auto* src = R"(
@group(0) @binding(0) var<storage, read> sb : @stride(64) array<i32>;
@stage(compute) @workgroup_size(1)
fn main() {
var len : u32 = arrayLength(&sb);
}
)";
auto* expect = R"(
@internal(intrinsic_buffer_size)
fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : @stride(64) array<i32>, result : ptr<function, u32>)
@group(0) @binding(0) var<storage, read> sb : @stride(64) array<i32>;
@stage(compute) @workgroup_size(1)
fn main() {
var tint_symbol_1 : u32 = 0u;
tint_symbol(sb, &(tint_symbol_1));
let tint_symbol_2 : u32 = (tint_symbol_1 / 64u);
var len : u32 = tint_symbol_2;
}
)";
auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(CalculateArrayLengthTest, WithStride_InStruct) {
auto* src = R"(
struct SB {
x : i32;
y : f32;
arr : @stride(64) array<i32>;
};
@group(0) @binding(0) var<storage, read> sb : SB;
@stage(compute) @workgroup_size(1)
fn main() {
var len : u32 = arrayLength(&sb.arr);
}
)";
auto* expect = R"(
@internal(intrinsic_buffer_size)
fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : SB, result : ptr<function, u32>)
struct SB {
x : i32;
y : f32;
arr : @stride(64) array<i32>;
}
@group(0) @binding(0) var<storage, read> sb : SB;
@stage(compute) @workgroup_size(1)
fn main() {
var tint_symbol_1 : u32 = 0u;
tint_symbol(sb, &(tint_symbol_1));
let tint_symbol_2 : u32 = ((tint_symbol_1 - 8u) / 64u);
var len : u32 = tint_symbol_2;
}
)";
auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(CalculateArrayLengthTest, Nested) {
auto* src = R"(
struct SB {
x : i32;
arr : array<i32>;
};
@group(0) @binding(0) var<storage, read> sb : SB;
@stage(compute) @workgroup_size(1)
fn main() {
if (true) {
var len : u32 = arrayLength(&sb.arr);
} else {
if (true) {
var len : u32 = arrayLength(&sb.arr);
}
}
}
)";
auto* expect = R"(
@internal(intrinsic_buffer_size)
fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : SB, result : ptr<function, u32>)
struct SB {
x : i32;
arr : array<i32>;
}
@group(0) @binding(0) var<storage, read> sb : SB;
@stage(compute) @workgroup_size(1)
fn main() {
if (true) {
var tint_symbol_1 : u32 = 0u;
tint_symbol(sb, &(tint_symbol_1));
let tint_symbol_2 : u32 = ((tint_symbol_1 - 4u) / 4u);
var len : u32 = tint_symbol_2;
} else {
if (true) {
var tint_symbol_3 : u32 = 0u;
tint_symbol(sb, &(tint_symbol_3));
let tint_symbol_4 : u32 = ((tint_symbol_3 - 4u) / 4u);
var len : u32 = tint_symbol_4;
}
}
}
)";
auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(CalculateArrayLengthTest, MultipleStorageBuffers) {
auto* src = R"(
struct SB1 {
x : i32;
arr1 : array<i32>;
};
struct SB2 {
x : i32;
arr2 : array<vec4<f32>>;
};
@group(0) @binding(0) var<storage, read> sb1 : SB1;
@group(0) @binding(1) var<storage, read> sb2 : SB2;
@group(0) @binding(2) var<storage, read> sb3 : array<i32>;
@stage(compute) @workgroup_size(1)
fn main() {
var len1 : u32 = arrayLength(&(sb1.arr1));
var len2 : u32 = arrayLength(&(sb2.arr2));
var len3 : u32 = arrayLength(&sb3);
var x : u32 = (len1 + len2 + len3);
}
)";
auto* expect = R"(
@internal(intrinsic_buffer_size)
fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : SB1, result : ptr<function, u32>)
@internal(intrinsic_buffer_size)
fn tint_symbol_3(@internal(disable_validation__ignore_constructible_function_parameter) buffer : SB2, result : ptr<function, u32>)
@internal(intrinsic_buffer_size)
fn tint_symbol_6(@internal(disable_validation__ignore_constructible_function_parameter) buffer : array<i32>, result : ptr<function, u32>)
struct SB1 {
x : i32;
arr1 : array<i32>;
}
struct SB2 {
x : i32;
arr2 : array<vec4<f32>>;
}
@group(0) @binding(0) var<storage, read> sb1 : SB1;
@group(0) @binding(1) var<storage, read> sb2 : SB2;
@group(0) @binding(2) var<storage, read> sb3 : array<i32>;
@stage(compute) @workgroup_size(1)
fn main() {
var tint_symbol_1 : u32 = 0u;
tint_symbol(sb1, &(tint_symbol_1));
let tint_symbol_2 : u32 = ((tint_symbol_1 - 4u) / 4u);
var tint_symbol_4 : u32 = 0u;
tint_symbol_3(sb2, &(tint_symbol_4));
let tint_symbol_5 : u32 = ((tint_symbol_4 - 16u) / 16u);
var tint_symbol_7 : u32 = 0u;
tint_symbol_6(sb3, &(tint_symbol_7));
let tint_symbol_8 : u32 = (tint_symbol_7 / 4u);
var len1 : u32 = tint_symbol_2;
var len2 : u32 = tint_symbol_5;
var len3 : u32 = tint_symbol_8;
var x : u32 = ((len1 + len2) + len3);
}
)";
auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(CalculateArrayLengthTest, Shadowing) {
auto* src = R"(
struct SB {
x : i32;
arr : array<i32>;
};
@group(0) @binding(0) var<storage, read> a : SB;
@group(0) @binding(1) var<storage, read> b : SB;
@stage(compute) @workgroup_size(1)
fn main() {
let x = &a;
var a : u32 = arrayLength(&a.arr);
{
var b : u32 = arrayLength(&((*x).arr));
}
}
)";
auto* expect =
R"(
@internal(intrinsic_buffer_size)
fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : SB, result : ptr<function, u32>)
struct SB {
x : i32;
arr : array<i32>;
}
@group(0) @binding(0) var<storage, read> a : SB;
@group(0) @binding(1) var<storage, read> b : SB;
@stage(compute) @workgroup_size(1)
fn main() {
var tint_symbol_1 : u32 = 0u;
tint_symbol(a, &(tint_symbol_1));
let tint_symbol_2 : u32 = ((tint_symbol_1 - 4u) / 4u);
var a_1 : u32 = tint_symbol_2;
{
var tint_symbol_3 : u32 = 0u;
tint_symbol(a, &(tint_symbol_3));
let tint_symbol_4 : u32 = ((tint_symbol_3 - 4u) / 4u);
var b_1 : u32 = tint_symbol_4;
}
}
)";
auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(CalculateArrayLengthTest, OutOfOrder) {
auto* src = R"(
@stage(compute) @workgroup_size(1)
fn main() {
var len1 : u32 = arrayLength(&(sb1.arr1));
var len2 : u32 = arrayLength(&(sb2.arr2));
var len3 : u32 = arrayLength(&sb3);
var x : u32 = (len1 + len2 + len3);
}
@group(0) @binding(0) var<storage, read> sb1 : SB1;
struct SB1 {
x : i32;
arr1 : array<i32>;
};
@group(0) @binding(1) var<storage, read> sb2 : SB2;
struct SB2 {
x : i32;
arr2 : array<vec4<f32>>;
};
@group(0) @binding(2) var<storage, read> sb3 : array<i32>;
)";
auto* expect = R"(
@internal(intrinsic_buffer_size)
fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : SB1, result : ptr<function, u32>)
@internal(intrinsic_buffer_size)
fn tint_symbol_3(@internal(disable_validation__ignore_constructible_function_parameter) buffer : SB2, result : ptr<function, u32>)
@internal(intrinsic_buffer_size)
fn tint_symbol_6(@internal(disable_validation__ignore_constructible_function_parameter) buffer : array<i32>, result : ptr<function, u32>)
@stage(compute) @workgroup_size(1)
fn main() {
var tint_symbol_1 : u32 = 0u;
tint_symbol(sb1, &(tint_symbol_1));
let tint_symbol_2 : u32 = ((tint_symbol_1 - 4u) / 4u);
var tint_symbol_4 : u32 = 0u;
tint_symbol_3(sb2, &(tint_symbol_4));
let tint_symbol_5 : u32 = ((tint_symbol_4 - 16u) / 16u);
var tint_symbol_7 : u32 = 0u;
tint_symbol_6(sb3, &(tint_symbol_7));
let tint_symbol_8 : u32 = (tint_symbol_7 / 4u);
var len1 : u32 = tint_symbol_2;
var len2 : u32 = tint_symbol_5;
var len3 : u32 = tint_symbol_8;
var x : u32 = ((len1 + len2) + len3);
}
@group(0) @binding(0) var<storage, read> sb1 : SB1;
struct SB1 {
x : i32;
arr1 : array<i32>;
}
@group(0) @binding(1) var<storage, read> sb2 : SB2;
struct SB2 {
x : i32;
arr2 : array<vec4<f32>>;
}
@group(0) @binding(2) var<storage, read> sb3 : array<i32>;
)";
auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);
EXPECT_EQ(expect, str(got));
}
} // namespace
} // namespace transform
} // namespace tint

View File

@@ -0,0 +1,779 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/canonicalize_entry_point_io.h"
#include <algorithm>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "src/tint/ast/disable_validation_attribute.h"
#include "src/tint/program_builder.h"
#include "src/tint/sem/function.h"
#include "src/tint/transform/unshadow.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::CanonicalizeEntryPointIO);
TINT_INSTANTIATE_TYPEINFO(tint::transform::CanonicalizeEntryPointIO::Config);
namespace tint {
namespace transform {
CanonicalizeEntryPointIO::CanonicalizeEntryPointIO() = default;
CanonicalizeEntryPointIO::~CanonicalizeEntryPointIO() = default;
namespace {
// Comparison function used to reorder struct members such that all members with
// location attributes appear first (ordered by location slot), followed by
// those with builtin attributes.
bool StructMemberComparator(const ast::StructMember* a,
const ast::StructMember* b) {
auto* a_loc = ast::GetAttribute<ast::LocationAttribute>(a->attributes);
auto* b_loc = ast::GetAttribute<ast::LocationAttribute>(b->attributes);
auto* a_blt = ast::GetAttribute<ast::BuiltinAttribute>(a->attributes);
auto* b_blt = ast::GetAttribute<ast::BuiltinAttribute>(b->attributes);
if (a_loc) {
if (!b_loc) {
// `a` has location attribute and `b` does not: `a` goes first.
return true;
}
// Both have location attributes: smallest goes first.
return a_loc->value < b_loc->value;
} else {
if (b_loc) {
// `b` has location attribute and `a` does not: `b` goes first.
return false;
}
// Both are builtins: order doesn't matter, just use enum value.
return a_blt->builtin < b_blt->builtin;
}
}
// Returns true if `attr` is a shader IO attribute.
bool IsShaderIOAttribute(const ast::Attribute* attr) {
return attr->IsAnyOf<ast::BuiltinAttribute, ast::InterpolateAttribute,
ast::InvariantAttribute, ast::LocationAttribute>();
}
// Returns true if `attrs` contains a `sample_mask` builtin.
bool HasSampleMask(const ast::AttributeList& attrs) {
auto* builtin = ast::GetAttribute<ast::BuiltinAttribute>(attrs);
return builtin && builtin->builtin == ast::Builtin::kSampleMask;
}
} // namespace
/// State holds the current transform state for a single entry point.
struct CanonicalizeEntryPointIO::State {
/// OutputValue represents a shader result that the wrapper function produces.
struct OutputValue {
/// The name of the output value.
std::string name;
/// The type of the output value.
const ast::Type* type;
/// The shader IO attributes.
ast::AttributeList attributes;
/// The value itself.
const ast::Expression* value;
};
/// The clone context.
CloneContext& ctx;
/// The transform config.
CanonicalizeEntryPointIO::Config const cfg;
/// The entry point function (AST).
const ast::Function* func_ast;
/// The entry point function (SEM).
const sem::Function* func_sem;
/// The new entry point wrapper function's parameters.
ast::VariableList wrapper_ep_parameters;
/// The members of the wrapper function's struct parameter.
ast::StructMemberList wrapper_struct_param_members;
/// The name of the wrapper function's struct parameter.
Symbol wrapper_struct_param_name;
/// The parameters that will be passed to the original function.
ast::ExpressionList inner_call_parameters;
/// The members of the wrapper function's struct return type.
ast::StructMemberList wrapper_struct_output_members;
/// The wrapper function output values.
std::vector<OutputValue> wrapper_output_values;
/// The body of the wrapper function.
ast::StatementList wrapper_body;
/// Input names used by the entrypoint
std::unordered_set<std::string> input_names;
/// Constructor
/// @param context the clone context
/// @param config the transform config
/// @param function the entry point function
State(CloneContext& context,
const CanonicalizeEntryPointIO::Config& config,
const ast::Function* function)
: ctx(context),
cfg(config),
func_ast(function),
func_sem(ctx.src->Sem().Get(function)) {}
/// Clones the shader IO attributes from `src`.
/// @param src the attributes to clone
/// @param do_interpolate whether to clone InterpolateAttribute
/// @return the cloned attributes
ast::AttributeList CloneShaderIOAttributes(const ast::AttributeList& src,
bool do_interpolate) {
ast::AttributeList new_attributes;
for (auto* attr : src) {
if (IsShaderIOAttribute(attr) &&
(do_interpolate || !attr->Is<ast::InterpolateAttribute>())) {
new_attributes.push_back(ctx.Clone(attr));
}
}
return new_attributes;
}
/// Create or return a symbol for the wrapper function's struct parameter.
/// @returns the symbol for the struct parameter
Symbol InputStructSymbol() {
if (!wrapper_struct_param_name.IsValid()) {
wrapper_struct_param_name = ctx.dst->Sym();
}
return wrapper_struct_param_name;
}
/// Add a shader input to the entry point.
/// @param name the name of the shader input
/// @param type the type of the shader input
/// @param attributes the attributes to apply to the shader input
/// @returns an expression which evaluates to the value of the shader input
const ast::Expression* AddInput(std::string name,
const sem::Type* type,
ast::AttributeList attributes) {
auto* ast_type = CreateASTTypeFor(ctx, type);
if (cfg.shader_style == ShaderStyle::kSpirv ||
cfg.shader_style == ShaderStyle::kGlsl) {
// Vulkan requires that integer user-defined fragment inputs are
// always decorated with `Flat`.
// TODO(crbug.com/tint/1224): Remove this once a flat interpolation
// attribute is required for integers.
if (type->is_integer_scalar_or_vector() &&
ast::HasAttribute<ast::LocationAttribute>(attributes) &&
!ast::HasAttribute<ast::InterpolateAttribute>(attributes) &&
func_ast->PipelineStage() == ast::PipelineStage::kFragment) {
attributes.push_back(ctx.dst->Interpolate(
ast::InterpolationType::kFlat, ast::InterpolationSampling::kNone));
}
// Disable validation for use of the `input` storage class.
attributes.push_back(
ctx.dst->Disable(ast::DisabledValidation::kIgnoreStorageClass));
// In GLSL, if it's a builtin, override the name with the
// corresponding gl_ builtin name
auto* builtin = ast::GetAttribute<ast::BuiltinAttribute>(attributes);
if (cfg.shader_style == ShaderStyle::kGlsl && builtin) {
name = GLSLBuiltinToString(builtin->builtin, func_ast->PipelineStage(),
ast::StorageClass::kInput);
}
auto symbol = ctx.dst->Symbols().New(name);
// Create the global variable and use its value for the shader input.
const ast::Expression* value = ctx.dst->Expr(symbol);
if (builtin) {
if (cfg.shader_style == ShaderStyle::kGlsl) {
value = FromGLSLBuiltin(builtin->builtin, value, ast_type);
} else if (builtin->builtin == ast::Builtin::kSampleMask) {
// Vulkan requires the type of a SampleMask builtin to be an array.
// Declare it as array<u32, 1> and then load the first element.
ast_type = ctx.dst->ty.array(ast_type, 1);
value = ctx.dst->IndexAccessor(value, 0);
}
}
ctx.dst->Global(symbol, ast_type, ast::StorageClass::kInput,
std::move(attributes));
return value;
} else if (cfg.shader_style == ShaderStyle::kMsl &&
ast::HasAttribute<ast::BuiltinAttribute>(attributes)) {
// If this input is a builtin and we are targeting MSL, then add it to the
// parameter list and pass it directly to the inner function.
Symbol symbol = input_names.emplace(name).second
? ctx.dst->Symbols().Register(name)
: ctx.dst->Symbols().New(name);
wrapper_ep_parameters.push_back(
ctx.dst->Param(symbol, ast_type, std::move(attributes)));
return ctx.dst->Expr(symbol);
} else {
// Otherwise, move it to the new structure member list.
Symbol symbol = input_names.emplace(name).second
? ctx.dst->Symbols().Register(name)
: ctx.dst->Symbols().New(name);
wrapper_struct_param_members.push_back(
ctx.dst->Member(symbol, ast_type, std::move(attributes)));
return ctx.dst->MemberAccessor(InputStructSymbol(), symbol);
}
}
/// Add a shader output to the entry point.
/// @param name the name of the shader output
/// @param type the type of the shader output
/// @param attributes the attributes to apply to the shader output
/// @param value the value of the shader output
void AddOutput(std::string name,
const sem::Type* type,
ast::AttributeList attributes,
const ast::Expression* value) {
// Vulkan requires that integer user-defined vertex outputs are
// always decorated with `Flat`.
// TODO(crbug.com/tint/1224): Remove this once a flat interpolation
// attribute is required for integers.
if (cfg.shader_style == ShaderStyle::kSpirv &&
type->is_integer_scalar_or_vector() &&
ast::HasAttribute<ast::LocationAttribute>(attributes) &&
!ast::HasAttribute<ast::InterpolateAttribute>(attributes) &&
func_ast->PipelineStage() == ast::PipelineStage::kVertex) {
attributes.push_back(ctx.dst->Interpolate(
ast::InterpolationType::kFlat, ast::InterpolationSampling::kNone));
}
// In GLSL, if it's a builtin, override the name with the
// corresponding gl_ builtin name
if (cfg.shader_style == ShaderStyle::kGlsl) {
if (auto* b = ast::GetAttribute<ast::BuiltinAttribute>(attributes)) {
name = GLSLBuiltinToString(b->builtin, func_ast->PipelineStage(),
ast::StorageClass::kOutput);
value = ToGLSLBuiltin(b->builtin, value, type);
}
}
OutputValue output;
output.name = name;
output.type = CreateASTTypeFor(ctx, type);
output.attributes = std::move(attributes);
output.value = value;
wrapper_output_values.push_back(output);
}
/// Process a non-struct parameter.
/// This creates a new object for the shader input, moving the shader IO
/// attributes to it. It also adds an expression to the list of parameters
/// that will be passed to the original function.
/// @param param the original function parameter
void ProcessNonStructParameter(const sem::Parameter* param) {
// Remove the shader IO attributes from the inner function parameter, and
// attach them to the new object instead.
ast::AttributeList attributes;
for (auto* attr : param->Declaration()->attributes) {
if (IsShaderIOAttribute(attr)) {
ctx.Remove(param->Declaration()->attributes, attr);
attributes.push_back(ctx.Clone(attr));
}
}
auto name = ctx.src->Symbols().NameFor(param->Declaration()->symbol);
auto* input_expr = AddInput(name, param->Type(), std::move(attributes));
inner_call_parameters.push_back(input_expr);
}
/// Process a struct parameter.
/// This creates new objects for each struct member, moving the shader IO
/// attributes to them. It also creates the structure that will be passed to
/// the original function.
/// @param param the original function parameter
void ProcessStructParameter(const sem::Parameter* param) {
auto* str = param->Type()->As<sem::Struct>();
// Recreate struct members in the outer entry point and build an initializer
// list to pass them through to the inner function.
ast::ExpressionList inner_struct_values;
for (auto* member : str->Members()) {
if (member->Type()->Is<sem::Struct>()) {
TINT_ICE(Transform, ctx.dst->Diagnostics()) << "nested IO struct";
continue;
}
auto* member_ast = member->Declaration();
auto name = ctx.src->Symbols().NameFor(member_ast->symbol);
// In GLSL, do not add interpolation attributes on vertex input
bool do_interpolate = true;
if (cfg.shader_style == ShaderStyle::kGlsl &&
func_ast->PipelineStage() == ast::PipelineStage::kVertex) {
do_interpolate = false;
}
auto attributes =
CloneShaderIOAttributes(member_ast->attributes, do_interpolate);
auto* input_expr = AddInput(name, member->Type(), std::move(attributes));
inner_struct_values.push_back(input_expr);
}
// Construct the original structure using the new shader input objects.
inner_call_parameters.push_back(ctx.dst->Construct(
ctx.Clone(param->Declaration()->type), inner_struct_values));
}
/// Process the entry point return type.
/// This generates a list of output values that are returned by the original
/// function.
/// @param inner_ret_type the original function return type
/// @param original_result the result object produced by the original function
void ProcessReturnType(const sem::Type* inner_ret_type,
Symbol original_result) {
bool do_interpolate = true;
// In GLSL, do not add interpolation attributes on fragment output
if (cfg.shader_style == ShaderStyle::kGlsl &&
func_ast->PipelineStage() == ast::PipelineStage::kFragment) {
do_interpolate = false;
}
if (auto* str = inner_ret_type->As<sem::Struct>()) {
for (auto* member : str->Members()) {
if (member->Type()->Is<sem::Struct>()) {
TINT_ICE(Transform, ctx.dst->Diagnostics()) << "nested IO struct";
continue;
}
auto* member_ast = member->Declaration();
auto name = ctx.src->Symbols().NameFor(member_ast->symbol);
auto attributes =
CloneShaderIOAttributes(member_ast->attributes, do_interpolate);
// Extract the original structure member.
AddOutput(name, member->Type(), std::move(attributes),
ctx.dst->MemberAccessor(original_result, name));
}
} else if (!inner_ret_type->Is<sem::Void>()) {
auto attributes = CloneShaderIOAttributes(
func_ast->return_type_attributes, do_interpolate);
// Propagate the non-struct return value as is.
AddOutput("value", func_sem->ReturnType(), std::move(attributes),
ctx.dst->Expr(original_result));
}
}
/// Add a fixed sample mask to the wrapper function output.
/// If there is already a sample mask, bitwise-and it with the fixed mask.
/// Otherwise, create a new output value from the fixed mask.
void AddFixedSampleMask() {
// Check the existing output values for a sample mask builtin.
for (auto& outval : wrapper_output_values) {
if (HasSampleMask(outval.attributes)) {
// Combine the authored sample mask with the fixed mask.
outval.value = ctx.dst->And(outval.value, cfg.fixed_sample_mask);
return;
}
}
// No existing sample mask builtin was found, so create a new output value
// using the fixed sample mask.
AddOutput("fixed_sample_mask", ctx.dst->create<sem::U32>(),
{ctx.dst->Builtin(ast::Builtin::kSampleMask)},
ctx.dst->Expr(cfg.fixed_sample_mask));
}
/// Add a point size builtin to the wrapper function output.
void AddVertexPointSize() {
// Create a new output value and assign it a literal 1.0 value.
AddOutput("vertex_point_size", ctx.dst->create<sem::F32>(),
{ctx.dst->Builtin(ast::Builtin::kPointSize)}, ctx.dst->Expr(1.f));
}
/// Create an expression for gl_Position.[component]
/// @param component the component of gl_Position to access
/// @returns the new expression
const ast::Expression* GLPosition(const char* component) {
Symbol pos = ctx.dst->Symbols().Register("gl_Position");
Symbol c = ctx.dst->Symbols().Register(component);
return ctx.dst->MemberAccessor(ctx.dst->Expr(pos), ctx.dst->Expr(c));
}
/// Create the wrapper function's struct parameter and type objects.
void CreateInputStruct() {
// Sort the struct members to satisfy HLSL interfacing matching rules.
std::sort(wrapper_struct_param_members.begin(),
wrapper_struct_param_members.end(), StructMemberComparator);
// Create the new struct type.
auto struct_name = ctx.dst->Sym();
auto* in_struct = ctx.dst->create<ast::Struct>(
struct_name, wrapper_struct_param_members, ast::AttributeList{});
ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func_ast, in_struct);
// Create a new function parameter using this struct type.
auto* param =
ctx.dst->Param(InputStructSymbol(), ctx.dst->ty.type_name(struct_name));
wrapper_ep_parameters.push_back(param);
}
/// Create and return the wrapper function's struct result object.
/// @returns the struct type
ast::Struct* CreateOutputStruct() {
ast::StatementList assignments;
auto wrapper_result = ctx.dst->Symbols().New("wrapper_result");
// Create the struct members and their corresponding assignment statements.
std::unordered_set<std::string> member_names;
for (auto& outval : wrapper_output_values) {
// Use the original output name, unless that is already taken.
Symbol name;
if (member_names.count(outval.name)) {
name = ctx.dst->Symbols().New(outval.name);
} else {
name = ctx.dst->Symbols().Register(outval.name);
}
member_names.insert(ctx.dst->Symbols().NameFor(name));
wrapper_struct_output_members.push_back(
ctx.dst->Member(name, outval.type, std::move(outval.attributes)));
assignments.push_back(ctx.dst->Assign(
ctx.dst->MemberAccessor(wrapper_result, name), outval.value));
}
// Sort the struct members to satisfy HLSL interfacing matching rules.
std::sort(wrapper_struct_output_members.begin(),
wrapper_struct_output_members.end(), StructMemberComparator);
// Create the new struct type.
auto* out_struct = ctx.dst->create<ast::Struct>(
ctx.dst->Sym(), wrapper_struct_output_members, ast::AttributeList{});
ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func_ast, out_struct);
// Create the output struct object, assign its members, and return it.
auto* result_object =
ctx.dst->Var(wrapper_result, ctx.dst->ty.type_name(out_struct->name));
wrapper_body.push_back(ctx.dst->Decl(result_object));
wrapper_body.insert(wrapper_body.end(), assignments.begin(),
assignments.end());
wrapper_body.push_back(ctx.dst->Return(wrapper_result));
return out_struct;
}
/// Create and assign the wrapper function's output variables.
void CreateGlobalOutputVariables() {
for (auto& outval : wrapper_output_values) {
// Disable validation for use of the `output` storage class.
ast::AttributeList attributes = std::move(outval.attributes);
attributes.push_back(
ctx.dst->Disable(ast::DisabledValidation::kIgnoreStorageClass));
// Create the global variable and assign it the output value.
auto name = ctx.dst->Symbols().New(outval.name);
auto* type = outval.type;
const ast::Expression* lhs = ctx.dst->Expr(name);
if (HasSampleMask(attributes)) {
// Vulkan requires the type of a SampleMask builtin to be an array.
// Declare it as array<u32, 1> and then store to the first element.
type = ctx.dst->ty.array(type, 1);
lhs = ctx.dst->IndexAccessor(lhs, 0);
}
ctx.dst->Global(name, type, ast::StorageClass::kOutput,
std::move(attributes));
wrapper_body.push_back(ctx.dst->Assign(lhs, outval.value));
}
}
// Recreate the original function without entry point attributes and call it.
/// @returns the inner function call expression
const ast::CallExpression* CallInnerFunction() {
Symbol inner_name;
if (cfg.shader_style == ShaderStyle::kGlsl) {
// In GLSL, clone the original entry point name, as the wrapper will be
// called "main".
inner_name = ctx.Clone(func_ast->symbol);
} else {
// Add a suffix to the function name, as the wrapper function will take
// the original entry point name.
auto ep_name = ctx.src->Symbols().NameFor(func_ast->symbol);
inner_name = ctx.dst->Symbols().New(ep_name + "_inner");
}
// Clone everything, dropping the function and return type attributes.
// The parameter attributes will have already been stripped during
// processing.
auto* inner_function = ctx.dst->create<ast::Function>(
inner_name, ctx.Clone(func_ast->params),
ctx.Clone(func_ast->return_type), ctx.Clone(func_ast->body),
ast::AttributeList{}, ast::AttributeList{});
ctx.Replace(func_ast, inner_function);
// Call the function.
return ctx.dst->Call(inner_function->symbol, inner_call_parameters);
}
/// Process the entry point function.
void Process() {
bool needs_fixed_sample_mask = false;
bool needs_vertex_point_size = false;
if (func_ast->PipelineStage() == ast::PipelineStage::kFragment &&
cfg.fixed_sample_mask != 0xFFFFFFFF) {
needs_fixed_sample_mask = true;
}
if (func_ast->PipelineStage() == ast::PipelineStage::kVertex &&
cfg.emit_vertex_point_size) {
needs_vertex_point_size = true;
}
// Exit early if there is no shader IO to handle.
if (func_sem->Parameters().size() == 0 &&
func_sem->ReturnType()->Is<sem::Void>() && !needs_fixed_sample_mask &&
!needs_vertex_point_size && cfg.shader_style != ShaderStyle::kGlsl) {
return;
}
// Process the entry point parameters, collecting those that need to be
// aggregated into a single structure.
if (!func_sem->Parameters().empty()) {
for (auto* param : func_sem->Parameters()) {
if (param->Type()->Is<sem::Struct>()) {
ProcessStructParameter(param);
} else {
ProcessNonStructParameter(param);
}
}
// Create a structure parameter for the outer entry point if necessary.
if (!wrapper_struct_param_members.empty()) {
CreateInputStruct();
}
}
// Recreate the original function and call it.
auto* call_inner = CallInnerFunction();
// Process the return type, and start building the wrapper function body.
std::function<const ast::Type*()> wrapper_ret_type = [&] {
return ctx.dst->ty.void_();
};
if (func_sem->ReturnType()->Is<sem::Void>()) {
// The function call is just a statement with no result.
wrapper_body.push_back(ctx.dst->CallStmt(call_inner));
} else {
// Capture the result of calling the original function.
auto* inner_result = ctx.dst->Const(
ctx.dst->Symbols().New("inner_result"), nullptr, call_inner);
wrapper_body.push_back(ctx.dst->Decl(inner_result));
// Process the original return type to determine the outputs that the
// outer function needs to produce.
ProcessReturnType(func_sem->ReturnType(), inner_result->symbol);
}
// Add a fixed sample mask, if necessary.
if (needs_fixed_sample_mask) {
AddFixedSampleMask();
}
// Add the pointsize builtin, if necessary.
if (needs_vertex_point_size) {
AddVertexPointSize();
}
// Produce the entry point outputs, if necessary.
if (!wrapper_output_values.empty()) {
if (cfg.shader_style == ShaderStyle::kSpirv ||
cfg.shader_style == ShaderStyle::kGlsl) {
CreateGlobalOutputVariables();
} else {
auto* output_struct = CreateOutputStruct();
wrapper_ret_type = [&, output_struct] {
return ctx.dst->ty.type_name(output_struct->name);
};
}
}
if (cfg.shader_style == ShaderStyle::kGlsl &&
func_ast->PipelineStage() == ast::PipelineStage::kVertex) {
auto* pos_y = GLPosition("y");
auto* negate_pos_y = ctx.dst->create<ast::UnaryOpExpression>(
ast::UnaryOp::kNegation, GLPosition("y"));
wrapper_body.push_back(ctx.dst->Assign(pos_y, negate_pos_y));
auto* two_z = ctx.dst->Mul(ctx.dst->Expr(2.0f), GLPosition("z"));
auto* fixed_z = ctx.dst->Sub(two_z, GLPosition("w"));
wrapper_body.push_back(ctx.dst->Assign(GLPosition("z"), fixed_z));
}
// Create the wrapper entry point function.
// For GLSL, use "main", otherwise take the name of the original
// entry point function.
Symbol name;
if (cfg.shader_style == ShaderStyle::kGlsl) {
name = ctx.dst->Symbols().New("main");
} else {
name = ctx.Clone(func_ast->symbol);
}
auto* wrapper_func = ctx.dst->create<ast::Function>(
name, wrapper_ep_parameters, wrapper_ret_type(),
ctx.dst->Block(wrapper_body), ctx.Clone(func_ast->attributes),
ast::AttributeList{});
ctx.InsertAfter(ctx.src->AST().GlobalDeclarations(), func_ast,
wrapper_func);
}
/// Retrieve the gl_ string corresponding to a builtin.
/// @param builtin the builtin
/// @param stage the current pipeline stage
/// @param storage_class the storage class (input or output)
/// @returns the gl_ string corresponding to that builtin
const char* GLSLBuiltinToString(ast::Builtin builtin,
ast::PipelineStage stage,
ast::StorageClass storage_class) {
switch (builtin) {
case ast::Builtin::kPosition:
switch (stage) {
case ast::PipelineStage::kVertex:
return "gl_Position";
case ast::PipelineStage::kFragment:
return "gl_FragCoord";
default:
return "";
}
case ast::Builtin::kVertexIndex:
return "gl_VertexID";
case ast::Builtin::kInstanceIndex:
return "gl_InstanceID";
case ast::Builtin::kFrontFacing:
return "gl_FrontFacing";
case ast::Builtin::kFragDepth:
return "gl_FragDepth";
case ast::Builtin::kLocalInvocationId:
return "gl_LocalInvocationID";
case ast::Builtin::kLocalInvocationIndex:
return "gl_LocalInvocationIndex";
case ast::Builtin::kGlobalInvocationId:
return "gl_GlobalInvocationID";
case ast::Builtin::kNumWorkgroups:
return "gl_NumWorkGroups";
case ast::Builtin::kWorkgroupId:
return "gl_WorkGroupID";
case ast::Builtin::kSampleIndex:
return "gl_SampleID";
case ast::Builtin::kSampleMask:
if (storage_class == ast::StorageClass::kInput) {
return "gl_SampleMaskIn";
} else {
return "gl_SampleMask";
}
default:
return "";
}
}
/// Convert a given GLSL builtin value to the corresponding WGSL value.
/// @param builtin the builtin variable
/// @param value the value to convert
/// @param ast_type (inout) the incoming WGSL and outgoing GLSL types
/// @returns an expression representing the GLSL builtin converted to what
/// WGSL expects
const ast::Expression* FromGLSLBuiltin(ast::Builtin builtin,
const ast::Expression* value,
const ast::Type*& ast_type) {
switch (builtin) {
case ast::Builtin::kVertexIndex:
case ast::Builtin::kInstanceIndex:
case ast::Builtin::kSampleIndex:
// GLSL uses i32 for these, so bitcast to u32.
value = ctx.dst->Bitcast(ast_type, value);
ast_type = ctx.dst->ty.i32();
break;
case ast::Builtin::kSampleMask:
// gl_SampleMask is an array of i32. Retrieve the first element and
// bitcast it to u32.
value = ctx.dst->IndexAccessor(value, 0);
value = ctx.dst->Bitcast(ast_type, value);
ast_type = ctx.dst->ty.array(ctx.dst->ty.i32(), 1);
break;
default:
break;
}
return value;
}
/// Convert a given WGSL value to the type expected when assigning to a
/// GLSL builtin.
/// @param builtin the builtin variable
/// @param value the value to convert
/// @param type (out) the type to which the value was converted
/// @returns the converted value which can be assigned to the GLSL builtin
const ast::Expression* ToGLSLBuiltin(ast::Builtin builtin,
const ast::Expression* value,
const sem::Type*& type) {
switch (builtin) {
case ast::Builtin::kVertexIndex:
case ast::Builtin::kInstanceIndex:
case ast::Builtin::kSampleIndex:
case ast::Builtin::kSampleMask:
type = ctx.dst->create<sem::I32>();
value = ctx.dst->Bitcast(CreateASTTypeFor(ctx, type), value);
break;
default:
break;
}
return value;
}
};
void CanonicalizeEntryPointIO::Run(CloneContext& ctx,
const DataMap& inputs,
DataMap&) const {
auto* cfg = inputs.Get<Config>();
if (cfg == nullptr) {
ctx.dst->Diagnostics().add_error(
diag::System::Transform,
"missing transform data for " + std::string(TypeInfo().name));
return;
}
// Remove entry point IO attributes from struct declarations.
// New structures will be created for each entry point, as necessary.
for (auto* ty : ctx.src->AST().TypeDecls()) {
if (auto* struct_ty = ty->As<ast::Struct>()) {
for (auto* member : struct_ty->members) {
for (auto* attr : member->attributes) {
if (IsShaderIOAttribute(attr)) {
ctx.Remove(member->attributes, attr);
}
}
}
}
}
for (auto* func_ast : ctx.src->AST().Functions()) {
if (!func_ast->IsEntryPoint()) {
continue;
}
State state(ctx, *cfg, func_ast);
state.Process();
}
ctx.Clone();
}
CanonicalizeEntryPointIO::Config::Config(ShaderStyle style,
uint32_t sample_mask,
bool emit_point_size)
: shader_style(style),
fixed_sample_mask(sample_mask),
emit_vertex_point_size(emit_point_size) {}
CanonicalizeEntryPointIO::Config::Config(const Config&) = default;
CanonicalizeEntryPointIO::Config::~Config() = default;
} // namespace transform
} // namespace tint

View File

@@ -0,0 +1,149 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TINT_TRANSFORM_CANONICALIZE_ENTRY_POINT_IO_H_
#define SRC_TINT_TRANSFORM_CANONICALIZE_ENTRY_POINT_IO_H_
#include "src/tint/transform/transform.h"
namespace tint {
namespace transform {
/// CanonicalizeEntryPointIO is a transform used to rewrite shader entry point
/// interfaces into a form that the generators can handle. Each entry point
/// function is stripped of all shader IO attributes and wrapped in a function
/// that provides the shader interface.
/// The transform config determines whether to use global variables, structures,
/// or parameters for the shader inputs and outputs, and optionally adds
/// additional builtins to the shader interface.
///
/// Before:
/// ```
/// struct Locations{
/// @location(1) loc1 : f32;
/// @location(2) loc2 : vec4<u32>;
/// };
///
/// @stage(fragment)
/// fn frag_main(@builtin(position) coord : vec4<f32>,
/// locations : Locations) -> @location(0) f32 {
/// if (coord.w > 1.0) {
/// return 0.0;
/// }
/// var col : f32 = (coord.x * locations.loc1);
/// return col;
/// }
/// ```
///
/// After (using structures for all parameters):
/// ```
/// struct Locations{
/// loc1 : f32;
/// loc2 : vec4<u32>;
/// };
///
/// struct frag_main_in {
/// @builtin(position) coord : vec4<f32>;
/// @location(1) loc1 : f32;
/// @location(2) loc2 : vec4<u32>
/// };
///
/// struct frag_main_out {
/// @location(0) loc0 : f32;
/// };
///
/// fn frag_main_inner(coord : vec4<f32>,
/// locations : Locations) -> f32 {
/// if (coord.w > 1.0) {
/// return 0.0;
/// }
/// var col : f32 = (coord.x * locations.loc1);
/// return col;
/// }
///
/// @stage(fragment)
/// fn frag_main(in : frag_main_in) -> frag_main_out {
/// let inner_retval = frag_main_inner(in.coord, Locations(in.loc1, in.loc2));
/// var wrapper_result : frag_main_out;
/// wrapper_result.loc0 = inner_retval;
/// return wrapper_result;
/// }
/// ```
///
/// @note Depends on the following transforms to have been run first:
/// * Unshadow
class CanonicalizeEntryPointIO
: public Castable<CanonicalizeEntryPointIO, Transform> {
public:
/// ShaderStyle is an enumerator of different ways to emit shader IO.
enum class ShaderStyle {
/// Target SPIR-V (using global variables).
kSpirv,
/// Target GLSL (using global variables).
kGlsl,
/// Target MSL (using non-struct function parameters for builtins).
kMsl,
/// Target HLSL (using structures for all IO).
kHlsl,
};
/// Configuration options for the transform.
struct Config : public Castable<Config, Data> {
/// Constructor
/// @param style the approach to use for emitting shader IO.
/// @param sample_mask an optional sample mask to combine with shader masks
/// @param emit_vertex_point_size `true` to generate a pointsize builtin
explicit Config(ShaderStyle style,
uint32_t sample_mask = 0xFFFFFFFF,
bool emit_vertex_point_size = false);
/// Copy constructor
Config(const Config&);
/// Destructor
~Config() override;
/// The approach to use for emitting shader IO.
const ShaderStyle shader_style;
/// A fixed sample mask to combine into masks produced by fragment shaders.
const uint32_t fixed_sample_mask;
/// Set to `true` to generate a pointsize builtin and have it set to 1.0
/// from all vertex shaders in the module.
const bool emit_vertex_point_size;
};
/// Constructor
CanonicalizeEntryPointIO();
~CanonicalizeEntryPointIO() override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const override;
struct State;
};
} // namespace transform
} // namespace tint
#endif // SRC_TINT_TRANSFORM_CANONICALIZE_ENTRY_POINT_IO_H_

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,355 @@
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/combine_samplers.h"
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "src/tint/sem/function.h"
#include "src/tint/sem/statement.h"
#include "src/tint/utils/map.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::CombineSamplers);
TINT_INSTANTIATE_TYPEINFO(tint::transform::CombineSamplers::BindingInfo);
namespace {
bool IsGlobal(const tint::sem::VariablePair& pair) {
return pair.first->Is<tint::sem::GlobalVariable>() &&
(!pair.second || pair.second->Is<tint::sem::GlobalVariable>());
}
} // namespace
namespace tint {
namespace transform {
CombineSamplers::BindingInfo::BindingInfo(const BindingMap& map,
const sem::BindingPoint& placeholder)
: binding_map(map), placeholder_binding_point(placeholder) {}
CombineSamplers::BindingInfo::BindingInfo(const BindingInfo& other) = default;
CombineSamplers::BindingInfo::~BindingInfo() = default;
/// The PIMPL state for the CombineSamplers transform
struct CombineSamplers::State {
/// The clone context
CloneContext& ctx;
/// The binding info
const BindingInfo* binding_info;
/// Map from a texture/sampler pair to the corresponding combined sampler
/// variable
using CombinedTextureSamplerMap =
std::unordered_map<sem::VariablePair, const ast::Variable*>;
/// Use sem::BindingPoint without scope.
using BindingPoint = sem::BindingPoint;
/// A map of all global texture/sampler variable pairs to the global
/// combined sampler variable that will replace it.
CombinedTextureSamplerMap global_combined_texture_samplers_;
/// A map of all texture/sampler variable pairs that contain a function
/// parameter to the combined sampler function paramter that will replace it.
std::unordered_map<const sem::Function*, CombinedTextureSamplerMap>
function_combined_texture_samplers_;
/// Placeholder global samplers used when a function contains texture-only
/// references (one comparison sampler, one regular). These are also used as
/// temporary sampler parameters to the texture builtins to satisfy the WGSL
/// resolver, but are then ignored and removed by the GLSL writer.
const ast::Variable* placeholder_samplers_[2] = {};
/// Group and binding attributes used by all combined sampler globals.
/// Group 0 and binding 0 are used, with collisions disabled.
/// @returns the newly-created attribute list
ast::AttributeList Attributes() const {
auto attributes = ctx.dst->GroupAndBinding(0, 0);
attributes.push_back(
ctx.dst->Disable(ast::DisabledValidation::kBindingPointCollision));
return attributes;
}
/// Constructor
/// @param context the clone context
/// @param info the binding map information
State(CloneContext& context, const BindingInfo* info)
: ctx(context), binding_info(info) {}
/// Creates a combined sampler global variables.
/// (Note this is actually a Texture node at the AST level, but it will be
/// written as the corresponding sampler (eg., sampler2D) on GLSL output.)
/// @param texture_var the texture (global) variable
/// @param sampler_var the sampler (global) variable
/// @param name the default name to use (may be overridden by map lookup)
/// @returns the newly-created global variable
const ast::Variable* CreateCombinedGlobal(const sem::Variable* texture_var,
const sem::Variable* sampler_var,
std::string name) {
SamplerTexturePair bp_pair;
bp_pair.texture_binding_point =
texture_var->As<sem::GlobalVariable>()->BindingPoint();
bp_pair.sampler_binding_point =
sampler_var ? sampler_var->As<sem::GlobalVariable>()->BindingPoint()
: binding_info->placeholder_binding_point;
auto it = binding_info->binding_map.find(bp_pair);
if (it != binding_info->binding_map.end()) {
name = it->second;
}
const ast::Type* type = CreateCombinedASTTypeFor(texture_var, sampler_var);
Symbol symbol = ctx.dst->Symbols().New(name);
return ctx.dst->Global(symbol, type, Attributes());
}
/// Creates placeholder global sampler variables.
/// @param kind the sampler kind to create for
/// @returns the newly-created global variable
const ast::Variable* CreatePlaceholder(ast::SamplerKind kind) {
const ast::Type* type = ctx.dst->ty.sampler(kind);
const char* name = kind == ast::SamplerKind::kComparisonSampler
? "placeholder_comparison_sampler"
: "placeholder_sampler";
Symbol symbol = ctx.dst->Symbols().New(name);
return ctx.dst->Global(symbol, type, Attributes());
}
/// Creates ast::Type for a given texture and sampler variable pair.
/// Depth textures with no samplers are turned into the corresponding
/// f32 texture (e.g., texture_depth_2d -> texture_2d<f32>).
/// @param texture the texture variable of interest
/// @param sampler the texture variable of interest
/// @returns the newly-created type
const ast::Type* CreateCombinedASTTypeFor(const sem::Variable* texture,
const sem::Variable* sampler) {
const sem::Type* texture_type = texture->Type()->UnwrapRef();
const sem::DepthTexture* depth = texture_type->As<sem::DepthTexture>();
if (depth && !sampler) {
return ctx.dst->create<ast::SampledTexture>(depth->dim(),
ctx.dst->create<ast::F32>());
} else {
return CreateASTTypeFor(ctx, texture_type);
}
}
/// Performs the transformation
void Run() {
auto& sem = ctx.src->Sem();
// Remove all texture and sampler global variables. These will be replaced
// by combined samplers.
for (auto* var : ctx.src->AST().GlobalVariables()) {
auto* type = sem.Get(var->type);
if (type && type->IsAnyOf<sem::Texture, sem::Sampler>() &&
!type->Is<sem::StorageTexture>()) {
ctx.Remove(ctx.src->AST().GlobalDeclarations(), var);
} else if (auto binding_point = var->BindingPoint()) {
if (binding_point.group->value == 0 &&
binding_point.binding->value == 0) {
auto* attribute =
ctx.dst->Disable(ast::DisabledValidation::kBindingPointCollision);
ctx.InsertFront(var->attributes, attribute);
}
}
}
// Rewrite all function signatures to use combined samplers, and remove
// separate textures & samplers. Create new combined globals where found.
ctx.ReplaceAll([&](const ast::Function* src) -> const ast::Function* {
if (auto* func = sem.Get(src)) {
auto pairs = func->TextureSamplerPairs();
if (pairs.empty()) {
return nullptr;
}
ast::VariableList params;
for (auto pair : func->TextureSamplerPairs()) {
const sem::Variable* texture_var = pair.first;
const sem::Variable* sampler_var = pair.second;
std::string name =
ctx.src->Symbols().NameFor(texture_var->Declaration()->symbol);
if (sampler_var) {
name += "_" + ctx.src->Symbols().NameFor(
sampler_var->Declaration()->symbol);
}
if (IsGlobal(pair)) {
// Both texture and sampler are global; add a new global variable
// to represent the combined sampler (if not already created).
utils::GetOrCreate(global_combined_texture_samplers_, pair, [&] {
return CreateCombinedGlobal(texture_var, sampler_var, name);
});
} else {
// Either texture or sampler (or both) is a function parameter;
// add a new function parameter to represent the combined sampler.
const ast::Type* type =
CreateCombinedASTTypeFor(texture_var, sampler_var);
const ast::Variable* var =
ctx.dst->Param(ctx.dst->Symbols().New(name), type);
params.push_back(var);
function_combined_texture_samplers_[func][pair] = var;
}
}
// Filter out separate textures and samplers from the original
// function signature.
for (auto* var : src->params) {
if (!sem.Get(var->type)->IsAnyOf<sem::Texture, sem::Sampler>()) {
params.push_back(ctx.Clone(var));
}
}
// Create a new function signature that differs only in the parameter
// list.
auto symbol = ctx.Clone(src->symbol);
auto* return_type = ctx.Clone(src->return_type);
auto* body = ctx.Clone(src->body);
auto attributes = ctx.Clone(src->attributes);
auto return_type_attributes = ctx.Clone(src->return_type_attributes);
return ctx.dst->create<ast::Function>(
symbol, params, return_type, body, std::move(attributes),
std::move(return_type_attributes));
}
return nullptr;
});
// Replace all function call expressions containing texture or
// sampler parameters to use the current function's combined samplers or
// the combined global samplers, as appropriate.
ctx.ReplaceAll([&](const ast::CallExpression* expr)
-> const ast::Expression* {
if (auto* call = sem.Get(expr)) {
ast::ExpressionList args;
// Replace all texture builtin calls.
if (auto* builtin = call->Target()->As<sem::Builtin>()) {
const auto& signature = builtin->Signature();
int sampler_index = signature.IndexOf(sem::ParameterUsage::kSampler);
int texture_index = signature.IndexOf(sem::ParameterUsage::kTexture);
if (texture_index == -1) {
return nullptr;
}
const sem::Expression* texture = call->Arguments()[texture_index];
// We don't want to combine storage textures with anything, since
// they never have associated samplers in GLSL.
if (texture->Type()->UnwrapRef()->Is<sem::StorageTexture>()) {
return nullptr;
}
const sem::Expression* sampler =
sampler_index != -1 ? call->Arguments()[sampler_index] : nullptr;
auto* texture_var = texture->As<sem::VariableUser>()->Variable();
auto* sampler_var =
sampler ? sampler->As<sem::VariableUser>()->Variable() : nullptr;
sem::VariablePair new_pair(texture_var, sampler_var);
for (auto* arg : expr->args) {
auto* type = ctx.src->TypeOf(arg)->UnwrapRef();
if (type->Is<sem::Texture>()) {
const ast::Variable* var =
IsGlobal(new_pair)
? global_combined_texture_samplers_[new_pair]
: function_combined_texture_samplers_
[call->Stmt()->Function()][new_pair];
args.push_back(ctx.dst->Expr(var->symbol));
} else if (auto* sampler_type = type->As<sem::Sampler>()) {
ast::SamplerKind kind = sampler_type->kind();
int index = (kind == ast::SamplerKind::kSampler) ? 0 : 1;
const ast::Variable*& p = placeholder_samplers_[index];
if (!p) {
p = CreatePlaceholder(kind);
}
args.push_back(ctx.dst->Expr(p->symbol));
} else {
args.push_back(ctx.Clone(arg));
}
}
const ast::Expression* value =
ctx.dst->Call(ctx.Clone(expr->target.name), args);
if (builtin->Type() == sem::BuiltinType::kTextureLoad &&
texture_var->Type()->UnwrapRef()->Is<sem::DepthTexture>() &&
!call->Stmt()->Declaration()->Is<ast::CallStatement>()) {
value = ctx.dst->MemberAccessor(value, "x");
}
return value;
}
// Replace all function calls.
if (auto* callee = call->Target()->As<sem::Function>()) {
for (auto pair : callee->TextureSamplerPairs()) {
// Global pairs used by the callee do not require a function
// parameter at the call site.
if (IsGlobal(pair)) {
continue;
}
const sem::Variable* texture_var = pair.first;
const sem::Variable* sampler_var = pair.second;
if (auto* param = texture_var->As<sem::Parameter>()) {
const sem::Expression* texture =
call->Arguments()[param->Index()];
texture_var = texture->As<sem::VariableUser>()->Variable();
}
if (sampler_var) {
if (auto* param = sampler_var->As<sem::Parameter>()) {
const sem::Expression* sampler =
call->Arguments()[param->Index()];
sampler_var = sampler->As<sem::VariableUser>()->Variable();
}
}
sem::VariablePair new_pair(texture_var, sampler_var);
// If both texture and sampler are (now) global, pass that
// global variable to the callee. Otherwise use the caller's
// function parameter for this pair.
const ast::Variable* var =
IsGlobal(new_pair) ? global_combined_texture_samplers_[new_pair]
: function_combined_texture_samplers_
[call->Stmt()->Function()][new_pair];
auto* arg = ctx.dst->Expr(var->symbol);
args.push_back(arg);
}
// Append all of the remaining non-texture and non-sampler
// parameters.
for (auto* arg : expr->args) {
if (!ctx.src->TypeOf(arg)
->UnwrapRef()
->IsAnyOf<sem::Texture, sem::Sampler>()) {
args.push_back(ctx.Clone(arg));
}
}
return ctx.dst->Call(ctx.Clone(expr->target.name), args);
}
}
return nullptr;
});
ctx.Clone();
}
};
CombineSamplers::CombineSamplers() = default;
CombineSamplers::~CombineSamplers() = default;
void CombineSamplers::Run(CloneContext& ctx,
const DataMap& inputs,
DataMap&) const {
auto* binding_info = inputs.Get<BindingInfo>();
if (!binding_info) {
ctx.dst->Diagnostics().add_error(
diag::System::Transform,
"missing transform data for " + std::string(TypeInfo().name));
return;
}
State(ctx, binding_info).Run();
}
} // namespace transform
} // namespace tint

View File

@@ -0,0 +1,110 @@
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TINT_TRANSFORM_COMBINE_SAMPLERS_H_
#define SRC_TINT_TRANSFORM_COMBINE_SAMPLERS_H_
#include <string>
#include <unordered_map>
#include "src/tint/sem/sampler_texture_pair.h"
#include "src/tint/transform/transform.h"
namespace tint {
namespace transform {
/// This transform converts all separate texture/sampler refences in a
/// program into combined texture/samplers. This is required for GLSL,
/// which does not support separate texture/samplers.
///
/// It utilizes the texture/sampler information collected by the
/// Resolver and stored on each sem::Function. For each function, all
/// separate texture/sampler parameters in the function signature are
/// removed. For each unique pair, if both texture and sampler are
/// global variables, the function passes the corresponding combined
/// global stored in global_combined_texture_samplers_ at the call
/// site. Otherwise, either the texture or sampler must be a function
/// parameter. In this case, a new parameter is added to the function
/// signature. All separate texture/sampler parameters are removed.
///
/// All texture builtin callsites are modified to pass the combined
/// texture/sampler as the first argument, and separate texture/sampler
/// arguments are removed.
///
/// Note that the sampler may be null, indicating that only a texture
/// reference was required (e.g., textureLoad). In this case, a
/// placeholder global sampler is used at the AST level. This will be
/// combined with the original texture to give a combined global, and
/// the placeholder removed (ignored) by the GLSL writer.
///
/// Note that the combined samplers are actually represented by a
/// Texture node at the AST level, since this contains all the
/// information needed to represent a combined sampler in GLSL
/// (dimensionality, component type, etc). The GLSL writer outputs such
/// (Tint) Textures as (GLSL) Samplers.
class CombineSamplers : public Castable<CombineSamplers, Transform> {
public:
/// A pair of binding points.
using SamplerTexturePair = sem::SamplerTexturePair;
/// A map from a sampler/texture pair to a named global.
using BindingMap = std::unordered_map<SamplerTexturePair, std::string>;
/// The client-provided mapping from separate texture and sampler binding
/// points to combined sampler binding point.
struct BindingInfo : public Castable<Data, transform::Data> {
/// Constructor
/// @param map the map of all (texture, sampler) -> (combined) pairs
/// @param placeholder the binding point to use for placeholder samplers.
BindingInfo(const BindingMap& map, const sem::BindingPoint& placeholder);
/// Copy constructor
/// @param other the other BindingInfo to copy
BindingInfo(const BindingInfo& other);
/// Destructor
~BindingInfo() override;
/// A map of bindings from (texture, sampler) -> combined sampler.
BindingMap binding_map;
/// The binding point to use for placeholder samplers.
sem::BindingPoint placeholder_binding_point;
};
/// Constructor
CombineSamplers();
/// Destructor
~CombineSamplers() override;
protected:
/// The PIMPL state for this transform
struct State;
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const override;
};
} // namespace transform
} // namespace tint
#endif // SRC_TINT_TRANSFORM_COMBINE_SAMPLERS_H_

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,998 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/decompose_memory_access.h"
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "src/tint/ast/assignment_statement.h"
#include "src/tint/ast/call_statement.h"
#include "src/tint/ast/disable_validation_attribute.h"
#include "src/tint/ast/type_name.h"
#include "src/tint/ast/unary_op.h"
#include "src/tint/block_allocator.h"
#include "src/tint/program_builder.h"
#include "src/tint/sem/array.h"
#include "src/tint/sem/atomic_type.h"
#include "src/tint/sem/call.h"
#include "src/tint/sem/member_accessor_expression.h"
#include "src/tint/sem/reference_type.h"
#include "src/tint/sem/statement.h"
#include "src/tint/sem/struct.h"
#include "src/tint/sem/variable.h"
#include "src/tint/utils/hash.h"
#include "src/tint/utils/map.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::DecomposeMemoryAccess);
TINT_INSTANTIATE_TYPEINFO(tint::transform::DecomposeMemoryAccess::Intrinsic);
namespace tint {
namespace transform {
namespace {
/// Offset is a simple ast::Expression builder interface, used to build byte
/// offsets for storage and uniform buffer accesses.
struct Offset : Castable<Offset> {
/// @returns builds and returns the ast::Expression in `ctx.dst`
virtual const ast::Expression* Build(CloneContext& ctx) const = 0;
};
/// OffsetExpr is an implementation of Offset that clones and casts the given
/// expression to `u32`.
struct OffsetExpr : Offset {
const ast::Expression* const expr = nullptr;
explicit OffsetExpr(const ast::Expression* e) : expr(e) {}
const ast::Expression* Build(CloneContext& ctx) const override {
auto* type = ctx.src->Sem().Get(expr)->Type()->UnwrapRef();
auto* res = ctx.Clone(expr);
if (!type->Is<sem::U32>()) {
res = ctx.dst->Construct<ProgramBuilder::u32>(res);
}
return res;
}
};
/// OffsetLiteral is an implementation of Offset that constructs a u32 literal
/// value.
struct OffsetLiteral : Castable<OffsetLiteral, Offset> {
uint32_t const literal = 0;
explicit OffsetLiteral(uint32_t lit) : literal(lit) {}
const ast::Expression* Build(CloneContext& ctx) const override {
return ctx.dst->Expr(literal);
}
};
/// OffsetBinOp is an implementation of Offset that constructs a binary-op of
/// two Offsets.
struct OffsetBinOp : Offset {
ast::BinaryOp op;
Offset const* lhs = nullptr;
Offset const* rhs = nullptr;
const ast::Expression* Build(CloneContext& ctx) const override {
return ctx.dst->create<ast::BinaryExpression>(op, lhs->Build(ctx),
rhs->Build(ctx));
}
};
/// LoadStoreKey is the unordered map key to a load or store intrinsic.
struct LoadStoreKey {
ast::StorageClass const storage_class; // buffer storage class
sem::Type const* buf_ty = nullptr; // buffer type
sem::Type const* el_ty = nullptr; // element type
bool operator==(const LoadStoreKey& rhs) const {
return storage_class == rhs.storage_class && buf_ty == rhs.buf_ty &&
el_ty == rhs.el_ty;
}
struct Hasher {
inline std::size_t operator()(const LoadStoreKey& u) const {
return utils::Hash(u.storage_class, u.buf_ty, u.el_ty);
}
};
};
/// AtomicKey is the unordered map key to an atomic intrinsic.
struct AtomicKey {
sem::Type const* buf_ty = nullptr; // buffer type
sem::Type const* el_ty = nullptr; // element type
sem::BuiltinType const op; // atomic op
bool operator==(const AtomicKey& rhs) const {
return buf_ty == rhs.buf_ty && el_ty == rhs.el_ty && op == rhs.op;
}
struct Hasher {
inline std::size_t operator()(const AtomicKey& u) const {
return utils::Hash(u.buf_ty, u.el_ty, u.op);
}
};
};
bool IntrinsicDataTypeFor(const sem::Type* ty,
DecomposeMemoryAccess::Intrinsic::DataType& out) {
if (ty->Is<sem::I32>()) {
out = DecomposeMemoryAccess::Intrinsic::DataType::kI32;
return true;
}
if (ty->Is<sem::U32>()) {
out = DecomposeMemoryAccess::Intrinsic::DataType::kU32;
return true;
}
if (ty->Is<sem::F32>()) {
out = DecomposeMemoryAccess::Intrinsic::DataType::kF32;
return true;
}
if (auto* vec = ty->As<sem::Vector>()) {
switch (vec->Width()) {
case 2:
if (vec->type()->Is<sem::I32>()) {
out = DecomposeMemoryAccess::Intrinsic::DataType::kVec2I32;
return true;
}
if (vec->type()->Is<sem::U32>()) {
out = DecomposeMemoryAccess::Intrinsic::DataType::kVec2U32;
return true;
}
if (vec->type()->Is<sem::F32>()) {
out = DecomposeMemoryAccess::Intrinsic::DataType::kVec2F32;
return true;
}
break;
case 3:
if (vec->type()->Is<sem::I32>()) {
out = DecomposeMemoryAccess::Intrinsic::DataType::kVec3I32;
return true;
}
if (vec->type()->Is<sem::U32>()) {
out = DecomposeMemoryAccess::Intrinsic::DataType::kVec3U32;
return true;
}
if (vec->type()->Is<sem::F32>()) {
out = DecomposeMemoryAccess::Intrinsic::DataType::kVec3F32;
return true;
}
break;
case 4:
if (vec->type()->Is<sem::I32>()) {
out = DecomposeMemoryAccess::Intrinsic::DataType::kVec4I32;
return true;
}
if (vec->type()->Is<sem::U32>()) {
out = DecomposeMemoryAccess::Intrinsic::DataType::kVec4U32;
return true;
}
if (vec->type()->Is<sem::F32>()) {
out = DecomposeMemoryAccess::Intrinsic::DataType::kVec4F32;
return true;
}
break;
}
return false;
}
return false;
}
/// @returns a DecomposeMemoryAccess::Intrinsic attribute that can be applied
/// to a stub function to load the type `ty`.
DecomposeMemoryAccess::Intrinsic* IntrinsicLoadFor(
ProgramBuilder* builder,
ast::StorageClass storage_class,
const sem::Type* ty) {
DecomposeMemoryAccess::Intrinsic::DataType type;
if (!IntrinsicDataTypeFor(ty, type)) {
return nullptr;
}
return builder->ASTNodes().Create<DecomposeMemoryAccess::Intrinsic>(
builder->ID(), DecomposeMemoryAccess::Intrinsic::Op::kLoad, storage_class,
type);
}
/// @returns a DecomposeMemoryAccess::Intrinsic attribute that can be applied
/// to a stub function to store the type `ty`.
DecomposeMemoryAccess::Intrinsic* IntrinsicStoreFor(
ProgramBuilder* builder,
ast::StorageClass storage_class,
const sem::Type* ty) {
DecomposeMemoryAccess::Intrinsic::DataType type;
if (!IntrinsicDataTypeFor(ty, type)) {
return nullptr;
}
return builder->ASTNodes().Create<DecomposeMemoryAccess::Intrinsic>(
builder->ID(), DecomposeMemoryAccess::Intrinsic::Op::kStore,
storage_class, type);
}
/// @returns a DecomposeMemoryAccess::Intrinsic attribute that can be applied
/// to a stub function for the atomic op and the type `ty`.
DecomposeMemoryAccess::Intrinsic* IntrinsicAtomicFor(ProgramBuilder* builder,
sem::BuiltinType ity,
const sem::Type* ty) {
auto op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicLoad;
switch (ity) {
case sem::BuiltinType::kAtomicLoad:
op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicLoad;
break;
case sem::BuiltinType::kAtomicStore:
op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicStore;
break;
case sem::BuiltinType::kAtomicAdd:
op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicAdd;
break;
case sem::BuiltinType::kAtomicSub:
op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicSub;
break;
case sem::BuiltinType::kAtomicMax:
op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicMax;
break;
case sem::BuiltinType::kAtomicMin:
op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicMin;
break;
case sem::BuiltinType::kAtomicAnd:
op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicAnd;
break;
case sem::BuiltinType::kAtomicOr:
op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicOr;
break;
case sem::BuiltinType::kAtomicXor:
op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicXor;
break;
case sem::BuiltinType::kAtomicExchange:
op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicExchange;
break;
case sem::BuiltinType::kAtomicCompareExchangeWeak:
op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicCompareExchangeWeak;
break;
default:
TINT_ICE(Transform, builder->Diagnostics())
<< "invalid IntrinsicType for DecomposeMemoryAccess::Intrinsic: "
<< ty->type_name();
break;
}
DecomposeMemoryAccess::Intrinsic::DataType type;
if (!IntrinsicDataTypeFor(ty, type)) {
return nullptr;
}
return builder->ASTNodes().Create<DecomposeMemoryAccess::Intrinsic>(
builder->ID(), op, ast::StorageClass::kStorage, type);
}
/// BufferAccess describes a single storage or uniform buffer access
struct BufferAccess {
sem::Expression const* var = nullptr; // Storage buffer variable
Offset const* offset = nullptr; // The byte offset on var
sem::Type const* type = nullptr; // The type of the access
operator bool() const { return var; } // Returns true if valid
};
/// Store describes a single storage or uniform buffer write
struct Store {
const ast::AssignmentStatement* assignment; // The AST assignment statement
BufferAccess target; // The target for the write
};
} // namespace
/// State holds the current transform state
struct DecomposeMemoryAccess::State {
/// The clone context
CloneContext& ctx;
/// Alias to `*ctx.dst`
ProgramBuilder& b;
/// Map of AST expression to storage or uniform buffer access
/// This map has entries added when encountered, and removed when outer
/// expressions chain the access.
/// Subset of #expression_order, as expressions are not removed from
/// #expression_order.
std::unordered_map<const ast::Expression*, BufferAccess> accesses;
/// The visited order of AST expressions (superset of #accesses)
std::vector<const ast::Expression*> expression_order;
/// [buffer-type, element-type] -> load function name
std::unordered_map<LoadStoreKey, Symbol, LoadStoreKey::Hasher> load_funcs;
/// [buffer-type, element-type] -> store function name
std::unordered_map<LoadStoreKey, Symbol, LoadStoreKey::Hasher> store_funcs;
/// [buffer-type, element-type, atomic-op] -> load function name
std::unordered_map<AtomicKey, Symbol, AtomicKey::Hasher> atomic_funcs;
/// List of storage or uniform buffer writes
std::vector<Store> stores;
/// Allocations for offsets
BlockAllocator<Offset> offsets_;
/// Constructor
/// @param context the CloneContext
explicit State(CloneContext& context) : ctx(context), b(*ctx.dst) {}
/// @param offset the offset value to wrap in an Offset
/// @returns an Offset for the given literal value
const Offset* ToOffset(uint32_t offset) {
return offsets_.Create<OffsetLiteral>(offset);
}
/// @param expr the expression to convert to an Offset
/// @returns an Offset for the given ast::Expression
const Offset* ToOffset(const ast::Expression* expr) {
if (auto* u32 = expr->As<ast::UintLiteralExpression>()) {
return offsets_.Create<OffsetLiteral>(u32->value);
} else if (auto* i32 = expr->As<ast::SintLiteralExpression>()) {
if (i32->value > 0) {
return offsets_.Create<OffsetLiteral>(i32->value);
}
}
return offsets_.Create<OffsetExpr>(expr);
}
/// @param offset the Offset that is returned
/// @returns the given offset (pass-through)
const Offset* ToOffset(const Offset* offset) { return offset; }
/// @param lhs_ the left-hand side of the add expression
/// @param rhs_ the right-hand side of the add expression
/// @return an Offset that is a sum of lhs and rhs, performing basic constant
/// folding if possible
template <typename LHS, typename RHS>
const Offset* Add(LHS&& lhs_, RHS&& rhs_) {
auto* lhs = ToOffset(std::forward<LHS>(lhs_));
auto* rhs = ToOffset(std::forward<RHS>(rhs_));
auto* lhs_lit = tint::As<OffsetLiteral>(lhs);
auto* rhs_lit = tint::As<OffsetLiteral>(rhs);
if (lhs_lit && lhs_lit->literal == 0) {
return rhs;
}
if (rhs_lit && rhs_lit->literal == 0) {
return lhs;
}
if (lhs_lit && rhs_lit) {
if (static_cast<uint64_t>(lhs_lit->literal) +
static_cast<uint64_t>(rhs_lit->literal) <=
0xffffffff) {
return offsets_.Create<OffsetLiteral>(lhs_lit->literal +
rhs_lit->literal);
}
}
auto* out = offsets_.Create<OffsetBinOp>();
out->op = ast::BinaryOp::kAdd;
out->lhs = lhs;
out->rhs = rhs;
return out;
}
/// @param lhs_ the left-hand side of the multiply expression
/// @param rhs_ the right-hand side of the multiply expression
/// @return an Offset that is the multiplication of lhs and rhs, performing
/// basic constant folding if possible
template <typename LHS, typename RHS>
const Offset* Mul(LHS&& lhs_, RHS&& rhs_) {
auto* lhs = ToOffset(std::forward<LHS>(lhs_));
auto* rhs = ToOffset(std::forward<RHS>(rhs_));
auto* lhs_lit = tint::As<OffsetLiteral>(lhs);
auto* rhs_lit = tint::As<OffsetLiteral>(rhs);
if (lhs_lit && lhs_lit->literal == 0) {
return offsets_.Create<OffsetLiteral>(0);
}
if (rhs_lit && rhs_lit->literal == 0) {
return offsets_.Create<OffsetLiteral>(0);
}
if (lhs_lit && lhs_lit->literal == 1) {
return rhs;
}
if (rhs_lit && rhs_lit->literal == 1) {
return lhs;
}
if (lhs_lit && rhs_lit) {
return offsets_.Create<OffsetLiteral>(lhs_lit->literal *
rhs_lit->literal);
}
auto* out = offsets_.Create<OffsetBinOp>();
out->op = ast::BinaryOp::kMultiply;
out->lhs = lhs;
out->rhs = rhs;
return out;
}
/// AddAccess() adds the `expr -> access` map item to #accesses, and `expr`
/// to #expression_order.
/// @param expr the expression that performs the access
/// @param access the access
void AddAccess(const ast::Expression* expr, const BufferAccess& access) {
TINT_ASSERT(Transform, access.type);
accesses.emplace(expr, access);
expression_order.emplace_back(expr);
}
/// TakeAccess() removes the `node` item from #accesses (if it exists),
/// returning the BufferAccess. If #accesses does not hold an item for
/// `node`, an invalid BufferAccess is returned.
/// @param node the expression that performed an access
/// @return the BufferAccess for the given expression
BufferAccess TakeAccess(const ast::Expression* node) {
auto lhs_it = accesses.find(node);
if (lhs_it == accesses.end()) {
return {};
}
auto access = lhs_it->second;
accesses.erase(node);
return access;
}
/// LoadFunc() returns a symbol to an intrinsic function that loads an element
/// of type `el_ty` from a storage or uniform buffer of type `buf_ty`.
/// The emitted function has the signature:
/// `fn load(buf : buf_ty, offset : u32) -> el_ty`
/// @param buf_ty the storage or uniform buffer type
/// @param el_ty the storage or uniform buffer element type
/// @param var_user the variable user
/// @return the name of the function that performs the load
Symbol LoadFunc(const sem::Type* buf_ty,
const sem::Type* el_ty,
const sem::VariableUser* var_user) {
auto storage_class = var_user->Variable()->StorageClass();
return utils::GetOrCreate(
load_funcs, LoadStoreKey{storage_class, buf_ty, el_ty}, [&] {
auto* buf_ast_ty = CreateASTTypeFor(ctx, buf_ty);
auto* disable_validation = b.Disable(
ast::DisabledValidation::kIgnoreConstructibleFunctionParameter);
ast::VariableList params = {
// Note: The buffer parameter requires the StorageClass in
// order for HLSL to emit this as a ByteAddressBuffer or cbuffer
// array.
b.create<ast::Variable>(b.Sym("buffer"), storage_class,
var_user->Variable()->Access(),
buf_ast_ty, true, false, nullptr,
ast::AttributeList{disable_validation}),
b.Param("offset", b.ty.u32()),
};
auto name = b.Sym();
if (auto* intrinsic =
IntrinsicLoadFor(ctx.dst, storage_class, el_ty)) {
auto* el_ast_ty = CreateASTTypeFor(ctx, el_ty);
auto* func = b.create<ast::Function>(
name, params, el_ast_ty, nullptr,
ast::AttributeList{
intrinsic,
b.Disable(ast::DisabledValidation::kFunctionHasNoBody),
},
ast::AttributeList{});
b.AST().AddFunction(func);
} else if (auto* arr_ty = el_ty->As<sem::Array>()) {
// fn load_func(buf : buf_ty, offset : u32) -> array<T, N> {
// var arr : array<T, N>;
// for (var i = 0u; i < array_count; i = i + 1) {
// arr[i] = el_load_func(buf, offset + i * array_stride)
// }
// return arr;
// }
auto load =
LoadFunc(buf_ty, arr_ty->ElemType()->UnwrapRef(), var_user);
auto* arr =
b.Var(b.Symbols().New("arr"), CreateASTTypeFor(ctx, arr_ty));
auto* i = b.Var(b.Symbols().New("i"), nullptr, b.Expr(0u));
auto* for_init = b.Decl(i);
auto* for_cond = b.create<ast::BinaryExpression>(
ast::BinaryOp::kLessThan, b.Expr(i), b.Expr(arr_ty->Count()));
auto* for_cont = b.Assign(i, b.Add(i, 1u));
auto* arr_el = b.IndexAccessor(arr, i);
auto* el_offset =
b.Add(b.Expr("offset"), b.Mul(i, arr_ty->Stride()));
auto* el_val = b.Call(load, "buffer", el_offset);
auto* for_loop = b.For(for_init, for_cond, for_cont,
b.Block(b.Assign(arr_el, el_val)));
b.Func(name, params, CreateASTTypeFor(ctx, arr_ty),
{
b.Decl(arr),
for_loop,
b.Return(arr),
});
} else {
ast::ExpressionList values;
if (auto* mat_ty = el_ty->As<sem::Matrix>()) {
auto* vec_ty = mat_ty->ColumnType();
Symbol load = LoadFunc(buf_ty, vec_ty, var_user);
for (uint32_t i = 0; i < mat_ty->columns(); i++) {
auto* offset = b.Add("offset", i * mat_ty->ColumnStride());
values.emplace_back(b.Call(load, "buffer", offset));
}
} else if (auto* str = el_ty->As<sem::Struct>()) {
for (auto* member : str->Members()) {
auto* offset = b.Add("offset", member->Offset());
Symbol load =
LoadFunc(buf_ty, member->Type()->UnwrapRef(), var_user);
values.emplace_back(b.Call(load, "buffer", offset));
}
}
b.Func(
name, params, CreateASTTypeFor(ctx, el_ty),
{
b.Return(b.Construct(CreateASTTypeFor(ctx, el_ty), values)),
});
}
return name;
});
}
/// StoreFunc() returns a symbol to an intrinsic function that stores an
/// element of type `el_ty` to a storage buffer of type `buf_ty`.
/// The function has the signature:
/// `fn store(buf : buf_ty, offset : u32, value : el_ty)`
/// @param buf_ty the storage buffer type
/// @param el_ty the storage buffer element type
/// @param var_user the variable user
/// @return the name of the function that performs the store
Symbol StoreFunc(const sem::Type* buf_ty,
const sem::Type* el_ty,
const sem::VariableUser* var_user) {
auto storage_class = var_user->Variable()->StorageClass();
return utils::GetOrCreate(
store_funcs, LoadStoreKey{storage_class, buf_ty, el_ty}, [&] {
auto* buf_ast_ty = CreateASTTypeFor(ctx, buf_ty);
auto* el_ast_ty = CreateASTTypeFor(ctx, el_ty);
auto* disable_validation = b.Disable(
ast::DisabledValidation::kIgnoreConstructibleFunctionParameter);
ast::VariableList params{
// Note: The buffer parameter requires the StorageClass in
// order for HLSL to emit this as a ByteAddressBuffer.
b.create<ast::Variable>(b.Sym("buffer"), storage_class,
var_user->Variable()->Access(),
buf_ast_ty, true, false, nullptr,
ast::AttributeList{disable_validation}),
b.Param("offset", b.ty.u32()),
b.Param("value", el_ast_ty),
};
auto name = b.Sym();
if (auto* intrinsic =
IntrinsicStoreFor(ctx.dst, storage_class, el_ty)) {
auto* func = b.create<ast::Function>(
name, params, b.ty.void_(), nullptr,
ast::AttributeList{
intrinsic,
b.Disable(ast::DisabledValidation::kFunctionHasNoBody),
},
ast::AttributeList{});
b.AST().AddFunction(func);
} else {
ast::StatementList body;
if (auto* arr_ty = el_ty->As<sem::Array>()) {
// fn store_func(buf : buf_ty, offset : u32, value : el_ty) {
// var array = value; // No dynamic indexing on constant arrays
// for (var i = 0u; i < array_count; i = i + 1) {
// arr[i] = el_store_func(buf, offset + i * array_stride,
// value[i])
// }
// return arr;
// }
auto* array =
b.Var(b.Symbols().New("array"), nullptr, b.Expr("value"));
auto store =
StoreFunc(buf_ty, arr_ty->ElemType()->UnwrapRef(), var_user);
auto* i = b.Var(b.Symbols().New("i"), nullptr, b.Expr(0u));
auto* for_init = b.Decl(i);
auto* for_cond = b.create<ast::BinaryExpression>(
ast::BinaryOp::kLessThan, b.Expr(i), b.Expr(arr_ty->Count()));
auto* for_cont = b.Assign(i, b.Add(i, 1u));
auto* arr_el = b.IndexAccessor(array, i);
auto* el_offset =
b.Add(b.Expr("offset"), b.Mul(i, arr_ty->Stride()));
auto* store_stmt =
b.CallStmt(b.Call(store, "buffer", el_offset, arr_el));
auto* for_loop =
b.For(for_init, for_cond, for_cont, b.Block(store_stmt));
body = {b.Decl(array), for_loop};
} else if (auto* mat_ty = el_ty->As<sem::Matrix>()) {
auto* vec_ty = mat_ty->ColumnType();
Symbol store = StoreFunc(buf_ty, vec_ty, var_user);
for (uint32_t i = 0; i < mat_ty->columns(); i++) {
auto* offset = b.Add("offset", i * mat_ty->ColumnStride());
auto* access = b.IndexAccessor("value", i);
auto* call = b.Call(store, "buffer", offset, access);
body.emplace_back(b.CallStmt(call));
}
} else if (auto* str = el_ty->As<sem::Struct>()) {
for (auto* member : str->Members()) {
auto* offset = b.Add("offset", member->Offset());
auto* access = b.MemberAccessor(
"value", ctx.Clone(member->Declaration()->symbol));
Symbol store =
StoreFunc(buf_ty, member->Type()->UnwrapRef(), var_user);
auto* call = b.Call(store, "buffer", offset, access);
body.emplace_back(b.CallStmt(call));
}
}
b.Func(name, params, b.ty.void_(), body);
}
return name;
});
}
/// AtomicFunc() returns a symbol to an intrinsic function that performs an
/// atomic operation from a storage buffer of type `buf_ty`. The function has
/// the signature:
// `fn atomic_op(buf : buf_ty, offset : u32, ...) -> T`
/// @param buf_ty the storage buffer type
/// @param el_ty the storage buffer element type
/// @param intrinsic the atomic intrinsic
/// @param var_user the variable user
/// @return the name of the function that performs the load
Symbol AtomicFunc(const sem::Type* buf_ty,
const sem::Type* el_ty,
const sem::Builtin* intrinsic,
const sem::VariableUser* var_user) {
auto op = intrinsic->Type();
return utils::GetOrCreate(atomic_funcs, AtomicKey{buf_ty, el_ty, op}, [&] {
auto* buf_ast_ty = CreateASTTypeFor(ctx, buf_ty);
auto* disable_validation = b.Disable(
ast::DisabledValidation::kIgnoreConstructibleFunctionParameter);
// The first parameter to all WGSL atomics is the expression to the
// atomic. This is replaced with two parameters: the buffer and offset.
ast::VariableList params = {
// Note: The buffer parameter requires the kStorage StorageClass in
// order for HLSL to emit this as a ByteAddressBuffer.
b.create<ast::Variable>(b.Sym("buffer"), ast::StorageClass::kStorage,
var_user->Variable()->Access(), buf_ast_ty,
true, false, nullptr,
ast::AttributeList{disable_validation}),
b.Param("offset", b.ty.u32()),
};
// Other parameters are copied as-is:
for (size_t i = 1; i < intrinsic->Parameters().size(); i++) {
auto* param = intrinsic->Parameters()[i];
auto* ty = CreateASTTypeFor(ctx, param->Type());
params.emplace_back(b.Param("param_" + std::to_string(i), ty));
}
auto* atomic = IntrinsicAtomicFor(ctx.dst, op, el_ty);
if (atomic == nullptr) {
TINT_ICE(Transform, b.Diagnostics())
<< "IntrinsicAtomicFor() returned nullptr for op " << op
<< " and type " << el_ty->type_name();
}
auto* ret_ty = CreateASTTypeFor(ctx, intrinsic->ReturnType());
auto* func = b.create<ast::Function>(
b.Sym(), params, ret_ty, nullptr,
ast::AttributeList{
atomic,
b.Disable(ast::DisabledValidation::kFunctionHasNoBody),
},
ast::AttributeList{});
b.AST().AddFunction(func);
return func->symbol;
});
}
};
DecomposeMemoryAccess::Intrinsic::Intrinsic(ProgramID pid,
Op o,
ast::StorageClass sc,
DataType ty)
: Base(pid), op(o), storage_class(sc), type(ty) {}
DecomposeMemoryAccess::Intrinsic::~Intrinsic() = default;
std::string DecomposeMemoryAccess::Intrinsic::InternalName() const {
std::stringstream ss;
switch (op) {
case Op::kLoad:
ss << "intrinsic_load_";
break;
case Op::kStore:
ss << "intrinsic_store_";
break;
case Op::kAtomicLoad:
ss << "intrinsic_atomic_load_";
break;
case Op::kAtomicStore:
ss << "intrinsic_atomic_store_";
break;
case Op::kAtomicAdd:
ss << "intrinsic_atomic_add_";
break;
case Op::kAtomicSub:
ss << "intrinsic_atomic_sub_";
break;
case Op::kAtomicMax:
ss << "intrinsic_atomic_max_";
break;
case Op::kAtomicMin:
ss << "intrinsic_atomic_min_";
break;
case Op::kAtomicAnd:
ss << "intrinsic_atomic_and_";
break;
case Op::kAtomicOr:
ss << "intrinsic_atomic_or_";
break;
case Op::kAtomicXor:
ss << "intrinsic_atomic_xor_";
break;
case Op::kAtomicExchange:
ss << "intrinsic_atomic_exchange_";
break;
case Op::kAtomicCompareExchangeWeak:
ss << "intrinsic_atomic_compare_exchange_weak_";
break;
}
ss << storage_class << "_";
switch (type) {
case DataType::kU32:
ss << "u32";
break;
case DataType::kF32:
ss << "f32";
break;
case DataType::kI32:
ss << "i32";
break;
case DataType::kVec2U32:
ss << "vec2_u32";
break;
case DataType::kVec2F32:
ss << "vec2_f32";
break;
case DataType::kVec2I32:
ss << "vec2_i32";
break;
case DataType::kVec3U32:
ss << "vec3_u32";
break;
case DataType::kVec3F32:
ss << "vec3_f32";
break;
case DataType::kVec3I32:
ss << "vec3_i32";
break;
case DataType::kVec4U32:
ss << "vec4_u32";
break;
case DataType::kVec4F32:
ss << "vec4_f32";
break;
case DataType::kVec4I32:
ss << "vec4_i32";
break;
}
return ss.str();
}
const DecomposeMemoryAccess::Intrinsic* DecomposeMemoryAccess::Intrinsic::Clone(
CloneContext* ctx) const {
return ctx->dst->ASTNodes().Create<DecomposeMemoryAccess::Intrinsic>(
ctx->dst->ID(), op, storage_class, type);
}
DecomposeMemoryAccess::DecomposeMemoryAccess() = default;
DecomposeMemoryAccess::~DecomposeMemoryAccess() = default;
bool DecomposeMemoryAccess::ShouldRun(const Program* program,
const DataMap&) const {
for (auto* decl : program->AST().GlobalDeclarations()) {
if (auto* var = program->Sem().Get<sem::Variable>(decl)) {
if (var->StorageClass() == ast::StorageClass::kStorage ||
var->StorageClass() == ast::StorageClass::kUniform) {
return true;
}
}
}
return false;
}
void DecomposeMemoryAccess::Run(CloneContext& ctx,
const DataMap&,
DataMap&) const {
auto& sem = ctx.src->Sem();
State state(ctx);
// Scan the AST nodes for storage and uniform buffer accesses. Complex
// expression chains (e.g. `storage_buffer.foo.bar[20].x`) are handled by
// maintaining an offset chain via the `state.TakeAccess()`,
// `state.AddAccess()` methods.
//
// Inner-most expression nodes are guaranteed to be visited first because AST
// nodes are fully immutable and require their children to be constructed
// first so their pointer can be passed to the parent's constructor.
for (auto* node : ctx.src->ASTNodes().Objects()) {
if (auto* ident = node->As<ast::IdentifierExpression>()) {
// X
if (auto* var = sem.Get<sem::VariableUser>(ident)) {
if (var->Variable()->StorageClass() == ast::StorageClass::kStorage ||
var->Variable()->StorageClass() == ast::StorageClass::kUniform) {
// Variable to a storage or uniform buffer
state.AddAccess(ident, {
var,
state.ToOffset(0u),
var->Type()->UnwrapRef(),
});
}
}
continue;
}
if (auto* accessor = node->As<ast::MemberAccessorExpression>()) {
// X.Y
auto* accessor_sem = sem.Get(accessor);
if (auto* swizzle = accessor_sem->As<sem::Swizzle>()) {
if (swizzle->Indices().size() == 1) {
if (auto access = state.TakeAccess(accessor->structure)) {
auto* vec_ty = access.type->As<sem::Vector>();
auto* offset =
state.Mul(vec_ty->type()->Size(), swizzle->Indices()[0]);
state.AddAccess(accessor, {
access.var,
state.Add(access.offset, offset),
vec_ty->type()->UnwrapRef(),
});
}
}
} else {
if (auto access = state.TakeAccess(accessor->structure)) {
auto* str_ty = access.type->As<sem::Struct>();
auto* member = str_ty->FindMember(accessor->member->symbol);
auto offset = member->Offset();
state.AddAccess(accessor, {
access.var,
state.Add(access.offset, offset),
member->Type()->UnwrapRef(),
});
}
}
continue;
}
if (auto* accessor = node->As<ast::IndexAccessorExpression>()) {
if (auto access = state.TakeAccess(accessor->object)) {
// X[Y]
if (auto* arr = access.type->As<sem::Array>()) {
auto* offset = state.Mul(arr->Stride(), accessor->index);
state.AddAccess(accessor, {
access.var,
state.Add(access.offset, offset),
arr->ElemType()->UnwrapRef(),
});
continue;
}
if (auto* vec_ty = access.type->As<sem::Vector>()) {
auto* offset = state.Mul(vec_ty->type()->Size(), accessor->index);
state.AddAccess(accessor, {
access.var,
state.Add(access.offset, offset),
vec_ty->type()->UnwrapRef(),
});
continue;
}
if (auto* mat_ty = access.type->As<sem::Matrix>()) {
auto* offset = state.Mul(mat_ty->ColumnStride(), accessor->index);
state.AddAccess(accessor, {
access.var,
state.Add(access.offset, offset),
mat_ty->ColumnType(),
});
continue;
}
}
}
if (auto* op = node->As<ast::UnaryOpExpression>()) {
if (op->op == ast::UnaryOp::kAddressOf) {
// &X
if (auto access = state.TakeAccess(op->expr)) {
// HLSL does not support pointers, so just take the access from the
// reference and place it on the pointer.
state.AddAccess(op, access);
continue;
}
}
}
if (auto* assign = node->As<ast::AssignmentStatement>()) {
// X = Y
// Move the LHS access to a store.
if (auto lhs = state.TakeAccess(assign->lhs)) {
state.stores.emplace_back(Store{assign, lhs});
}
}
if (auto* call_expr = node->As<ast::CallExpression>()) {
auto* call = sem.Get(call_expr);
if (auto* builtin = call->Target()->As<sem::Builtin>()) {
if (builtin->Type() == sem::BuiltinType::kArrayLength) {
// arrayLength(X)
// Don't convert X into a load, this builtin actually requires the
// real pointer.
state.TakeAccess(call_expr->args[0]);
continue;
}
if (builtin->IsAtomic()) {
if (auto access = state.TakeAccess(call_expr->args[0])) {
// atomic___(X)
ctx.Replace(call_expr, [=, &ctx, &state] {
auto* buf = access.var->Declaration();
auto* offset = access.offset->Build(ctx);
auto* buf_ty = access.var->Type()->UnwrapRef();
auto* el_ty = access.type->UnwrapRef()->As<sem::Atomic>()->Type();
Symbol func = state.AtomicFunc(
buf_ty, el_ty, builtin, access.var->As<sem::VariableUser>());
ast::ExpressionList args{ctx.Clone(buf), offset};
for (size_t i = 1; i < call_expr->args.size(); i++) {
auto* arg = call_expr->args[i];
args.emplace_back(ctx.Clone(arg));
}
return ctx.dst->Call(func, args);
});
}
}
}
}
}
// All remaining accesses are loads, transform these into calls to the
// corresponding load function
for (auto* expr : state.expression_order) {
auto access_it = state.accesses.find(expr);
if (access_it == state.accesses.end()) {
continue;
}
BufferAccess access = access_it->second;
ctx.Replace(expr, [=, &ctx, &state] {
auto* buf = access.var->Declaration();
auto* offset = access.offset->Build(ctx);
auto* buf_ty = access.var->Type()->UnwrapRef();
auto* el_ty = access.type->UnwrapRef();
Symbol func =
state.LoadFunc(buf_ty, el_ty, access.var->As<sem::VariableUser>());
return ctx.dst->Call(func, ctx.CloneWithoutTransform(buf), offset);
});
}
// And replace all storage and uniform buffer assignments with stores
for (auto store : state.stores) {
ctx.Replace(store.assignment, [=, &ctx, &state] {
auto* buf = store.target.var->Declaration();
auto* offset = store.target.offset->Build(ctx);
auto* buf_ty = store.target.var->Type()->UnwrapRef();
auto* el_ty = store.target.type->UnwrapRef();
auto* value = store.assignment->rhs;
Symbol func = state.StoreFunc(buf_ty, el_ty,
store.target.var->As<sem::VariableUser>());
auto* call = ctx.dst->Call(func, ctx.CloneWithoutTransform(buf), offset,
ctx.Clone(value));
return ctx.dst->CallStmt(call);
});
}
ctx.Clone();
}
} // namespace transform
} // namespace tint
TINT_INSTANTIATE_TYPEINFO(tint::transform::Offset);
TINT_INSTANTIATE_TYPEINFO(tint::transform::OffsetLiteral);

View File

@@ -0,0 +1,131 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TINT_TRANSFORM_DECOMPOSE_MEMORY_ACCESS_H_
#define SRC_TINT_TRANSFORM_DECOMPOSE_MEMORY_ACCESS_H_
#include <string>
#include "src/tint/ast/internal_attribute.h"
#include "src/tint/transform/transform.h"
namespace tint {
// Forward declarations
class CloneContext;
namespace transform {
/// DecomposeMemoryAccess is a transform used to replace storage and uniform
/// buffer accesses with a combination of load, store or atomic functions on
/// primitive types.
class DecomposeMemoryAccess
: public Castable<DecomposeMemoryAccess, Transform> {
public:
/// Intrinsic is an InternalAttribute that's used to decorate a stub function
/// so that the HLSL transforms this into calls to
/// `[RW]ByteAddressBuffer.Load[N]()` or `[RW]ByteAddressBuffer.Store[N]()`,
/// with a possible cast.
class Intrinsic : public Castable<Intrinsic, ast::InternalAttribute> {
public:
/// Intrinsic op
enum class Op {
kLoad,
kStore,
kAtomicLoad,
kAtomicStore,
kAtomicAdd,
kAtomicSub,
kAtomicMax,
kAtomicMin,
kAtomicAnd,
kAtomicOr,
kAtomicXor,
kAtomicExchange,
kAtomicCompareExchangeWeak,
};
/// Intrinsic data type
enum class DataType {
kU32,
kF32,
kI32,
kVec2U32,
kVec2F32,
kVec2I32,
kVec3U32,
kVec3F32,
kVec3I32,
kVec4U32,
kVec4F32,
kVec4I32,
};
/// Constructor
/// @param program_id the identifier of the program that owns this node
/// @param o the op of the intrinsic
/// @param sc the storage class of the buffer
/// @param ty the data type of the intrinsic
Intrinsic(ProgramID program_id, Op o, ast::StorageClass sc, DataType ty);
/// Destructor
~Intrinsic() override;
/// @return a short description of the internal attribute which will be
/// displayed as `@internal(<name>)`
std::string InternalName() const override;
/// Performs a deep clone of this object using the CloneContext `ctx`.
/// @param ctx the clone context
/// @return the newly cloned object
const Intrinsic* Clone(CloneContext* ctx) const override;
/// The op of the intrinsic
const Op op;
/// The storage class of the buffer this intrinsic operates on
ast::StorageClass const storage_class;
/// The type of the intrinsic
const DataType type;
};
/// Constructor
DecomposeMemoryAccess();
/// Destructor
~DecomposeMemoryAccess() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program,
const DataMap& data = {}) const override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const override;
struct State;
};
} // namespace transform
} // namespace tint
#endif // SRC_TINT_TRANSFORM_DECOMPOSE_MEMORY_ACCESS_H_

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,162 @@
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/decompose_strided_array.h"
#include <unordered_map>
#include <utility>
#include <vector>
#include "src/tint/program_builder.h"
#include "src/tint/sem/call.h"
#include "src/tint/sem/expression.h"
#include "src/tint/sem/member_accessor_expression.h"
#include "src/tint/sem/type_constructor.h"
#include "src/tint/transform/simplify_pointers.h"
#include "src/tint/utils/hash.h"
#include "src/tint/utils/map.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::DecomposeStridedArray);
namespace tint {
namespace transform {
namespace {
using DecomposedArrays = std::unordered_map<const sem::Array*, Symbol>;
} // namespace
DecomposeStridedArray::DecomposeStridedArray() = default;
DecomposeStridedArray::~DecomposeStridedArray() = default;
bool DecomposeStridedArray::ShouldRun(const Program* program,
const DataMap&) const {
for (auto* node : program->ASTNodes().Objects()) {
if (auto* ast = node->As<ast::Array>()) {
if (ast::GetAttribute<ast::StrideAttribute>(ast->attributes)) {
return true;
}
}
}
return false;
}
void DecomposeStridedArray::Run(CloneContext& ctx,
const DataMap&,
DataMap&) const {
const auto& sem = ctx.src->Sem();
static constexpr const char* kMemberName = "el";
// Maps an array type in the source program to the name of the struct wrapper
// type in the target program.
std::unordered_map<const sem::Array*, Symbol> decomposed;
// Find and replace all arrays with a @stride attribute with a array that has
// the @stride removed. If the source array stride does not match the natural
// stride for the array element type, then replace the array element type with
// a structure, holding a single field with a @size attribute equal to the
// array stride.
ctx.ReplaceAll([&](const ast::Array* ast) -> const ast::Array* {
if (auto* arr = sem.Get(ast)) {
if (!arr->IsStrideImplicit()) {
auto el_ty = utils::GetOrCreate(decomposed, arr, [&] {
auto name = ctx.dst->Symbols().New("strided_arr");
auto* member_ty = ctx.Clone(ast->type);
auto* member = ctx.dst->Member(kMemberName, member_ty,
{ctx.dst->MemberSize(arr->Stride())});
ctx.dst->Structure(name, {member});
return name;
});
auto* count = ctx.Clone(ast->count);
return ctx.dst->ty.array(ctx.dst->ty.type_name(el_ty), count);
}
if (ast::GetAttribute<ast::StrideAttribute>(ast->attributes)) {
// Strip the @stride attribute
auto* ty = ctx.Clone(ast->type);
auto* count = ctx.Clone(ast->count);
return ctx.dst->ty.array(ty, count);
}
}
return nullptr;
});
// Find all array index-accessors expressions for arrays that have had their
// element changed to a single field structure. These expressions are adjusted
// to insert an additional member accessor for the single structure field.
// Example: `arr[i]` -> `arr[i].el`
ctx.ReplaceAll(
[&](const ast::IndexAccessorExpression* idx) -> const ast::Expression* {
if (auto* ty = ctx.src->TypeOf(idx->object)) {
if (auto* arr = ty->UnwrapRef()->As<sem::Array>()) {
if (!arr->IsStrideImplicit()) {
auto* expr = ctx.CloneWithoutTransform(idx);
return ctx.dst->MemberAccessor(expr, kMemberName);
}
}
}
return nullptr;
});
// Find all array type constructor expressions for array types that have had
// their element changed to a single field structure. These constructors are
// adjusted to wrap each of the arguments with an additional constructor for
// the new element structure type.
// Example:
// `@stride(32) array<i32, 3>(1, 2, 3)`
// ->
// `array<strided_arr, 3>(strided_arr(1), strided_arr(2), strided_arr(3))`
ctx.ReplaceAll(
[&](const ast::CallExpression* expr) -> const ast::Expression* {
if (!expr->args.empty()) {
if (auto* call = sem.Get(expr)) {
if (auto* ctor = call->Target()->As<sem::TypeConstructor>()) {
if (auto* arr = ctor->ReturnType()->As<sem::Array>()) {
// Begin by cloning the array constructor type or name
// If this is an unaliased array, this may add a new entry to
// decomposed.
// If this is an aliased array, decomposed should already be
// populated with any strided aliases.
ast::CallExpression::Target target;
if (expr->target.type) {
target.type = ctx.Clone(expr->target.type);
} else {
target.name = ctx.Clone(expr->target.name);
}
ast::ExpressionList args;
if (auto it = decomposed.find(arr); it != decomposed.end()) {
args.reserve(expr->args.size());
for (auto* arg : expr->args) {
args.emplace_back(
ctx.dst->Call(it->second, ctx.Clone(arg)));
}
} else {
args = ctx.Clone(expr->args);
}
return target.type ? ctx.dst->Construct(target.type, args)
: ctx.dst->Call(target.name, args);
}
}
}
}
return nullptr;
});
ctx.Clone();
}
} // namespace transform
} // namespace tint

View File

@@ -0,0 +1,61 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TINT_TRANSFORM_DECOMPOSE_STRIDED_ARRAY_H_
#define SRC_TINT_TRANSFORM_DECOMPOSE_STRIDED_ARRAY_H_
#include "src/tint/transform/transform.h"
namespace tint {
namespace transform {
/// DecomposeStridedArray transforms replaces arrays with a non-default
/// `@stride` attribute with an array of structure elements, where the
/// structure contains a single field with an equivalent `@size` attribute.
/// `@stride` attributes on arrays that match the default stride are also
/// removed.
///
/// @note Depends on the following transforms to have been run first:
/// * SimplifyPointers
class DecomposeStridedArray
: public Castable<DecomposeStridedArray, Transform> {
public:
/// Constructor
DecomposeStridedArray();
/// Destructor
~DecomposeStridedArray() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program,
const DataMap& data = {}) const override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const override;
};
} // namespace transform
} // namespace tint
#endif // SRC_TINT_TRANSFORM_DECOMPOSE_STRIDED_ARRAY_H_

View File

@@ -0,0 +1,698 @@
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/decompose_strided_array.h"
#include <memory>
#include <utility>
#include <vector>
#include "src/tint/program_builder.h"
#include "src/tint/transform/simplify_pointers.h"
#include "src/tint/transform/test_helper.h"
#include "src/tint/transform/unshadow.h"
namespace tint {
namespace transform {
namespace {
using DecomposeStridedArrayTest = TransformTest;
using f32 = ProgramBuilder::f32;
TEST_F(DecomposeStridedArrayTest, ShouldRunEmptyModule) {
ProgramBuilder b;
EXPECT_FALSE(ShouldRun<DecomposeStridedArray>(Program(std::move(b))));
}
TEST_F(DecomposeStridedArrayTest, ShouldRunNonStridedArray) {
// var<private> arr : array<f32, 4>
ProgramBuilder b;
b.Global("arr", b.ty.array<f32, 4>(), ast::StorageClass::kPrivate);
EXPECT_FALSE(ShouldRun<DecomposeStridedArray>(Program(std::move(b))));
}
TEST_F(DecomposeStridedArrayTest, ShouldRunDefaultStridedArray) {
// var<private> arr : @stride(4) array<f32, 4>
ProgramBuilder b;
b.Global("arr", b.ty.array<f32, 4>(4), ast::StorageClass::kPrivate);
EXPECT_TRUE(ShouldRun<DecomposeStridedArray>(Program(std::move(b))));
}
TEST_F(DecomposeStridedArrayTest, ShouldRunExplicitStridedArray) {
// var<private> arr : @stride(16) array<f32, 4>
ProgramBuilder b;
b.Global("arr", b.ty.array<f32, 4>(16), ast::StorageClass::kPrivate);
EXPECT_TRUE(ShouldRun<DecomposeStridedArray>(Program(std::move(b))));
}
TEST_F(DecomposeStridedArrayTest, Empty) {
auto* src = R"()";
auto* expect = src;
auto got = Run<DecomposeStridedArray>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedArrayTest, PrivateDefaultStridedArray) {
// var<private> arr : @stride(4) array<f32, 4>
//
// @stage(compute) @workgroup_size(1)
// fn f() {
// let a : @stride(4) array<f32, 4> = a;
// let b : f32 = arr[1];
// }
ProgramBuilder b;
b.Global("arr", b.ty.array<f32, 4>(4), ast::StorageClass::kPrivate);
b.Func("f", {}, b.ty.void_(),
{
b.Decl(b.Const("a", b.ty.array<f32, 4>(4), b.Expr("arr"))),
b.Decl(b.Const("b", b.ty.f32(), b.IndexAccessor("arr", 1))),
},
{
b.Stage(ast::PipelineStage::kCompute),
b.WorkgroupSize(1),
});
auto* expect = R"(
var<private> arr : array<f32, 4>;
@stage(compute) @workgroup_size(1)
fn f() {
let a : array<f32, 4> = arr;
let b : f32 = arr[1];
}
)";
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(
Program(std::move(b)));
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedArrayTest, PrivateStridedArray) {
// var<private> arr : @stride(32) array<f32, 4>
//
// @stage(compute) @workgroup_size(1)
// fn f() {
// let a : @stride(32) array<f32, 4> = a;
// let b : f32 = arr[1];
// }
ProgramBuilder b;
b.Global("arr", b.ty.array<f32, 4>(32), ast::StorageClass::kPrivate);
b.Func("f", {}, b.ty.void_(),
{
b.Decl(b.Const("a", b.ty.array<f32, 4>(32), b.Expr("arr"))),
b.Decl(b.Const("b", b.ty.f32(), b.IndexAccessor("arr", 1))),
},
{
b.Stage(ast::PipelineStage::kCompute),
b.WorkgroupSize(1),
});
auto* expect = R"(
struct strided_arr {
@size(32)
el : f32;
}
var<private> arr : array<strided_arr, 4>;
@stage(compute) @workgroup_size(1)
fn f() {
let a : array<strided_arr, 4> = arr;
let b : f32 = arr[1].el;
}
)";
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(
Program(std::move(b)));
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedArrayTest, ReadUniformStridedArray) {
// struct S {
// a : @stride(32) array<f32, 4>;
// };
// @group(0) @binding(0) var<uniform> s : S;
//
// @stage(compute) @workgroup_size(1)
// fn f() {
// let a : @stride(32) array<f32, 4> = s.a;
// let b : f32 = s.a[1];
// }
ProgramBuilder b;
auto* S = b.Structure("S", {b.Member("a", b.ty.array<f32, 4>(32))});
b.Global("s", b.ty.Of(S), ast::StorageClass::kUniform,
b.GroupAndBinding(0, 0));
b.Func("f", {}, b.ty.void_(),
{
b.Decl(b.Const("a", b.ty.array<f32, 4>(32),
b.MemberAccessor("s", "a"))),
b.Decl(b.Const("b", b.ty.f32(),
b.IndexAccessor(b.MemberAccessor("s", "a"), 1))),
},
{
b.Stage(ast::PipelineStage::kCompute),
b.WorkgroupSize(1),
});
auto* expect = R"(
struct strided_arr {
@size(32)
el : f32;
}
struct S {
a : array<strided_arr, 4>;
}
@group(0) @binding(0) var<uniform> s : S;
@stage(compute) @workgroup_size(1)
fn f() {
let a : array<strided_arr, 4> = s.a;
let b : f32 = s.a[1].el;
}
)";
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(
Program(std::move(b)));
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedArrayTest, ReadUniformDefaultStridedArray) {
// struct S {
// a : @stride(16) array<vec4<f32>, 4>;
// };
// @group(0) @binding(0) var<uniform> s : S;
//
// @stage(compute) @workgroup_size(1)
// fn f() {
// let a : @stride(16) array<vec4<f32>, 4> = s.a;
// let b : f32 = s.a[1][2];
// }
ProgramBuilder b;
auto* S =
b.Structure("S", {b.Member("a", b.ty.array(b.ty.vec4<f32>(), 4, 16))});
b.Global("s", b.ty.Of(S), ast::StorageClass::kUniform,
b.GroupAndBinding(0, 0));
b.Func("f", {}, b.ty.void_(),
{
b.Decl(b.Const("a", b.ty.array(b.ty.vec4<f32>(), 4, 16),
b.MemberAccessor("s", "a"))),
b.Decl(b.Const(
"b", b.ty.f32(),
b.IndexAccessor(b.IndexAccessor(b.MemberAccessor("s", "a"), 1),
2))),
},
{
b.Stage(ast::PipelineStage::kCompute),
b.WorkgroupSize(1),
});
auto* expect =
R"(
struct S {
a : array<vec4<f32>, 4>;
}
@group(0) @binding(0) var<uniform> s : S;
@stage(compute) @workgroup_size(1)
fn f() {
let a : array<vec4<f32>, 4> = s.a;
let b : f32 = s.a[1][2];
}
)";
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(
Program(std::move(b)));
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedArrayTest, ReadStorageStridedArray) {
// struct S {
// a : @stride(32) array<f32, 4>;
// };
// @group(0) @binding(0) var<storage> s : S;
//
// @stage(compute) @workgroup_size(1)
// fn f() {
// let a : @stride(32) array<f32, 4> = s.a;
// let b : f32 = s.a[1];
// }
ProgramBuilder b;
auto* S = b.Structure("S", {b.Member("a", b.ty.array<f32, 4>(32))});
b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
b.GroupAndBinding(0, 0));
b.Func("f", {}, b.ty.void_(),
{
b.Decl(b.Const("a", b.ty.array<f32, 4>(32),
b.MemberAccessor("s", "a"))),
b.Decl(b.Const("b", b.ty.f32(),
b.IndexAccessor(b.MemberAccessor("s", "a"), 1))),
},
{
b.Stage(ast::PipelineStage::kCompute),
b.WorkgroupSize(1),
});
auto* expect = R"(
struct strided_arr {
@size(32)
el : f32;
}
struct S {
a : array<strided_arr, 4>;
}
@group(0) @binding(0) var<storage> s : S;
@stage(compute) @workgroup_size(1)
fn f() {
let a : array<strided_arr, 4> = s.a;
let b : f32 = s.a[1].el;
}
)";
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(
Program(std::move(b)));
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedArrayTest, ReadStorageDefaultStridedArray) {
// struct S {
// a : @stride(4) array<f32, 4>;
// };
// @group(0) @binding(0) var<storage> s : S;
//
// @stage(compute) @workgroup_size(1)
// fn f() {
// let a : @stride(4) array<f32, 4> = s.a;
// let b : f32 = s.a[1];
// }
ProgramBuilder b;
auto* S = b.Structure("S", {b.Member("a", b.ty.array<f32, 4>(4))});
b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
b.GroupAndBinding(0, 0));
b.Func("f", {}, b.ty.void_(),
{
b.Decl(b.Const("a", b.ty.array<f32, 4>(4),
b.MemberAccessor("s", "a"))),
b.Decl(b.Const("b", b.ty.f32(),
b.IndexAccessor(b.MemberAccessor("s", "a"), 1))),
},
{
b.Stage(ast::PipelineStage::kCompute),
b.WorkgroupSize(1),
});
auto* expect = R"(
struct S {
a : array<f32, 4>;
}
@group(0) @binding(0) var<storage> s : S;
@stage(compute) @workgroup_size(1)
fn f() {
let a : array<f32, 4> = s.a;
let b : f32 = s.a[1];
}
)";
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(
Program(std::move(b)));
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedArrayTest, WriteStorageStridedArray) {
// struct S {
// a : @stride(32) array<f32, 4>;
// };
// @group(0) @binding(0) var<storage, read_write> s : S;
//
// @stage(compute) @workgroup_size(1)
// fn f() {
// s.a = @stride(32) array<f32, 4>();
// s.a = @stride(32) array<f32, 4>(1.0, 2.0, 3.0, 4.0);
// s.a[1] = 5.0;
// }
ProgramBuilder b;
auto* S = b.Structure("S", {b.Member("a", b.ty.array<f32, 4>(32))});
b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
ast::Access::kReadWrite, b.GroupAndBinding(0, 0));
b.Func(
"f", {}, b.ty.void_(),
{
b.Assign(b.MemberAccessor("s", "a"),
b.Construct(b.ty.array<f32, 4>(32))),
b.Assign(b.MemberAccessor("s", "a"),
b.Construct(b.ty.array<f32, 4>(32), 1.0f, 2.0f, 3.0f, 4.0f)),
b.Assign(b.IndexAccessor(b.MemberAccessor("s", "a"), 1), 5.0f),
},
{
b.Stage(ast::PipelineStage::kCompute),
b.WorkgroupSize(1),
});
auto* expect =
R"(
struct strided_arr {
@size(32)
el : f32;
}
struct S {
a : array<strided_arr, 4>;
}
@group(0) @binding(0) var<storage, read_write> s : S;
@stage(compute) @workgroup_size(1)
fn f() {
s.a = array<strided_arr, 4>();
s.a = array<strided_arr, 4>(strided_arr(1.0), strided_arr(2.0), strided_arr(3.0), strided_arr(4.0));
s.a[1].el = 5.0;
}
)";
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(
Program(std::move(b)));
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedArrayTest, WriteStorageDefaultStridedArray) {
// struct S {
// a : @stride(4) array<f32, 4>;
// };
// @group(0) @binding(0) var<storage, read_write> s : S;
//
// @stage(compute) @workgroup_size(1)
// fn f() {
// s.a = @stride(4) array<f32, 4>();
// s.a = @stride(4) array<f32, 4>(1.0, 2.0, 3.0, 4.0);
// s.a[1] = 5.0;
// }
ProgramBuilder b;
auto* S = b.Structure("S", {b.Member("a", b.ty.array<f32, 4>(4))});
b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
ast::Access::kReadWrite, b.GroupAndBinding(0, 0));
b.Func(
"f", {}, b.ty.void_(),
{
b.Assign(b.MemberAccessor("s", "a"),
b.Construct(b.ty.array<f32, 4>(4))),
b.Assign(b.MemberAccessor("s", "a"),
b.Construct(b.ty.array<f32, 4>(4), 1.0f, 2.0f, 3.0f, 4.0f)),
b.Assign(b.IndexAccessor(b.MemberAccessor("s", "a"), 1), 5.0f),
},
{
b.Stage(ast::PipelineStage::kCompute),
b.WorkgroupSize(1),
});
auto* expect =
R"(
struct S {
a : array<f32, 4>;
}
@group(0) @binding(0) var<storage, read_write> s : S;
@stage(compute) @workgroup_size(1)
fn f() {
s.a = array<f32, 4>();
s.a = array<f32, 4>(1.0, 2.0, 3.0, 4.0);
s.a[1] = 5.0;
}
)";
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(
Program(std::move(b)));
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedArrayTest, ReadWriteViaPointerLets) {
// struct S {
// a : @stride(32) array<f32, 4>;
// };
// @group(0) @binding(0) var<storage, read_write> s : S;
//
// @stage(compute) @workgroup_size(1)
// fn f() {
// let a = &s.a;
// let b = &*&*(a);
// let c = *b;
// let d = (*b)[1];
// (*b) = @stride(32) array<f32, 4>(1.0, 2.0, 3.0, 4.0);
// (*b)[1] = 5.0;
// }
ProgramBuilder b;
auto* S = b.Structure("S", {b.Member("a", b.ty.array<f32, 4>(32))});
b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
ast::Access::kReadWrite, b.GroupAndBinding(0, 0));
b.Func("f", {}, b.ty.void_(),
{
b.Decl(b.Const("a", nullptr,
b.AddressOf(b.MemberAccessor("s", "a")))),
b.Decl(b.Const("b", nullptr,
b.AddressOf(b.Deref(b.AddressOf(b.Deref("a")))))),
b.Decl(b.Const("c", nullptr, b.Deref("b"))),
b.Decl(b.Const("d", nullptr, b.IndexAccessor(b.Deref("b"), 1))),
b.Assign(b.Deref("b"), b.Construct(b.ty.array<f32, 4>(32), 1.0f,
2.0f, 3.0f, 4.0f)),
b.Assign(b.IndexAccessor(b.Deref("b"), 1), 5.0f),
},
{
b.Stage(ast::PipelineStage::kCompute),
b.WorkgroupSize(1),
});
auto* expect =
R"(
struct strided_arr {
@size(32)
el : f32;
}
struct S {
a : array<strided_arr, 4>;
}
@group(0) @binding(0) var<storage, read_write> s : S;
@stage(compute) @workgroup_size(1)
fn f() {
let c = s.a;
let d = s.a[1].el;
s.a = array<strided_arr, 4>(strided_arr(1.0), strided_arr(2.0), strided_arr(3.0), strided_arr(4.0));
s.a[1].el = 5.0;
}
)";
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(
Program(std::move(b)));
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedArrayTest, PrivateAliasedStridedArray) {
// type ARR = @stride(32) array<f32, 4>;
// struct S {
// a : ARR;
// };
// @group(0) @binding(0) var<storage, read_write> s : S;
//
// @stage(compute) @workgroup_size(1)
// fn f() {
// let a : ARR = s.a;
// let b : f32 = s.a[1];
// s.a = ARR();
// s.a = ARR(1.0, 2.0, 3.0, 4.0);
// s.a[1] = 5.0;
// }
ProgramBuilder b;
b.Alias("ARR", b.ty.array<f32, 4>(32));
auto* S = b.Structure("S", {b.Member("a", b.ty.type_name("ARR"))});
b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
ast::Access::kReadWrite, b.GroupAndBinding(0, 0));
b.Func(
"f", {}, b.ty.void_(),
{
b.Decl(
b.Const("a", b.ty.type_name("ARR"), b.MemberAccessor("s", "a"))),
b.Decl(b.Const("b", b.ty.f32(),
b.IndexAccessor(b.MemberAccessor("s", "a"), 1))),
b.Assign(b.MemberAccessor("s", "a"),
b.Construct(b.ty.type_name("ARR"))),
b.Assign(b.MemberAccessor("s", "a"),
b.Construct(b.ty.type_name("ARR"), 1.0f, 2.0f, 3.0f, 4.0f)),
b.Assign(b.IndexAccessor(b.MemberAccessor("s", "a"), 1), 5.0f),
},
{
b.Stage(ast::PipelineStage::kCompute),
b.WorkgroupSize(1),
});
auto* expect = R"(
struct strided_arr {
@size(32)
el : f32;
}
type ARR = array<strided_arr, 4>;
struct S {
a : ARR;
}
@group(0) @binding(0) var<storage, read_write> s : S;
@stage(compute) @workgroup_size(1)
fn f() {
let a : ARR = s.a;
let b : f32 = s.a[1].el;
s.a = ARR();
s.a = ARR(strided_arr(1.0), strided_arr(2.0), strided_arr(3.0), strided_arr(4.0));
s.a[1].el = 5.0;
}
)";
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(
Program(std::move(b)));
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedArrayTest, PrivateNestedStridedArray) {
// type ARR_A = @stride(8) array<f32, 2>;
// type ARR_B = @stride(128) array<@stride(16) array<ARR_A, 3>, 4>;
// struct S {
// a : ARR_B;
// };
// @group(0) @binding(0) var<storage, read_write> s : S;
//
// @stage(compute) @workgroup_size(1)
// fn f() {
// let a : ARR_B = s.a;
// let b : array<@stride(8) array<f32, 2>, 3> = s.a[3];
// let c = s.a[3][2];
// let d = s.a[3][2][1];
// s.a = ARR_B();
// s.a[3][2][1] = 5.0;
// }
ProgramBuilder b;
b.Alias("ARR_A", b.ty.array<f32, 2>(8));
b.Alias("ARR_B",
b.ty.array( //
b.ty.array(b.ty.type_name("ARR_A"), 3, 16), //
4, 128));
auto* S = b.Structure("S", {b.Member("a", b.ty.type_name("ARR_B"))});
b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
ast::Access::kReadWrite, b.GroupAndBinding(0, 0));
b.Func("f", {}, b.ty.void_(),
{
b.Decl(b.Const("a", b.ty.type_name("ARR_B"),
b.MemberAccessor("s", "a"))),
b.Decl(b.Const("b", b.ty.array(b.ty.type_name("ARR_A"), 3, 16),
b.IndexAccessor( //
b.MemberAccessor("s", "a"), //
3))),
b.Decl(b.Const("c", b.ty.type_name("ARR_A"),
b.IndexAccessor( //
b.IndexAccessor( //
b.MemberAccessor("s", "a"), //
3),
2))),
b.Decl(b.Const("d", b.ty.f32(),
b.IndexAccessor( //
b.IndexAccessor( //
b.IndexAccessor( //
b.MemberAccessor("s", "a"), //
3),
2),
1))),
b.Assign(b.MemberAccessor("s", "a"),
b.Construct(b.ty.type_name("ARR_B"))),
b.Assign(b.IndexAccessor( //
b.IndexAccessor( //
b.IndexAccessor( //
b.MemberAccessor("s", "a"), //
3),
2),
1),
5.0f),
},
{
b.Stage(ast::PipelineStage::kCompute),
b.WorkgroupSize(1),
});
auto* expect =
R"(
struct strided_arr {
@size(8)
el : f32;
}
type ARR_A = array<strided_arr, 2>;
struct strided_arr_1 {
@size(128)
el : array<ARR_A, 3>;
}
type ARR_B = array<strided_arr_1, 4>;
struct S {
a : ARR_B;
}
@group(0) @binding(0) var<storage, read_write> s : S;
@stage(compute) @workgroup_size(1)
fn f() {
let a : ARR_B = s.a;
let b : array<ARR_A, 3> = s.a[3].el;
let c : ARR_A = s.a[3].el[2];
let d : f32 = s.a[3].el[2][1].el;
s.a = ARR_B();
s.a[3].el[2][1].el = 5.0;
}
)";
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedArray>(
Program(std::move(b)));
EXPECT_EQ(expect, str(got));
}
} // namespace
} // namespace transform
} // namespace tint

View File

@@ -0,0 +1,251 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/decompose_strided_matrix.h"
#include <unordered_map>
#include <utility>
#include <vector>
#include "src/tint/program_builder.h"
#include "src/tint/sem/expression.h"
#include "src/tint/sem/member_accessor_expression.h"
#include "src/tint/transform/simplify_pointers.h"
#include "src/tint/utils/hash.h"
#include "src/tint/utils/map.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::DecomposeStridedMatrix);
namespace tint {
namespace transform {
namespace {
/// MatrixInfo describes a matrix member with a custom stride
struct MatrixInfo {
/// The stride in bytes between columns of the matrix
uint32_t stride = 0;
/// The type of the matrix
const sem::Matrix* matrix = nullptr;
/// @returns a new ast::Array that holds an vector column for each row of the
/// matrix.
const ast::Array* array(ProgramBuilder* b) const {
return b->ty.array(b->ty.vec<ProgramBuilder::f32>(matrix->rows()),
matrix->columns(), stride);
}
/// Equality operator
bool operator==(const MatrixInfo& info) const {
return stride == info.stride && matrix == info.matrix;
}
/// Hash function
struct Hasher {
size_t operator()(const MatrixInfo& t) const {
return utils::Hash(t.stride, t.matrix);
}
};
};
/// Return type of the callback function of GatherCustomStrideMatrixMembers
enum GatherResult { kContinue, kStop };
/// GatherCustomStrideMatrixMembers scans `program` for all matrix members of
/// storage and uniform structs, which are of a matrix type, and have a custom
/// matrix stride attribute. For each matrix member found, `callback` is called.
/// `callback` is a function with the signature:
/// GatherResult(const sem::StructMember* member,
/// sem::Matrix* matrix,
/// uint32_t stride)
/// If `callback` return GatherResult::kStop, then the scanning will immediately
/// terminate, and GatherCustomStrideMatrixMembers() will return, otherwise
/// scanning will continue.
template <typename F>
void GatherCustomStrideMatrixMembers(const Program* program, F&& callback) {
for (auto* node : program->ASTNodes().Objects()) {
if (auto* str = node->As<ast::Struct>()) {
auto* str_ty = program->Sem().Get(str);
if (!str_ty->UsedAs(ast::StorageClass::kUniform) &&
!str_ty->UsedAs(ast::StorageClass::kStorage)) {
continue;
}
for (auto* member : str_ty->Members()) {
auto* matrix = member->Type()->As<sem::Matrix>();
if (!matrix) {
continue;
}
auto* attr = ast::GetAttribute<ast::StrideAttribute>(
member->Declaration()->attributes);
if (!attr) {
continue;
}
uint32_t stride = attr->stride;
if (matrix->ColumnStride() == stride) {
continue;
}
if (callback(member, matrix, stride) == GatherResult::kStop) {
return;
}
}
}
}
}
} // namespace
DecomposeStridedMatrix::DecomposeStridedMatrix() = default;
DecomposeStridedMatrix::~DecomposeStridedMatrix() = default;
bool DecomposeStridedMatrix::ShouldRun(const Program* program,
const DataMap&) const {
bool should_run = false;
GatherCustomStrideMatrixMembers(
program, [&](const sem::StructMember*, sem::Matrix*, uint32_t) {
should_run = true;
return GatherResult::kStop;
});
return should_run;
}
void DecomposeStridedMatrix::Run(CloneContext& ctx,
const DataMap&,
DataMap&) const {
// Scan the program for all storage and uniform structure matrix members with
// a custom stride attribute. Replace these matrices with an equivalent array,
// and populate the `decomposed` map with the members that have been replaced.
std::unordered_map<const ast::StructMember*, MatrixInfo> decomposed;
GatherCustomStrideMatrixMembers(
ctx.src, [&](const sem::StructMember* member, sem::Matrix* matrix,
uint32_t stride) {
// We've got ourselves a struct member of a matrix type with a custom
// stride. Replace this with an array of column vectors.
MatrixInfo info{stride, matrix};
auto* replacement = ctx.dst->Member(
member->Offset(), ctx.Clone(member->Name()), info.array(ctx.dst));
ctx.Replace(member->Declaration(), replacement);
decomposed.emplace(member->Declaration(), info);
return GatherResult::kContinue;
});
// For all expressions where a single matrix column vector was indexed, we can
// preserve these without calling conversion functions.
// Example:
// ssbo.mat[2] -> ssbo.mat[2]
ctx.ReplaceAll([&](const ast::IndexAccessorExpression* expr)
-> const ast::IndexAccessorExpression* {
if (auto* access =
ctx.src->Sem().Get<sem::StructMemberAccess>(expr->object)) {
auto it = decomposed.find(access->Member()->Declaration());
if (it != decomposed.end()) {
auto* obj = ctx.CloneWithoutTransform(expr->object);
auto* idx = ctx.Clone(expr->index);
return ctx.dst->IndexAccessor(obj, idx);
}
}
return nullptr;
});
// For all struct member accesses to the matrix on the LHS of an assignment,
// we need to convert the matrix to the array before assigning to the
// structure.
// Example:
// ssbo.mat = mat_to_arr(m)
std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> mat_to_arr;
ctx.ReplaceAll([&](const ast::AssignmentStatement* stmt)
-> const ast::Statement* {
if (auto* access = ctx.src->Sem().Get<sem::StructMemberAccess>(stmt->lhs)) {
auto it = decomposed.find(access->Member()->Declaration());
if (it == decomposed.end()) {
return nullptr;
}
MatrixInfo info = it->second;
auto fn = utils::GetOrCreate(mat_to_arr, info, [&] {
auto name = ctx.dst->Symbols().New(
"mat" + std::to_string(info.matrix->columns()) + "x" +
std::to_string(info.matrix->rows()) + "_stride_" +
std::to_string(info.stride) + "_to_arr");
auto matrix = [&] { return CreateASTTypeFor(ctx, info.matrix); };
auto array = [&] { return info.array(ctx.dst); };
auto mat = ctx.dst->Sym("m");
ast::ExpressionList columns(info.matrix->columns());
for (uint32_t i = 0; i < static_cast<uint32_t>(columns.size()); i++) {
columns[i] = ctx.dst->IndexAccessor(mat, i);
}
ctx.dst->Func(name,
{
ctx.dst->Param(mat, matrix()),
},
array(),
{
ctx.dst->Return(ctx.dst->Construct(array(), columns)),
});
return name;
});
auto* lhs = ctx.CloneWithoutTransform(stmt->lhs);
auto* rhs = ctx.dst->Call(fn, ctx.Clone(stmt->rhs));
return ctx.dst->Assign(lhs, rhs);
}
return nullptr;
});
// For all other struct member accesses, we need to convert the array to the
// matrix type. Example:
// m = arr_to_mat(ssbo.mat)
std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> arr_to_mat;
ctx.ReplaceAll(
[&](const ast::MemberAccessorExpression* expr) -> const ast::Expression* {
if (auto* access = ctx.src->Sem().Get<sem::StructMemberAccess>(expr)) {
auto it = decomposed.find(access->Member()->Declaration());
if (it == decomposed.end()) {
return nullptr;
}
MatrixInfo info = it->second;
auto fn = utils::GetOrCreate(arr_to_mat, info, [&] {
auto name = ctx.dst->Symbols().New(
"arr_to_mat" + std::to_string(info.matrix->columns()) + "x" +
std::to_string(info.matrix->rows()) + "_stride_" +
std::to_string(info.stride));
auto matrix = [&] { return CreateASTTypeFor(ctx, info.matrix); };
auto array = [&] { return info.array(ctx.dst); };
auto arr = ctx.dst->Sym("arr");
ast::ExpressionList columns(info.matrix->columns());
for (uint32_t i = 0; i < static_cast<uint32_t>(columns.size());
i++) {
columns[i] = ctx.dst->IndexAccessor(arr, i);
}
ctx.dst->Func(
name,
{
ctx.dst->Param(arr, array()),
},
matrix(),
{
ctx.dst->Return(ctx.dst->Construct(matrix(), columns)),
});
return name;
});
return ctx.dst->Call(fn, ctx.CloneWithoutTransform(expr));
}
return nullptr;
});
ctx.Clone();
}
} // namespace transform
} // namespace tint

View File

@@ -0,0 +1,61 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TINT_TRANSFORM_DECOMPOSE_STRIDED_MATRIX_H_
#define SRC_TINT_TRANSFORM_DECOMPOSE_STRIDED_MATRIX_H_
#include "src/tint/transform/transform.h"
namespace tint {
namespace transform {
/// DecomposeStridedMatrix transforms replaces matrix members of storage or
/// uniform buffer structures, that have a [[stride]] attribute, into an array
/// of N column vectors.
/// This transform is used by the SPIR-V reader to handle the SPIR-V
/// MatrixStride attribute.
///
/// @note Depends on the following transforms to have been run first:
/// * SimplifyPointers
class DecomposeStridedMatrix
: public Castable<DecomposeStridedMatrix, Transform> {
public:
/// Constructor
DecomposeStridedMatrix();
/// Destructor
~DecomposeStridedMatrix() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program,
const DataMap& data = {}) const override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const override;
};
} // namespace transform
} // namespace tint
#endif // SRC_TINT_TRANSFORM_DECOMPOSE_STRIDED_MATRIX_H_

View File

@@ -0,0 +1,671 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/decompose_strided_matrix.h"
#include <memory>
#include <utility>
#include <vector>
#include "src/tint/ast/disable_validation_attribute.h"
#include "src/tint/program_builder.h"
#include "src/tint/transform/simplify_pointers.h"
#include "src/tint/transform/test_helper.h"
#include "src/tint/transform/unshadow.h"
namespace tint {
namespace transform {
namespace {
using DecomposeStridedMatrixTest = TransformTest;
using f32 = ProgramBuilder::f32;
TEST_F(DecomposeStridedMatrixTest, ShouldRunEmptyModule) {
auto* src = R"()";
EXPECT_FALSE(ShouldRun<DecomposeStridedMatrix>(src));
}
TEST_F(DecomposeStridedMatrixTest, ShouldRunNonStridedMatrox) {
auto* src = R"(
var<private> m : mat3x2<f32>;
)";
EXPECT_FALSE(ShouldRun<DecomposeStridedMatrix>(src));
}
TEST_F(DecomposeStridedMatrixTest, Empty) {
auto* src = R"()";
auto* expect = src;
auto got = Run<DecomposeStridedMatrix>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedMatrixTest, ReadUniformMatrix) {
// struct S {
// @offset(16) @stride(32)
// @internal(ignore_stride_attribute)
// m : mat2x2<f32>;
// };
// @group(0) @binding(0) var<uniform> s : S;
//
// @stage(compute) @workgroup_size(1)
// fn f() {
// let x : mat2x2<f32> = s.m;
// }
ProgramBuilder b;
auto* S = b.Structure(
"S",
{
b.Member(
"m", b.ty.mat2x2<f32>(),
{
b.create<ast::StructMemberOffsetAttribute>(16),
b.create<ast::StrideAttribute>(32),
b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
}),
});
b.Global("s", b.ty.Of(S), ast::StorageClass::kUniform,
b.GroupAndBinding(0, 0));
b.Func(
"f", {}, b.ty.void_(),
{
b.Decl(b.Const("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))),
},
{
b.Stage(ast::PipelineStage::kCompute),
b.WorkgroupSize(1),
});
auto* expect = R"(
struct S {
@size(16)
padding : u32;
m : @stride(32) array<vec2<f32>, 2u>;
}
@group(0) @binding(0) var<uniform> s : S;
fn arr_to_mat2x2_stride_32(arr : @stride(32) array<vec2<f32>, 2u>) -> mat2x2<f32> {
return mat2x2<f32>(arr[0u], arr[1u]);
}
@stage(compute) @workgroup_size(1)
fn f() {
let x : mat2x2<f32> = arr_to_mat2x2_stride_32(s.m);
}
)";
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(
Program(std::move(b)));
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedMatrixTest, ReadUniformColumn) {
// struct S {
// @offset(16) @stride(32)
// @internal(ignore_stride_attribute)
// m : mat2x2<f32>;
// };
// @group(0) @binding(0) var<uniform> s : S;
//
// @stage(compute) @workgroup_size(1)
// fn f() {
// let x : vec2<f32> = s.m[1];
// }
ProgramBuilder b;
auto* S = b.Structure(
"S",
{
b.Member(
"m", b.ty.mat2x2<f32>(),
{
b.create<ast::StructMemberOffsetAttribute>(16),
b.create<ast::StrideAttribute>(32),
b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
}),
});
b.Global("s", b.ty.Of(S), ast::StorageClass::kUniform,
b.GroupAndBinding(0, 0));
b.Func("f", {}, b.ty.void_(),
{
b.Decl(b.Const("x", b.ty.vec2<f32>(),
b.IndexAccessor(b.MemberAccessor("s", "m"), 1))),
},
{
b.Stage(ast::PipelineStage::kCompute),
b.WorkgroupSize(1),
});
auto* expect = R"(
struct S {
@size(16)
padding : u32;
m : @stride(32) array<vec2<f32>, 2u>;
}
@group(0) @binding(0) var<uniform> s : S;
@stage(compute) @workgroup_size(1)
fn f() {
let x : vec2<f32> = s.m[1];
}
)";
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(
Program(std::move(b)));
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedMatrixTest, ReadUniformMatrix_DefaultStride) {
// struct S {
// @offset(16) @stride(8)
// @internal(ignore_stride_attribute)
// m : mat2x2<f32>;
// };
// @group(0) @binding(0) var<uniform> s : S;
//
// @stage(compute) @workgroup_size(1)
// fn f() {
// let x : mat2x2<f32> = s.m;
// }
ProgramBuilder b;
auto* S = b.Structure(
"S",
{
b.Member(
"m", b.ty.mat2x2<f32>(),
{
b.create<ast::StructMemberOffsetAttribute>(16),
b.create<ast::StrideAttribute>(8),
b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
}),
});
b.Global("s", b.ty.Of(S), ast::StorageClass::kUniform,
b.GroupAndBinding(0, 0));
b.Func(
"f", {}, b.ty.void_(),
{
b.Decl(b.Const("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))),
},
{
b.Stage(ast::PipelineStage::kCompute),
b.WorkgroupSize(1),
});
auto* expect = R"(
struct S {
@size(16)
padding : u32;
@stride(8) @internal(disable_validation__ignore_stride)
m : mat2x2<f32>;
}
@group(0) @binding(0) var<uniform> s : S;
@stage(compute) @workgroup_size(1)
fn f() {
let x : mat2x2<f32> = s.m;
}
)";
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(
Program(std::move(b)));
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedMatrixTest, ReadStorageMatrix) {
// struct S {
// @offset(8) @stride(32)
// @internal(ignore_stride_attribute)
// m : mat2x2<f32>;
// };
// @group(0) @binding(0) var<storage, read_write> s : S;
//
// @stage(compute) @workgroup_size(1)
// fn f() {
// let x : mat2x2<f32> = s.m;
// }
ProgramBuilder b;
auto* S = b.Structure(
"S",
{
b.Member(
"m", b.ty.mat2x2<f32>(),
{
b.create<ast::StructMemberOffsetAttribute>(8),
b.create<ast::StrideAttribute>(32),
b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
}),
});
b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
ast::Access::kReadWrite, b.GroupAndBinding(0, 0));
b.Func(
"f", {}, b.ty.void_(),
{
b.Decl(b.Const("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))),
},
{
b.Stage(ast::PipelineStage::kCompute),
b.WorkgroupSize(1),
});
auto* expect = R"(
struct S {
@size(8)
padding : u32;
m : @stride(32) array<vec2<f32>, 2u>;
}
@group(0) @binding(0) var<storage, read_write> s : S;
fn arr_to_mat2x2_stride_32(arr : @stride(32) array<vec2<f32>, 2u>) -> mat2x2<f32> {
return mat2x2<f32>(arr[0u], arr[1u]);
}
@stage(compute) @workgroup_size(1)
fn f() {
let x : mat2x2<f32> = arr_to_mat2x2_stride_32(s.m);
}
)";
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(
Program(std::move(b)));
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedMatrixTest, ReadStorageColumn) {
// struct S {
// @offset(16) @stride(32)
// @internal(ignore_stride_attribute)
// m : mat2x2<f32>;
// };
// @group(0) @binding(0) var<storage, read_write> s : S;
//
// @stage(compute) @workgroup_size(1)
// fn f() {
// let x : vec2<f32> = s.m[1];
// }
ProgramBuilder b;
auto* S = b.Structure(
"S",
{
b.Member(
"m", b.ty.mat2x2<f32>(),
{
b.create<ast::StructMemberOffsetAttribute>(16),
b.create<ast::StrideAttribute>(32),
b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
}),
});
b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
ast::Access::kReadWrite, b.GroupAndBinding(0, 0));
b.Func("f", {}, b.ty.void_(),
{
b.Decl(b.Const("x", b.ty.vec2<f32>(),
b.IndexAccessor(b.MemberAccessor("s", "m"), 1))),
},
{
b.Stage(ast::PipelineStage::kCompute),
b.WorkgroupSize(1),
});
auto* expect = R"(
struct S {
@size(16)
padding : u32;
m : @stride(32) array<vec2<f32>, 2u>;
}
@group(0) @binding(0) var<storage, read_write> s : S;
@stage(compute) @workgroup_size(1)
fn f() {
let x : vec2<f32> = s.m[1];
}
)";
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(
Program(std::move(b)));
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedMatrixTest, WriteStorageMatrix) {
// struct S {
// @offset(8) @stride(32)
// @internal(ignore_stride_attribute)
// m : mat2x2<f32>;
// };
// @group(0) @binding(0) var<storage, read_write> s : S;
//
// @stage(compute) @workgroup_size(1)
// fn f() {
// s.m = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
// }
ProgramBuilder b;
auto* S = b.Structure(
"S",
{
b.Member(
"m", b.ty.mat2x2<f32>(),
{
b.create<ast::StructMemberOffsetAttribute>(8),
b.create<ast::StrideAttribute>(32),
b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
}),
});
b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
ast::Access::kReadWrite, b.GroupAndBinding(0, 0));
b.Func("f", {}, b.ty.void_(),
{
b.Assign(b.MemberAccessor("s", "m"),
b.mat2x2<f32>(b.vec2<f32>(1.0f, 2.0f),
b.vec2<f32>(3.0f, 4.0f))),
},
{
b.Stage(ast::PipelineStage::kCompute),
b.WorkgroupSize(1),
});
auto* expect = R"(
struct S {
@size(8)
padding : u32;
m : @stride(32) array<vec2<f32>, 2u>;
}
@group(0) @binding(0) var<storage, read_write> s : S;
fn mat2x2_stride_32_to_arr(m : mat2x2<f32>) -> @stride(32) array<vec2<f32>, 2u> {
return @stride(32) array<vec2<f32>, 2u>(m[0u], m[1u]);
}
@stage(compute) @workgroup_size(1)
fn f() {
s.m = mat2x2_stride_32_to_arr(mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0)));
}
)";
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(
Program(std::move(b)));
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedMatrixTest, WriteStorageColumn) {
// struct S {
// @offset(8) @stride(32)
// @internal(ignore_stride_attribute)
// m : mat2x2<f32>;
// };
// @group(0) @binding(0) var<storage, read_write> s : S;
//
// @stage(compute) @workgroup_size(1)
// fn f() {
// s.m[1] = vec2<f32>(1.0, 2.0);
// }
ProgramBuilder b;
auto* S = b.Structure(
"S",
{
b.Member(
"m", b.ty.mat2x2<f32>(),
{
b.create<ast::StructMemberOffsetAttribute>(8),
b.create<ast::StrideAttribute>(32),
b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
}),
});
b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
ast::Access::kReadWrite, b.GroupAndBinding(0, 0));
b.Func("f", {}, b.ty.void_(),
{
b.Assign(b.IndexAccessor(b.MemberAccessor("s", "m"), 1),
b.vec2<f32>(1.0f, 2.0f)),
},
{
b.Stage(ast::PipelineStage::kCompute),
b.WorkgroupSize(1),
});
auto* expect = R"(
struct S {
@size(8)
padding : u32;
m : @stride(32) array<vec2<f32>, 2u>;
}
@group(0) @binding(0) var<storage, read_write> s : S;
@stage(compute) @workgroup_size(1)
fn f() {
s.m[1] = vec2<f32>(1.0, 2.0);
}
)";
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(
Program(std::move(b)));
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedMatrixTest, ReadWriteViaPointerLets) {
// struct S {
// @offset(8) @stride(32)
// @internal(ignore_stride_attribute)
// m : mat2x2<f32>;
// };
// @group(0) @binding(0) var<storage, read_write> s : S;
//
// @stage(compute) @workgroup_size(1)
// fn f() {
// let a = &s.m;
// let b = &*&*(a);
// let x = *b;
// let y = (*b)[1];
// let z = x[1];
// (*b) = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
// (*b)[1] = vec2<f32>(5.0, 6.0);
// }
ProgramBuilder b;
auto* S = b.Structure(
"S",
{
b.Member(
"m", b.ty.mat2x2<f32>(),
{
b.create<ast::StructMemberOffsetAttribute>(8),
b.create<ast::StrideAttribute>(32),
b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
}),
});
b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
ast::Access::kReadWrite, b.GroupAndBinding(0, 0));
b.Func(
"f", {}, b.ty.void_(),
{
b.Decl(
b.Const("a", nullptr, b.AddressOf(b.MemberAccessor("s", "m")))),
b.Decl(b.Const("b", nullptr,
b.AddressOf(b.Deref(b.AddressOf(b.Deref("a")))))),
b.Decl(b.Const("x", nullptr, b.Deref("b"))),
b.Decl(b.Const("y", nullptr, b.IndexAccessor(b.Deref("b"), 1))),
b.Decl(b.Const("z", nullptr, b.IndexAccessor("x", 1))),
b.Assign(b.Deref("b"), b.mat2x2<f32>(b.vec2<f32>(1.0f, 2.0f),
b.vec2<f32>(3.0f, 4.0f))),
b.Assign(b.IndexAccessor(b.Deref("b"), 1), b.vec2<f32>(5.0f, 6.0f)),
},
{
b.Stage(ast::PipelineStage::kCompute),
b.WorkgroupSize(1),
});
auto* expect = R"(
struct S {
@size(8)
padding : u32;
m : @stride(32) array<vec2<f32>, 2u>;
}
@group(0) @binding(0) var<storage, read_write> s : S;
fn arr_to_mat2x2_stride_32(arr : @stride(32) array<vec2<f32>, 2u>) -> mat2x2<f32> {
return mat2x2<f32>(arr[0u], arr[1u]);
}
fn mat2x2_stride_32_to_arr(m : mat2x2<f32>) -> @stride(32) array<vec2<f32>, 2u> {
return @stride(32) array<vec2<f32>, 2u>(m[0u], m[1u]);
}
@stage(compute) @workgroup_size(1)
fn f() {
let x = arr_to_mat2x2_stride_32(s.m);
let y = s.m[1];
let z = x[1];
s.m = mat2x2_stride_32_to_arr(mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0)));
s.m[1] = vec2<f32>(5.0, 6.0);
}
)";
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(
Program(std::move(b)));
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedMatrixTest, ReadPrivateMatrix) {
// struct S {
// @offset(8) @stride(32)
// @internal(ignore_stride_attribute)
// m : mat2x2<f32>;
// };
// var<private> s : S;
//
// @stage(compute) @workgroup_size(1)
// fn f() {
// let x : mat2x2<f32> = s.m;
// }
ProgramBuilder b;
auto* S = b.Structure(
"S",
{
b.Member(
"m", b.ty.mat2x2<f32>(),
{
b.create<ast::StructMemberOffsetAttribute>(8),
b.create<ast::StrideAttribute>(32),
b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
}),
});
b.Global("s", b.ty.Of(S), ast::StorageClass::kPrivate);
b.Func(
"f", {}, b.ty.void_(),
{
b.Decl(b.Const("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))),
},
{
b.Stage(ast::PipelineStage::kCompute),
b.WorkgroupSize(1),
});
auto* expect = R"(
struct S {
@size(8)
padding : u32;
@stride(32) @internal(disable_validation__ignore_stride)
m : mat2x2<f32>;
}
var<private> s : S;
@stage(compute) @workgroup_size(1)
fn f() {
let x : mat2x2<f32> = s.m;
}
)";
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(
Program(std::move(b)));
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedMatrixTest, WritePrivateMatrix) {
// struct S {
// @offset(8) @stride(32)
// @internal(ignore_stride_attribute)
// m : mat2x2<f32>;
// };
// var<private> s : S;
//
// @stage(compute) @workgroup_size(1)
// fn f() {
// s.m = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
// }
ProgramBuilder b;
auto* S = b.Structure(
"S",
{
b.Member(
"m", b.ty.mat2x2<f32>(),
{
b.create<ast::StructMemberOffsetAttribute>(8),
b.create<ast::StrideAttribute>(32),
b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
}),
});
b.Global("s", b.ty.Of(S), ast::StorageClass::kPrivate);
b.Func("f", {}, b.ty.void_(),
{
b.Assign(b.MemberAccessor("s", "m"),
b.mat2x2<f32>(b.vec2<f32>(1.0f, 2.0f),
b.vec2<f32>(3.0f, 4.0f))),
},
{
b.Stage(ast::PipelineStage::kCompute),
b.WorkgroupSize(1),
});
auto* expect = R"(
struct S {
@size(8)
padding : u32;
@stride(32) @internal(disable_validation__ignore_stride)
m : mat2x2<f32>;
}
var<private> s : S;
@stage(compute) @workgroup_size(1)
fn f() {
s.m = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
}
)";
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(
Program(std::move(b)));
EXPECT_EQ(expect, str(got));
}
} // namespace
} // namespace transform
} // namespace tint

View File

@@ -0,0 +1,131 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/external_texture_transform.h"
#include "src/tint/program_builder.h"
#include "src/tint/sem/call.h"
#include "src/tint/sem/variable.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::ExternalTextureTransform);
namespace tint {
namespace transform {
ExternalTextureTransform::ExternalTextureTransform() = default;
ExternalTextureTransform::~ExternalTextureTransform() = default;
void ExternalTextureTransform::Run(CloneContext& ctx,
const DataMap&,
DataMap&) const {
auto& sem = ctx.src->Sem();
// Within this transform, usages of texture_external are replaced with a
// texture_2d<f32>, which will allow us perform operations on a
// texture_external without maintaining texture_external-specific code
// generation paths in the backends.
// When replacing instances of texture_external with texture_2d<f32> we must
// also modify calls to the texture_external overloads of textureLoad and
// textureSampleLevel, which unlike their texture_2d<f32> overloads do not
// require a level parameter. To do this we identify calls to textureLoad and
// textureSampleLevel that use texture_external as the first parameter and add
// a parameter for the level (which is always 0).
// Scan the AST nodes for calls to textureLoad or textureSampleLevel.
for (auto* node : ctx.src->ASTNodes().Objects()) {
if (auto* call_expr = node->As<ast::CallExpression>()) {
if (auto* builtin = sem.Get(call_expr)->Target()->As<sem::Builtin>()) {
if (builtin->Type() == sem::BuiltinType::kTextureLoad ||
builtin->Type() == sem::BuiltinType::kTextureSampleLevel) {
// When a textureLoad or textureSampleLevel has been identified, check
// if the first parameter is an external texture.
if (auto* var =
sem.Get(call_expr->args[0])->As<sem::VariableUser>()) {
if (var->Variable()
->Type()
->UnwrapRef()
->Is<sem::ExternalTexture>()) {
if (builtin->Type() == sem::BuiltinType::kTextureLoad &&
call_expr->args.size() != 2) {
TINT_ICE(Transform, ctx.dst->Diagnostics())
<< "expected textureLoad call with a texture_external to "
"have 2 parameters, found "
<< call_expr->args.size() << " parameters";
}
if (builtin->Type() == sem::BuiltinType::kTextureSampleLevel &&
call_expr->args.size() != 3) {
TINT_ICE(Transform, ctx.dst->Diagnostics())
<< "expected textureSampleLevel call with a "
"texture_external to have 3 parameters, found "
<< call_expr->args.size() << " parameters";
}
// Replace the call with another that has the same parameters in
// addition to a level parameter (always zero for external
// textures).
auto* exp = ctx.Clone(call_expr->target.name);
auto* externalTextureParam = ctx.Clone(call_expr->args[0]);
ast::ExpressionList params;
if (builtin->Type() == sem::BuiltinType::kTextureLoad) {
auto* coordsParam = ctx.Clone(call_expr->args[1]);
auto* levelParam = ctx.dst->Expr(0);
params = {externalTextureParam, coordsParam, levelParam};
} else if (builtin->Type() ==
sem::BuiltinType::kTextureSampleLevel) {
auto* samplerParam = ctx.Clone(call_expr->args[1]);
auto* coordsParam = ctx.Clone(call_expr->args[2]);
auto* levelParam = ctx.dst->Expr(0.0f);
params = {externalTextureParam, samplerParam, coordsParam,
levelParam};
}
auto* newCall = ctx.dst->create<ast::CallExpression>(exp, params);
ctx.Replace(call_expr, newCall);
}
}
}
}
}
}
// Scan the AST nodes for external texture declarations.
for (auto* node : ctx.src->ASTNodes().Objects()) {
if (auto* var = node->As<ast::Variable>()) {
if (::tint::Is<ast::ExternalTexture>(var->type)) {
// Replace a single-plane external texture with a 2D, f32 sampled
// texture.
auto* newType = ctx.dst->ty.sampled_texture(ast::TextureDimension::k2d,
ctx.dst->ty.f32());
auto clonedSrc = ctx.Clone(var->source);
auto clonedSym = ctx.Clone(var->symbol);
auto* clonedConstructor = ctx.Clone(var->constructor);
auto clonedAttributes = ctx.Clone(var->attributes);
auto* newVar = ctx.dst->create<ast::Variable>(
clonedSrc, clonedSym, var->declared_storage_class,
var->declared_access, newType, false, false, clonedConstructor,
clonedAttributes);
ctx.Replace(var, newVar);
}
}
}
ctx.Clone();
}
} // namespace transform
} // namespace tint

View File

@@ -0,0 +1,53 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TINT_TRANSFORM_EXTERNAL_TEXTURE_TRANSFORM_H_
#define SRC_TINT_TRANSFORM_EXTERNAL_TEXTURE_TRANSFORM_H_
#include <utility>
#include "src/tint/transform/transform.h"
namespace tint {
namespace transform {
/// Because an external texture is comprised of 1-3 texture views we can simply
/// transform external textures into the appropriate number of sampled textures.
/// This allows us to share SPIR-V/HLSL writer paths for sampled textures
/// instead of adding dedicated writer paths for external textures.
/// ExternalTextureTransform performs this transformation.
class ExternalTextureTransform
: public Castable<ExternalTextureTransform, Transform> {
public:
/// Constructor
ExternalTextureTransform();
/// Destructor
~ExternalTextureTransform() override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const override;
};
} // namespace transform
} // namespace tint
#endif // SRC_TINT_TRANSFORM_EXTERNAL_TEXTURE_TRANSFORM_H_

View File

@@ -0,0 +1,187 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/external_texture_transform.h"
#include "src/tint/transform/test_helper.h"
namespace tint {
namespace transform {
namespace {
using ExternalTextureTransformTest = TransformTest;
TEST_F(ExternalTextureTransformTest, SampleLevelSinglePlane) {
auto* src = R"(
@group(0) @binding(0) var s : sampler;
@group(0) @binding(1) var t : texture_external;
@stage(fragment)
fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
return textureSampleLevel(t, s, (coord.xy / vec2<f32>(4.0, 4.0)));
}
)";
auto* expect = R"(
@group(0) @binding(0) var s : sampler;
@group(0) @binding(1) var t : texture_2d<f32>;
@stage(fragment)
fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
return textureSampleLevel(t, s, (coord.xy / vec2<f32>(4.0, 4.0)), 0.0);
}
)";
auto got = Run<ExternalTextureTransform>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(ExternalTextureTransformTest, SampleLevelSinglePlane_OutOfOrder) {
auto* src = R"(
@stage(fragment)
fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
return textureSampleLevel(t, s, (coord.xy / vec2<f32>(4.0, 4.0)));
}
@group(0) @binding(1) var t : texture_external;
@group(0) @binding(0) var s : sampler;
)";
auto* expect = R"(
@stage(fragment)
fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
return textureSampleLevel(t, s, (coord.xy / vec2<f32>(4.0, 4.0)), 0.0);
}
@group(0) @binding(1) var t : texture_2d<f32>;
@group(0) @binding(0) var s : sampler;
)";
auto got = Run<ExternalTextureTransform>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(ExternalTextureTransformTest, LoadSinglePlane) {
auto* src = R"(
@group(0) @binding(0) var t : texture_external;
@stage(fragment)
fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
return textureLoad(t, vec2<i32>(1, 1));
}
)";
auto* expect = R"(
@group(0) @binding(0) var t : texture_2d<f32>;
@stage(fragment)
fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
return textureLoad(t, vec2<i32>(1, 1), 0);
}
)";
auto got = Run<ExternalTextureTransform>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(ExternalTextureTransformTest, LoadSinglePlane_OutOfOrder) {
auto* src = R"(
@stage(fragment)
fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
return textureLoad(t, vec2<i32>(1, 1));
}
@group(0) @binding(0) var t : texture_external;
)";
auto* expect = R"(
@stage(fragment)
fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
return textureLoad(t, vec2<i32>(1, 1), 0);
}
@group(0) @binding(0) var t : texture_2d<f32>;
)";
auto got = Run<ExternalTextureTransform>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(ExternalTextureTransformTest, DimensionsSinglePlane) {
auto* src = R"(
@group(0) @binding(0) var t : texture_external;
@stage(fragment)
fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
var dim : vec2<i32>;
dim = textureDimensions(t);
return vec4<f32>(0.0, 0.0, 0.0, 0.0);
}
)";
auto* expect = R"(
@group(0) @binding(0) var t : texture_2d<f32>;
@stage(fragment)
fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
var dim : vec2<i32>;
dim = textureDimensions(t);
return vec4<f32>(0.0, 0.0, 0.0, 0.0);
}
)";
auto got = Run<ExternalTextureTransform>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(ExternalTextureTransformTest, DimensionsSinglePlane_OutOfOrder) {
auto* src = R"(
@stage(fragment)
fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
var dim : vec2<i32>;
dim = textureDimensions(t);
return vec4<f32>(0.0, 0.0, 0.0, 0.0);
}
@group(0) @binding(0) var t : texture_external;
)";
auto* expect = R"(
@stage(fragment)
fn main(@builtin(position) coord : vec4<f32>) -> @location(0) vec4<f32> {
var dim : vec2<i32>;
dim = textureDimensions(t);
return vec4<f32>(0.0, 0.0, 0.0, 0.0);
}
@group(0) @binding(0) var t : texture_2d<f32>;
)";
auto got = Run<ExternalTextureTransform>(src);
EXPECT_EQ(expect, str(got));
}
} // namespace
} // namespace transform
} // namespace tint

View File

@@ -0,0 +1,188 @@
// Copyright 2020 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/first_index_offset.h"
#include <memory>
#include <unordered_map>
#include <utility>
#include "src/tint/program_builder.h"
#include "src/tint/sem/function.h"
#include "src/tint/sem/member_accessor_expression.h"
#include "src/tint/sem/struct.h"
#include "src/tint/sem/variable.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::FirstIndexOffset);
TINT_INSTANTIATE_TYPEINFO(tint::transform::FirstIndexOffset::BindingPoint);
TINT_INSTANTIATE_TYPEINFO(tint::transform::FirstIndexOffset::Data);
namespace tint {
namespace transform {
namespace {
// Uniform buffer member names
constexpr char kFirstVertexName[] = "first_vertex_index";
constexpr char kFirstInstanceName[] = "first_instance_index";
} // namespace
FirstIndexOffset::BindingPoint::BindingPoint() = default;
FirstIndexOffset::BindingPoint::BindingPoint(uint32_t b, uint32_t g)
: binding(b), group(g) {}
FirstIndexOffset::BindingPoint::~BindingPoint() = default;
FirstIndexOffset::Data::Data(bool has_vtx_index,
bool has_inst_index,
uint32_t first_vtx_offset,
uint32_t first_inst_offset)
: has_vertex_index(has_vtx_index),
has_instance_index(has_inst_index),
first_vertex_offset(first_vtx_offset),
first_instance_offset(first_inst_offset) {}
FirstIndexOffset::Data::Data(const Data&) = default;
FirstIndexOffset::Data::~Data() = default;
FirstIndexOffset::FirstIndexOffset() = default;
FirstIndexOffset::~FirstIndexOffset() = default;
bool FirstIndexOffset::ShouldRun(const Program* program, const DataMap&) const {
for (auto* fn : program->AST().Functions()) {
if (fn->PipelineStage() == ast::PipelineStage::kVertex) {
return true;
}
}
return false;
}
void FirstIndexOffset::Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const {
// Get the uniform buffer binding point
uint32_t ub_binding = binding_;
uint32_t ub_group = group_;
if (auto* binding_point = inputs.Get<BindingPoint>()) {
ub_binding = binding_point->binding;
ub_group = binding_point->group;
}
// Map of builtin usages
std::unordered_map<const sem::Variable*, const char*> builtin_vars;
std::unordered_map<const sem::StructMember*, const char*> builtin_members;
bool has_vertex_index = false;
bool has_instance_index = false;
// Traverse the AST scanning for builtin accesses via variables (includes
// parameters) or structure member accesses.
for (auto* node : ctx.src->ASTNodes().Objects()) {
if (auto* var = node->As<ast::Variable>()) {
for (auto* attr : var->attributes) {
if (auto* builtin_attr = attr->As<ast::BuiltinAttribute>()) {
ast::Builtin builtin = builtin_attr->builtin;
if (builtin == ast::Builtin::kVertexIndex) {
auto* sem_var = ctx.src->Sem().Get(var);
builtin_vars.emplace(sem_var, kFirstVertexName);
has_vertex_index = true;
}
if (builtin == ast::Builtin::kInstanceIndex) {
auto* sem_var = ctx.src->Sem().Get(var);
builtin_vars.emplace(sem_var, kFirstInstanceName);
has_instance_index = true;
}
}
}
}
if (auto* member = node->As<ast::StructMember>()) {
for (auto* attr : member->attributes) {
if (auto* builtin_attr = attr->As<ast::BuiltinAttribute>()) {
ast::Builtin builtin = builtin_attr->builtin;
if (builtin == ast::Builtin::kVertexIndex) {
auto* sem_mem = ctx.src->Sem().Get(member);
builtin_members.emplace(sem_mem, kFirstVertexName);
has_vertex_index = true;
}
if (builtin == ast::Builtin::kInstanceIndex) {
auto* sem_mem = ctx.src->Sem().Get(member);
builtin_members.emplace(sem_mem, kFirstInstanceName);
has_instance_index = true;
}
}
}
}
}
// Byte offsets on the uniform buffer
uint32_t vertex_index_offset = 0;
uint32_t instance_index_offset = 0;
if (has_vertex_index || has_instance_index) {
// Add uniform buffer members and calculate byte offsets
uint32_t offset = 0;
ast::StructMemberList members;
if (has_vertex_index) {
members.push_back(ctx.dst->Member(kFirstVertexName, ctx.dst->ty.u32()));
vertex_index_offset = offset;
offset += 4;
}
if (has_instance_index) {
members.push_back(ctx.dst->Member(kFirstInstanceName, ctx.dst->ty.u32()));
instance_index_offset = offset;
offset += 4;
}
auto* struct_ = ctx.dst->Structure(ctx.dst->Sym(), std::move(members));
// Create a global to hold the uniform buffer
Symbol buffer_name = ctx.dst->Sym();
ctx.dst->Global(buffer_name, ctx.dst->ty.Of(struct_),
ast::StorageClass::kUniform, nullptr,
ast::AttributeList{
ctx.dst->create<ast::BindingAttribute>(ub_binding),
ctx.dst->create<ast::GroupAttribute>(ub_group),
});
// Fix up all references to the builtins with the offsets
ctx.ReplaceAll(
[=, &ctx](const ast::Expression* expr) -> const ast::Expression* {
if (auto* sem = ctx.src->Sem().Get(expr)) {
if (auto* user = sem->As<sem::VariableUser>()) {
auto it = builtin_vars.find(user->Variable());
if (it != builtin_vars.end()) {
return ctx.dst->Add(
ctx.CloneWithoutTransform(expr),
ctx.dst->MemberAccessor(buffer_name, it->second));
}
}
if (auto* access = sem->As<sem::StructMemberAccess>()) {
auto it = builtin_members.find(access->Member());
if (it != builtin_members.end()) {
return ctx.dst->Add(
ctx.CloneWithoutTransform(expr),
ctx.dst->MemberAccessor(buffer_name, it->second));
}
}
}
// Not interested in this experssion. Just clone.
return nullptr;
});
}
ctx.Clone();
outputs.Add<Data>(has_vertex_index, has_instance_index, vertex_index_offset,
instance_index_offset);
}
} // namespace transform
} // namespace tint

View File

@@ -0,0 +1,143 @@
// Copyright 2020 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TINT_TRANSFORM_FIRST_INDEX_OFFSET_H_
#define SRC_TINT_TRANSFORM_FIRST_INDEX_OFFSET_H_
#include "src/tint/transform/transform.h"
namespace tint {
namespace transform {
/// Adds firstVertex/Instance (injected via root constants) to
/// vertex/instance index builtins.
///
/// This transform assumes that Name transform has been run before.
///
/// Unlike other APIs, D3D always starts vertex and instance numbering at 0,
/// regardless of the firstVertex/Instance value specified. This transformer
/// adds the value of firstVertex/Instance to each builtin. This action is
/// performed by adding a new constant equal to original builtin +
/// firstVertex/Instance to each function that references one of these builtins.
///
/// Note that D3D does not have any semantics for firstVertex/Instance.
/// Therefore, these values must by passed to the shader.
///
/// Before:
/// ```
/// @builtin(vertex_index) var<in> vert_idx : u32;
/// fn func() -> u32 {
/// return vert_idx;
/// }
/// ```
///
/// After:
/// ```
/// struct TintFirstIndexOffsetData {
/// tint_first_vertex_index : u32;
/// tint_first_instance_index : u32;
/// };
/// @builtin(vertex_index) var<in> tint_first_index_offset_vert_idx : u32;
/// @binding(N) @group(M) var<uniform> tint_first_index_data :
/// TintFirstIndexOffsetData;
/// fn func() -> u32 {
/// const vert_idx = (tint_first_index_offset_vert_idx +
/// tint_first_index_data.tint_first_vertex_index);
/// return vert_idx;
/// }
/// ```
///
class FirstIndexOffset : public Castable<FirstIndexOffset, Transform> {
public:
/// BindingPoint is consumed by the FirstIndexOffset transform.
/// BindingPoint specifies the binding point of the first index uniform
/// buffer.
struct BindingPoint : public Castable<BindingPoint, transform::Data> {
/// Constructor
BindingPoint();
/// Constructor
/// @param b the binding index
/// @param g the binding group
BindingPoint(uint32_t b, uint32_t g);
/// Destructor
~BindingPoint() override;
/// `@binding()` for the first vertex / first instance uniform buffer
uint32_t binding = 0;
/// `@group()` for the first vertex / first instance uniform buffer
uint32_t group = 0;
};
/// Data is outputted by the FirstIndexOffset transform.
/// Data holds information about shader usage and constant buffer offsets.
struct Data : public Castable<Data, transform::Data> {
/// Constructor
/// @param has_vtx_index True if the shader uses vertex_index
/// @param has_inst_index True if the shader uses instance_index
/// @param first_vtx_offset Offset of first vertex into constant buffer
/// @param first_inst_offset Offset of first instance into constant buffer
Data(bool has_vtx_index,
bool has_inst_index,
uint32_t first_vtx_offset,
uint32_t first_inst_offset);
/// Copy constructor
Data(const Data&);
/// Destructor
~Data() override;
/// True if the shader uses vertex_index
const bool has_vertex_index;
/// True if the shader uses instance_index
const bool has_instance_index;
/// Offset of first vertex into constant buffer
const uint32_t first_vertex_offset;
/// Offset of first instance into constant buffer
const uint32_t first_instance_offset;
};
/// Constructor
FirstIndexOffset();
/// Destructor
~FirstIndexOffset() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program,
const DataMap& data = {}) const override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const override;
private:
uint32_t binding_ = 0;
uint32_t group_ = 0;
};
} // namespace transform
} // namespace tint
#endif // SRC_TINT_TRANSFORM_FIRST_INDEX_OFFSET_H_

View File

@@ -0,0 +1,650 @@
// Copyright 2020 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/first_index_offset.h"
#include <memory>
#include <utility>
#include <vector>
#include "src/tint/transform/test_helper.h"
namespace tint {
namespace transform {
namespace {
using FirstIndexOffsetTest = TransformTest;
TEST_F(FirstIndexOffsetTest, ShouldRunEmptyModule) {
auto* src = R"()";
EXPECT_FALSE(ShouldRun<FirstIndexOffset>(src));
}
TEST_F(FirstIndexOffsetTest, ShouldRunFragmentStage) {
auto* src = R"(
[[stage(fragment)]]
fn entry() {
return;
}
)";
EXPECT_FALSE(ShouldRun<FirstIndexOffset>(src));
}
TEST_F(FirstIndexOffsetTest, ShouldRunVertexStage) {
auto* src = R"(
[[stage(vertex)]]
fn entry() -> [[builtin(position)]] vec4<f32> {
return vec4<f32>();
}
)";
EXPECT_TRUE(ShouldRun<FirstIndexOffset>(src));
}
TEST_F(FirstIndexOffsetTest, EmptyModule) {
auto* src = "";
auto* expect = "";
DataMap config;
config.Add<FirstIndexOffset::BindingPoint>(0, 0);
auto got = Run<FirstIndexOffset>(src, std::move(config));
EXPECT_EQ(expect, str(got));
auto* data = got.data.Get<FirstIndexOffset::Data>();
EXPECT_EQ(data, nullptr);
}
TEST_F(FirstIndexOffsetTest, BasicVertexShader) {
auto* src = R"(
@stage(vertex)
fn entry() -> @builtin(position) vec4<f32> {
return vec4<f32>();
}
)";
auto* expect = src;
DataMap config;
config.Add<FirstIndexOffset::BindingPoint>(0, 0);
auto got = Run<FirstIndexOffset>(src, std::move(config));
EXPECT_EQ(expect, str(got));
auto* data = got.data.Get<FirstIndexOffset::Data>();
ASSERT_NE(data, nullptr);
EXPECT_EQ(data->has_vertex_index, false);
EXPECT_EQ(data->has_instance_index, false);
EXPECT_EQ(data->first_vertex_offset, 0u);
EXPECT_EQ(data->first_instance_offset, 0u);
}
TEST_F(FirstIndexOffsetTest, BasicModuleVertexIndex) {
auto* src = R"(
fn test(vert_idx : u32) -> u32 {
return vert_idx;
}
@stage(vertex)
fn entry(@builtin(vertex_index) vert_idx : u32) -> @builtin(position) vec4<f32> {
test(vert_idx);
return vec4<f32>();
}
)";
auto* expect = R"(
struct tint_symbol {
first_vertex_index : u32;
}
@binding(1) @group(2) var<uniform> tint_symbol_1 : tint_symbol;
fn test(vert_idx : u32) -> u32 {
return vert_idx;
}
@stage(vertex)
fn entry(@builtin(vertex_index) vert_idx : u32) -> @builtin(position) vec4<f32> {
test((vert_idx + tint_symbol_1.first_vertex_index));
return vec4<f32>();
}
)";
DataMap config;
config.Add<FirstIndexOffset::BindingPoint>(1, 2);
auto got = Run<FirstIndexOffset>(src, std::move(config));
EXPECT_EQ(expect, str(got));
auto* data = got.data.Get<FirstIndexOffset::Data>();
ASSERT_NE(data, nullptr);
EXPECT_EQ(data->has_vertex_index, true);
EXPECT_EQ(data->has_instance_index, false);
EXPECT_EQ(data->first_vertex_offset, 0u);
EXPECT_EQ(data->first_instance_offset, 0u);
}
TEST_F(FirstIndexOffsetTest, BasicModuleVertexIndex_OutOfOrder) {
auto* src = R"(
@stage(vertex)
fn entry(@builtin(vertex_index) vert_idx : u32) -> @builtin(position) vec4<f32> {
test(vert_idx);
return vec4<f32>();
}
fn test(vert_idx : u32) -> u32 {
return vert_idx;
}
)";
auto* expect = R"(
struct tint_symbol {
first_vertex_index : u32;
}
@binding(1) @group(2) var<uniform> tint_symbol_1 : tint_symbol;
@stage(vertex)
fn entry(@builtin(vertex_index) vert_idx : u32) -> @builtin(position) vec4<f32> {
test((vert_idx + tint_symbol_1.first_vertex_index));
return vec4<f32>();
}
fn test(vert_idx : u32) -> u32 {
return vert_idx;
}
)";
DataMap config;
config.Add<FirstIndexOffset::BindingPoint>(1, 2);
auto got = Run<FirstIndexOffset>(src, std::move(config));
EXPECT_EQ(expect, str(got));
auto* data = got.data.Get<FirstIndexOffset::Data>();
ASSERT_NE(data, nullptr);
EXPECT_EQ(data->has_vertex_index, true);
EXPECT_EQ(data->has_instance_index, false);
EXPECT_EQ(data->first_vertex_offset, 0u);
EXPECT_EQ(data->first_instance_offset, 0u);
}
TEST_F(FirstIndexOffsetTest, BasicModuleInstanceIndex) {
auto* src = R"(
fn test(inst_idx : u32) -> u32 {
return inst_idx;
}
@stage(vertex)
fn entry(@builtin(instance_index) inst_idx : u32) -> @builtin(position) vec4<f32> {
test(inst_idx);
return vec4<f32>();
}
)";
auto* expect = R"(
struct tint_symbol {
first_instance_index : u32;
}
@binding(1) @group(7) var<uniform> tint_symbol_1 : tint_symbol;
fn test(inst_idx : u32) -> u32 {
return inst_idx;
}
@stage(vertex)
fn entry(@builtin(instance_index) inst_idx : u32) -> @builtin(position) vec4<f32> {
test((inst_idx + tint_symbol_1.first_instance_index));
return vec4<f32>();
}
)";
DataMap config;
config.Add<FirstIndexOffset::BindingPoint>(1, 7);
auto got = Run<FirstIndexOffset>(src, std::move(config));
EXPECT_EQ(expect, str(got));
auto* data = got.data.Get<FirstIndexOffset::Data>();
ASSERT_NE(data, nullptr);
EXPECT_EQ(data->has_vertex_index, false);
EXPECT_EQ(data->has_instance_index, true);
EXPECT_EQ(data->first_vertex_offset, 0u);
EXPECT_EQ(data->first_instance_offset, 0u);
}
TEST_F(FirstIndexOffsetTest, BasicModuleInstanceIndex_OutOfOrder) {
auto* src = R"(
@stage(vertex)
fn entry(@builtin(instance_index) inst_idx : u32) -> @builtin(position) vec4<f32> {
test(inst_idx);
return vec4<f32>();
}
fn test(inst_idx : u32) -> u32 {
return inst_idx;
}
)";
auto* expect = R"(
struct tint_symbol {
first_instance_index : u32;
}
@binding(1) @group(7) var<uniform> tint_symbol_1 : tint_symbol;
@stage(vertex)
fn entry(@builtin(instance_index) inst_idx : u32) -> @builtin(position) vec4<f32> {
test((inst_idx + tint_symbol_1.first_instance_index));
return vec4<f32>();
}
fn test(inst_idx : u32) -> u32 {
return inst_idx;
}
)";
DataMap config;
config.Add<FirstIndexOffset::BindingPoint>(1, 7);
auto got = Run<FirstIndexOffset>(src, std::move(config));
EXPECT_EQ(expect, str(got));
auto* data = got.data.Get<FirstIndexOffset::Data>();
ASSERT_NE(data, nullptr);
EXPECT_EQ(data->has_vertex_index, false);
EXPECT_EQ(data->has_instance_index, true);
EXPECT_EQ(data->first_vertex_offset, 0u);
EXPECT_EQ(data->first_instance_offset, 0u);
}
TEST_F(FirstIndexOffsetTest, BasicModuleBothIndex) {
auto* src = R"(
fn test(instance_idx : u32, vert_idx : u32) -> u32 {
return instance_idx + vert_idx;
}
struct Inputs {
@builtin(instance_index) instance_idx : u32;
@builtin(vertex_index) vert_idx : u32;
};
@stage(vertex)
fn entry(inputs : Inputs) -> @builtin(position) vec4<f32> {
test(inputs.instance_idx, inputs.vert_idx);
return vec4<f32>();
}
)";
auto* expect = R"(
struct tint_symbol {
first_vertex_index : u32;
first_instance_index : u32;
}
@binding(1) @group(2) var<uniform> tint_symbol_1 : tint_symbol;
fn test(instance_idx : u32, vert_idx : u32) -> u32 {
return (instance_idx + vert_idx);
}
struct Inputs {
@builtin(instance_index)
instance_idx : u32;
@builtin(vertex_index)
vert_idx : u32;
}
@stage(vertex)
fn entry(inputs : Inputs) -> @builtin(position) vec4<f32> {
test((inputs.instance_idx + tint_symbol_1.first_instance_index), (inputs.vert_idx + tint_symbol_1.first_vertex_index));
return vec4<f32>();
}
)";
DataMap config;
config.Add<FirstIndexOffset::BindingPoint>(1, 2);
auto got = Run<FirstIndexOffset>(src, std::move(config));
EXPECT_EQ(expect, str(got));
auto* data = got.data.Get<FirstIndexOffset::Data>();
ASSERT_NE(data, nullptr);
EXPECT_EQ(data->has_vertex_index, true);
EXPECT_EQ(data->has_instance_index, true);
EXPECT_EQ(data->first_vertex_offset, 0u);
EXPECT_EQ(data->first_instance_offset, 4u);
}
TEST_F(FirstIndexOffsetTest, BasicModuleBothIndex_OutOfOrder) {
auto* src = R"(
@stage(vertex)
fn entry(inputs : Inputs) -> @builtin(position) vec4<f32> {
test(inputs.instance_idx, inputs.vert_idx);
return vec4<f32>();
}
struct Inputs {
@builtin(instance_index) instance_idx : u32;
@builtin(vertex_index) vert_idx : u32;
};
fn test(instance_idx : u32, vert_idx : u32) -> u32 {
return instance_idx + vert_idx;
}
)";
auto* expect = R"(
struct tint_symbol {
first_vertex_index : u32;
first_instance_index : u32;
}
@binding(1) @group(2) var<uniform> tint_symbol_1 : tint_symbol;
@stage(vertex)
fn entry(inputs : Inputs) -> @builtin(position) vec4<f32> {
test((inputs.instance_idx + tint_symbol_1.first_instance_index), (inputs.vert_idx + tint_symbol_1.first_vertex_index));
return vec4<f32>();
}
struct Inputs {
@builtin(instance_index)
instance_idx : u32;
@builtin(vertex_index)
vert_idx : u32;
}
fn test(instance_idx : u32, vert_idx : u32) -> u32 {
return (instance_idx + vert_idx);
}
)";
DataMap config;
config.Add<FirstIndexOffset::BindingPoint>(1, 2);
auto got = Run<FirstIndexOffset>(src, std::move(config));
EXPECT_EQ(expect, str(got));
auto* data = got.data.Get<FirstIndexOffset::Data>();
ASSERT_NE(data, nullptr);
EXPECT_EQ(data->has_vertex_index, true);
EXPECT_EQ(data->has_instance_index, true);
EXPECT_EQ(data->first_vertex_offset, 0u);
EXPECT_EQ(data->first_instance_offset, 4u);
}
TEST_F(FirstIndexOffsetTest, NestedCalls) {
auto* src = R"(
fn func1(vert_idx : u32) -> u32 {
return vert_idx;
}
fn func2(vert_idx : u32) -> u32 {
return func1(vert_idx);
}
@stage(vertex)
fn entry(@builtin(vertex_index) vert_idx : u32) -> @builtin(position) vec4<f32> {
func2(vert_idx);
return vec4<f32>();
}
)";
auto* expect = R"(
struct tint_symbol {
first_vertex_index : u32;
}
@binding(1) @group(2) var<uniform> tint_symbol_1 : tint_symbol;
fn func1(vert_idx : u32) -> u32 {
return vert_idx;
}
fn func2(vert_idx : u32) -> u32 {
return func1(vert_idx);
}
@stage(vertex)
fn entry(@builtin(vertex_index) vert_idx : u32) -> @builtin(position) vec4<f32> {
func2((vert_idx + tint_symbol_1.first_vertex_index));
return vec4<f32>();
}
)";
DataMap config;
config.Add<FirstIndexOffset::BindingPoint>(1, 2);
auto got = Run<FirstIndexOffset>(src, std::move(config));
EXPECT_EQ(expect, str(got));
auto* data = got.data.Get<FirstIndexOffset::Data>();
ASSERT_NE(data, nullptr);
EXPECT_EQ(data->has_vertex_index, true);
EXPECT_EQ(data->has_instance_index, false);
EXPECT_EQ(data->first_vertex_offset, 0u);
EXPECT_EQ(data->first_instance_offset, 0u);
}
TEST_F(FirstIndexOffsetTest, NestedCalls_OutOfOrder) {
auto* src = R"(
@stage(vertex)
fn entry(@builtin(vertex_index) vert_idx : u32) -> @builtin(position) vec4<f32> {
func2(vert_idx);
return vec4<f32>();
}
fn func2(vert_idx : u32) -> u32 {
return func1(vert_idx);
}
fn func1(vert_idx : u32) -> u32 {
return vert_idx;
}
)";
auto* expect = R"(
struct tint_symbol {
first_vertex_index : u32;
}
@binding(1) @group(2) var<uniform> tint_symbol_1 : tint_symbol;
@stage(vertex)
fn entry(@builtin(vertex_index) vert_idx : u32) -> @builtin(position) vec4<f32> {
func2((vert_idx + tint_symbol_1.first_vertex_index));
return vec4<f32>();
}
fn func2(vert_idx : u32) -> u32 {
return func1(vert_idx);
}
fn func1(vert_idx : u32) -> u32 {
return vert_idx;
}
)";
DataMap config;
config.Add<FirstIndexOffset::BindingPoint>(1, 2);
auto got = Run<FirstIndexOffset>(src, std::move(config));
EXPECT_EQ(expect, str(got));
auto* data = got.data.Get<FirstIndexOffset::Data>();
ASSERT_NE(data, nullptr);
EXPECT_EQ(data->has_vertex_index, true);
EXPECT_EQ(data->has_instance_index, false);
EXPECT_EQ(data->first_vertex_offset, 0u);
EXPECT_EQ(data->first_instance_offset, 0u);
}
TEST_F(FirstIndexOffsetTest, MultipleEntryPoints) {
auto* src = R"(
fn func(i : u32) -> u32 {
return i;
}
@stage(vertex)
fn entry_a(@builtin(vertex_index) vert_idx : u32) -> @builtin(position) vec4<f32> {
func(vert_idx);
return vec4<f32>();
}
@stage(vertex)
fn entry_b(@builtin(vertex_index) vert_idx : u32, @builtin(instance_index) inst_idx : u32) -> @builtin(position) vec4<f32> {
func(vert_idx + inst_idx);
return vec4<f32>();
}
@stage(vertex)
fn entry_c(@builtin(instance_index) inst_idx : u32) -> @builtin(position) vec4<f32> {
func(inst_idx);
return vec4<f32>();
}
)";
auto* expect = R"(
struct tint_symbol {
first_vertex_index : u32;
first_instance_index : u32;
}
@binding(1) @group(2) var<uniform> tint_symbol_1 : tint_symbol;
fn func(i : u32) -> u32 {
return i;
}
@stage(vertex)
fn entry_a(@builtin(vertex_index) vert_idx : u32) -> @builtin(position) vec4<f32> {
func((vert_idx + tint_symbol_1.first_vertex_index));
return vec4<f32>();
}
@stage(vertex)
fn entry_b(@builtin(vertex_index) vert_idx : u32, @builtin(instance_index) inst_idx : u32) -> @builtin(position) vec4<f32> {
func(((vert_idx + tint_symbol_1.first_vertex_index) + (inst_idx + tint_symbol_1.first_instance_index)));
return vec4<f32>();
}
@stage(vertex)
fn entry_c(@builtin(instance_index) inst_idx : u32) -> @builtin(position) vec4<f32> {
func((inst_idx + tint_symbol_1.first_instance_index));
return vec4<f32>();
}
)";
DataMap config;
config.Add<FirstIndexOffset::BindingPoint>(1, 2);
auto got = Run<FirstIndexOffset>(src, std::move(config));
EXPECT_EQ(expect, str(got));
auto* data = got.data.Get<FirstIndexOffset::Data>();
ASSERT_NE(data, nullptr);
EXPECT_EQ(data->has_vertex_index, true);
EXPECT_EQ(data->has_instance_index, true);
EXPECT_EQ(data->first_vertex_offset, 0u);
EXPECT_EQ(data->first_instance_offset, 4u);
}
TEST_F(FirstIndexOffsetTest, MultipleEntryPoints_OutOfOrder) {
auto* src = R"(
@stage(vertex)
fn entry_a(@builtin(vertex_index) vert_idx : u32) -> @builtin(position) vec4<f32> {
func(vert_idx);
return vec4<f32>();
}
@stage(vertex)
fn entry_b(@builtin(vertex_index) vert_idx : u32, @builtin(instance_index) inst_idx : u32) -> @builtin(position) vec4<f32> {
func(vert_idx + inst_idx);
return vec4<f32>();
}
@stage(vertex)
fn entry_c(@builtin(instance_index) inst_idx : u32) -> @builtin(position) vec4<f32> {
func(inst_idx);
return vec4<f32>();
}
fn func(i : u32) -> u32 {
return i;
}
)";
auto* expect = R"(
struct tint_symbol {
first_vertex_index : u32;
first_instance_index : u32;
}
@binding(1) @group(2) var<uniform> tint_symbol_1 : tint_symbol;
@stage(vertex)
fn entry_a(@builtin(vertex_index) vert_idx : u32) -> @builtin(position) vec4<f32> {
func((vert_idx + tint_symbol_1.first_vertex_index));
return vec4<f32>();
}
@stage(vertex)
fn entry_b(@builtin(vertex_index) vert_idx : u32, @builtin(instance_index) inst_idx : u32) -> @builtin(position) vec4<f32> {
func(((vert_idx + tint_symbol_1.first_vertex_index) + (inst_idx + tint_symbol_1.first_instance_index)));
return vec4<f32>();
}
@stage(vertex)
fn entry_c(@builtin(instance_index) inst_idx : u32) -> @builtin(position) vec4<f32> {
func((inst_idx + tint_symbol_1.first_instance_index));
return vec4<f32>();
}
fn func(i : u32) -> u32 {
return i;
}
)";
DataMap config;
config.Add<FirstIndexOffset::BindingPoint>(1, 2);
auto got = Run<FirstIndexOffset>(src, std::move(config));
EXPECT_EQ(expect, str(got));
auto* data = got.data.Get<FirstIndexOffset::Data>();
ASSERT_NE(data, nullptr);
EXPECT_EQ(data->has_vertex_index, true);
EXPECT_EQ(data->has_instance_index, true);
EXPECT_EQ(data->first_vertex_offset, 0u);
EXPECT_EQ(data->first_instance_offset, 4u);
}
} // namespace
} // namespace transform
} // namespace tint

View File

@@ -0,0 +1,99 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/fold_constants.h"
#include <unordered_map>
#include <utility>
#include <vector>
#include "src/tint/program_builder.h"
#include "src/tint/sem/call.h"
#include "src/tint/sem/expression.h"
#include "src/tint/sem/type_constructor.h"
#include "src/tint/sem/type_conversion.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::FoldConstants);
namespace tint {
namespace transform {
FoldConstants::FoldConstants() = default;
FoldConstants::~FoldConstants() = default;
void FoldConstants::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
ctx.ReplaceAll([&](const ast::Expression* expr) -> const ast::Expression* {
auto* call = ctx.src->Sem().Get<sem::Call>(expr);
if (!call) {
return nullptr;
}
auto value = call->ConstantValue();
if (!value.IsValid()) {
return nullptr;
}
auto* ty = call->Type();
if (!call->Target()->IsAnyOf<sem::TypeConversion, sem::TypeConstructor>()) {
return nullptr;
}
// If original ctor expression had no init values, don't replace the
// expression
if (call->Arguments().empty()) {
return nullptr;
}
if (auto* vec = ty->As<sem::Vector>()) {
uint32_t vec_size = static_cast<uint32_t>(vec->Width());
// We'd like to construct the new vector with the same number of
// constructor args that the original node had, but after folding
// constants, cases like the following are problematic:
//
// vec3<f32> = vec3<f32>(vec2<f32>, 1.0) // vec_size=3, ctor_size=2
//
// In this case, creating a vec3 with 2 args is invalid, so we should
// create it with 3. So what we do is construct with vec_size args,
// except if the original vector was single-value initialized, in
// which case, we only construct with one arg again.
uint32_t ctor_size = (call->Arguments().size() == 1) ? 1 : vec_size;
ast::ExpressionList ctors;
for (uint32_t i = 0; i < ctor_size; ++i) {
value.WithScalarAt(
i, [&](auto&& s) { ctors.emplace_back(ctx.dst->Expr(s)); });
}
auto* el_ty = CreateASTTypeFor(ctx, vec->type());
return ctx.dst->vec(el_ty, vec_size, ctors);
}
if (ty->is_scalar()) {
return value.WithScalarAt(0,
[&](auto&& s) -> const ast::LiteralExpression* {
return ctx.dst->Expr(s);
});
}
return nullptr;
});
ctx.Clone();
}
} // namespace transform
} // namespace tint

View File

@@ -0,0 +1,47 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TINT_TRANSFORM_FOLD_CONSTANTS_H_
#define SRC_TINT_TRANSFORM_FOLD_CONSTANTS_H_
#include "src/tint/transform/transform.h"
namespace tint {
namespace transform {
/// FoldConstants transforms the AST by folding constant expressions
class FoldConstants : public Castable<FoldConstants, Transform> {
public:
/// Constructor
FoldConstants();
/// Destructor
~FoldConstants() override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const override;
};
} // namespace transform
} // namespace tint
#endif // SRC_TINT_TRANSFORM_FOLD_CONSTANTS_H_

View File

@@ -0,0 +1,427 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/fold_constants.h"
#include <memory>
#include <utility>
#include <vector>
#include "src/tint/transform/test_helper.h"
namespace tint {
namespace transform {
namespace {
using FoldConstantsTest = TransformTest;
TEST_F(FoldConstantsTest, Module_Scalar_NoConversion) {
auto* src = R"(
var<private> a : i32 = i32(123);
var<private> b : u32 = u32(123u);
var<private> c : f32 = f32(123.0);
var<private> d : bool = bool(true);
fn f() {
}
)";
auto* expect = R"(
var<private> a : i32 = 123;
var<private> b : u32 = 123u;
var<private> c : f32 = 123.0;
var<private> d : bool = true;
fn f() {
}
)";
auto got = Run<FoldConstants>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(FoldConstantsTest, Module_Scalar_Conversion) {
auto* src = R"(
var<private> a : i32 = i32(123.0);
var<private> b : u32 = u32(123);
var<private> c : f32 = f32(123u);
var<private> d : bool = bool(123);
fn f() {
}
)";
auto* expect = R"(
var<private> a : i32 = 123;
var<private> b : u32 = 123u;
var<private> c : f32 = 123.0;
var<private> d : bool = true;
fn f() {
}
)";
auto got = Run<FoldConstants>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(FoldConstantsTest, Module_Scalar_MultipleConversions) {
auto* src = R"(
var<private> a : i32 = i32(u32(f32(u32(i32(123.0)))));
var<private> b : u32 = u32(i32(f32(i32(u32(123)))));
var<private> c : f32 = f32(u32(i32(u32(f32(123u)))));
var<private> d : bool = bool(i32(f32(i32(u32(123)))));
fn f() {
}
)";
auto* expect = R"(
var<private> a : i32 = 123;
var<private> b : u32 = 123u;
var<private> c : f32 = 123.0;
var<private> d : bool = true;
fn f() {
}
)";
auto got = Run<FoldConstants>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(FoldConstantsTest, Module_Vector_NoConversion) {
auto* src = R"(
var<private> a : vec3<i32> = vec3<i32>(123);
var<private> b : vec3<u32> = vec3<u32>(123u);
var<private> c : vec3<f32> = vec3<f32>(123.0);
var<private> d : vec3<bool> = vec3<bool>(true);
fn f() {
}
)";
auto* expect = R"(
var<private> a : vec3<i32> = vec3<i32>(123);
var<private> b : vec3<u32> = vec3<u32>(123u);
var<private> c : vec3<f32> = vec3<f32>(123.0);
var<private> d : vec3<bool> = vec3<bool>(true);
fn f() {
}
)";
auto got = Run<FoldConstants>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(FoldConstantsTest, Module_Vector_Conversion) {
auto* src = R"(
var<private> a : vec3<i32> = vec3<i32>(vec3<f32>(123.0));
var<private> b : vec3<u32> = vec3<u32>(vec3<i32>(123));
var<private> c : vec3<f32> = vec3<f32>(vec3<u32>(123u));
var<private> d : vec3<bool> = vec3<bool>(vec3<i32>(123));
fn f() {
}
)";
auto* expect = R"(
var<private> a : vec3<i32> = vec3<i32>(123);
var<private> b : vec3<u32> = vec3<u32>(123u);
var<private> c : vec3<f32> = vec3<f32>(123.0);
var<private> d : vec3<bool> = vec3<bool>(true);
fn f() {
}
)";
auto got = Run<FoldConstants>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(FoldConstantsTest, Module_Vector_MultipleConversions) {
auto* src = R"(
var<private> a : vec3<i32> = vec3<i32>(vec3<u32>(vec3<f32>(vec3<u32>(u32(123.0)))));
var<private> b : vec3<u32> = vec3<u32>(vec3<i32>(vec3<f32>(vec3<i32>(i32(123)))));
var<private> c : vec3<f32> = vec3<f32>(vec3<u32>(vec3<i32>(vec3<u32>(u32(123u)))));
var<private> d : vec3<bool> = vec3<bool>(vec3<i32>(vec3<f32>(vec3<i32>(i32(123)))));
fn f() {
}
)";
auto* expect = R"(
var<private> a : vec3<i32> = vec3<i32>(123);
var<private> b : vec3<u32> = vec3<u32>(123u);
var<private> c : vec3<f32> = vec3<f32>(123.0);
var<private> d : vec3<bool> = vec3<bool>(true);
fn f() {
}
)";
auto got = Run<FoldConstants>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(FoldConstantsTest, Module_Vector_MixedSizeConversions) {
auto* src = R"(
var<private> a : vec4<i32> = vec4<i32>(vec3<i32>(vec3<u32>(1u, 2u, 3u)), 4);
var<private> b : vec4<i32> = vec4<i32>(vec2<i32>(vec2<u32>(1u, 2u)), vec2<i32>(4, 5));
var<private> c : vec4<i32> = vec4<i32>(1, vec2<i32>(vec2<f32>(2.0, 3.0)), 4);
var<private> d : vec4<i32> = vec4<i32>(1, 2, vec2<i32>(vec2<f32>(3.0, 4.0)));
var<private> e : vec4<bool> = vec4<bool>(false, bool(f32(1.0)), vec2<bool>(vec2<i32>(0, i32(4u))));
fn f() {
}
)";
auto* expect = R"(
var<private> a : vec4<i32> = vec4<i32>(1, 2, 3, 4);
var<private> b : vec4<i32> = vec4<i32>(1, 2, 4, 5);
var<private> c : vec4<i32> = vec4<i32>(1, 2, 3, 4);
var<private> d : vec4<i32> = vec4<i32>(1, 2, 3, 4);
var<private> e : vec4<bool> = vec4<bool>(false, true, false, true);
fn f() {
}
)";
auto got = Run<FoldConstants>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(FoldConstantsTest, Function_Scalar_NoConversion) {
auto* src = R"(
fn f() {
var a : i32 = i32(123);
var b : u32 = u32(123u);
var c : f32 = f32(123.0);
var d : bool = bool(true);
}
)";
auto* expect = R"(
fn f() {
var a : i32 = 123;
var b : u32 = 123u;
var c : f32 = 123.0;
var d : bool = true;
}
)";
auto got = Run<FoldConstants>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(FoldConstantsTest, Function_Scalar_Conversion) {
auto* src = R"(
fn f() {
var a : i32 = i32(123.0);
var b : u32 = u32(123);
var c : f32 = f32(123u);
var d : bool = bool(123);
}
)";
auto* expect = R"(
fn f() {
var a : i32 = 123;
var b : u32 = 123u;
var c : f32 = 123.0;
var d : bool = true;
}
)";
auto got = Run<FoldConstants>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(FoldConstantsTest, Function_Scalar_MultipleConversions) {
auto* src = R"(
fn f() {
var a : i32 = i32(u32(f32(u32(i32(123.0)))));
var b : u32 = u32(i32(f32(i32(u32(123)))));
var c : f32 = f32(u32(i32(u32(f32(123u)))));
var d : bool = bool(i32(f32(i32(u32(123)))));
}
)";
auto* expect = R"(
fn f() {
var a : i32 = 123;
var b : u32 = 123u;
var c : f32 = 123.0;
var d : bool = true;
}
)";
auto got = Run<FoldConstants>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(FoldConstantsTest, Function_Vector_NoConversion) {
auto* src = R"(
fn f() {
var a : vec3<i32> = vec3<i32>(123);
var b : vec3<u32> = vec3<u32>(123u);
var c : vec3<f32> = vec3<f32>(123.0);
var d : vec3<bool> = vec3<bool>(true);
}
)";
auto* expect = R"(
fn f() {
var a : vec3<i32> = vec3<i32>(123);
var b : vec3<u32> = vec3<u32>(123u);
var c : vec3<f32> = vec3<f32>(123.0);
var d : vec3<bool> = vec3<bool>(true);
}
)";
auto got = Run<FoldConstants>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(FoldConstantsTest, Function_Vector_Conversion) {
auto* src = R"(
fn f() {
var a : vec3<i32> = vec3<i32>(vec3<f32>(123.0));
var b : vec3<u32> = vec3<u32>(vec3<i32>(123));
var c : vec3<f32> = vec3<f32>(vec3<u32>(123u));
var d : vec3<bool> = vec3<bool>(vec3<i32>(123));
}
)";
auto* expect = R"(
fn f() {
var a : vec3<i32> = vec3<i32>(123);
var b : vec3<u32> = vec3<u32>(123u);
var c : vec3<f32> = vec3<f32>(123.0);
var d : vec3<bool> = vec3<bool>(true);
}
)";
auto got = Run<FoldConstants>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(FoldConstantsTest, Function_Vector_MultipleConversions) {
auto* src = R"(
fn f() {
var a : vec3<i32> = vec3<i32>(vec3<u32>(vec3<f32>(vec3<u32>(u32(123.0)))));
var b : vec3<u32> = vec3<u32>(vec3<i32>(vec3<f32>(vec3<i32>(i32(123)))));
var c : vec3<f32> = vec3<f32>(vec3<u32>(vec3<i32>(vec3<u32>(u32(123u)))));
var d : vec3<bool> = vec3<bool>(vec3<i32>(vec3<f32>(vec3<i32>(i32(123)))));
}
)";
auto* expect = R"(
fn f() {
var a : vec3<i32> = vec3<i32>(123);
var b : vec3<u32> = vec3<u32>(123u);
var c : vec3<f32> = vec3<f32>(123.0);
var d : vec3<bool> = vec3<bool>(true);
}
)";
auto got = Run<FoldConstants>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(FoldConstantsTest, Function_Vector_MixedSizeConversions) {
auto* src = R"(
fn f() {
var a : vec4<i32> = vec4<i32>(vec3<i32>(vec3<u32>(1u, 2u, 3u)), 4);
var b : vec4<i32> = vec4<i32>(vec2<i32>(vec2<u32>(1u, 2u)), vec2<i32>(4, 5));
var c : vec4<i32> = vec4<i32>(1, vec2<i32>(vec2<f32>(2.0, 3.0)), 4);
var d : vec4<i32> = vec4<i32>(1, 2, vec2<i32>(vec2<f32>(3.0, 4.0)));
var e : vec4<bool> = vec4<bool>(false, bool(f32(1.0)), vec2<bool>(vec2<i32>(0, i32(4u))));
}
)";
auto* expect = R"(
fn f() {
var a : vec4<i32> = vec4<i32>(1, 2, 3, 4);
var b : vec4<i32> = vec4<i32>(1, 2, 4, 5);
var c : vec4<i32> = vec4<i32>(1, 2, 3, 4);
var d : vec4<i32> = vec4<i32>(1, 2, 3, 4);
var e : vec4<bool> = vec4<bool>(false, true, false, true);
}
)";
auto got = Run<FoldConstants>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(FoldConstantsTest, Function_Vector_ConstantWithNonConstant) {
auto* src = R"(
fn f() {
var a : f32 = f32();
var b : vec2<f32> = vec2<f32>(f32(i32(1)), a);
}
)";
auto* expect = R"(
fn f() {
var a : f32 = f32();
var b : vec2<f32> = vec2<f32>(1.0, a);
}
)";
auto got = Run<FoldConstants>(src);
EXPECT_EQ(expect, str(got));
}
} // namespace
} // namespace transform
} // namespace tint

View File

@@ -0,0 +1,92 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/fold_trivial_single_use_lets.h"
#include "src/tint/program_builder.h"
#include "src/tint/sem/block_statement.h"
#include "src/tint/sem/function.h"
#include "src/tint/sem/statement.h"
#include "src/tint/sem/variable.h"
#include "src/tint/utils/scoped_assignment.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::FoldTrivialSingleUseLets);
namespace tint {
namespace transform {
namespace {
const ast::VariableDeclStatement* AsTrivialLetDecl(const ast::Statement* stmt) {
auto* var_decl = stmt->As<ast::VariableDeclStatement>();
if (!var_decl) {
return nullptr;
}
auto* var = var_decl->variable;
if (!var->is_const) {
return nullptr;
}
auto* ctor = var->constructor;
if (!IsAnyOf<ast::IdentifierExpression, ast::LiteralExpression>(ctor)) {
return nullptr;
}
return var_decl;
}
} // namespace
FoldTrivialSingleUseLets::FoldTrivialSingleUseLets() = default;
FoldTrivialSingleUseLets::~FoldTrivialSingleUseLets() = default;
void FoldTrivialSingleUseLets::Run(CloneContext& ctx,
const DataMap&,
DataMap&) const {
for (auto* node : ctx.src->ASTNodes().Objects()) {
if (auto* block = node->As<ast::BlockStatement>()) {
auto& stmts = block->statements;
for (size_t stmt_idx = 0; stmt_idx < stmts.size(); stmt_idx++) {
auto* stmt = stmts[stmt_idx];
if (auto* let_decl = AsTrivialLetDecl(stmt)) {
auto* let = let_decl->variable;
auto* sem_let = ctx.src->Sem().Get(let);
auto& users = sem_let->Users();
if (users.size() != 1) {
continue; // Does not have a single user.
}
auto* user = users[0];
auto* user_stmt = user->Stmt()->Declaration();
for (size_t i = stmt_idx; i < stmts.size(); i++) {
if (user_stmt == stmts[i]) {
auto* user_expr = user->Declaration();
ctx.Remove(stmts, let_decl);
ctx.Replace(user_expr, ctx.Clone(let->constructor));
}
if (!AsTrivialLetDecl(stmts[i])) {
// Stop if we hit a statement that isn't the single use of the
// let, and isn't a let itself.
break;
}
}
}
}
}
}
ctx.Clone();
}
} // namespace transform
} // namespace tint

View File

@@ -0,0 +1,61 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TINT_TRANSFORM_FOLD_TRIVIAL_SINGLE_USE_LETS_H_
#define SRC_TINT_TRANSFORM_FOLD_TRIVIAL_SINGLE_USE_LETS_H_
#include <string>
#include <unordered_map>
#include "src/tint/transform/transform.h"
namespace tint {
namespace transform {
/// FoldTrivialSingleUseLets is an optimizer for folding away trivial `let`
/// statements into their single place of use. This transform is intended to
/// clean up the SSA `let`s produced by the SPIR-V reader.
/// `let`s can only be folded if:
/// * There is a single usage of the `let` value.
/// * The `let` is constructed with a ScalarConstructorExpression, or with an
/// IdentifierExpression.
/// * There are only other foldable `let`s between the `let` declaration and its
/// single usage.
/// These rules prevent any hoisting of the let that may affect execution
/// behaviour.
class FoldTrivialSingleUseLets
: public Castable<FoldTrivialSingleUseLets, Transform> {
public:
/// Constructor
FoldTrivialSingleUseLets();
/// Destructor
~FoldTrivialSingleUseLets() override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const override;
};
} // namespace transform
} // namespace tint
#endif // SRC_TINT_TRANSFORM_FOLD_TRIVIAL_SINGLE_USE_LETS_H_

View File

@@ -0,0 +1,188 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/fold_trivial_single_use_lets.h"
#include "src/tint/transform/test_helper.h"
namespace tint {
namespace transform {
namespace {
using FoldTrivialSingleUseLetsTest = TransformTest;
TEST_F(FoldTrivialSingleUseLetsTest, EmptyModule) {
auto* src = "";
auto* expect = "";
auto got = Run<FoldTrivialSingleUseLets>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(FoldTrivialSingleUseLetsTest, Single) {
auto* src = R"(
fn f() {
let x = 1;
_ = x;
}
)";
auto* expect = R"(
fn f() {
_ = 1;
}
)";
auto got = Run<FoldTrivialSingleUseLets>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(FoldTrivialSingleUseLetsTest, Multiple) {
auto* src = R"(
fn f() {
let x = 1;
let y = 2;
let z = 3;
_ = x + y + z;
}
)";
auto* expect = R"(
fn f() {
_ = ((1 + 2) + 3);
}
)";
auto got = Run<FoldTrivialSingleUseLets>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(FoldTrivialSingleUseLetsTest, Chained) {
auto* src = R"(
fn f() {
let x = 1;
let y = x;
let z = y;
_ = z;
}
)";
auto* expect = R"(
fn f() {
_ = 1;
}
)";
auto got = Run<FoldTrivialSingleUseLets>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(FoldTrivialSingleUseLetsTest, NoFold_NonTrivialLet) {
auto* src = R"(
fn function_with_posssible_side_effect() -> i32 {
return 1;
}
fn f() {
let x = 1;
let y = function_with_posssible_side_effect();
_ = (x + y);
}
)";
auto* expect = src;
auto got = Run<FoldTrivialSingleUseLets>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(FoldTrivialSingleUseLetsTest, NoFold_NonTrivialLet_OutOfOrder) {
auto* src = R"(
fn f() {
let x = 1;
let y = function_with_posssible_side_effect();
_ = (x + y);
}
fn function_with_posssible_side_effect() -> i32 {
return 1;
}
)";
auto* expect = src;
auto got = Run<FoldTrivialSingleUseLets>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(FoldTrivialSingleUseLetsTest, NoFold_UseInSubBlock) {
auto* src = R"(
fn f() {
let x = 1;
{
_ = x;
}
}
)";
auto* expect = src;
auto got = Run<FoldTrivialSingleUseLets>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(FoldTrivialSingleUseLetsTest, NoFold_MultipleUses) {
auto* src = R"(
fn f() {
let x = 1;
_ = (x + x);
}
)";
auto* expect = src;
auto got = Run<FoldTrivialSingleUseLets>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(FoldTrivialSingleUseLetsTest, NoFold_Shadowing) {
auto* src = R"(
fn f() {
var y = 1;
let x = y;
{
let y = false;
_ = (x + x);
}
}
)";
auto* expect = src;
auto got = Run<FoldTrivialSingleUseLets>(src);
EXPECT_EQ(expect, str(got));
}
} // namespace
} // namespace transform
} // namespace tint

View File

@@ -0,0 +1,76 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/for_loop_to_loop.h"
#include "src/tint/ast/break_statement.h"
#include "src/tint/program_builder.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::ForLoopToLoop);
namespace tint {
namespace transform {
ForLoopToLoop::ForLoopToLoop() = default;
ForLoopToLoop::~ForLoopToLoop() = default;
bool ForLoopToLoop::ShouldRun(const Program* program, const DataMap&) const {
for (auto* node : program->ASTNodes().Objects()) {
if (node->Is<ast::ForLoopStatement>()) {
return true;
}
}
return false;
}
void ForLoopToLoop::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
ctx.ReplaceAll(
[&](const ast::ForLoopStatement* for_loop) -> const ast::Statement* {
ast::StatementList stmts;
if (auto* cond = for_loop->condition) {
// !condition
auto* not_cond = ctx.dst->create<ast::UnaryOpExpression>(
ast::UnaryOp::kNot, ctx.Clone(cond));
// { break; }
auto* break_body =
ctx.dst->Block(ctx.dst->create<ast::BreakStatement>());
// if (!condition) { break; }
stmts.emplace_back(ctx.dst->If(not_cond, break_body));
}
for (auto* stmt : for_loop->body->statements) {
stmts.emplace_back(ctx.Clone(stmt));
}
const ast::BlockStatement* continuing = nullptr;
if (auto* cont = for_loop->continuing) {
continuing = ctx.dst->Block(ctx.Clone(cont));
}
auto* body = ctx.dst->Block(stmts);
auto* loop = ctx.dst->create<ast::LoopStatement>(body, continuing);
if (auto* init = for_loop->initializer) {
return ctx.dst->Block(ctx.Clone(init), loop);
}
return loop;
});
ctx.Clone();
}
} // namespace transform
} // namespace tint

View File

@@ -0,0 +1,54 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TINT_TRANSFORM_FOR_LOOP_TO_LOOP_H_
#define SRC_TINT_TRANSFORM_FOR_LOOP_TO_LOOP_H_
#include "src/tint/transform/transform.h"
namespace tint {
namespace transform {
/// ForLoopToLoop is a Transform that converts a for-loop statement into a loop
/// statement. This is required by the SPIR-V writer.
class ForLoopToLoop : public Castable<ForLoopToLoop, Transform> {
public:
/// Constructor
ForLoopToLoop();
/// Destructor
~ForLoopToLoop() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program,
const DataMap& data = {}) const override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const override;
};
} // namespace transform
} // namespace tint
#endif // SRC_TINT_TRANSFORM_FOR_LOOP_TO_LOOP_H_

View File

@@ -0,0 +1,374 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/for_loop_to_loop.h"
#include "src/tint/transform/test_helper.h"
namespace tint {
namespace transform {
namespace {
using ForLoopToLoopTest = TransformTest;
TEST_F(ForLoopToLoopTest, ShouldRunEmptyModule) {
auto* src = R"()";
EXPECT_FALSE(ShouldRun<ForLoopToLoop>(src));
}
TEST_F(ForLoopToLoopTest, ShouldRunHasForLoop) {
auto* src = R"(
fn f() {
for (;;) {
break;
}
}
)";
EXPECT_TRUE(ShouldRun<ForLoopToLoop>(src));
}
TEST_F(ForLoopToLoopTest, EmptyModule) {
auto* src = "";
auto* expect = src;
auto got = Run<ForLoopToLoop>(src);
EXPECT_EQ(expect, str(got));
}
// Test an empty for loop.
TEST_F(ForLoopToLoopTest, Empty) {
auto* src = R"(
fn f() {
for (;;) {
break;
}
}
)";
auto* expect = R"(
fn f() {
loop {
break;
}
}
)";
auto got = Run<ForLoopToLoop>(src);
EXPECT_EQ(expect, str(got));
}
// Test a for loop with non-empty body.
TEST_F(ForLoopToLoopTest, Body) {
auto* src = R"(
fn f() {
for (;;) {
discard;
}
}
)";
auto* expect = R"(
fn f() {
loop {
discard;
}
}
)";
auto got = Run<ForLoopToLoop>(src);
EXPECT_EQ(expect, str(got));
}
// Test a for loop declaring a variable in the initializer statement.
TEST_F(ForLoopToLoopTest, InitializerStatementDecl) {
auto* src = R"(
fn f() {
for (var i: i32;;) {
break;
}
}
)";
auto* expect = R"(
fn f() {
{
var i : i32;
loop {
break;
}
}
}
)";
auto got = Run<ForLoopToLoop>(src);
EXPECT_EQ(expect, str(got));
}
// Test a for loop declaring and initializing a variable in the initializer
// statement.
TEST_F(ForLoopToLoopTest, InitializerStatementDeclEqual) {
auto* src = R"(
fn f() {
for (var i: i32 = 0;;) {
break;
}
}
)";
auto* expect = R"(
fn f() {
{
var i : i32 = 0;
loop {
break;
}
}
}
)";
auto got = Run<ForLoopToLoop>(src);
EXPECT_EQ(expect, str(got));
}
// Test a for loop declaring a const variable in the initializer statement.
TEST_F(ForLoopToLoopTest, InitializerStatementConstDecl) {
auto* src = R"(
fn f() {
for (let i: i32 = 0;;) {
break;
}
}
)";
auto* expect = R"(
fn f() {
{
let i : i32 = 0;
loop {
break;
}
}
}
)";
auto got = Run<ForLoopToLoop>(src);
EXPECT_EQ(expect, str(got));
}
// Test a for loop assigning a variable in the initializer statement.
TEST_F(ForLoopToLoopTest, InitializerStatementAssignment) {
auto* src = R"(
fn f() {
var i: i32;
for (i = 0;;) {
break;
}
}
)";
auto* expect = R"(
fn f() {
var i : i32;
{
i = 0;
loop {
break;
}
}
}
)";
auto got = Run<ForLoopToLoop>(src);
EXPECT_EQ(expect, str(got));
}
// Test a for loop calling a function in the initializer statement.
TEST_F(ForLoopToLoopTest, InitializerStatementFuncCall) {
auto* src = R"(
fn a(x : i32, y : i32) {
}
fn f() {
var b : i32;
var c : i32;
for (a(b,c);;) {
break;
}
}
)";
auto* expect = R"(
fn a(x : i32, y : i32) {
}
fn f() {
var b : i32;
var c : i32;
{
a(b, c);
loop {
break;
}
}
}
)";
auto got = Run<ForLoopToLoop>(src);
EXPECT_EQ(expect, str(got));
}
// Test a for loop with a break condition
TEST_F(ForLoopToLoopTest, BreakCondition) {
auto* src = R"(
fn f() {
for (; 0 == 1;) {
}
}
)";
auto* expect = R"(
fn f() {
loop {
if (!((0 == 1))) {
break;
}
}
}
)";
auto got = Run<ForLoopToLoop>(src);
EXPECT_EQ(expect, str(got));
}
// Test a for loop assigning a variable in the continuing statement.
TEST_F(ForLoopToLoopTest, ContinuingAssignment) {
auto* src = R"(
fn f() {
var x: i32;
for (;;x = 2) {
break;
}
}
)";
auto* expect = R"(
fn f() {
var x : i32;
loop {
break;
continuing {
x = 2;
}
}
}
)";
auto got = Run<ForLoopToLoop>(src);
EXPECT_EQ(expect, str(got));
}
// Test a for loop calling a function in the continuing statement.
TEST_F(ForLoopToLoopTest, ContinuingFuncCall) {
auto* src = R"(
fn a(x : i32, y : i32) {
}
fn f() {
var b : i32;
var c : i32;
for (;;a(b,c)) {
break;
}
}
)";
auto* expect = R"(
fn a(x : i32, y : i32) {
}
fn f() {
var b : i32;
var c : i32;
loop {
break;
continuing {
a(b, c);
}
}
}
)";
auto got = Run<ForLoopToLoop>(src);
EXPECT_EQ(expect, str(got));
}
// Test a for loop with all statements non-empty.
TEST_F(ForLoopToLoopTest, All) {
auto* src = R"(
fn f() {
var a : i32;
for(var i : i32 = 0; i < 4; i = i + 1) {
if (a == 0) {
continue;
}
a = a + 2;
}
}
)";
auto* expect = R"(
fn f() {
var a : i32;
{
var i : i32 = 0;
loop {
if (!((i < 4))) {
break;
}
if ((a == 0)) {
continue;
}
a = (a + 2);
continuing {
i = (i + 1);
}
}
}
}
)";
auto got = Run<ForLoopToLoop>(src);
EXPECT_EQ(expect, str(got));
}
} // namespace
} // namespace transform
} // namespace tint

118
src/tint/transform/glsl.cc Normal file
View File

@@ -0,0 +1,118 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/glsl.h"
#include <utility>
#include "src/tint/program_builder.h"
#include "src/tint/transform/add_empty_entry_point.h"
#include "src/tint/transform/add_spirv_block_attribute.h"
#include "src/tint/transform/binding_remapper.h"
#include "src/tint/transform/canonicalize_entry_point_io.h"
#include "src/tint/transform/combine_samplers.h"
#include "src/tint/transform/decompose_memory_access.h"
#include "src/tint/transform/external_texture_transform.h"
#include "src/tint/transform/fold_trivial_single_use_lets.h"
#include "src/tint/transform/loop_to_for_loop.h"
#include "src/tint/transform/manager.h"
#include "src/tint/transform/pad_array_elements.h"
#include "src/tint/transform/promote_initializers_to_const_var.h"
#include "src/tint/transform/remove_phonies.h"
#include "src/tint/transform/renamer.h"
#include "src/tint/transform/simplify_pointers.h"
#include "src/tint/transform/single_entry_point.h"
#include "src/tint/transform/unshadow.h"
#include "src/tint/transform/zero_init_workgroup_memory.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::Glsl);
TINT_INSTANTIATE_TYPEINFO(tint::transform::Glsl::Config);
namespace tint {
namespace transform {
Glsl::Glsl() = default;
Glsl::~Glsl() = default;
Output Glsl::Run(const Program* in, const DataMap& inputs) const {
Manager manager;
DataMap data;
auto* cfg = inputs.Get<Config>();
if (cfg && !cfg->entry_point.empty()) {
manager.Add<SingleEntryPoint>();
data.Add<SingleEntryPoint::Config>(cfg->entry_point);
}
manager.Add<Renamer>();
data.Add<Renamer::Config>(Renamer::Target::kGlslKeywords,
/* preserve_unicode */ false);
manager.Add<Unshadow>();
// Attempt to convert `loop`s into for-loops. This is to try and massage the
// output into something that will not cause FXC to choke or misbehave.
manager.Add<FoldTrivialSingleUseLets>();
manager.Add<LoopToForLoop>();
if (!cfg || !cfg->disable_workgroup_init) {
// ZeroInitWorkgroupMemory must come before CanonicalizeEntryPointIO as
// ZeroInitWorkgroupMemory may inject new builtin parameters.
manager.Add<ZeroInitWorkgroupMemory>();
}
manager.Add<CanonicalizeEntryPointIO>();
manager.Add<SimplifyPointers>();
manager.Add<RemovePhonies>();
manager.Add<CombineSamplers>();
if (auto* binding_info = inputs.Get<CombineSamplers::BindingInfo>()) {
data.Add<CombineSamplers::BindingInfo>(*binding_info);
} else {
data.Add<CombineSamplers::BindingInfo>(CombineSamplers::BindingMap(),
sem::BindingPoint());
}
manager.Add<BindingRemapper>();
if (auto* remappings = inputs.Get<BindingRemapper::Remappings>()) {
data.Add<BindingRemapper::Remappings>(*remappings);
} else {
BindingRemapper::BindingPoints bp;
BindingRemapper::AccessControls ac;
data.Add<BindingRemapper::Remappings>(bp, ac, /* mayCollide */ true);
}
manager.Add<ExternalTextureTransform>();
manager.Add<PromoteInitializersToConstVar>();
manager.Add<PadArrayElements>();
manager.Add<AddEmptyEntryPoint>();
manager.Add<AddSpirvBlockAttribute>();
data.Add<CanonicalizeEntryPointIO::Config>(
CanonicalizeEntryPointIO::ShaderStyle::kGlsl);
auto out = manager.Run(in, data);
if (!out.program.IsValid()) {
return out;
}
ProgramBuilder builder;
CloneContext ctx(&builder, &out.program);
ctx.Clone();
return Output{Program(std::move(builder))};
}
Glsl::Config::Config(const std::string& ep, bool disable_wi)
: entry_point(ep), disable_workgroup_init(disable_wi) {}
Glsl::Config::Config(const Config&) = default;
Glsl::Config::~Config() = default;
} // namespace transform
} // namespace tint

70
src/tint/transform/glsl.h Normal file
View File

@@ -0,0 +1,70 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TINT_TRANSFORM_GLSL_H_
#define SRC_TINT_TRANSFORM_GLSL_H_
#include <string>
#include "src/tint/transform/transform.h"
namespace tint {
// Forward declarations
class CloneContext;
namespace transform {
/// Glsl is a transform used to sanitize a Program for use with the Glsl writer.
/// Passing a non-sanitized Program to the Glsl writer will result in undefined
/// behavior.
class Glsl : public Castable<Glsl, Transform> {
public:
/// Configuration options for the Glsl sanitizer transform.
struct Config : public Castable<Data, transform::Data> {
/// Constructor
/// @param entry_point the root entry point function to generate
/// @param disable_workgroup_init `true` to disable workgroup memory zero
/// initialization
explicit Config(const std::string& entry_point,
bool disable_workgroup_init = false);
/// Copy constructor
Config(const Config&);
/// Destructor
~Config() override;
/// GLSL generator wraps a single entry point in a main() function.
std::string entry_point;
/// Set to `true` to disable workgroup memory zero initialization
bool disable_workgroup_init = false;
};
/// Constructor
Glsl();
~Glsl() override;
/// Runs the transform on `program`, returning the transformation result.
/// @param program the source program to transform
/// @param data optional extra transform-specific data
/// @returns the transformation result
Output Run(const Program* program, const DataMap& data = {}) const override;
};
} // namespace transform
} // namespace tint
#endif // SRC_TINT_TRANSFORM_GLSL_H_

View File

@@ -0,0 +1,41 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/glsl.h"
#include "src/tint/transform/test_helper.h"
namespace tint {
namespace transform {
namespace {
using GlslTest = TransformTest;
TEST_F(GlslTest, AddEmptyEntryPoint) {
auto* src = R"()";
auto* expect = R"(
@stage(compute) @workgroup_size(1)
fn unused_entry_point() {
}
)";
auto got = Run<Glsl>(src);
EXPECT_EQ(expect, str(got));
}
} // namespace
} // namespace transform
} // namespace tint

View File

@@ -0,0 +1,224 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/localize_struct_array_assignment.h"
#include <unordered_map>
#include <utility>
#include "src/tint/ast/assignment_statement.h"
#include "src/tint/ast/traverse_expressions.h"
#include "src/tint/program_builder.h"
#include "src/tint/sem/expression.h"
#include "src/tint/sem/member_accessor_expression.h"
#include "src/tint/sem/reference_type.h"
#include "src/tint/sem/statement.h"
#include "src/tint/sem/variable.h"
#include "src/tint/transform/simplify_pointers.h"
#include "src/tint/utils/scoped_assignment.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::LocalizeStructArrayAssignment);
namespace tint {
namespace transform {
/// Private implementation of LocalizeStructArrayAssignment transform
class LocalizeStructArrayAssignment::State {
private:
CloneContext& ctx;
ProgramBuilder& b;
/// Returns true if `expr` contains an index accessor expression to a
/// structure member of array type.
bool ContainsStructArrayIndex(const ast::Expression* expr) {
bool result = false;
ast::TraverseExpressions(
expr, b.Diagnostics(), [&](const ast::IndexAccessorExpression* ia) {
// Indexing using a runtime value?
auto* idx_sem = ctx.src->Sem().Get(ia->index);
if (!idx_sem->ConstantValue().IsValid()) {
// Indexing a member access expr?
if (auto* ma = ia->object->As<ast::MemberAccessorExpression>()) {
// That accesses an array?
if (ctx.src->TypeOf(ma)->UnwrapRef()->Is<sem::Array>()) {
result = true;
return ast::TraverseAction::Stop;
}
}
}
return ast::TraverseAction::Descend;
});
return result;
}
// Returns the type and storage class of the originating variable of the lhs
// of the assignment statement.
// See https://www.w3.org/TR/WGSL/#originating-variable-section
std::pair<const sem::Type*, ast::StorageClass>
GetOriginatingTypeAndStorageClass(
const ast::AssignmentStatement* assign_stmt) {
// Get first IdentifierExpr from lhs of assignment, which should resolve to
// the pointer or reference of the originating variable of the assignment.
// TraverseExpressions traverses left to right, and this code depends on the
// fact that for an assignment statement, the variable will be the left-most
// expression.
// TODO(crbug.com/tint/1341): do this in the Resolver, setting the
// originating variable on sem::Expression.
const ast::IdentifierExpression* ident = nullptr;
ast::TraverseExpressions(assign_stmt->lhs, b.Diagnostics(),
[&](const ast::IdentifierExpression* id) {
ident = id;
return ast::TraverseAction::Stop;
});
auto* sem_var_user = ctx.src->Sem().Get<sem::VariableUser>(ident);
if (!sem_var_user) {
TINT_ICE(Transform, b.Diagnostics())
<< "Expected to find variable of lhs of assignment statement";
return {};
}
auto* var = sem_var_user->Variable();
if (auto* ptr = var->Type()->As<sem::Pointer>()) {
return {ptr->StoreType(), ptr->StorageClass()};
}
auto* ref = var->Type()->As<sem::Reference>();
if (!ref) {
TINT_ICE(Transform, b.Diagnostics())
<< "Expecting to find variable of type pointer or reference on lhs "
"of assignment statement";
return {};
}
return {ref->StoreType(), ref->StorageClass()};
}
public:
/// Constructor
/// @param ctx_in the CloneContext primed with the input program and
/// ProgramBuilder
explicit State(CloneContext& ctx_in) : ctx(ctx_in), b(*ctx_in.dst) {}
/// Runs the transform
void Run() {
struct Shared {
bool process_nested_nodes = false;
ast::StatementList insert_before_stmts;
ast::StatementList insert_after_stmts;
} s;
ctx.ReplaceAll([&](const ast::AssignmentStatement* assign_stmt)
-> const ast::Statement* {
// Process if it's an assignment statement to a dynamically indexed array
// within a struct on a function or private storage variable. This
// specific use-case is what FXC fails to compile with:
// error X3500: array reference cannot be used as an l-value; not natively
// addressable
if (!ContainsStructArrayIndex(assign_stmt->lhs)) {
return nullptr;
}
auto og = GetOriginatingTypeAndStorageClass(assign_stmt);
if (!(og.first->Is<sem::Struct>() &&
(og.second == ast::StorageClass::kFunction ||
og.second == ast::StorageClass::kPrivate))) {
return nullptr;
}
// Reset shared state for this assignment statement
s = Shared{};
const ast::Expression* new_lhs = nullptr;
{
TINT_SCOPED_ASSIGNMENT(s.process_nested_nodes, true);
new_lhs = ctx.Clone(assign_stmt->lhs);
}
auto* new_assign_stmt = b.Assign(new_lhs, ctx.Clone(assign_stmt->rhs));
// Combine insert_before_stmts + new_assign_stmt + insert_after_stmts into
// a block and return it
ast::StatementList stmts = std::move(s.insert_before_stmts);
stmts.reserve(1 + s.insert_after_stmts.size());
stmts.emplace_back(new_assign_stmt);
stmts.insert(stmts.end(), s.insert_after_stmts.begin(),
s.insert_after_stmts.end());
return b.Block(std::move(stmts));
});
ctx.ReplaceAll([&](const ast::IndexAccessorExpression* index_access)
-> const ast::Expression* {
if (!s.process_nested_nodes) {
return nullptr;
}
// Indexing a member access expr?
auto* mem_access =
index_access->object->As<ast::MemberAccessorExpression>();
if (!mem_access) {
return nullptr;
}
// Process any nested IndexAccessorExpressions
mem_access = ctx.Clone(mem_access);
// Store the address of the member access into a let as we need to read
// the value twice e.g. let tint_symbol = &(s.a1);
auto mem_access_ptr = b.Sym();
s.insert_before_stmts.push_back(
b.Decl(b.Const(mem_access_ptr, nullptr, b.AddressOf(mem_access))));
// Disable further transforms when cloning
TINT_SCOPED_ASSIGNMENT(s.process_nested_nodes, false);
// Copy entire array out of struct into local temp var
// e.g. var tint_symbol_1 = *(tint_symbol);
auto tmp_var = b.Sym();
s.insert_before_stmts.push_back(
b.Decl(b.Var(tmp_var, nullptr, b.Deref(mem_access_ptr))));
// Replace input index_access with a clone of itself, but with its
// .object replaced by the new temp var. This is returned from this
// function to modify the original assignment statement. e.g.
// tint_symbol_1[uniforms.i]
auto* new_index_access =
b.IndexAccessor(tmp_var, ctx.Clone(index_access->index));
// Assign temp var back to array
// e.g. *(tint_symbol) = tint_symbol_1;
auto* assign_rhs_to_temp = b.Assign(b.Deref(mem_access_ptr), tmp_var);
s.insert_after_stmts.insert(s.insert_after_stmts.begin(),
assign_rhs_to_temp); // push_front
return new_index_access;
});
ctx.Clone();
}
};
LocalizeStructArrayAssignment::LocalizeStructArrayAssignment() = default;
LocalizeStructArrayAssignment::~LocalizeStructArrayAssignment() = default;
void LocalizeStructArrayAssignment::Run(CloneContext& ctx,
const DataMap&,
DataMap&) const {
State state(ctx);
state.Run();
}
} // namespace transform
} // namespace tint

View File

@@ -0,0 +1,58 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TINT_TRANSFORM_LOCALIZE_STRUCT_ARRAY_ASSIGNMENT_H_
#define SRC_TINT_TRANSFORM_LOCALIZE_STRUCT_ARRAY_ASSIGNMENT_H_
#include "src/tint/transform/transform.h"
namespace tint {
namespace transform {
/// This transforms replaces assignment to dynamically-indexed fixed-size arrays
/// in structs on shader-local variables with code that copies the arrays to a
/// temporary local variable, assigns to the local variable, and copies the
/// array back. This is to work around FXC's compilation failure for these cases
/// (see crbug.com/tint/1206).
///
/// @note Depends on the following transforms to have been run first:
/// * SimplifyPointers
class LocalizeStructArrayAssignment
: public Castable<LocalizeStructArrayAssignment, Transform> {
public:
/// Constructor
LocalizeStructArrayAssignment();
/// Destructor
~LocalizeStructArrayAssignment() override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const override;
private:
class State;
};
} // namespace transform
} // namespace tint
#endif // SRC_TINT_TRANSFORM_LOCALIZE_STRUCT_ARRAY_ASSIGNMENT_H_

View File

@@ -0,0 +1,884 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/localize_struct_array_assignment.h"
#include "src/tint/transform/simplify_pointers.h"
#include "src/tint/transform/unshadow.h"
#include "src/tint/transform/test_helper.h"
namespace tint {
namespace transform {
namespace {
using LocalizeStructArrayAssignmentTest = TransformTest;
TEST_F(LocalizeStructArrayAssignmentTest, EmptyModule) {
auto* src = R"()";
auto* expect = src;
auto got =
Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(LocalizeStructArrayAssignmentTest, StructArray) {
auto* src = R"(
@block struct Uniforms {
i : u32;
};
struct InnerS {
v : i32;
};
struct OuterS {
a1 : array<InnerS, 8>;
};
@group(1) @binding(4) var<uniform> uniforms : Uniforms;
@stage(compute) @workgroup_size(1)
fn main() {
var v : InnerS;
var s1 : OuterS;
s1.a1[uniforms.i] = v;
}
)";
auto* expect = R"(
@block
struct Uniforms {
i : u32;
}
struct InnerS {
v : i32;
}
struct OuterS {
a1 : array<InnerS, 8>;
}
@group(1) @binding(4) var<uniform> uniforms : Uniforms;
@stage(compute) @workgroup_size(1)
fn main() {
var v : InnerS;
var s1 : OuterS;
{
let tint_symbol = &(s1.a1);
var tint_symbol_1 = *(tint_symbol);
tint_symbol_1[uniforms.i] = v;
*(tint_symbol) = tint_symbol_1;
}
}
)";
auto got =
Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(LocalizeStructArrayAssignmentTest, StructArray_OutOfOrder) {
auto* src = R"(
@stage(compute) @workgroup_size(1)
fn main() {
var v : InnerS;
var s1 : OuterS;
s1.a1[uniforms.i] = v;
}
@group(1) @binding(4) var<uniform> uniforms : Uniforms;
struct OuterS {
a1 : array<InnerS, 8>;
};
struct InnerS {
v : i32;
};
@block struct Uniforms {
i : u32;
};
)";
auto* expect = R"(
@stage(compute) @workgroup_size(1)
fn main() {
var v : InnerS;
var s1 : OuterS;
{
let tint_symbol = &(s1.a1);
var tint_symbol_1 = *(tint_symbol);
tint_symbol_1[uniforms.i] = v;
*(tint_symbol) = tint_symbol_1;
}
}
@group(1) @binding(4) var<uniform> uniforms : Uniforms;
struct OuterS {
a1 : array<InnerS, 8>;
}
struct InnerS {
v : i32;
}
@block
struct Uniforms {
i : u32;
}
)";
auto got =
Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(LocalizeStructArrayAssignmentTest, StructStructArray) {
auto* src = R"(
@block struct Uniforms {
i : u32;
};
struct InnerS {
v : i32;
};
struct S1 {
a : array<InnerS, 8>;
};
struct OuterS {
s2 : S1;
};
@group(1) @binding(4) var<uniform> uniforms : Uniforms;
@stage(compute) @workgroup_size(1)
fn main() {
var v : InnerS;
var s1 : OuterS;
s1.s2.a[uniforms.i] = v;
}
)";
auto* expect = R"(
@block
struct Uniforms {
i : u32;
}
struct InnerS {
v : i32;
}
struct S1 {
a : array<InnerS, 8>;
}
struct OuterS {
s2 : S1;
}
@group(1) @binding(4) var<uniform> uniforms : Uniforms;
@stage(compute) @workgroup_size(1)
fn main() {
var v : InnerS;
var s1 : OuterS;
{
let tint_symbol = &(s1.s2.a);
var tint_symbol_1 = *(tint_symbol);
tint_symbol_1[uniforms.i] = v;
*(tint_symbol) = tint_symbol_1;
}
}
)";
auto got =
Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(LocalizeStructArrayAssignmentTest, StructStructArray_OutOfOrder) {
auto* src = R"(
@stage(compute) @workgroup_size(1)
fn main() {
var v : InnerS;
var s1 : OuterS;
s1.s2.a[uniforms.i] = v;
}
@group(1) @binding(4) var<uniform> uniforms : Uniforms;
struct OuterS {
s2 : S1;
};
struct S1 {
a : array<InnerS, 8>;
};
struct InnerS {
v : i32;
};
@block struct Uniforms {
i : u32;
};
)";
auto* expect = R"(
@stage(compute) @workgroup_size(1)
fn main() {
var v : InnerS;
var s1 : OuterS;
{
let tint_symbol = &(s1.s2.a);
var tint_symbol_1 = *(tint_symbol);
tint_symbol_1[uniforms.i] = v;
*(tint_symbol) = tint_symbol_1;
}
}
@group(1) @binding(4) var<uniform> uniforms : Uniforms;
struct OuterS {
s2 : S1;
}
struct S1 {
a : array<InnerS, 8>;
}
struct InnerS {
v : i32;
}
@block
struct Uniforms {
i : u32;
}
)";
auto got =
Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(LocalizeStructArrayAssignmentTest, StructArrayArray) {
auto* src = R"(
@block struct Uniforms {
i : u32;
j : u32;
};
struct InnerS {
v : i32;
};
struct OuterS {
a1 : array<array<InnerS, 8>, 8>;
};
@group(1) @binding(4) var<uniform> uniforms : Uniforms;
@stage(compute) @workgroup_size(1)
fn main() {
var v : InnerS;
var s1 : OuterS;
s1.a1[uniforms.i][uniforms.j] = v;
}
)";
auto* expect = R"(
@block
struct Uniforms {
i : u32;
j : u32;
}
struct InnerS {
v : i32;
}
struct OuterS {
a1 : array<array<InnerS, 8>, 8>;
}
@group(1) @binding(4) var<uniform> uniforms : Uniforms;
@stage(compute) @workgroup_size(1)
fn main() {
var v : InnerS;
var s1 : OuterS;
{
let tint_symbol = &(s1.a1);
var tint_symbol_1 = *(tint_symbol);
tint_symbol_1[uniforms.i][uniforms.j] = v;
*(tint_symbol) = tint_symbol_1;
}
}
)";
auto got =
Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(LocalizeStructArrayAssignmentTest, StructArrayStruct) {
auto* src = R"(
@block struct Uniforms {
i : u32;
};
struct InnerS {
v : i32;
};
struct S1 {
s2 : InnerS;
};
struct OuterS {
a1 : array<S1, 8>;
};
@group(1) @binding(4) var<uniform> uniforms : Uniforms;
@stage(compute) @workgroup_size(1)
fn main() {
var v : InnerS;
var s1 : OuterS;
s1.a1[uniforms.i].s2 = v;
}
)";
auto* expect = R"(
@block
struct Uniforms {
i : u32;
}
struct InnerS {
v : i32;
}
struct S1 {
s2 : InnerS;
}
struct OuterS {
a1 : array<S1, 8>;
}
@group(1) @binding(4) var<uniform> uniforms : Uniforms;
@stage(compute) @workgroup_size(1)
fn main() {
var v : InnerS;
var s1 : OuterS;
{
let tint_symbol = &(s1.a1);
var tint_symbol_1 = *(tint_symbol);
tint_symbol_1[uniforms.i].s2 = v;
*(tint_symbol) = tint_symbol_1;
}
}
)";
auto got =
Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(LocalizeStructArrayAssignmentTest, StructArrayStructArray) {
auto* src = R"(
@block struct Uniforms {
i : u32;
j : u32;
};
struct InnerS {
v : i32;
};
struct S1 {
a2 : array<InnerS, 8>;
};
struct OuterS {
a1 : array<S1, 8>;
};
@group(1) @binding(4) var<uniform> uniforms : Uniforms;
@stage(compute) @workgroup_size(1)
fn main() {
var v : InnerS;
var s : OuterS;
s.a1[uniforms.i].a2[uniforms.j] = v;
}
)";
auto* expect = R"(
@block
struct Uniforms {
i : u32;
j : u32;
}
struct InnerS {
v : i32;
}
struct S1 {
a2 : array<InnerS, 8>;
}
struct OuterS {
a1 : array<S1, 8>;
}
@group(1) @binding(4) var<uniform> uniforms : Uniforms;
@stage(compute) @workgroup_size(1)
fn main() {
var v : InnerS;
var s : OuterS;
{
let tint_symbol = &(s.a1);
var tint_symbol_1 = *(tint_symbol);
let tint_symbol_2 = &(tint_symbol_1[uniforms.i].a2);
var tint_symbol_3 = *(tint_symbol_2);
tint_symbol_3[uniforms.j] = v;
*(tint_symbol_2) = tint_symbol_3;
*(tint_symbol) = tint_symbol_1;
}
}
)";
auto got =
Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(LocalizeStructArrayAssignmentTest, IndexingWithSideEffectFunc) {
auto* src = R"(
@block struct Uniforms {
i : u32;
j : u32;
};
struct InnerS {
v : i32;
};
struct S1 {
a2 : array<InnerS, 8>;
};
struct OuterS {
a1 : array<S1, 8>;
};
var<private> nextIndex : u32;
fn getNextIndex() -> u32 {
nextIndex = nextIndex + 1u;
return nextIndex;
}
@group(1) @binding(4) var<uniform> uniforms : Uniforms;
@stage(compute) @workgroup_size(1)
fn main() {
var v : InnerS;
var s : OuterS;
s.a1[getNextIndex()].a2[uniforms.j] = v;
}
)";
auto* expect = R"(
@block
struct Uniforms {
i : u32;
j : u32;
}
struct InnerS {
v : i32;
}
struct S1 {
a2 : array<InnerS, 8>;
}
struct OuterS {
a1 : array<S1, 8>;
}
var<private> nextIndex : u32;
fn getNextIndex() -> u32 {
nextIndex = (nextIndex + 1u);
return nextIndex;
}
@group(1) @binding(4) var<uniform> uniforms : Uniforms;
@stage(compute) @workgroup_size(1)
fn main() {
var v : InnerS;
var s : OuterS;
{
let tint_symbol = &(s.a1);
var tint_symbol_1 = *(tint_symbol);
let tint_symbol_2 = &(tint_symbol_1[getNextIndex()].a2);
var tint_symbol_3 = *(tint_symbol_2);
tint_symbol_3[uniforms.j] = v;
*(tint_symbol_2) = tint_symbol_3;
*(tint_symbol) = tint_symbol_1;
}
}
)";
auto got =
Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(LocalizeStructArrayAssignmentTest,
IndexingWithSideEffectFunc_OutOfOrder) {
auto* src = R"(
@stage(compute) @workgroup_size(1)
fn main() {
var v : InnerS;
var s : OuterS;
s.a1[getNextIndex()].a2[uniforms.j] = v;
}
@group(1) @binding(4) var<uniform> uniforms : Uniforms;
@block struct Uniforms {
i : u32;
j : u32;
};
var<private> nextIndex : u32;
fn getNextIndex() -> u32 {
nextIndex = nextIndex + 1u;
return nextIndex;
}
struct OuterS {
a1 : array<S1, 8>;
};
struct S1 {
a2 : array<InnerS, 8>;
};
struct InnerS {
v : i32;
};
)";
auto* expect = R"(
@stage(compute) @workgroup_size(1)
fn main() {
var v : InnerS;
var s : OuterS;
{
let tint_symbol = &(s.a1);
var tint_symbol_1 = *(tint_symbol);
let tint_symbol_2 = &(tint_symbol_1[getNextIndex()].a2);
var tint_symbol_3 = *(tint_symbol_2);
tint_symbol_3[uniforms.j] = v;
*(tint_symbol_2) = tint_symbol_3;
*(tint_symbol) = tint_symbol_1;
}
}
@group(1) @binding(4) var<uniform> uniforms : Uniforms;
@block
struct Uniforms {
i : u32;
j : u32;
}
var<private> nextIndex : u32;
fn getNextIndex() -> u32 {
nextIndex = (nextIndex + 1u);
return nextIndex;
}
struct OuterS {
a1 : array<S1, 8>;
}
struct S1 {
a2 : array<InnerS, 8>;
}
struct InnerS {
v : i32;
}
)";
auto got =
Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(LocalizeStructArrayAssignmentTest, ViaPointerArg) {
auto* src = R"(
@block struct Uniforms {
i : u32;
};
struct InnerS {
v : i32;
};
struct OuterS {
a1 : array<InnerS, 8>;
};
@group(1) @binding(4) var<uniform> uniforms : Uniforms;
fn f(p : ptr<function, OuterS>) {
var v : InnerS;
(*p).a1[uniforms.i] = v;
}
@stage(compute) @workgroup_size(1)
fn main() {
var s1 : OuterS;
f(&s1);
}
)";
auto* expect = R"(
@block
struct Uniforms {
i : u32;
}
struct InnerS {
v : i32;
}
struct OuterS {
a1 : array<InnerS, 8>;
}
@group(1) @binding(4) var<uniform> uniforms : Uniforms;
fn f(p : ptr<function, OuterS>) {
var v : InnerS;
{
let tint_symbol = &((*(p)).a1);
var tint_symbol_1 = *(tint_symbol);
tint_symbol_1[uniforms.i] = v;
*(tint_symbol) = tint_symbol_1;
}
}
@stage(compute) @workgroup_size(1)
fn main() {
var s1 : OuterS;
f(&(s1));
}
)";
auto got =
Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(LocalizeStructArrayAssignmentTest, ViaPointerArg_OutOfOrder) {
auto* src = R"(
@stage(compute) @workgroup_size(1)
fn main() {
var s1 : OuterS;
f(&s1);
}
fn f(p : ptr<function, OuterS>) {
var v : InnerS;
(*p).a1[uniforms.i] = v;
}
struct InnerS {
v : i32;
};
struct OuterS {
a1 : array<InnerS, 8>;
};
@group(1) @binding(4) var<uniform> uniforms : Uniforms;
@block struct Uniforms {
i : u32;
};
)";
auto* expect = R"(
@stage(compute) @workgroup_size(1)
fn main() {
var s1 : OuterS;
f(&(s1));
}
fn f(p : ptr<function, OuterS>) {
var v : InnerS;
{
let tint_symbol = &((*(p)).a1);
var tint_symbol_1 = *(tint_symbol);
tint_symbol_1[uniforms.i] = v;
*(tint_symbol) = tint_symbol_1;
}
}
struct InnerS {
v : i32;
}
struct OuterS {
a1 : array<InnerS, 8>;
}
@group(1) @binding(4) var<uniform> uniforms : Uniforms;
@block
struct Uniforms {
i : u32;
}
)";
auto got =
Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(LocalizeStructArrayAssignmentTest, ViaPointerVar) {
auto* src = R"(
@block
struct Uniforms {
i : u32;
};
struct InnerS {
v : i32;
};
struct OuterS {
a1 : array<InnerS, 8>;
};
@group(1) @binding(4) var<uniform> uniforms : Uniforms;
fn f(p : ptr<function, InnerS>, v : InnerS) {
*(p) = v;
}
@stage(compute) @workgroup_size(1)
fn main() {
var v : InnerS;
var s1 : OuterS;
let p = &(s1.a1[uniforms.i]);
*(p) = v;
}
)";
auto* expect = R"(
@block
struct Uniforms {
i : u32;
}
struct InnerS {
v : i32;
}
struct OuterS {
a1 : array<InnerS, 8>;
}
@group(1) @binding(4) var<uniform> uniforms : Uniforms;
fn f(p : ptr<function, InnerS>, v : InnerS) {
*(p) = v;
}
@stage(compute) @workgroup_size(1)
fn main() {
var v : InnerS;
var s1 : OuterS;
let p_save = uniforms.i;
{
let tint_symbol = &(s1.a1);
var tint_symbol_1 = *(tint_symbol);
tint_symbol_1[p_save] = v;
*(tint_symbol) = tint_symbol_1;
}
}
)";
auto got =
Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(LocalizeStructArrayAssignmentTest, VectorAssignment) {
auto* src = R"(
@block
struct Uniforms {
i : u32;
}
@block
struct OuterS {
a1 : array<u32, 8>;
}
@group(1) @binding(4) var<uniform> uniforms : Uniforms;
fn f(i : u32) -> u32 {
return (i + 1u);
}
@stage(compute) @workgroup_size(1)
fn main() {
var s1 : OuterS;
var v : vec3<f32>;
v[s1.a1[uniforms.i]] = 1.0;
v[f(s1.a1[uniforms.i])] = 1.0;
}
)";
// Transform does nothing here as we're not actually assigning to the array in
// the struct.
auto* expect = src;
auto got =
Run<Unshadow, SimplifyPointers, LocalizeStructArrayAssignment>(src);
EXPECT_EQ(expect, str(got));
}
} // namespace
} // namespace transform
} // namespace tint

View File

@@ -0,0 +1,145 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/loop_to_for_loop.h"
#include "src/tint/ast/break_statement.h"
#include "src/tint/ast/for_loop_statement.h"
#include "src/tint/program_builder.h"
#include "src/tint/sem/block_statement.h"
#include "src/tint/sem/function.h"
#include "src/tint/sem/statement.h"
#include "src/tint/sem/variable.h"
#include "src/tint/utils/scoped_assignment.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::LoopToForLoop);
namespace tint {
namespace transform {
namespace {
bool IsBlockWithSingleBreak(const ast::BlockStatement* block) {
if (block->statements.size() != 1) {
return false;
}
return block->statements[0]->Is<ast::BreakStatement>();
}
bool IsVarUsedByStmt(const sem::Info& sem,
const ast::Variable* var,
const ast::Statement* stmt) {
auto* var_sem = sem.Get(var);
for (auto* user : var_sem->Users()) {
if (auto* s = user->Stmt()) {
if (s->Declaration() == stmt) {
return true;
}
}
}
return false;
}
} // namespace
LoopToForLoop::LoopToForLoop() = default;
LoopToForLoop::~LoopToForLoop() = default;
bool LoopToForLoop::ShouldRun(const Program* program, const DataMap&) const {
for (auto* node : program->ASTNodes().Objects()) {
if (node->Is<ast::LoopStatement>()) {
return true;
}
}
return false;
}
void LoopToForLoop::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
ctx.ReplaceAll([&](const ast::LoopStatement* loop) -> const ast::Statement* {
// For loop condition is taken from the first statement in the loop.
// This requires an if-statement with either:
// * A true block with no else statements, and the true block contains a
// single 'break' statement.
// * An empty true block with a single, no-condition else statement
// containing a single 'break' statement.
// Examples:
// loop { if (condition) { break; } ... }
// loop { if (condition) {} else { break; } ... }
auto& stmts = loop->body->statements;
if (stmts.empty()) {
return nullptr;
}
auto* if_stmt = stmts[0]->As<ast::IfStatement>();
if (!if_stmt) {
return nullptr;
}
bool negate_condition = false;
if (IsBlockWithSingleBreak(if_stmt->body) &&
if_stmt->else_statements.empty()) {
negate_condition = true;
} else if (if_stmt->body->Empty() && if_stmt->else_statements.size() == 1 &&
if_stmt->else_statements[0]->condition == nullptr &&
IsBlockWithSingleBreak(if_stmt->else_statements[0]->body)) {
negate_condition = false;
} else {
return nullptr;
}
// The continuing block must be empty or contain a single, assignment or
// function call statement.
const ast::Statement* continuing = nullptr;
if (auto* loop_cont = loop->continuing) {
if (loop_cont->statements.size() != 1) {
return nullptr;
}
continuing = loop_cont->statements[0];
if (!continuing
->IsAnyOf<ast::AssignmentStatement, ast::CallStatement>()) {
return nullptr;
}
// And the continuing statement must not use any of the variables declared
// in the loop body.
for (auto* stmt : loop->body->statements) {
if (auto* var_decl = stmt->As<ast::VariableDeclStatement>()) {
if (IsVarUsedByStmt(ctx.src->Sem(), var_decl->variable, continuing)) {
return nullptr;
}
}
}
continuing = ctx.Clone(continuing);
}
auto* condition = ctx.Clone(if_stmt->condition);
if (negate_condition) {
condition = ctx.dst->create<ast::UnaryOpExpression>(ast::UnaryOp::kNot,
condition);
}
ast::Statement* initializer = nullptr;
ctx.Remove(loop->body->statements, if_stmt);
auto* body = ctx.Clone(loop->body);
return ctx.dst->create<ast::ForLoopStatement>(initializer, condition,
continuing, body);
});
ctx.Clone();
}
} // namespace transform
} // namespace tint

View File

@@ -0,0 +1,54 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TINT_TRANSFORM_LOOP_TO_FOR_LOOP_H_
#define SRC_TINT_TRANSFORM_LOOP_TO_FOR_LOOP_H_
#include "src/tint/transform/transform.h"
namespace tint {
namespace transform {
/// LoopToForLoop is a Transform that attempts to convert WGSL `loop {}`
/// statements into a for-loop statement.
class LoopToForLoop : public Castable<LoopToForLoop, Transform> {
public:
/// Constructor
LoopToForLoop();
/// Destructor
~LoopToForLoop() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program,
const DataMap& data = {}) const override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const override;
};
} // namespace transform
} // namespace tint
#endif // SRC_TINT_TRANSFORM_LOOP_TO_FOR_LOOP_H_

View File

@@ -0,0 +1,308 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/loop_to_for_loop.h"
#include "src/tint/transform/test_helper.h"
namespace tint {
namespace transform {
namespace {
using LoopToForLoopTest = TransformTest;
TEST_F(LoopToForLoopTest, ShouldRunEmptyModule) {
auto* src = R"()";
EXPECT_FALSE(ShouldRun<LoopToForLoop>(src));
}
TEST_F(LoopToForLoopTest, ShouldRunHasForLoop) {
auto* src = R"(
fn f() {
loop {
break;
}
}
)";
EXPECT_TRUE(ShouldRun<LoopToForLoop>(src));
}
TEST_F(LoopToForLoopTest, EmptyModule) {
auto* src = "";
auto* expect = "";
auto got = Run<LoopToForLoop>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(LoopToForLoopTest, IfBreak) {
auto* src = R"(
fn f() {
var i : i32;
i = 0;
loop {
if (i > 15) {
break;
}
_ = 123;
continuing {
i = i + 1;
}
}
}
)";
auto* expect = R"(
fn f() {
var i : i32;
i = 0;
for(; !((i > 15)); i = (i + 1)) {
_ = 123;
}
}
)";
auto got = Run<LoopToForLoop>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(LoopToForLoopTest, IfElseBreak) {
auto* src = R"(
fn f() {
var i : i32;
i = 0;
loop {
if (i < 15) {
} else {
break;
}
_ = 123;
continuing {
i = i + 1;
}
}
}
)";
auto* expect = R"(
fn f() {
var i : i32;
i = 0;
for(; (i < 15); i = (i + 1)) {
_ = 123;
}
}
)";
auto got = Run<LoopToForLoop>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(LoopToForLoopTest, Nested) {
auto* src = R"(
let N = 16u;
fn f() {
var i : u32 = 0u;
loop {
if (i >= N) {
break;
}
{
var j : u32 = 0u;
loop {
if (j >= N) {
break;
}
_ = i;
_ = j;
continuing {
j = (j + 1u);
}
}
}
continuing {
i = (i + 1u);
}
}
}
)";
auto* expect = R"(
let N = 16u;
fn f() {
var i : u32 = 0u;
for(; !((i >= N)); i = (i + 1u)) {
{
var j : u32 = 0u;
for(; !((j >= N)); j = (j + 1u)) {
_ = i;
_ = j;
}
}
}
}
)";
auto got = Run<LoopToForLoop>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(LoopToForLoopTest, NoTransform_IfMultipleStmts) {
auto* src = R"(
fn f() {
var i : i32;
i = 0;
loop {
if ((i < 15)) {
_ = i;
break;
}
_ = 123;
continuing {
i = (i + 1);
}
}
}
)";
auto* expect = src;
auto got = Run<LoopToForLoop>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(LoopToForLoopTest, NoTransform_IfElseMultipleStmts) {
auto* src = R"(
fn f() {
var i : i32;
i = 0;
loop {
if ((i < 15)) {
} else {
_ = i;
break;
}
_ = 123;
continuing {
i = (i + 1);
}
}
}
)";
auto* expect = src;
auto got = Run<LoopToForLoop>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(LoopToForLoopTest, NoTransform_ContinuingIsCompound) {
auto* src = R"(
fn f() {
var i : i32;
i = 0;
loop {
if ((i < 15)) {
break;
}
_ = 123;
continuing {
if (false) {
}
}
}
}
)";
auto* expect = src;
auto got = Run<LoopToForLoop>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(LoopToForLoopTest, NoTransform_ContinuingMultipleStmts) {
auto* src = R"(
fn f() {
var i : i32;
i = 0;
loop {
if ((i < 15)) {
break;
}
_ = 123;
continuing {
i = (i + 1);
_ = i;
}
}
}
)";
auto* expect = src;
auto got = Run<LoopToForLoop>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(LoopToForLoopTest, NoTransform_ContinuingUsesVarDeclInLoopBody) {
auto* src = R"(
fn f() {
var i : i32;
i = 0;
loop {
if ((i < 15)) {
break;
}
var j : i32;
continuing {
i = (i + j);
}
}
}
)";
auto* expect = src;
auto got = Run<LoopToForLoop>(src);
EXPECT_EQ(expect, str(got));
}
} // namespace
} // namespace transform
} // namespace tint

View File

@@ -0,0 +1,86 @@
// Copyright 2020 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/manager.h"
/// If set to 1 then the transform::Manager will dump the WGSL of the program
/// before and after each transform. Helpful for debugging bad output.
#define TINT_PRINT_PROGRAM_FOR_EACH_TRANSFORM 0
#if TINT_PRINT_PROGRAM_FOR_EACH_TRANSFORM
#define TINT_IF_PRINT_PROGRAM(x) x
#else // TINT_PRINT_PROGRAM_FOR_EACH_TRANSFORM
#define TINT_IF_PRINT_PROGRAM(x)
#endif // TINT_PRINT_PROGRAM_FOR_EACH_TRANSFORM
TINT_INSTANTIATE_TYPEINFO(tint::transform::Manager);
namespace tint {
namespace transform {
Manager::Manager() = default;
Manager::~Manager() = default;
Output Manager::Run(const Program* program, const DataMap& data) const {
const Program* in = program;
#if TINT_PRINT_PROGRAM_FOR_EACH_TRANSFORM
auto print_program = [&](const char* msg, const Transform* transform) {
auto wgsl = Program::printer(in);
std::cout << "---------------------------------------------------------"
<< std::endl;
std::cout << "-- " << msg << " " << transform->TypeInfo().name << ":"
<< std::endl;
std::cout << "---------------------------------------------------------"
<< std::endl;
std::cout << wgsl << std::endl;
std::cout << "---------------------------------------------------------"
<< std::endl
<< std::endl;
};
#endif
Output out;
for (const auto& transform : transforms_) {
if (!transform->ShouldRun(in, data)) {
TINT_IF_PRINT_PROGRAM(std::cout << "Skipping "
<< transform->TypeInfo().name);
continue;
}
TINT_IF_PRINT_PROGRAM(print_program("Input to", transform.get()));
auto res = transform->Run(in, data);
out.program = std::move(res.program);
out.data.Add(std::move(res.data));
in = &out.program;
if (!in->IsValid()) {
TINT_IF_PRINT_PROGRAM(
print_program("Invalid output of", transform.get()));
return out;
}
if (transform == transforms_.back()) {
TINT_IF_PRINT_PROGRAM(print_program("Output of", transform.get()));
}
}
if (program == in) {
out.program = program->Clone();
}
return out;
}
} // namespace transform
} // namespace tint

View File

@@ -0,0 +1,64 @@
// Copyright 2020 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TINT_TRANSFORM_MANAGER_H_
#define SRC_TINT_TRANSFORM_MANAGER_H_
#include <memory>
#include <utility>
#include <vector>
#include "src/tint/transform/transform.h"
namespace tint {
namespace transform {
/// A collection of Transforms that act as a single Transform.
/// The inner transforms will execute in the appended order.
/// If any inner transform fails the manager will return immediately and
/// the error can be retrieved with the Output's diagnostics.
class Manager : public Castable<Manager, Transform> {
public:
/// Constructor
Manager();
~Manager() override;
/// Add pass to the manager
/// @param transform the transform to append
void append(std::unique_ptr<Transform> transform) {
transforms_.push_back(std::move(transform));
}
/// Add pass to the manager of type `T`, constructed with the provided
/// arguments.
/// @param args the arguments to forward to the `T` constructor
template <typename T, typename... ARGS>
void Add(ARGS&&... args) {
transforms_.emplace_back(std::make_unique<T>(std::forward<ARGS>(args)...));
}
/// Runs the transforms on `program`, returning the transformation result.
/// @param program the source program to transform
/// @param data optional extra transform-specific input data
/// @returns the transformed program and diagnostics
Output Run(const Program* program, const DataMap& data = {}) const override;
private:
std::vector<std::unique_ptr<Transform>> transforms_;
};
} // namespace transform
} // namespace tint
#endif // SRC_TINT_TRANSFORM_MANAGER_H_

View File

@@ -0,0 +1,399 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/module_scope_var_to_entry_point_param.h"
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "src/tint/ast/disable_validation_attribute.h"
#include "src/tint/program_builder.h"
#include "src/tint/sem/call.h"
#include "src/tint/sem/function.h"
#include "src/tint/sem/statement.h"
#include "src/tint/sem/variable.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::ModuleScopeVarToEntryPointParam);
namespace tint {
namespace transform {
namespace {
// Returns `true` if `type` is or contains a matrix type.
bool ContainsMatrix(const sem::Type* type) {
type = type->UnwrapRef();
if (type->Is<sem::Matrix>()) {
return true;
} else if (auto* ary = type->As<sem::Array>()) {
return ContainsMatrix(ary->ElemType());
} else if (auto* str = type->As<sem::Struct>()) {
for (auto* member : str->Members()) {
if (ContainsMatrix(member->Type())) {
return true;
}
}
}
return false;
}
} // namespace
/// State holds the current transform state.
struct ModuleScopeVarToEntryPointParam::State {
/// The clone context.
CloneContext& ctx;
/// Constructor
/// @param context the clone context
explicit State(CloneContext& context) : ctx(context) {}
/// Clone any struct types that are contained in `ty` (including `ty` itself),
/// and add it to the global declarations now, so that they precede new global
/// declarations that need to reference them.
/// @param ty the type to clone
void CloneStructTypes(const sem::Type* ty) {
if (auto* str = ty->As<sem::Struct>()) {
if (!cloned_structs_.emplace(str).second) {
// The struct has already been cloned.
return;
}
// Recurse into members.
for (auto* member : str->Members()) {
CloneStructTypes(member->Type());
}
// Clone the struct and add it to the global declaration list.
// Remove the old declaration.
auto* ast_str = str->Declaration();
ctx.dst->AST().AddTypeDecl(ctx.Clone(ast_str));
ctx.Remove(ctx.src->AST().GlobalDeclarations(), ast_str);
} else if (auto* arr = ty->As<sem::Array>()) {
CloneStructTypes(arr->ElemType());
}
}
/// Process the module.
void Process() {
// Predetermine the list of function calls that need to be replaced.
using CallList = std::vector<const ast::CallExpression*>;
std::unordered_map<const ast::Function*, CallList> calls_to_replace;
std::vector<const ast::Function*> functions_to_process;
// Build a list of functions that transitively reference any module-scope
// variables.
for (auto* func_ast : ctx.src->AST().Functions()) {
auto* func_sem = ctx.src->Sem().Get(func_ast);
bool needs_processing = false;
for (auto* var : func_sem->TransitivelyReferencedGlobals()) {
if (var->StorageClass() != ast::StorageClass::kNone) {
needs_processing = true;
break;
}
}
if (needs_processing) {
functions_to_process.push_back(func_ast);
// Find all of the calls to this function that will need to be replaced.
for (auto* call : func_sem->CallSites()) {
calls_to_replace[call->Stmt()->Function()->Declaration()].push_back(
call->Declaration());
}
}
}
// Build a list of `&ident` expressions. We'll use this later to avoid
// generating expressions of the form `&*ident`, which break WGSL validation
// rules when this expression is passed to a function.
// TODO(jrprice): We should add support for bidirectional SEM tree traversal
// so that we can do this on the fly instead.
std::unordered_map<const ast::IdentifierExpression*,
const ast::UnaryOpExpression*>
ident_to_address_of;
for (auto* node : ctx.src->ASTNodes().Objects()) {
auto* address_of = node->As<ast::UnaryOpExpression>();
if (!address_of || address_of->op != ast::UnaryOp::kAddressOf) {
continue;
}
if (auto* ident = address_of->expr->As<ast::IdentifierExpression>()) {
ident_to_address_of[ident] = address_of;
}
}
for (auto* func_ast : functions_to_process) {
auto* func_sem = ctx.src->Sem().Get(func_ast);
bool is_entry_point = func_ast->IsEntryPoint();
// Map module-scope variables onto their replacement.
struct NewVar {
Symbol symbol;
bool is_pointer;
bool is_wrapped;
};
const char* kWrappedArrayMemberName = "arr";
std::unordered_map<const sem::Variable*, NewVar> var_to_newvar;
// We aggregate all workgroup variables into a struct to avoid hitting
// MSL's limit for threadgroup memory arguments.
Symbol workgroup_parameter_symbol;
ast::StructMemberList workgroup_parameter_members;
auto workgroup_param = [&]() {
if (!workgroup_parameter_symbol.IsValid()) {
workgroup_parameter_symbol = ctx.dst->Sym();
}
return workgroup_parameter_symbol;
};
for (auto* var : func_sem->TransitivelyReferencedGlobals()) {
auto sc = var->StorageClass();
auto* ty = var->Type()->UnwrapRef();
if (sc == ast::StorageClass::kNone) {
continue;
}
if (sc != ast::StorageClass::kPrivate &&
sc != ast::StorageClass::kStorage &&
sc != ast::StorageClass::kUniform &&
sc != ast::StorageClass::kUniformConstant &&
sc != ast::StorageClass::kWorkgroup) {
TINT_ICE(Transform, ctx.dst->Diagnostics())
<< "unhandled module-scope storage class (" << sc << ")";
}
// This is the symbol for the variable that replaces the module-scope
// var.
auto new_var_symbol = ctx.dst->Sym();
// Helper to create an AST node for the store type of the variable.
auto store_type = [&]() { return CreateASTTypeFor(ctx, ty); };
// Track whether the new variable is a pointer or not.
bool is_pointer = false;
// Track whether the new variable was wrapped in a struct or not.
bool is_wrapped = false;
if (is_entry_point) {
if (var->Type()->UnwrapRef()->is_handle()) {
// For a texture or sampler variable, redeclare it as an entry point
// parameter. Disable entry point parameter validation.
auto* disable_validation =
ctx.dst->Disable(ast::DisabledValidation::kEntryPointParameter);
auto attrs = ctx.Clone(var->Declaration()->attributes);
attrs.push_back(disable_validation);
auto* param = ctx.dst->Param(new_var_symbol, store_type(), attrs);
ctx.InsertFront(func_ast->params, param);
} else if (sc == ast::StorageClass::kStorage ||
sc == ast::StorageClass::kUniform) {
// Variables into the Storage and Uniform storage classes are
// redeclared as entry point parameters with a pointer type.
auto attributes = ctx.Clone(var->Declaration()->attributes);
attributes.push_back(ctx.dst->Disable(
ast::DisabledValidation::kEntryPointParameter));
attributes.push_back(
ctx.dst->Disable(ast::DisabledValidation::kIgnoreStorageClass));
auto* param_type = store_type();
if (auto* arr = ty->As<sem::Array>();
arr && arr->IsRuntimeSized()) {
// Wrap runtime-sized arrays in structures, so that we can declare
// pointers to them. Ideally we'd just emit the array itself as a
// pointer, but this is not representable in Tint's AST.
CloneStructTypes(ty);
auto* wrapper = ctx.dst->Structure(
ctx.dst->Sym(),
{ctx.dst->Member(kWrappedArrayMemberName, param_type)});
param_type = ctx.dst->ty.Of(wrapper);
is_wrapped = true;
}
param_type = ctx.dst->ty.pointer(
param_type, sc, var->Declaration()->declared_access);
auto* param =
ctx.dst->Param(new_var_symbol, param_type, attributes);
ctx.InsertFront(func_ast->params, param);
is_pointer = true;
} else if (sc == ast::StorageClass::kWorkgroup &&
ContainsMatrix(var->Type())) {
// Due to a bug in the MSL compiler, we use a threadgroup memory
// argument for any workgroup allocation that contains a matrix.
// See crbug.com/tint/938.
// TODO(jrprice): Do this for all other workgroup variables too.
// Create a member in the workgroup parameter struct.
auto member = ctx.Clone(var->Declaration()->symbol);
workgroup_parameter_members.push_back(
ctx.dst->Member(member, store_type()));
CloneStructTypes(var->Type()->UnwrapRef());
// Create a function-scope variable that is a pointer to the member.
auto* member_ptr = ctx.dst->AddressOf(ctx.dst->MemberAccessor(
ctx.dst->Deref(workgroup_param()), member));
auto* local_var =
ctx.dst->Const(new_var_symbol,
ctx.dst->ty.pointer(
store_type(), ast::StorageClass::kWorkgroup),
member_ptr);
ctx.InsertFront(func_ast->body->statements,
ctx.dst->Decl(local_var));
is_pointer = true;
} else {
// Variables in the Private and Workgroup storage classes are
// redeclared at function scope. Disable storage class validation on
// this variable.
auto* disable_validation =
ctx.dst->Disable(ast::DisabledValidation::kIgnoreStorageClass);
auto* constructor = ctx.Clone(var->Declaration()->constructor);
auto* local_var =
ctx.dst->Var(new_var_symbol, store_type(), sc, constructor,
ast::AttributeList{disable_validation});
ctx.InsertFront(func_ast->body->statements,
ctx.dst->Decl(local_var));
}
} else {
// For a regular function, redeclare the variable as a parameter.
// Use a pointer for non-handle types.
auto* param_type = store_type();
ast::AttributeList attributes;
if (!var->Type()->UnwrapRef()->is_handle()) {
param_type = ctx.dst->ty.pointer(
param_type, sc, var->Declaration()->declared_access);
is_pointer = true;
// Disable validation of the parameter's storage class and of
// arguments passed it.
attributes.push_back(
ctx.dst->Disable(ast::DisabledValidation::kIgnoreStorageClass));
attributes.push_back(ctx.dst->Disable(
ast::DisabledValidation::kIgnoreInvalidPointerArgument));
}
ctx.InsertBack(
func_ast->params,
ctx.dst->Param(new_var_symbol, param_type, attributes));
}
// Replace all uses of the module-scope variable.
// For non-entry points, dereference non-handle pointer parameters.
for (auto* user : var->Users()) {
if (user->Stmt()->Function()->Declaration() == func_ast) {
const ast::Expression* expr = ctx.dst->Expr(new_var_symbol);
if (is_pointer) {
// If this identifier is used by an address-of operator, just
// remove the address-of instead of adding a deref, since we
// already have a pointer.
auto* ident =
user->Declaration()->As<ast::IdentifierExpression>();
if (ident_to_address_of.count(ident)) {
ctx.Replace(ident_to_address_of[ident], expr);
continue;
}
expr = ctx.dst->Deref(expr);
}
if (is_wrapped) {
// Get the member from the wrapper structure.
expr = ctx.dst->MemberAccessor(expr, kWrappedArrayMemberName);
}
ctx.Replace(user->Declaration(), expr);
}
}
var_to_newvar[var] = {new_var_symbol, is_pointer, is_wrapped};
}
if (!workgroup_parameter_members.empty()) {
// Create the workgroup memory parameter.
// The parameter is a struct that contains members for each workgroup
// variable.
auto* str = ctx.dst->Structure(ctx.dst->Sym(),
std::move(workgroup_parameter_members));
auto* param_type = ctx.dst->ty.pointer(ctx.dst->ty.Of(str),
ast::StorageClass::kWorkgroup);
auto* disable_validation =
ctx.dst->Disable(ast::DisabledValidation::kEntryPointParameter);
auto* param =
ctx.dst->Param(workgroup_param(), param_type, {disable_validation});
ctx.InsertFront(func_ast->params, param);
}
// Pass the variables as pointers to any functions that need them.
for (auto* call : calls_to_replace[func_ast]) {
auto* target =
ctx.src->AST().Functions().Find(call->target.name->symbol);
auto* target_sem = ctx.src->Sem().Get(target);
// Add new arguments for any variables that are needed by the callee.
// For entry points, pass non-handle types as pointers.
for (auto* target_var : target_sem->TransitivelyReferencedGlobals()) {
auto sc = target_var->StorageClass();
if (sc == ast::StorageClass::kNone) {
continue;
}
auto new_var = var_to_newvar[target_var];
bool is_handle = target_var->Type()->UnwrapRef()->is_handle();
const ast::Expression* arg = ctx.dst->Expr(new_var.symbol);
if (new_var.is_wrapped) {
// The variable is wrapped in a struct, so we need to pass a pointer
// to the struct member instead.
arg = ctx.dst->AddressOf(ctx.dst->MemberAccessor(
ctx.dst->Deref(arg), kWrappedArrayMemberName));
} else if (is_entry_point && !is_handle && !new_var.is_pointer) {
// We need to pass a pointer and we don't already have one, so take
// the address of the new variable.
arg = ctx.dst->AddressOf(arg);
}
ctx.InsertBack(call->args, arg);
}
}
}
// Now remove all module-scope variables with these storage classes.
for (auto* var_ast : ctx.src->AST().GlobalVariables()) {
auto* var_sem = ctx.src->Sem().Get(var_ast);
if (var_sem->StorageClass() != ast::StorageClass::kNone) {
ctx.Remove(ctx.src->AST().GlobalDeclarations(), var_ast);
}
}
}
private:
std::unordered_set<const sem::Struct*> cloned_structs_;
};
ModuleScopeVarToEntryPointParam::ModuleScopeVarToEntryPointParam() = default;
ModuleScopeVarToEntryPointParam::~ModuleScopeVarToEntryPointParam() = default;
bool ModuleScopeVarToEntryPointParam::ShouldRun(const Program* program,
const DataMap&) const {
for (auto* decl : program->AST().GlobalDeclarations()) {
if (decl->Is<ast::Variable>()) {
return true;
}
}
return false;
}
void ModuleScopeVarToEntryPointParam::Run(CloneContext& ctx,
const DataMap&,
DataMap&) const {
State state{ctx};
state.Process();
ctx.Clone();
}
} // namespace transform
} // namespace tint

View File

@@ -0,0 +1,96 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TINT_TRANSFORM_MODULE_SCOPE_VAR_TO_ENTRY_POINT_PARAM_H_
#define SRC_TINT_TRANSFORM_MODULE_SCOPE_VAR_TO_ENTRY_POINT_PARAM_H_
#include "src/tint/transform/transform.h"
namespace tint {
namespace transform {
/// Move module-scope variables into the entry point as parameters.
///
/// MSL does not allow module-scope variables to have any address space other
/// than `constant`. This transform moves all module-scope declarations into the
/// entry point function (either as parameters or function-scope variables) and
/// then passes them as pointer parameters to any function that references them.
///
/// Since WGSL does not allow entry point parameters or function-scope variables
/// to have these storage classes, we annotate the new variable declarations
/// with an attribute that bypasses that validation rule.
///
/// Before:
/// ```
/// struct S {
/// f : f32;
/// };
/// @binding(0) @group(0)
/// var<storage, read> s : S;
/// var<private> p : f32 = 2.0;
///
/// fn foo() {
/// p = p + f;
/// }
///
/// @stage(compute) @workgroup_size(1)
/// fn main() {
/// foo();
/// }
/// ```
///
/// After:
/// ```
/// fn foo(p : ptr<private, f32>, sptr : ptr<storage, S, read>) {
/// *p = *p + (*sptr).f;
/// }
///
/// @stage(compute) @workgroup_size(1)
/// fn main(sptr : ptr<storage, S, read>) {
/// var<private> p : f32 = 2.0;
/// foo(&p, sptr);
/// }
/// ```
class ModuleScopeVarToEntryPointParam
: public Castable<ModuleScopeVarToEntryPointParam, Transform> {
public:
/// Constructor
ModuleScopeVarToEntryPointParam();
/// Destructor
~ModuleScopeVarToEntryPointParam() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program,
const DataMap& data = {}) const override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const override;
struct State;
};
} // namespace transform
} // namespace tint
#endif // SRC_TINT_TRANSFORM_MODULE_SCOPE_VAR_TO_ENTRY_POINT_PARAM_H_

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,454 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/multiplanar_external_texture.h"
#include <string>
#include <vector>
#include "src/tint/ast/function.h"
#include "src/tint/program_builder.h"
#include "src/tint/sem/call.h"
#include "src/tint/sem/function.h"
#include "src/tint/sem/variable.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::MultiplanarExternalTexture);
TINT_INSTANTIATE_TYPEINFO(
tint::transform::MultiplanarExternalTexture::NewBindingPoints);
namespace tint {
namespace transform {
namespace {
/// This struct stores symbols for new bindings created as a result of
/// transforming a texture_external instance.
struct NewBindingSymbols {
Symbol params;
Symbol plane_0;
Symbol plane_1;
};
} // namespace
/// State holds the current transform state
struct MultiplanarExternalTexture::State {
/// The clone context.
CloneContext& ctx;
/// ProgramBuilder for the context
ProgramBuilder& b;
/// Destination binding locations for the expanded texture_external provided
/// as input into the transform.
const NewBindingPoints* new_binding_points;
/// Symbol for the ExternalTextureParams struct
Symbol params_struct_sym;
/// Symbol for the textureLoadExternal function
Symbol texture_load_external_sym;
/// Symbol for the textureSampleExternal function
Symbol texture_sample_external_sym;
/// Storage for new bindings that have been created corresponding to an
/// original texture_external binding.
std::unordered_map<const sem::Variable*, NewBindingSymbols>
new_binding_symbols;
/// Constructor
/// @param context the clone
/// @param newBindingPoints the input destination binding locations for the
/// expanded texture_external
State(CloneContext& context, const NewBindingPoints* newBindingPoints)
: ctx(context), b(*context.dst), new_binding_points(newBindingPoints) {}
/// Processes the module
void Process() {
auto& sem = ctx.src->Sem();
// For each texture_external binding, we replace it with a texture_2d<f32>
// binding and create two additional bindings (one texture_2d<f32> to
// represent the secondary plane and one uniform buffer for the
// ExternalTextureParams struct).
for (auto* var : ctx.src->AST().GlobalVariables()) {
auto* sem_var = sem.Get(var);
if (!sem_var->Type()->UnwrapRef()->Is<sem::ExternalTexture>()) {
continue;
}
// If the attributes are empty, then this must be a texture_external
// passed as a function parameter. These variables are transformed
// elsewhere.
if (var->attributes.empty()) {
continue;
}
// If we find a texture_external binding, we know we must emit the
// ExternalTextureParams struct.
if (!params_struct_sym.IsValid()) {
createExtTexParamsStruct();
}
// The binding points for the newly introduced bindings must have been
// provided to this transform. We fetch the new binding points by
// providing the original texture_external binding points into the
// passed map.
BindingPoint bp = {var->BindingPoint().group->value,
var->BindingPoint().binding->value};
BindingsMap::const_iterator it =
new_binding_points->bindings_map.find(bp);
if (it == new_binding_points->bindings_map.end()) {
b.Diagnostics().add_error(
diag::System::Transform,
"missing new binding points for texture_external at binding {" +
std::to_string(bp.group) + "," + std::to_string(bp.binding) +
"}");
continue;
}
BindingPoints bps = it->second;
// Symbols for the newly created bindings must be saved so they can be
// passed as parameters later. These are placed in a map and keyed by
// the source symbol associated with the texture_external binding that
// corresponds with the new destination bindings.
// NewBindingSymbols new_binding_syms;
auto& syms = new_binding_symbols[sem_var];
syms.plane_0 = ctx.Clone(var->symbol);
syms.plane_1 = b.Symbols().New("ext_tex_plane_1");
b.Global(syms.plane_1,
b.ty.sampled_texture(ast::TextureDimension::k2d, b.ty.f32()),
b.GroupAndBinding(bps.plane_1.group, bps.plane_1.binding));
syms.params = b.Symbols().New("ext_tex_params");
b.Global(syms.params, b.ty.type_name("ExternalTextureParams"),
ast::StorageClass::kUniform,
b.GroupAndBinding(bps.params.group, bps.params.binding));
// Replace the original texture_external binding with a texture_2d<f32>
// binding.
ast::AttributeList cloned_attributes = ctx.Clone(var->attributes);
const ast::Expression* cloned_constructor = ctx.Clone(var->constructor);
auto* replacement =
b.Var(syms.plane_0,
b.ty.sampled_texture(ast::TextureDimension::k2d, b.ty.f32()),
cloned_constructor, cloned_attributes);
ctx.Replace(var, replacement);
}
// We must update all the texture_external parameters for user declared
// functions.
for (auto* fn : ctx.src->AST().Functions()) {
for (const ast::Variable* param : fn->params) {
if (auto* sem_var = sem.Get(param)) {
if (!sem_var->Type()->UnwrapRef()->Is<sem::ExternalTexture>()) {
continue;
}
// If we find a texture_external, we must ensure the
// ExternalTextureParams struct exists.
if (!params_struct_sym.IsValid()) {
createExtTexParamsStruct();
}
// When a texture_external is found, we insert all components
// the texture_external into the parameter list. We must also place
// the new symbols into the transform state so they can be used when
// transforming function calls.
auto& syms = new_binding_symbols[sem_var];
syms.plane_0 = ctx.Clone(param->symbol);
syms.plane_1 = b.Symbols().New("ext_tex_plane_1");
syms.params = b.Symbols().New("ext_tex_params");
auto tex2d_f32 = [&] {
return b.ty.sampled_texture(ast::TextureDimension::k2d, b.ty.f32());
};
ctx.Replace(param, b.Param(syms.plane_0, tex2d_f32()));
ctx.InsertAfter(fn->params, param,
b.Param(syms.plane_1, tex2d_f32()));
ctx.InsertAfter(
fn->params, param,
b.Param(syms.params, b.ty.type_name(params_struct_sym)));
}
}
}
// Transform the original textureLoad and textureSampleLevel calls into
// textureLoadExternal and textureSampleExternal calls.
ctx.ReplaceAll(
[&](const ast::CallExpression* expr) -> const ast::CallExpression* {
auto* builtin = sem.Get(expr)->Target()->As<sem::Builtin>();
if (builtin && !builtin->Parameters().empty() &&
builtin->Parameters()[0]->Type()->Is<sem::ExternalTexture>() &&
builtin->Type() != sem::BuiltinType::kTextureDimensions) {
if (auto* var_user = sem.Get<sem::VariableUser>(expr->args[0])) {
auto it = new_binding_symbols.find(var_user->Variable());
if (it == new_binding_symbols.end()) {
// If valid new binding locations were not provided earlier, we
// would have been unable to create these symbols. An error
// message was emitted earlier, so just return early to avoid
// internal compiler errors and retain a clean error message.
return nullptr;
}
auto& syms = it->second;
if (builtin->Type() == sem::BuiltinType::kTextureLoad) {
return createTexLdExt(expr, syms);
}
if (builtin->Type() == sem::BuiltinType::kTextureSampleLevel) {
return createTexSmpExt(expr, syms);
}
}
} else if (sem.Get(expr)->Target()->Is<sem::Function>()) {
// The call expression may be to a user-defined function that
// contains a texture_external parameter. These need to be expanded
// out to multiple plane textures and the texture parameters
// structure.
for (auto* arg : expr->args) {
if (auto* var_user = sem.Get<sem::VariableUser>(arg)) {
// Check if a parameter is a texture_external by trying to find
// it in the transform state.
auto it = new_binding_symbols.find(var_user->Variable());
if (it != new_binding_symbols.end()) {
auto& syms = it->second;
// When we find a texture_external, we must unpack it into its
// components.
ctx.Replace(arg, b.Expr(syms.plane_0));
ctx.InsertAfter(expr->args, arg, b.Expr(syms.plane_1));
ctx.InsertAfter(expr->args, arg, b.Expr(syms.params));
}
}
}
}
return nullptr;
});
}
/// Creates the ExternalTextureParams struct.
void createExtTexParamsStruct() {
ast::StructMemberList member_list = {
b.Member("numPlanes", b.ty.u32()), b.Member("vr", b.ty.f32()),
b.Member("ug", b.ty.f32()), b.Member("vg", b.ty.f32()),
b.Member("ub", b.ty.f32())};
params_struct_sym = b.Symbols().New("ExternalTextureParams");
b.Structure(params_struct_sym, member_list);
}
/// Constructs a StatementList containing all the statements making up the
/// bodies of the textureSampleExternal and textureLoadExternal functions.
/// @param call_type determines which function body to generate
/// @returns a statement list that makes of the body of the chosen function
ast::StatementList createTexFnExtStatementList(sem::BuiltinType call_type) {
using f32 = ProgramBuilder::f32;
const ast::CallExpression* single_plane_call = nullptr;
const ast::CallExpression* plane_0_call = nullptr;
const ast::CallExpression* plane_1_call = nullptr;
if (call_type == sem::BuiltinType::kTextureSampleLevel) {
// textureSampleLevel(plane0, smp, coord.xy, 0.0);
single_plane_call =
b.Call("textureSampleLevel", "plane0", "smp", "coord", 0.0f);
// textureSampleLevel(plane0, smp, coord.xy, 0.0);
plane_0_call =
b.Call("textureSampleLevel", "plane0", "smp", "coord", 0.0f);
// textureSampleLevel(plane1, smp, coord.xy, 0.0);
plane_1_call =
b.Call("textureSampleLevel", "plane1", "smp", "coord", 0.0f);
} else if (call_type == sem::BuiltinType::kTextureLoad) {
// textureLoad(plane0, coords.xy, 0);
single_plane_call = b.Call("textureLoad", "plane0", "coord", 0);
// textureLoad(plane0, coords.xy, 0);
plane_0_call = b.Call("textureLoad", "plane0", "coord", 0);
// textureLoad(plane1, coords.xy, 0);
plane_1_call = b.Call("textureLoad", "plane1", "coord", 0);
} else {
TINT_ICE(Transform, b.Diagnostics())
<< "unhandled builtin: " << call_type;
}
return {
// if (params.numPlanes == 1u) {
// return singlePlaneCall
// }
b.If(b.create<ast::BinaryExpression>(
ast::BinaryOp::kEqual, b.MemberAccessor("params", "numPlanes"),
b.Expr(1u)),
b.Block(b.Return(single_plane_call))),
// let y = plane0Call.r - 0.0625;
b.Decl(b.Const("y", nullptr,
b.Sub(b.MemberAccessor(plane_0_call, "r"), 0.0625f))),
// let uv = plane1Call.rg - 0.5;
b.Decl(b.Const("uv", nullptr,
b.Sub(b.MemberAccessor(plane_1_call, "rg"), 0.5f))),
// let u = uv.x;
b.Decl(b.Const("u", nullptr, b.MemberAccessor("uv", "x"))),
// let v = uv.y;
b.Decl(b.Const("v", nullptr, b.MemberAccessor("uv", "y"))),
// let r = 1.164 * y + params.vr * v;
b.Decl(b.Const("r", nullptr,
b.Add(b.Mul(1.164f, "y"),
b.Mul(b.MemberAccessor("params", "vr"), "v")))),
// let g = 1.164 * y - params.ug * u - params.vg * v;
b.Decl(
b.Const("g", nullptr,
b.Sub(b.Sub(b.Mul(1.164f, "y"),
b.Mul(b.MemberAccessor("params", "ug"), "u")),
b.Mul(b.MemberAccessor("params", "vg"), "v")))),
// let b = 1.164 * y + params.ub * u;
b.Decl(b.Const("b", nullptr,
b.Add(b.Mul(1.164f, "y"),
b.Mul(b.MemberAccessor("params", "ub"), "u")))),
// return vec4<f32>(r, g, b, 1.0);
b.Return(b.vec4<f32>("r", "g", "b", 1.0f)),
};
}
/// Creates the textureSampleExternal function if needed and returns a call
/// expression to it.
/// @param expr the call expression being transformed
/// @param syms the expanded symbols to be used in the new call
/// @returns a call expression to textureSampleExternal
const ast::CallExpression* createTexSmpExt(const ast::CallExpression* expr,
NewBindingSymbols syms) {
ast::ExpressionList params;
const ast::Expression* plane_0_binding_param = ctx.Clone(expr->args[0]);
if (expr->args.size() != 3) {
TINT_ICE(Transform, b.Diagnostics())
<< "expected textureSampleLevel call with a "
"texture_external to have 3 parameters, found "
<< expr->args.size() << " parameters";
}
if (!texture_sample_external_sym.IsValid()) {
texture_sample_external_sym = b.Symbols().New("textureSampleExternal");
// Emit the textureSampleExternal function.
ast::VariableList varList = {
b.Param("plane0",
b.ty.sampled_texture(ast::TextureDimension::k2d, b.ty.f32())),
b.Param("plane1",
b.ty.sampled_texture(ast::TextureDimension::k2d, b.ty.f32())),
b.Param("smp", b.ty.sampler(ast::SamplerKind::kSampler)),
b.Param("coord", b.ty.vec2(b.ty.f32())),
b.Param("params", b.ty.type_name(params_struct_sym))};
ast::StatementList statementList =
createTexFnExtStatementList(sem::BuiltinType::kTextureSampleLevel);
b.Func(texture_sample_external_sym, varList, b.ty.vec4(b.ty.f32()),
statementList, {});
}
const ast::IdentifierExpression* exp = b.Expr(texture_sample_external_sym);
params = {plane_0_binding_param, b.Expr(syms.plane_1),
ctx.Clone(expr->args[1]), ctx.Clone(expr->args[2]),
b.Expr(syms.params)};
return b.Call(exp, params);
}
/// Creates the textureLoadExternal function if needed and returns a call
/// expression to it.
/// @param expr the call expression being transformed
/// @param syms the expanded symbols to be used in the new call
/// @returns a call expression to textureLoadExternal
const ast::CallExpression* createTexLdExt(const ast::CallExpression* expr,
NewBindingSymbols syms) {
ast::ExpressionList params;
const ast::Expression* plane_0_binding_param = ctx.Clone(expr->args[0]);
if (expr->args.size() != 2) {
TINT_ICE(Transform, b.Diagnostics())
<< "expected textureLoad call with a texture_external "
"to have 2 parameters, found "
<< expr->args.size() << " parameters";
}
if (!texture_load_external_sym.IsValid()) {
texture_load_external_sym = b.Symbols().New("textureLoadExternal");
// Emit the textureLoadExternal function.
ast::VariableList var_list = {
b.Param("plane0",
b.ty.sampled_texture(ast::TextureDimension::k2d, b.ty.f32())),
b.Param("plane1",
b.ty.sampled_texture(ast::TextureDimension::k2d, b.ty.f32())),
b.Param("coord", b.ty.vec2(b.ty.i32())),
b.Param("params", b.ty.type_name(params_struct_sym))};
ast::StatementList statement_list =
createTexFnExtStatementList(sem::BuiltinType::kTextureLoad);
b.Func(texture_load_external_sym, var_list, b.ty.vec4(b.ty.f32()),
statement_list, {});
}
const ast::IdentifierExpression* exp = b.Expr(texture_load_external_sym);
params = {plane_0_binding_param, b.Expr(syms.plane_1),
ctx.Clone(expr->args[1]), b.Expr(syms.params)};
return b.Call(exp, params);
}
};
MultiplanarExternalTexture::NewBindingPoints::NewBindingPoints(
BindingsMap inputBindingsMap)
: bindings_map(std::move(inputBindingsMap)) {}
MultiplanarExternalTexture::NewBindingPoints::~NewBindingPoints() = default;
MultiplanarExternalTexture::MultiplanarExternalTexture() = default;
MultiplanarExternalTexture::~MultiplanarExternalTexture() = default;
bool MultiplanarExternalTexture::ShouldRun(const Program* program,
const DataMap&) const {
for (auto* node : program->ASTNodes().Objects()) {
if (auto* ty = node->As<ast::Type>()) {
if (program->Sem().Get<sem::ExternalTexture>(ty)) {
return true;
}
}
}
return false;
}
// Within this transform, an instance of a texture_external binding is unpacked
// into two texture_2d<f32> bindings representing two possible planes of a
// single texture and a uniform buffer binding representing a struct of
// parameters. Calls to textureLoad or textureSampleLevel that contain a
// texture_external parameter will be transformed into a newly generated version
// of the function, which can perform the desired operation on a single RGBA
// plane or on separate Y and UV planes.
void MultiplanarExternalTexture::Run(CloneContext& ctx,
const DataMap& inputs,
DataMap&) const {
auto* new_binding_points = inputs.Get<NewBindingPoints>();
if (!new_binding_points) {
ctx.dst->Diagnostics().add_error(
diag::System::Transform,
"missing new binding point data for " + std::string(TypeInfo().name));
return;
}
State state(ctx, new_binding_points);
state.Process();
ctx.Clone();
}
} // namespace transform
} // namespace tint

View File

@@ -0,0 +1,102 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TINT_TRANSFORM_MULTIPLANAR_EXTERNAL_TEXTURE_H_
#define SRC_TINT_TRANSFORM_MULTIPLANAR_EXTERNAL_TEXTURE_H_
#include <unordered_map>
#include <utility>
#include "src/tint/ast/struct_member.h"
#include "src/tint/sem/binding_point.h"
#include "src/tint/sem/builtin_type.h"
#include "src/tint/transform/transform.h"
namespace tint {
namespace transform {
/// BindingPoint is an alias to sem::BindingPoint
using BindingPoint = sem::BindingPoint;
/// This struct identifies the binding groups and locations for new bindings to
/// use when transforming a texture_external instance.
struct BindingPoints {
/// The desired binding location of the texture_2d representing plane #1 when
/// a texture_external binding is expanded.
BindingPoint plane_1;
/// The desired binding location of the ExternalTextureParams uniform when a
/// texture_external binding is expanded.
BindingPoint params;
};
/// Within the MultiplanarExternalTexture transform, each instance of a
/// texture_external binding is unpacked into two texture_2d<f32> bindings
/// representing two possible planes of a texture and a uniform buffer binding
/// representing a struct of parameters. Calls to textureLoad or
/// textureSampleLevel that contain a texture_external parameter will be
/// transformed into a newly generated version of the function, which can
/// perform the desired operation on a single RGBA plane or on seperate Y and UV
/// planes.
class MultiplanarExternalTexture
: public Castable<MultiplanarExternalTexture, Transform> {
public:
/// BindingsMap is a map where the key is the binding location of a
/// texture_external and the value is a struct containing the desired
/// locations for new bindings expanded from the texture_external instance.
using BindingsMap = std::unordered_map<BindingPoint, BindingPoints>;
/// NewBindingPoints is consumed by the MultiplanarExternalTexture transform.
/// Data holds information about location of each texture_external binding and
/// which binding slots it should expand into.
struct NewBindingPoints : public Castable<Data, transform::Data> {
/// Constructor
/// @param bm a map to the new binding slots to use.
explicit NewBindingPoints(BindingsMap bm);
/// Destructor
~NewBindingPoints() override;
/// A map of new binding points to use.
const BindingsMap bindings_map;
};
/// Constructor
MultiplanarExternalTexture();
/// Destructor
~MultiplanarExternalTexture() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program,
const DataMap& data = {}) const override;
protected:
struct State;
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const override;
};
} // namespace transform
} // namespace tint
#endif // SRC_TINT_TRANSFORM_MULTIPLANAR_EXTERNAL_TEXTURE_H_

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,170 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/num_workgroups_from_uniform.h"
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include "src/tint/program_builder.h"
#include "src/tint/sem/function.h"
#include "src/tint/transform/canonicalize_entry_point_io.h"
#include "src/tint/utils/hash.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::NumWorkgroupsFromUniform);
TINT_INSTANTIATE_TYPEINFO(tint::transform::NumWorkgroupsFromUniform::Config);
namespace tint {
namespace transform {
namespace {
/// Accessor describes the identifiers used in a member accessor that is being
/// used to retrieve the num_workgroups builtin from a parameter.
struct Accessor {
Symbol param;
Symbol member;
/// Equality operator
bool operator==(const Accessor& other) const {
return param == other.param && member == other.member;
}
/// Hash function
struct Hasher {
size_t operator()(const Accessor& a) const {
return utils::Hash(a.param, a.member);
}
};
};
} // namespace
NumWorkgroupsFromUniform::NumWorkgroupsFromUniform() = default;
NumWorkgroupsFromUniform::~NumWorkgroupsFromUniform() = default;
bool NumWorkgroupsFromUniform::ShouldRun(const Program* program,
const DataMap&) const {
for (auto* node : program->ASTNodes().Objects()) {
if (auto* attr = node->As<ast::BuiltinAttribute>()) {
if (attr->builtin == ast::Builtin::kNumWorkgroups) {
return true;
}
}
}
return false;
}
void NumWorkgroupsFromUniform::Run(CloneContext& ctx,
const DataMap& inputs,
DataMap&) const {
auto* cfg = inputs.Get<Config>();
if (cfg == nullptr) {
ctx.dst->Diagnostics().add_error(
diag::System::Transform,
"missing transform data for " + std::string(TypeInfo().name));
return;
}
const char* kNumWorkgroupsMemberName = "num_workgroups";
// Find all entry point parameters that declare the num_workgroups builtin.
std::unordered_set<Accessor, Accessor::Hasher> to_replace;
for (auto* func : ctx.src->AST().Functions()) {
// num_workgroups is only valid for compute stages.
if (func->PipelineStage() != ast::PipelineStage::kCompute) {
continue;
}
for (auto* param : ctx.src->Sem().Get(func)->Parameters()) {
// Because the CanonicalizeEntryPointIO transform has been run, builtins
// will only appear as struct members.
auto* str = param->Type()->As<sem::Struct>();
if (!str) {
continue;
}
for (auto* member : str->Members()) {
auto* builtin = ast::GetAttribute<ast::BuiltinAttribute>(
member->Declaration()->attributes);
if (!builtin || builtin->builtin != ast::Builtin::kNumWorkgroups) {
continue;
}
// Capture the symbols that would be used to access this member, which
// we will replace later. We currently have no way to get from the
// parameter directly to the member accessor expressions that use it.
to_replace.insert(
{param->Declaration()->symbol, member->Declaration()->symbol});
// Remove the struct member.
// The CanonicalizeEntryPointIO transform will have generated this
// struct uniquely for this particular entry point, so we know that
// there will be no other uses of this struct in the module and that we
// can safely modify it here.
ctx.Remove(str->Declaration()->members, member->Declaration());
// If this is the only member, remove the struct and parameter too.
if (str->Members().size() == 1) {
ctx.Remove(func->params, param->Declaration());
ctx.Remove(ctx.src->AST().GlobalDeclarations(), str->Declaration());
}
}
}
}
// Get (or create, on first call) the uniform buffer that will receive the
// number of workgroups.
const ast::Variable* num_workgroups_ubo = nullptr;
auto get_ubo = [&]() {
if (!num_workgroups_ubo) {
auto* num_workgroups_struct = ctx.dst->Structure(
ctx.dst->Sym(),
{ctx.dst->Member(kNumWorkgroupsMemberName,
ctx.dst->ty.vec3(ctx.dst->ty.u32()))});
num_workgroups_ubo = ctx.dst->Global(
ctx.dst->Sym(), ctx.dst->ty.Of(num_workgroups_struct),
ast::StorageClass::kUniform,
ast::AttributeList{ctx.dst->GroupAndBinding(
cfg->ubo_binding.group, cfg->ubo_binding.binding)});
}
return num_workgroups_ubo;
};
// Now replace all the places where the builtins are accessed with the value
// loaded from the uniform buffer.
for (auto* node : ctx.src->ASTNodes().Objects()) {
auto* accessor = node->As<ast::MemberAccessorExpression>();
if (!accessor) {
continue;
}
auto* ident = accessor->structure->As<ast::IdentifierExpression>();
if (!ident) {
continue;
}
if (to_replace.count({ident->symbol, accessor->member->symbol})) {
ctx.Replace(accessor, ctx.dst->MemberAccessor(get_ubo()->symbol,
kNumWorkgroupsMemberName));
}
}
ctx.Clone();
}
NumWorkgroupsFromUniform::Config::Config(sem::BindingPoint ubo_bp)
: ubo_binding(ubo_bp) {}
NumWorkgroupsFromUniform::Config::Config(const Config&) = default;
NumWorkgroupsFromUniform::Config::~Config() = default;
} // namespace transform
} // namespace tint

View File

@@ -0,0 +1,90 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TINT_TRANSFORM_NUM_WORKGROUPS_FROM_UNIFORM_H_
#define SRC_TINT_TRANSFORM_NUM_WORKGROUPS_FROM_UNIFORM_H_
#include "src/tint/sem/binding_point.h"
#include "src/tint/transform/transform.h"
namespace tint {
// Forward declarations
class CloneContext;
namespace transform {
/// NumWorkgroupsFromUniform is a transform that implements the `num_workgroups`
/// builtin by loading it from a uniform buffer.
///
/// The generated uniform buffer will have the form:
/// ```
/// struct num_workgroups_struct {
/// num_workgroups : vec3<u32>;
/// };
///
/// @group(0) @binding(0)
/// var<uniform> num_workgroups_ubo : num_workgroups_struct;
/// ```
/// The binding group and number used for this uniform buffer is provided via
/// the `Config` transform input.
///
/// @note Depends on the following transforms to have been run first:
/// * CanonicalizeEntryPointIO
class NumWorkgroupsFromUniform
: public Castable<NumWorkgroupsFromUniform, Transform> {
public:
/// Constructor
NumWorkgroupsFromUniform();
/// Destructor
~NumWorkgroupsFromUniform() override;
/// Configuration options for the NumWorkgroupsFromUniform transform.
struct Config : public Castable<Data, transform::Data> {
/// Constructor
/// @param ubo_bp the binding point to use for the generated uniform buffer.
explicit Config(sem::BindingPoint ubo_bp);
/// Copy constructor
Config(const Config&);
/// Destructor
~Config() override;
/// The binding point to use for the generated uniform buffer.
sem::BindingPoint ubo_binding;
};
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program,
const DataMap& data = {}) const override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const override;
};
} // namespace transform
} // namespace tint
#endif // SRC_TINT_TRANSFORM_NUM_WORKGROUPS_FROM_UNIFORM_H_

View File

@@ -0,0 +1,456 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/num_workgroups_from_uniform.h"
#include <utility>
#include "src/tint/transform/canonicalize_entry_point_io.h"
#include "src/tint/transform/test_helper.h"
#include "src/tint/transform/unshadow.h"
namespace tint {
namespace transform {
namespace {
using NumWorkgroupsFromUniformTest = TransformTest;
TEST_F(NumWorkgroupsFromUniformTest, ShouldRunEmptyModule) {
auto* src = R"()";
EXPECT_FALSE(ShouldRun<NumWorkgroupsFromUniform>(src));
}
TEST_F(NumWorkgroupsFromUniformTest, ShouldRunHasNumWorkgroups) {
auto* src = R"(
[[stage(compute), workgroup_size(1)]]
fn main([[builtin(num_workgroups)]] num_wgs : vec3<u32>) {
}
)";
EXPECT_TRUE(ShouldRun<NumWorkgroupsFromUniform>(src));
}
TEST_F(NumWorkgroupsFromUniformTest, Error_MissingTransformData) {
auto* src = R"(
[[stage(compute), workgroup_size(1)]]
fn main([[builtin(num_workgroups)]] num_wgs : vec3<u32>) {
}
)";
auto* expect =
"error: missing transform data for "
"tint::transform::NumWorkgroupsFromUniform";
DataMap data;
data.Add<CanonicalizeEntryPointIO::Config>(
CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
auto got = Run<Unshadow, CanonicalizeEntryPointIO, NumWorkgroupsFromUniform>(
src, data);
EXPECT_EQ(expect, str(got));
}
TEST_F(NumWorkgroupsFromUniformTest, Basic) {
auto* src = R"(
@stage(compute) @workgroup_size(1)
fn main(@builtin(num_workgroups) num_wgs : vec3<u32>) {
let groups_x = num_wgs.x;
let groups_y = num_wgs.y;
let groups_z = num_wgs.z;
}
)";
auto* expect = R"(
struct tint_symbol_2 {
num_workgroups : vec3<u32>;
}
@group(0) @binding(30) var<uniform> tint_symbol_3 : tint_symbol_2;
fn main_inner(num_wgs : vec3<u32>) {
let groups_x = num_wgs.x;
let groups_y = num_wgs.y;
let groups_z = num_wgs.z;
}
@stage(compute) @workgroup_size(1)
fn main() {
main_inner(tint_symbol_3.num_workgroups);
}
)";
DataMap data;
data.Add<CanonicalizeEntryPointIO::Config>(
CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
data.Add<NumWorkgroupsFromUniform::Config>(sem::BindingPoint{0, 30u});
auto got = Run<Unshadow, CanonicalizeEntryPointIO, NumWorkgroupsFromUniform>(
src, data);
EXPECT_EQ(expect, str(got));
}
TEST_F(NumWorkgroupsFromUniformTest, StructOnlyMember) {
auto* src = R"(
struct Builtins {
@builtin(num_workgroups) num_wgs : vec3<u32>;
};
@stage(compute) @workgroup_size(1)
fn main(in : Builtins) {
let groups_x = in.num_wgs.x;
let groups_y = in.num_wgs.y;
let groups_z = in.num_wgs.z;
}
)";
auto* expect = R"(
struct tint_symbol_2 {
num_workgroups : vec3<u32>;
}
@group(0) @binding(30) var<uniform> tint_symbol_3 : tint_symbol_2;
struct Builtins {
num_wgs : vec3<u32>;
}
fn main_inner(in : Builtins) {
let groups_x = in.num_wgs.x;
let groups_y = in.num_wgs.y;
let groups_z = in.num_wgs.z;
}
@stage(compute) @workgroup_size(1)
fn main() {
main_inner(Builtins(tint_symbol_3.num_workgroups));
}
)";
DataMap data;
data.Add<CanonicalizeEntryPointIO::Config>(
CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
data.Add<NumWorkgroupsFromUniform::Config>(sem::BindingPoint{0, 30u});
auto got = Run<Unshadow, CanonicalizeEntryPointIO, NumWorkgroupsFromUniform>(
src, data);
EXPECT_EQ(expect, str(got));
}
TEST_F(NumWorkgroupsFromUniformTest, StructOnlyMember_OutOfOrder) {
auto* src = R"(
@stage(compute) @workgroup_size(1)
fn main(in : Builtins) {
let groups_x = in.num_wgs.x;
let groups_y = in.num_wgs.y;
let groups_z = in.num_wgs.z;
}
struct Builtins {
@builtin(num_workgroups) num_wgs : vec3<u32>;
};
)";
auto* expect = R"(
struct tint_symbol_2 {
num_workgroups : vec3<u32>;
}
@group(0) @binding(30) var<uniform> tint_symbol_3 : tint_symbol_2;
fn main_inner(in : Builtins) {
let groups_x = in.num_wgs.x;
let groups_y = in.num_wgs.y;
let groups_z = in.num_wgs.z;
}
@stage(compute) @workgroup_size(1)
fn main() {
main_inner(Builtins(tint_symbol_3.num_workgroups));
}
struct Builtins {
num_wgs : vec3<u32>;
}
)";
DataMap data;
data.Add<CanonicalizeEntryPointIO::Config>(
CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
data.Add<NumWorkgroupsFromUniform::Config>(sem::BindingPoint{0, 30u});
auto got = Run<Unshadow, CanonicalizeEntryPointIO, NumWorkgroupsFromUniform>(
src, data);
EXPECT_EQ(expect, str(got));
}
TEST_F(NumWorkgroupsFromUniformTest, StructMultipleMembers) {
auto* src = R"(
struct Builtins {
@builtin(global_invocation_id) gid : vec3<u32>;
@builtin(num_workgroups) num_wgs : vec3<u32>;
@builtin(workgroup_id) wgid : vec3<u32>;
};
@stage(compute) @workgroup_size(1)
fn main(in : Builtins) {
let groups_x = in.num_wgs.x;
let groups_y = in.num_wgs.y;
let groups_z = in.num_wgs.z;
}
)";
auto* expect = R"(
struct tint_symbol_2 {
num_workgroups : vec3<u32>;
}
@group(0) @binding(30) var<uniform> tint_symbol_3 : tint_symbol_2;
struct Builtins {
gid : vec3<u32>;
num_wgs : vec3<u32>;
wgid : vec3<u32>;
}
struct tint_symbol_1 {
@builtin(global_invocation_id)
gid : vec3<u32>;
@builtin(workgroup_id)
wgid : vec3<u32>;
}
fn main_inner(in : Builtins) {
let groups_x = in.num_wgs.x;
let groups_y = in.num_wgs.y;
let groups_z = in.num_wgs.z;
}
@stage(compute) @workgroup_size(1)
fn main(tint_symbol : tint_symbol_1) {
main_inner(Builtins(tint_symbol.gid, tint_symbol_3.num_workgroups, tint_symbol.wgid));
}
)";
DataMap data;
data.Add<CanonicalizeEntryPointIO::Config>(
CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
data.Add<NumWorkgroupsFromUniform::Config>(sem::BindingPoint{0, 30u});
auto got = Run<Unshadow, CanonicalizeEntryPointIO, NumWorkgroupsFromUniform>(
src, data);
EXPECT_EQ(expect, str(got));
}
TEST_F(NumWorkgroupsFromUniformTest, StructMultipleMembers_OutOfOrder) {
auto* src = R"(
@stage(compute) @workgroup_size(1)
fn main(in : Builtins) {
let groups_x = in.num_wgs.x;
let groups_y = in.num_wgs.y;
let groups_z = in.num_wgs.z;
}
struct Builtins {
@builtin(global_invocation_id) gid : vec3<u32>;
@builtin(num_workgroups) num_wgs : vec3<u32>;
@builtin(workgroup_id) wgid : vec3<u32>;
};
)";
auto* expect = R"(
struct tint_symbol_2 {
num_workgroups : vec3<u32>;
}
@group(0) @binding(30) var<uniform> tint_symbol_3 : tint_symbol_2;
struct tint_symbol_1 {
@builtin(global_invocation_id)
gid : vec3<u32>;
@builtin(workgroup_id)
wgid : vec3<u32>;
}
fn main_inner(in : Builtins) {
let groups_x = in.num_wgs.x;
let groups_y = in.num_wgs.y;
let groups_z = in.num_wgs.z;
}
@stage(compute) @workgroup_size(1)
fn main(tint_symbol : tint_symbol_1) {
main_inner(Builtins(tint_symbol.gid, tint_symbol_3.num_workgroups, tint_symbol.wgid));
}
struct Builtins {
gid : vec3<u32>;
num_wgs : vec3<u32>;
wgid : vec3<u32>;
}
)";
DataMap data;
data.Add<CanonicalizeEntryPointIO::Config>(
CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
data.Add<NumWorkgroupsFromUniform::Config>(sem::BindingPoint{0, 30u});
auto got = Run<Unshadow, CanonicalizeEntryPointIO, NumWorkgroupsFromUniform>(
src, data);
EXPECT_EQ(expect, str(got));
}
TEST_F(NumWorkgroupsFromUniformTest, MultipleEntryPoints) {
auto* src = R"(
struct Builtins1 {
@builtin(num_workgroups) num_wgs : vec3<u32>;
};
struct Builtins2 {
@builtin(global_invocation_id) gid : vec3<u32>;
@builtin(num_workgroups) num_wgs : vec3<u32>;
@builtin(workgroup_id) wgid : vec3<u32>;
};
@stage(compute) @workgroup_size(1)
fn main1(in : Builtins1) {
let groups_x = in.num_wgs.x;
let groups_y = in.num_wgs.y;
let groups_z = in.num_wgs.z;
}
@stage(compute) @workgroup_size(1)
fn main2(in : Builtins2) {
let groups_x = in.num_wgs.x;
let groups_y = in.num_wgs.y;
let groups_z = in.num_wgs.z;
}
@stage(compute) @workgroup_size(1)
fn main3(@builtin(num_workgroups) num_wgs : vec3<u32>) {
let groups_x = num_wgs.x;
let groups_y = num_wgs.y;
let groups_z = num_wgs.z;
}
)";
auto* expect = R"(
struct tint_symbol_6 {
num_workgroups : vec3<u32>;
}
@group(0) @binding(30) var<uniform> tint_symbol_7 : tint_symbol_6;
struct Builtins1 {
num_wgs : vec3<u32>;
}
struct Builtins2 {
gid : vec3<u32>;
num_wgs : vec3<u32>;
wgid : vec3<u32>;
}
fn main1_inner(in : Builtins1) {
let groups_x = in.num_wgs.x;
let groups_y = in.num_wgs.y;
let groups_z = in.num_wgs.z;
}
@stage(compute) @workgroup_size(1)
fn main1() {
main1_inner(Builtins1(tint_symbol_7.num_workgroups));
}
struct tint_symbol_3 {
@builtin(global_invocation_id)
gid : vec3<u32>;
@builtin(workgroup_id)
wgid : vec3<u32>;
}
fn main2_inner(in : Builtins2) {
let groups_x = in.num_wgs.x;
let groups_y = in.num_wgs.y;
let groups_z = in.num_wgs.z;
}
@stage(compute) @workgroup_size(1)
fn main2(tint_symbol_2 : tint_symbol_3) {
main2_inner(Builtins2(tint_symbol_2.gid, tint_symbol_7.num_workgroups, tint_symbol_2.wgid));
}
fn main3_inner(num_wgs : vec3<u32>) {
let groups_x = num_wgs.x;
let groups_y = num_wgs.y;
let groups_z = num_wgs.z;
}
@stage(compute) @workgroup_size(1)
fn main3() {
main3_inner(tint_symbol_7.num_workgroups);
}
)";
DataMap data;
data.Add<CanonicalizeEntryPointIO::Config>(
CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
data.Add<NumWorkgroupsFromUniform::Config>(sem::BindingPoint{0, 30u});
auto got = Run<Unshadow, CanonicalizeEntryPointIO, NumWorkgroupsFromUniform>(
src, data);
EXPECT_EQ(expect, str(got));
}
TEST_F(NumWorkgroupsFromUniformTest, NoUsages) {
auto* src = R"(
struct Builtins {
@builtin(global_invocation_id) gid : vec3<u32>;
@builtin(workgroup_id) wgid : vec3<u32>;
};
@stage(compute) @workgroup_size(1)
fn main(in : Builtins) {
}
)";
auto* expect = R"(
struct Builtins {
gid : vec3<u32>;
wgid : vec3<u32>;
}
struct tint_symbol_1 {
@builtin(global_invocation_id)
gid : vec3<u32>;
@builtin(workgroup_id)
wgid : vec3<u32>;
}
fn main_inner(in : Builtins) {
}
@stage(compute) @workgroup_size(1)
fn main(tint_symbol : tint_symbol_1) {
main_inner(Builtins(tint_symbol.gid, tint_symbol.wgid));
}
)";
DataMap data;
data.Add<CanonicalizeEntryPointIO::Config>(
CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
data.Add<NumWorkgroupsFromUniform::Config>(sem::BindingPoint{0, 30u});
auto got = Run<Unshadow, CanonicalizeEntryPointIO, NumWorkgroupsFromUniform>(
src, data);
EXPECT_EQ(expect, str(got));
}
} // namespace
} // namespace transform
} // namespace tint

View File

@@ -0,0 +1,177 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/pad_array_elements.h"
#include <unordered_map>
#include <utility>
#include "src/tint/program_builder.h"
#include "src/tint/sem/array.h"
#include "src/tint/sem/call.h"
#include "src/tint/sem/expression.h"
#include "src/tint/sem/type_constructor.h"
#include "src/tint/utils/map.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::PadArrayElements);
namespace tint {
namespace transform {
namespace {
using ArrayBuilder = std::function<const ast::Array*()>;
/// PadArray returns a function that constructs a new array in `ctx.dst` with
/// the element type padded to account for the explicit stride. PadArray will
/// recursively pad arrays-of-arrays. The new array element type will be added
/// to module-scope type declarations of `ctx.dst`.
/// @param ctx the CloneContext
/// @param create_ast_type_for Transform::CreateASTTypeFor()
/// @param padded_arrays a map of src array type to the new array name
/// @param array the array type
/// @return the new AST array
template <typename CREATE_AST_TYPE_FOR>
ArrayBuilder PadArray(
CloneContext& ctx,
CREATE_AST_TYPE_FOR&& create_ast_type_for,
std::unordered_map<const sem::Array*, ArrayBuilder>& padded_arrays,
const sem::Array* array) {
if (array->IsStrideImplicit()) {
// We don't want to wrap arrays that have an implicit stride
return nullptr;
}
return utils::GetOrCreate(padded_arrays, array, [&] {
// Generate a unique name for the array element type
auto name = ctx.dst->Symbols().New("tint_padded_array_element");
// Examine the element type. Is it also an array?
const ast::Type* el_ty = nullptr;
if (auto* el_array = array->ElemType()->As<sem::Array>()) {
// Array of array - call PadArray() on the element type
if (auto p =
PadArray(ctx, create_ast_type_for, padded_arrays, el_array)) {
el_ty = p();
}
}
// If the element wasn't a padded array, just create the typical AST type
// for it
if (el_ty == nullptr) {
el_ty = create_ast_type_for(ctx, array->ElemType());
}
// Structure() will create and append the ast::Struct to the
// global declarations of `ctx.dst`. As we haven't finished building the
// current module-scope statement or function, this will be placed
// immediately before the usage.
ctx.dst->Structure(
name,
{ctx.dst->Member("el", el_ty, {ctx.dst->MemberSize(array->Stride())})});
auto* dst = ctx.dst;
return [=] {
if (array->IsRuntimeSized()) {
return dst->ty.array(dst->create<ast::TypeName>(name));
} else {
return dst->ty.array(dst->create<ast::TypeName>(name), array->Count());
}
};
});
}
} // namespace
PadArrayElements::PadArrayElements() = default;
PadArrayElements::~PadArrayElements() = default;
bool PadArrayElements::ShouldRun(const Program* program, const DataMap&) const {
for (auto* node : program->ASTNodes().Objects()) {
if (auto* var = node->As<ast::Type>()) {
if (auto* arr = program->Sem().Get<sem::Array>(var)) {
if (!arr->IsStrideImplicit()) {
return true;
}
}
}
}
return false;
}
void PadArrayElements::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
auto& sem = ctx.src->Sem();
std::unordered_map<const sem::Array*, ArrayBuilder> padded_arrays;
auto pad = [&](const sem::Array* array) {
return PadArray(ctx, CreateASTTypeFor, padded_arrays, array);
};
// Replace all array types with their corresponding padded array type
ctx.ReplaceAll([&](const ast::Type* ast_type) -> const ast::Type* {
auto* type = ctx.src->TypeOf(ast_type);
if (auto* array = type->UnwrapRef()->As<sem::Array>()) {
if (auto p = pad(array)) {
return p();
}
}
return nullptr;
});
// Fix up index accessors so `a[1]` becomes `a[1].el`
ctx.ReplaceAll([&](const ast::IndexAccessorExpression* accessor)
-> const ast::Expression* {
if (auto* array = tint::As<sem::Array>(
sem.Get(accessor->object)->Type()->UnwrapRef())) {
if (pad(array)) {
// Array element is wrapped in a structure. Emit a member accessor
// to get to the actual array element.
auto* idx = ctx.CloneWithoutTransform(accessor);
return ctx.dst->MemberAccessor(idx, "el");
}
}
return nullptr;
});
// Fix up array constructors so `A(1,2)` becomes
// `A(padded(1), padded(2))`
ctx.ReplaceAll(
[&](const ast::CallExpression* expr) -> const ast::Expression* {
auto* call = sem.Get(expr);
if (auto* ctor = call->Target()->As<sem::TypeConstructor>()) {
if (auto* array = ctor->ReturnType()->As<sem::Array>()) {
if (auto p = pad(array)) {
auto* arr_ty = p();
auto el_typename = arr_ty->type->As<ast::TypeName>()->name;
ast::ExpressionList args;
args.reserve(call->Arguments().size());
for (auto* arg : call->Arguments()) {
auto* val = ctx.Clone(arg->Declaration());
args.emplace_back(ctx.dst->Construct(
ctx.dst->create<ast::TypeName>(el_typename), val));
}
return ctx.dst->Construct(arr_ty, args);
}
}
}
return nullptr;
});
ctx.Clone();
}
} // namespace transform
} // namespace tint

View File

@@ -0,0 +1,62 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TINT_TRANSFORM_PAD_ARRAY_ELEMENTS_H_
#define SRC_TINT_TRANSFORM_PAD_ARRAY_ELEMENTS_H_
#include "src/tint/transform/transform.h"
namespace tint {
namespace transform {
/// PadArrayElements is a transform that replaces array types with an explicit
/// stride that is larger than the implicit stride, with an array of a new
/// structure type. This structure holds with a single field of the element
/// type, decorated with a `@size` attribute to pad the structure to the
/// required array stride. The new array types have no explicit stride,
/// structure size is equal to the desired stride.
/// Array index expressions and constructors are also adjusted to deal with this
/// structure element type.
/// This transform helps with backends that cannot directly return arrays or use
/// them as parameters.
class PadArrayElements : public Castable<PadArrayElements, Transform> {
public:
/// Constructor
PadArrayElements();
/// Destructor
~PadArrayElements() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program,
const DataMap& data = {}) const override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const override;
};
} // namespace transform
} // namespace tint
#endif // SRC_TINT_TRANSFORM_PAD_ARRAY_ELEMENTS_H_

View File

@@ -0,0 +1,518 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/pad_array_elements.h"
#include <utility>
#include "src/tint/transform/test_helper.h"
namespace tint {
namespace transform {
namespace {
using PadArrayElementsTest = TransformTest;
TEST_F(PadArrayElementsTest, ShouldRunEmptyModule) {
auto* src = R"()";
EXPECT_FALSE(ShouldRun<PadArrayElements>(src));
}
TEST_F(PadArrayElementsTest, ShouldRunHasImplicitArrayStride) {
auto* src = R"(
var<private> arr : array<i32, 4>;
)";
EXPECT_FALSE(ShouldRun<PadArrayElements>(src));
}
TEST_F(PadArrayElementsTest, ShouldRunHasExplicitArrayStride) {
auto* src = R"(
var<private> arr : [[stride(8)]] array<i32, 4>;
)";
EXPECT_TRUE(ShouldRun<PadArrayElements>(src));
}
TEST_F(PadArrayElementsTest, EmptyModule) {
auto* src = "";
auto* expect = "";
auto got = Run<PadArrayElements>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(PadArrayElementsTest, ImplicitArrayStride) {
auto* src = R"(
var<private> arr : array<i32, 4>;
)";
auto* expect = src;
auto got = Run<PadArrayElements>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(PadArrayElementsTest, ArrayAsGlobal) {
auto* src = R"(
var<private> arr : @stride(8) array<i32, 4>;
)";
auto* expect = R"(
struct tint_padded_array_element {
@size(8)
el : i32;
}
var<private> arr : array<tint_padded_array_element, 4u>;
)";
auto got = Run<PadArrayElements>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(PadArrayElementsTest, RuntimeArray) {
auto* src = R"(
struct S {
rta : @stride(8) array<i32>;
};
)";
auto* expect = R"(
struct tint_padded_array_element {
@size(8)
el : i32;
}
struct S {
rta : array<tint_padded_array_element>;
}
)";
auto got = Run<PadArrayElements>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(PadArrayElementsTest, ArrayFunctionVar) {
auto* src = R"(
fn f() {
var arr : @stride(16) array<i32, 4>;
arr = @stride(16) array<i32, 4>();
arr = @stride(16) array<i32, 4>(1, 2, 3, 4);
let x = arr[3];
}
)";
auto* expect = R"(
struct tint_padded_array_element {
@size(16)
el : i32;
}
fn f() {
var arr : array<tint_padded_array_element, 4u>;
arr = array<tint_padded_array_element, 4u>();
arr = array<tint_padded_array_element, 4u>(tint_padded_array_element(1), tint_padded_array_element(2), tint_padded_array_element(3), tint_padded_array_element(4));
let x = arr[3].el;
}
)";
auto got = Run<PadArrayElements>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(PadArrayElementsTest, ArrayAsParam) {
auto* src = R"(
fn f(a : @stride(12) array<i32, 4>) -> i32 {
return a[2];
}
)";
auto* expect = R"(
struct tint_padded_array_element {
@size(12)
el : i32;
}
fn f(a : array<tint_padded_array_element, 4u>) -> i32 {
return a[2].el;
}
)";
auto got = Run<PadArrayElements>(src);
EXPECT_EQ(expect, str(got));
}
// TODO(crbug.com/tint/781): Cannot parse the stride on the return array type.
TEST_F(PadArrayElementsTest, DISABLED_ArrayAsReturn) {
auto* src = R"(
fn f() -> @stride(8) array<i32, 4> {
return array<i32, 4>(1, 2, 3, 4);
}
)";
auto* expect = R"(
struct tint_padded_array_element {
el : i32;
@size(4)
padding : u32;
};
fn f() -> array<tint_padded_array_element, 4> {
return array<tint_padded_array_element, 4>(tint_padded_array_element(1, 0u), tint_padded_array_element(2, 0u), tint_padded_array_element(3, 0u), tint_padded_array_element(4, 0u));
}
)";
auto got = Run<PadArrayElements>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(PadArrayElementsTest, ArrayAlias) {
auto* src = R"(
type Array = @stride(16) array<i32, 4>;
fn f() {
var arr : Array;
arr = Array();
arr = Array(1, 2, 3, 4);
let vals : Array = Array(1, 2, 3, 4);
arr = vals;
let x = arr[3];
}
)";
auto* expect = R"(
struct tint_padded_array_element {
@size(16)
el : i32;
}
type Array = array<tint_padded_array_element, 4u>;
fn f() {
var arr : array<tint_padded_array_element, 4u>;
arr = array<tint_padded_array_element, 4u>();
arr = array<tint_padded_array_element, 4u>(tint_padded_array_element(1), tint_padded_array_element(2), tint_padded_array_element(3), tint_padded_array_element(4));
let vals : array<tint_padded_array_element, 4u> = array<tint_padded_array_element, 4u>(tint_padded_array_element(1), tint_padded_array_element(2), tint_padded_array_element(3), tint_padded_array_element(4));
arr = vals;
let x = arr[3].el;
}
)";
auto got = Run<PadArrayElements>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(PadArrayElementsTest, ArrayAlias_OutOfOrder) {
auto* src = R"(
fn f() {
var arr : Array;
arr = Array();
arr = Array(1, 2, 3, 4);
let vals : Array = Array(1, 2, 3, 4);
arr = vals;
let x = arr[3];
}
type Array = @stride(16) array<i32, 4>;
)";
auto* expect = R"(
struct tint_padded_array_element {
@size(16)
el : i32;
}
fn f() {
var arr : array<tint_padded_array_element, 4u>;
arr = array<tint_padded_array_element, 4u>();
arr = array<tint_padded_array_element, 4u>(tint_padded_array_element(1), tint_padded_array_element(2), tint_padded_array_element(3), tint_padded_array_element(4));
let vals : array<tint_padded_array_element, 4u> = array<tint_padded_array_element, 4u>(tint_padded_array_element(1), tint_padded_array_element(2), tint_padded_array_element(3), tint_padded_array_element(4));
arr = vals;
let x = arr[3].el;
}
type Array = array<tint_padded_array_element, 4u>;
)";
auto got = Run<PadArrayElements>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(PadArrayElementsTest, ArraysInStruct) {
auto* src = R"(
struct S {
a : @stride(8) array<i32, 4>;
b : @stride(8) array<i32, 8>;
c : @stride(8) array<i32, 4>;
d : @stride(12) array<i32, 8>;
};
)";
auto* expect = R"(
struct tint_padded_array_element {
@size(8)
el : i32;
}
struct tint_padded_array_element_1 {
@size(8)
el : i32;
}
struct tint_padded_array_element_2 {
@size(12)
el : i32;
}
struct S {
a : array<tint_padded_array_element, 4u>;
b : array<tint_padded_array_element_1, 8u>;
c : array<tint_padded_array_element, 4u>;
d : array<tint_padded_array_element_2, 8u>;
}
)";
auto got = Run<PadArrayElements>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(PadArrayElementsTest, ArraysOfArraysInStruct) {
auto* src = R"(
struct S {
a : @stride(512) array<i32, 4>;
b : @stride(512) array<@stride(32) array<i32, 4>, 4>;
c : @stride(512) array<@stride(64) array<@stride(8) array<i32, 4>, 4>, 4>;
};
)";
auto* expect = R"(
struct tint_padded_array_element {
@size(512)
el : i32;
}
struct tint_padded_array_element_2 {
@size(32)
el : i32;
}
struct tint_padded_array_element_1 {
@size(512)
el : array<tint_padded_array_element_2, 4u>;
}
struct tint_padded_array_element_5 {
@size(8)
el : i32;
}
struct tint_padded_array_element_4 {
@size(64)
el : array<tint_padded_array_element_5, 4u>;
}
struct tint_padded_array_element_3 {
@size(512)
el : array<tint_padded_array_element_4, 4u>;
}
struct S {
a : array<tint_padded_array_element, 4u>;
b : array<tint_padded_array_element_1, 4u>;
c : array<tint_padded_array_element_3, 4u>;
}
)";
auto got = Run<PadArrayElements>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(PadArrayElementsTest, AccessArraysOfArraysInStruct) {
auto* src = R"(
struct S {
a : @stride(512) array<i32, 4>;
b : @stride(512) array<@stride(32) array<i32, 4>, 4>;
c : @stride(512) array<@stride(64) array<@stride(8) array<i32, 4>, 4>, 4>;
};
fn f(s : S) -> i32 {
return s.a[2] + s.b[1][2] + s.c[3][1][2];
}
)";
auto* expect = R"(
struct tint_padded_array_element {
@size(512)
el : i32;
}
struct tint_padded_array_element_2 {
@size(32)
el : i32;
}
struct tint_padded_array_element_1 {
@size(512)
el : array<tint_padded_array_element_2, 4u>;
}
struct tint_padded_array_element_5 {
@size(8)
el : i32;
}
struct tint_padded_array_element_4 {
@size(64)
el : array<tint_padded_array_element_5, 4u>;
}
struct tint_padded_array_element_3 {
@size(512)
el : array<tint_padded_array_element_4, 4u>;
}
struct S {
a : array<tint_padded_array_element, 4u>;
b : array<tint_padded_array_element_1, 4u>;
c : array<tint_padded_array_element_3, 4u>;
}
fn f(s : S) -> i32 {
return ((s.a[2].el + s.b[1].el[2].el) + s.c[3].el[1].el[2].el);
}
)";
auto got = Run<PadArrayElements>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(PadArrayElementsTest, AccessArraysOfArraysInStruct_OutOfOrder) {
auto* src = R"(
fn f(s : S) -> i32 {
return s.a[2] + s.b[1][2] + s.c[3][1][2];
}
struct S {
a : @stride(512) array<i32, 4>;
b : @stride(512) array<@stride(32) array<i32, 4>, 4>;
c : @stride(512) array<@stride(64) array<@stride(8) array<i32, 4>, 4>, 4>;
};
)";
auto* expect = R"(
struct tint_padded_array_element {
@size(512)
el : i32;
}
struct tint_padded_array_element_1 {
@size(32)
el : i32;
}
struct tint_padded_array_element_2 {
@size(512)
el : array<tint_padded_array_element_1, 4u>;
}
struct tint_padded_array_element_3 {
@size(8)
el : i32;
}
struct tint_padded_array_element_4 {
@size(64)
el : array<tint_padded_array_element_3, 4u>;
}
struct tint_padded_array_element_5 {
@size(512)
el : array<tint_padded_array_element_4, 4u>;
}
fn f(s : S) -> i32 {
return ((s.a[2].el + s.b[1].el[2].el) + s.c[3].el[1].el[2].el);
}
struct S {
a : array<tint_padded_array_element, 4u>;
b : array<tint_padded_array_element_2, 4u>;
c : array<tint_padded_array_element_5, 4u>;
}
)";
auto got = Run<PadArrayElements>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(PadArrayElementsTest, DeclarationOrder) {
auto* src = R"(
type T0 = i32;
type T1 = @stride(8) array<i32, 1>;
type T2 = i32;
fn f1(a : @stride(8) array<i32, 2>) {
}
type T3 = i32;
fn f2() {
var v : @stride(8) array<i32, 3>;
}
)";
auto* expect = R"(
type T0 = i32;
struct tint_padded_array_element {
@size(8)
el : i32;
}
type T1 = array<tint_padded_array_element, 1u>;
type T2 = i32;
struct tint_padded_array_element_1 {
@size(8)
el : i32;
}
fn f1(a : array<tint_padded_array_element_1, 2u>) {
}
type T3 = i32;
struct tint_padded_array_element_2 {
@size(8)
el : i32;
}
fn f2() {
var v : array<tint_padded_array_element_2, 3u>;
}
)";
auto got = Run<PadArrayElements>(src);
EXPECT_EQ(expect, str(got));
}
} // namespace
} // namespace transform
} // namespace tint

View File

@@ -0,0 +1,83 @@
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/promote_initializers_to_const_var.h"
#include "src/tint/program_builder.h"
#include "src/tint/sem/call.h"
#include "src/tint/sem/statement.h"
#include "src/tint/sem/type_constructor.h"
#include "src/tint/transform/utils/hoist_to_decl_before.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::PromoteInitializersToConstVar);
namespace tint::transform {
PromoteInitializersToConstVar::PromoteInitializersToConstVar() = default;
PromoteInitializersToConstVar::~PromoteInitializersToConstVar() = default;
void PromoteInitializersToConstVar::Run(CloneContext& ctx,
const DataMap&,
DataMap&) const {
HoistToDeclBefore hoist_to_decl_before(ctx);
// Hoists array and structure initializers to a constant variable, declared
// just before the statement of usage.
auto type_ctor_to_let = [&](const ast::CallExpression* expr) {
auto* ctor = ctx.src->Sem().Get(expr);
if (!ctor->Target()->Is<sem::TypeConstructor>()) {
return true;
}
auto* sem_stmt = ctor->Stmt();
if (!sem_stmt) {
// Expression is outside of a statement. This usually means the
// expression is part of a global (module-scope) constant declaration.
// These must be constexpr, and so cannot contain the type of
// expressions that must be sanitized.
return true;
}
auto* stmt = sem_stmt->Declaration();
if (auto* src_var_decl = stmt->As<ast::VariableDeclStatement>()) {
if (src_var_decl->variable->constructor == expr) {
// This statement is just a variable declaration with the
// initializer as the constructor value. This is what we're
// attempting to transform to, and so ignore.
return true;
}
}
auto* src_ty = ctor->Type();
if (!src_ty->IsAnyOf<sem::Array, sem::Struct>()) {
// We only care about array and struct initializers
return true;
}
return hoist_to_decl_before.Add(ctor, expr, true);
};
for (auto* node : ctx.src->ASTNodes().Objects()) {
if (auto* call_expr = node->As<ast::CallExpression>()) {
if (!type_ctor_to_let(call_expr)) {
return;
}
}
}
hoist_to_decl_before.Apply();
ctx.Clone();
}
} // namespace tint::transform

View File

@@ -0,0 +1,48 @@
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TINT_TRANSFORM_PROMOTE_INITIALIZERS_TO_CONST_VAR_H_
#define SRC_TINT_TRANSFORM_PROMOTE_INITIALIZERS_TO_CONST_VAR_H_
#include "src/tint/transform/transform.h"
namespace tint::transform {
/// A transform that hoists the array and structure initializers to a constant
/// variable, declared just before the statement of usage.
/// @see crbug.com/tint/406
class PromoteInitializersToConstVar
: public Castable<PromoteInitializersToConstVar, Transform> {
public:
/// Constructor
PromoteInitializersToConstVar();
/// Destructor
~PromoteInitializersToConstVar() override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const override;
};
} // namespace tint::transform
#endif // SRC_TINT_TRANSFORM_PROMOTE_INITIALIZERS_TO_CONST_VAR_H_

View File

@@ -0,0 +1,627 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/promote_initializers_to_const_var.h"
#include "src/tint/transform/test_helper.h"
namespace tint {
namespace transform {
namespace {
using PromoteInitializersToConstVarTest = TransformTest;
TEST_F(PromoteInitializersToConstVarTest, EmptyModule) {
auto* src = "";
auto* expect = "";
auto got = Run<PromoteInitializersToConstVar>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteInitializersToConstVarTest, BasicArray) {
auto* src = R"(
fn f() {
var f0 = 1.0;
var f1 = 2.0;
var f2 = 3.0;
var f3 = 4.0;
var i = array<f32, 4u>(f0, f1, f2, f3)[2];
}
)";
auto* expect = R"(
fn f() {
var f0 = 1.0;
var f1 = 2.0;
var f2 = 3.0;
var f3 = 4.0;
let tint_symbol = array<f32, 4u>(f0, f1, f2, f3);
var i = tint_symbol[2];
}
)";
DataMap data;
auto got = Run<PromoteInitializersToConstVar>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteInitializersToConstVarTest, BasicStruct) {
auto* src = R"(
struct S {
a : i32;
b : f32;
c : vec3<f32>;
};
fn f() {
var x = S(1, 2.0, vec3<f32>()).b;
}
)";
auto* expect = R"(
struct S {
a : i32;
b : f32;
c : vec3<f32>;
}
fn f() {
let tint_symbol = S(1, 2.0, vec3<f32>());
var x = tint_symbol.b;
}
)";
DataMap data;
auto got = Run<PromoteInitializersToConstVar>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteInitializersToConstVarTest, BasicStruct_OutOfOrder) {
auto* src = R"(
fn f() {
var x = S(1, 2.0, vec3<f32>()).b;
}
struct S {
a : i32;
b : f32;
c : vec3<f32>;
};
)";
auto* expect = R"(
fn f() {
let tint_symbol = S(1, 2.0, vec3<f32>());
var x = tint_symbol.b;
}
struct S {
a : i32;
b : f32;
c : vec3<f32>;
}
)";
DataMap data;
auto got = Run<PromoteInitializersToConstVar>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteInitializersToConstVarTest, ArrayInForLoopInit) {
auto* src = R"(
fn f() {
var insert_after = 1;
for(var i = array<f32, 4u>(0.0, 1.0, 2.0, 3.0)[2]; ; ) {
break;
}
}
)";
auto* expect = R"(
fn f() {
var insert_after = 1;
let tint_symbol = array<f32, 4u>(0.0, 1.0, 2.0, 3.0);
for(var i = tint_symbol[2]; ; ) {
break;
}
}
)";
DataMap data;
auto got = Run<PromoteInitializersToConstVar>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteInitializersToConstVarTest, StructInForLoopInit) {
auto* src = R"(
struct S {
a : i32;
b : f32;
c : vec3<f32>;
};
fn f() {
var insert_after = 1;
for(var x = S(1, 2.0, vec3<f32>()).b; ; ) {
break;
}
}
)";
auto* expect = R"(
struct S {
a : i32;
b : f32;
c : vec3<f32>;
}
fn f() {
var insert_after = 1;
let tint_symbol = S(1, 2.0, vec3<f32>());
for(var x = tint_symbol.b; ; ) {
break;
}
}
)";
DataMap data;
auto got = Run<PromoteInitializersToConstVar>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteInitializersToConstVarTest, StructInForLoopInit_OutOfOrder) {
auto* src = R"(
fn f() {
var insert_after = 1;
for(var x = S(1, 2.0, vec3<f32>()).b; ; ) {
break;
}
}
struct S {
a : i32;
b : f32;
c : vec3<f32>;
};
)";
auto* expect = R"(
fn f() {
var insert_after = 1;
let tint_symbol = S(1, 2.0, vec3<f32>());
for(var x = tint_symbol.b; ; ) {
break;
}
}
struct S {
a : i32;
b : f32;
c : vec3<f32>;
}
)";
DataMap data;
auto got = Run<PromoteInitializersToConstVar>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteInitializersToConstVarTest, ArrayInForLoopCond) {
auto* src = R"(
fn f() {
var f = 1.0;
for(; f == array<f32, 1u>(f)[0]; f = f + 1.0) {
var marker = 1;
}
}
)";
auto* expect = R"(
fn f() {
var f = 1.0;
loop {
let tint_symbol = array<f32, 1u>(f);
if (!((f == tint_symbol[0]))) {
break;
}
{
var marker = 1;
}
continuing {
f = (f + 1.0);
}
}
}
)";
DataMap data;
auto got = Run<PromoteInitializersToConstVar>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteInitializersToConstVarTest, ArrayInForLoopCont) {
auto* src = R"(
fn f() {
var f = 0.0;
for(; f < 10.0; f = f + array<f32, 1u>(1.0)[0]) {
var marker = 1;
}
}
)";
auto* expect = R"(
fn f() {
var f = 0.0;
loop {
if (!((f < 10.0))) {
break;
}
{
var marker = 1;
}
continuing {
let tint_symbol = array<f32, 1u>(1.0);
f = (f + tint_symbol[0]);
}
}
}
)";
DataMap data;
auto got = Run<PromoteInitializersToConstVar>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteInitializersToConstVarTest, ArrayInForLoopInitCondCont) {
auto* src = R"(
fn f() {
for(var f = array<f32, 1u>(0.0)[0];
f < array<f32, 1u>(1.0)[0];
f = f + array<f32, 1u>(2.0)[0]) {
var marker = 1;
}
}
)";
auto* expect = R"(
fn f() {
let tint_symbol = array<f32, 1u>(0.0);
{
var f = tint_symbol[0];
loop {
let tint_symbol_1 = array<f32, 1u>(1.0);
if (!((f < tint_symbol_1[0]))) {
break;
}
{
var marker = 1;
}
continuing {
let tint_symbol_2 = array<f32, 1u>(2.0);
f = (f + tint_symbol_2[0]);
}
}
}
}
)";
DataMap data;
auto got = Run<PromoteInitializersToConstVar>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteInitializersToConstVarTest, ArrayInElseIf) {
auto* src = R"(
fn f() {
var f = 1.0;
if (true) {
var marker = 0;
} else if (f == array<f32, 2u>(f, f)[0]) {
var marker = 1;
}
}
)";
auto* expect = R"(
fn f() {
var f = 1.0;
if (true) {
var marker = 0;
} else {
let tint_symbol = array<f32, 2u>(f, f);
if ((f == tint_symbol[0])) {
var marker = 1;
}
}
}
)";
DataMap data;
auto got = Run<PromoteInitializersToConstVar>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteInitializersToConstVarTest, ArrayInElseIfChain) {
auto* src = R"(
fn f() {
var f = 1.0;
if (true) {
var marker = 0;
} else if (true) {
var marker = 1;
} else if (f == array<f32, 2u>(f, f)[0]) {
var marker = 2;
} else if (f == array<f32, 2u>(f, f)[1]) {
var marker = 3;
} else if (true) {
var marker = 4;
} else {
var marker = 5;
}
}
)";
auto* expect = R"(
fn f() {
var f = 1.0;
if (true) {
var marker = 0;
} else if (true) {
var marker = 1;
} else {
let tint_symbol = array<f32, 2u>(f, f);
if ((f == tint_symbol[0])) {
var marker = 2;
} else {
let tint_symbol_1 = array<f32, 2u>(f, f);
if ((f == tint_symbol_1[1])) {
var marker = 3;
} else if (true) {
var marker = 4;
} else {
var marker = 5;
}
}
}
}
)";
DataMap data;
auto got = Run<PromoteInitializersToConstVar>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteInitializersToConstVarTest, ArrayInArrayArray) {
auto* src = R"(
fn f() {
var i = array<array<f32, 2u>, 2u>(array<f32, 2u>(1.0, 2.0), array<f32, 2u>(3.0, 4.0))[0][1];
}
)";
auto* expect = R"(
fn f() {
let tint_symbol = array<f32, 2u>(1.0, 2.0);
let tint_symbol_1 = array<f32, 2u>(3.0, 4.0);
let tint_symbol_2 = array<array<f32, 2u>, 2u>(tint_symbol, tint_symbol_1);
var i = tint_symbol_2[0][1];
}
)";
DataMap data;
auto got = Run<PromoteInitializersToConstVar>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteInitializersToConstVarTest, StructNested) {
auto* src = R"(
struct S1 {
a : i32;
};
struct S2 {
a : i32;
b : S1;
c : i32;
};
struct S3 {
a : S2;
};
fn f() {
var x = S3(S2(1, S1(2), 3)).a.b.a;
}
)";
auto* expect = R"(
struct S1 {
a : i32;
}
struct S2 {
a : i32;
b : S1;
c : i32;
}
struct S3 {
a : S2;
}
fn f() {
let tint_symbol = S1(2);
let tint_symbol_1 = S2(1, tint_symbol, 3);
let tint_symbol_2 = S3(tint_symbol_1);
var x = tint_symbol_2.a.b.a;
}
)";
DataMap data;
auto got = Run<PromoteInitializersToConstVar>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteInitializersToConstVarTest, Mixed) {
auto* src = R"(
struct S1 {
a : i32;
};
struct S2 {
a : array<S1, 3u>;
};
fn f() {
var x = S2(array<S1, 3u>(S1(1), S1(2), S1(3))).a[1].a;
}
)";
auto* expect = R"(
struct S1 {
a : i32;
}
struct S2 {
a : array<S1, 3u>;
}
fn f() {
let tint_symbol = S1(1);
let tint_symbol_1 = S1(2);
let tint_symbol_2 = S1(3);
let tint_symbol_3 = array<S1, 3u>(tint_symbol, tint_symbol_1, tint_symbol_2);
let tint_symbol_4 = S2(tint_symbol_3);
var x = tint_symbol_4.a[1].a;
}
)";
DataMap data;
auto got = Run<PromoteInitializersToConstVar>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteInitializersToConstVarTest, Mixed_OutOfOrder) {
auto* src = R"(
fn f() {
var x = S2(array<S1, 3u>(S1(1), S1(2), S1(3))).a[1].a;
}
struct S2 {
a : array<S1, 3u>;
};
struct S1 {
a : i32;
};
)";
auto* expect = R"(
fn f() {
let tint_symbol = S1(1);
let tint_symbol_1 = S1(2);
let tint_symbol_2 = S1(3);
let tint_symbol_3 = array<S1, 3u>(tint_symbol, tint_symbol_1, tint_symbol_2);
let tint_symbol_4 = S2(tint_symbol_3);
var x = tint_symbol_4.a[1].a;
}
struct S2 {
a : array<S1, 3u>;
}
struct S1 {
a : i32;
}
)";
DataMap data;
auto got = Run<PromoteInitializersToConstVar>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteInitializersToConstVarTest, NoChangeOnVarDecl) {
auto* src = R"(
struct S {
a : i32;
b : f32;
c : i32;
}
fn f() {
var local_arr = array<f32, 4u>(0.0, 1.0, 2.0, 3.0);
var local_str = S(1, 2.0, 3);
}
let module_arr : array<f32, 4u> = array<f32, 4u>(0.0, 1.0, 2.0, 3.0);
let module_str : S = S(1, 2.0, 3);
)";
auto* expect = src;
DataMap data;
auto got = Run<PromoteInitializersToConstVar>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteInitializersToConstVarTest, NoChangeOnVarDecl_OutOfOrder) {
auto* src = R"(
fn f() {
var local_arr = array<f32, 4u>(0.0, 1.0, 2.0, 3.0);
var local_str = S(1, 2.0, 3);
}
let module_str : S = S(1, 2.0, 3);
struct S {
a : i32;
b : f32;
c : i32;
}
let module_arr : array<f32, 4u> = array<f32, 4u>(0.0, 1.0, 2.0, 3.0);
)";
auto* expect = src;
DataMap data;
auto got = Run<PromoteInitializersToConstVar>(src);
EXPECT_EQ(expect, str(got));
}
} // namespace
} // namespace transform
} // namespace tint

View File

@@ -0,0 +1,156 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/remove_phonies.h"
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>
#include "src/tint/ast/traverse_expressions.h"
#include "src/tint/program_builder.h"
#include "src/tint/sem/block_statement.h"
#include "src/tint/sem/function.h"
#include "src/tint/sem/statement.h"
#include "src/tint/sem/variable.h"
#include "src/tint/utils/map.h"
#include "src/tint/utils/scoped_assignment.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::RemovePhonies);
namespace tint {
namespace transform {
namespace {
struct SinkSignature {
std::vector<const sem::Type*> types;
bool operator==(const SinkSignature& other) const {
if (types.size() != other.types.size()) {
return false;
}
for (size_t i = 0; i < types.size(); i++) {
if (types[i] != other.types[i]) {
return false;
}
}
return true;
}
struct Hasher {
/// @param sig the CallTargetSignature to hash
/// @return the hash value
std::size_t operator()(const SinkSignature& sig) const {
size_t hash = tint::utils::Hash(sig.types.size());
for (auto* ty : sig.types) {
tint::utils::HashCombine(&hash, ty);
}
return hash;
}
};
};
} // namespace
RemovePhonies::RemovePhonies() = default;
RemovePhonies::~RemovePhonies() = default;
bool RemovePhonies::ShouldRun(const Program* program, const DataMap&) const {
for (auto* node : program->ASTNodes().Objects()) {
if (node->Is<ast::PhonyExpression>()) {
return true;
}
}
return false;
}
void RemovePhonies::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
auto& sem = ctx.src->Sem();
std::unordered_map<SinkSignature, Symbol, SinkSignature::Hasher> sinks;
for (auto* node : ctx.src->ASTNodes().Objects()) {
if (auto* stmt = node->As<ast::AssignmentStatement>()) {
if (stmt->lhs->Is<ast::PhonyExpression>()) {
std::vector<const ast::Expression*> side_effects;
if (!ast::TraverseExpressions(
stmt->rhs, ctx.dst->Diagnostics(),
[&](const ast::CallExpression* call) {
// ast::CallExpression may map to a function or builtin call
// (both may have side-effects), or a type constructor or
// type conversion (both do not have side effects).
if (sem.Get(call)
->Target()
->IsAnyOf<sem::Function, sem::Builtin>()) {
side_effects.push_back(call);
return ast::TraverseAction::Skip;
}
return ast::TraverseAction::Descend;
})) {
return;
}
if (side_effects.empty()) {
// Phony assignment with no side effects.
// Just remove it.
RemoveStatement(ctx, stmt);
continue;
}
if (side_effects.size() == 1) {
if (auto* call = side_effects[0]->As<ast::CallExpression>()) {
// Phony assignment with single call side effect.
// Replace phony assignment with call.
ctx.Replace(
stmt, [&, call] { return ctx.dst->CallStmt(ctx.Clone(call)); });
continue;
}
}
// Phony assignment with multiple side effects.
// Generate a call to a placeholder function with the side
// effects as arguments.
ctx.Replace(stmt, [&, side_effects] {
SinkSignature sig;
for (auto* arg : side_effects) {
sig.types.push_back(sem.Get(arg)->Type()->UnwrapRef());
}
auto sink = utils::GetOrCreate(sinks, sig, [&] {
auto name = ctx.dst->Symbols().New("phony_sink");
ast::VariableList params;
for (auto* ty : sig.types) {
auto* ast_ty = CreateASTTypeFor(ctx, ty);
params.push_back(
ctx.dst->Param("p" + std::to_string(params.size()), ast_ty));
}
ctx.dst->Func(name, params, ctx.dst->ty.void_(), {});
return name;
});
ast::ExpressionList args;
for (auto* arg : side_effects) {
args.push_back(ctx.Clone(arg));
}
return ctx.dst->CallStmt(ctx.dst->Call(sink, args));
});
}
}
}
ctx.Clone();
}
} // namespace transform
} // namespace tint

View File

@@ -0,0 +1,58 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TINT_TRANSFORM_REMOVE_PHONIES_H_
#define SRC_TINT_TRANSFORM_REMOVE_PHONIES_H_
#include <string>
#include <unordered_map>
#include "src/tint/transform/transform.h"
namespace tint {
namespace transform {
/// RemovePhonies is a Transform that removes all phony-assignment statements,
/// while preserving function call expressions in the RHS of the assignment that
/// may have side-effects.
class RemovePhonies : public Castable<RemovePhonies, Transform> {
public:
/// Constructor
RemovePhonies();
/// Destructor
~RemovePhonies() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program,
const DataMap& data = {}) const override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const override;
};
} // namespace transform
} // namespace tint
#endif // SRC_TINT_TRANSFORM_REMOVE_PHONIES_H_

View File

@@ -0,0 +1,431 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/remove_phonies.h"
#include <memory>
#include <utility>
#include <vector>
#include "src/tint/transform/test_helper.h"
namespace tint {
namespace transform {
namespace {
using RemovePhoniesTest = TransformTest;
TEST_F(RemovePhoniesTest, ShouldRunEmptyModule) {
auto* src = R"()";
EXPECT_FALSE(ShouldRun<RemovePhonies>(src));
}
TEST_F(RemovePhoniesTest, ShouldRunHasPhony) {
auto* src = R"(
fn f() {
_ = 1;
}
)";
EXPECT_TRUE(ShouldRun<RemovePhonies>(src));
}
TEST_F(RemovePhoniesTest, EmptyModule) {
auto* src = "";
auto* expect = "";
auto got = Run<RemovePhonies>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(RemovePhoniesTest, NoSideEffects) {
auto* src = R"(
@group(0) @binding(0) var t : texture_2d<f32>;
fn f() {
var v : i32;
_ = &v;
_ = 1;
_ = 1 + 2;
_ = t;
_ = u32(3.0);
_ = f32(i32(4u));
_ = vec2<f32>(5.0);
_ = vec3<i32>(6, 7, 8);
_ = mat2x2<f32>(9.0, 10.0, 11.0, 12.0);
}
)";
auto* expect = R"(
@group(0) @binding(0) var t : texture_2d<f32>;
fn f() {
var v : i32;
}
)";
auto got = Run<RemovePhonies>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(RemovePhoniesTest, SingleSideEffects) {
auto* src = R"(
fn neg(a : i32) -> i32 {
return -(a);
}
fn add(a : i32, b : i32) -> i32 {
return (a + b);
}
fn f() {
_ = neg(1);
_ = add(2, 3);
_ = add(neg(4), neg(5));
_ = u32(neg(6));
_ = f32(add(7, 8));
_ = vec2<f32>(f32(neg(9)));
_ = vec3<i32>(1, neg(10), 3);
_ = mat2x2<f32>(1.0, f32(add(11, 12)), 3.0, 4.0);
}
)";
auto* expect = R"(
fn neg(a : i32) -> i32 {
return -(a);
}
fn add(a : i32, b : i32) -> i32 {
return (a + b);
}
fn f() {
neg(1);
add(2, 3);
add(neg(4), neg(5));
neg(6);
add(7, 8);
neg(9);
neg(10);
add(11, 12);
}
)";
auto got = Run<RemovePhonies>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(RemovePhoniesTest, SingleSideEffects_OutOfOrder) {
auto* src = R"(
fn f() {
_ = neg(1);
_ = add(2, 3);
_ = add(neg(4), neg(5));
_ = u32(neg(6));
_ = f32(add(7, 8));
_ = vec2<f32>(f32(neg(9)));
_ = vec3<i32>(1, neg(10), 3);
_ = mat2x2<f32>(1.0, f32(add(11, 12)), 3.0, 4.0);
}
fn add(a : i32, b : i32) -> i32 {
return (a + b);
}
fn neg(a : i32) -> i32 {
return -(a);
}
)";
auto* expect = R"(
fn f() {
neg(1);
add(2, 3);
add(neg(4), neg(5));
neg(6);
add(7, 8);
neg(9);
neg(10);
add(11, 12);
}
fn add(a : i32, b : i32) -> i32 {
return (a + b);
}
fn neg(a : i32) -> i32 {
return -(a);
}
)";
auto got = Run<RemovePhonies>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(RemovePhoniesTest, MultipleSideEffects) {
auto* src = R"(
fn neg(a : i32) -> i32 {
return -(a);
}
fn add(a : i32, b : i32) -> i32 {
return (a + b);
}
fn xor(a : u32, b : u32) -> u32 {
return (a ^ b);
}
fn f() {
_ = (1 + add(2 + add(3, 4), 5)) * add(6, 7) * neg(8);
_ = add(9, neg(10)) + neg(11);
_ = xor(12u, 13u) + xor(14u, 15u);
_ = neg(16) / neg(17) + add(18, 19);
}
)";
auto* expect = R"(
fn neg(a : i32) -> i32 {
return -(a);
}
fn add(a : i32, b : i32) -> i32 {
return (a + b);
}
fn xor(a : u32, b : u32) -> u32 {
return (a ^ b);
}
fn phony_sink(p0 : i32, p1 : i32, p2 : i32) {
}
fn phony_sink_1(p0 : i32, p1 : i32) {
}
fn phony_sink_2(p0 : u32, p1 : u32) {
}
fn f() {
phony_sink(add((2 + add(3, 4)), 5), add(6, 7), neg(8));
phony_sink_1(add(9, neg(10)), neg(11));
phony_sink_2(xor(12u, 13u), xor(14u, 15u));
phony_sink(neg(16), neg(17), add(18, 19));
}
)";
auto got = Run<RemovePhonies>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(RemovePhoniesTest, MultipleSideEffects_OutOfOrder) {
auto* src = R"(
fn f() {
_ = (1 + add(2 + add(3, 4), 5)) * add(6, 7) * neg(8);
_ = add(9, neg(10)) + neg(11);
_ = xor(12u, 13u) + xor(14u, 15u);
_ = neg(16) / neg(17) + add(18, 19);
}
fn neg(a : i32) -> i32 {
return -(a);
}
fn add(a : i32, b : i32) -> i32 {
return (a + b);
}
fn xor(a : u32, b : u32) -> u32 {
return (a ^ b);
}
)";
auto* expect = R"(
fn phony_sink(p0 : i32, p1 : i32, p2 : i32) {
}
fn phony_sink_1(p0 : i32, p1 : i32) {
}
fn phony_sink_2(p0 : u32, p1 : u32) {
}
fn f() {
phony_sink(add((2 + add(3, 4)), 5), add(6, 7), neg(8));
phony_sink_1(add(9, neg(10)), neg(11));
phony_sink_2(xor(12u, 13u), xor(14u, 15u));
phony_sink(neg(16), neg(17), add(18, 19));
}
fn neg(a : i32) -> i32 {
return -(a);
}
fn add(a : i32, b : i32) -> i32 {
return (a + b);
}
fn xor(a : u32, b : u32) -> u32 {
return (a ^ b);
}
)";
auto got = Run<RemovePhonies>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(RemovePhoniesTest, ForLoop) {
auto* src = R"(
struct S {
arr : array<i32>;
};
@group(0) @binding(0) var<storage, read_write> s : S;
fn x() -> i32 {
return 0;
}
fn y() -> i32 {
return 0;
}
fn z() -> i32 {
return 0;
}
fn f() {
for (_ = &s.arr; ;_ = &s.arr) {
break;
}
for (_ = x(); ;_ = y() + z()) {
break;
}
}
)";
auto* expect = R"(
struct S {
arr : array<i32>;
}
@group(0) @binding(0) var<storage, read_write> s : S;
fn x() -> i32 {
return 0;
}
fn y() -> i32 {
return 0;
}
fn z() -> i32 {
return 0;
}
fn phony_sink(p0 : i32, p1 : i32) {
}
fn f() {
for(; ; ) {
break;
}
for(x(); ; phony_sink(y(), z())) {
break;
}
}
)";
auto got = Run<RemovePhonies>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(RemovePhoniesTest, ForLoop_OutOfOrder) {
auto* src = R"(
fn f() {
for (_ = &s.arr; ;_ = &s.arr) {
break;
}
for (_ = x(); ;_ = y() + z()) {
break;
}
}
fn x() -> i32 {
return 0;
}
fn y() -> i32 {
return 0;
}
fn z() -> i32 {
return 0;
}
struct S {
arr : array<i32>;
};
@group(0) @binding(0) var<storage, read_write> s : S;
)";
auto* expect = R"(
fn phony_sink(p0 : i32, p1 : i32) {
}
fn f() {
for(; ; ) {
break;
}
for(x(); ; phony_sink(y(), z())) {
break;
}
}
fn x() -> i32 {
return 0;
}
fn y() -> i32 {
return 0;
}
fn z() -> i32 {
return 0;
}
struct S {
arr : array<i32>;
}
@group(0) @binding(0) var<storage, read_write> s : S;
)";
auto got = Run<RemovePhonies>(src);
EXPECT_EQ(expect, str(got));
}
} // namespace
} // namespace transform
} // namespace tint

View File

@@ -0,0 +1,67 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/remove_unreachable_statements.h"
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>
#include "src/tint/ast/traverse_expressions.h"
#include "src/tint/program_builder.h"
#include "src/tint/sem/block_statement.h"
#include "src/tint/sem/function.h"
#include "src/tint/sem/statement.h"
#include "src/tint/sem/variable.h"
#include "src/tint/utils/map.h"
#include "src/tint/utils/scoped_assignment.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::RemoveUnreachableStatements);
namespace tint {
namespace transform {
RemoveUnreachableStatements::RemoveUnreachableStatements() = default;
RemoveUnreachableStatements::~RemoveUnreachableStatements() = default;
bool RemoveUnreachableStatements::ShouldRun(const Program* program,
const DataMap&) const {
for (auto* node : program->ASTNodes().Objects()) {
if (auto* stmt = program->Sem().Get<sem::Statement>(node)) {
if (!stmt->IsReachable()) {
return true;
}
}
}
return false;
}
void RemoveUnreachableStatements::Run(CloneContext& ctx,
const DataMap&,
DataMap&) const {
for (auto* node : ctx.src->ASTNodes().Objects()) {
if (auto* stmt = ctx.src->Sem().Get<sem::Statement>(node)) {
if (!stmt->IsReachable()) {
RemoveStatement(ctx, stmt->Declaration());
}
}
}
ctx.Clone();
}
} // namespace transform
} // namespace tint

View File

@@ -0,0 +1,58 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TINT_TRANSFORM_REMOVE_UNREACHABLE_STATEMENTS_H_
#define SRC_TINT_TRANSFORM_REMOVE_UNREACHABLE_STATEMENTS_H_
#include <string>
#include <unordered_map>
#include "src/tint/transform/transform.h"
namespace tint {
namespace transform {
/// RemoveUnreachableStatements is a Transform that removes all statements
/// marked as unreachable.
class RemoveUnreachableStatements
: public Castable<RemoveUnreachableStatements, Transform> {
public:
/// Constructor
RemoveUnreachableStatements();
/// Destructor
~RemoveUnreachableStatements() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program,
const DataMap& data = {}) const override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const override;
};
} // namespace transform
} // namespace tint
#endif // SRC_TINT_TRANSFORM_REMOVE_UNREACHABLE_STATEMENTS_H_

View File

@@ -0,0 +1,571 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/remove_unreachable_statements.h"
#include "src/tint/transform/test_helper.h"
namespace tint {
namespace transform {
namespace {
using RemoveUnreachableStatementsTest = TransformTest;
TEST_F(RemoveUnreachableStatementsTest, ShouldRunEmptyModule) {
auto* src = R"()";
EXPECT_FALSE(ShouldRun<RemoveUnreachableStatements>(src));
}
TEST_F(RemoveUnreachableStatementsTest, ShouldRunHasNoUnreachable) {
auto* src = R"(
fn f() {
if (true) {
var x = 1;
}
}
)";
EXPECT_FALSE(ShouldRun<RemoveUnreachableStatements>(src));
}
TEST_F(RemoveUnreachableStatementsTest, ShouldRunHasUnreachable) {
auto* src = R"(
fn f() {
return;
if (true) {
var x = 1;
}
}
)";
EXPECT_TRUE(ShouldRun<RemoveUnreachableStatements>(src));
}
TEST_F(RemoveUnreachableStatementsTest, EmptyModule) {
auto* src = "";
auto* expect = "";
auto got = Run<RemoveUnreachableStatements>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(RemoveUnreachableStatementsTest, Return) {
auto* src = R"(
fn f() {
return;
var remove_me = 1;
if (true) {
var remove_me_too = 1;
}
}
)";
auto* expect = R"(
fn f() {
return;
}
)";
auto got = Run<RemoveUnreachableStatements>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(RemoveUnreachableStatementsTest, NestedReturn) {
auto* src = R"(
fn f() {
{
{
return;
}
}
var remove_me = 1;
if (true) {
var remove_me_too = 1;
}
}
)";
auto* expect = R"(
fn f() {
{
{
return;
}
}
}
)";
auto got = Run<RemoveUnreachableStatements>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(RemoveUnreachableStatementsTest, Discard) {
auto* src = R"(
fn f() {
discard;
var remove_me = 1;
if (true) {
var remove_me_too = 1;
}
}
)";
auto* expect = R"(
fn f() {
discard;
}
)";
auto got = Run<RemoveUnreachableStatements>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(RemoveUnreachableStatementsTest, NestedDiscard) {
auto* src = R"(
fn f() {
{
{
discard;
}
}
var remove_me = 1;
if (true) {
var remove_me_too = 1;
}
}
)";
auto* expect = R"(
fn f() {
{
{
discard;
}
}
}
)";
auto got = Run<RemoveUnreachableStatements>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(RemoveUnreachableStatementsTest, CallToFuncWithDiscard) {
auto* src = R"(
fn DISCARD() {
discard;
}
fn f() {
DISCARD();
var remove_me = 1;
if (true) {
var remove_me_too = 1;
}
}
)";
auto* expect = R"(
fn DISCARD() {
discard;
}
fn f() {
DISCARD();
}
)";
auto got = Run<RemoveUnreachableStatements>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(RemoveUnreachableStatementsTest, CallToFuncWithIfDiscard) {
auto* src = R"(
fn DISCARD() {
if (true) {
discard;
}
}
fn f() {
DISCARD();
var preserve_me = 1;
if (true) {
var preserve_me_too = 1;
}
}
)";
auto* expect = src;
auto got = Run<RemoveUnreachableStatements>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(RemoveUnreachableStatementsTest, IfDiscardElseDiscard) {
auto* src = R"(
fn f() {
if (true) {
discard;
} else {
discard;
}
var remove_me = 1;
if (true) {
var remove_me_too = 1;
}
}
)";
auto* expect = R"(
fn f() {
if (true) {
discard;
} else {
discard;
}
}
)";
auto got = Run<RemoveUnreachableStatements>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(RemoveUnreachableStatementsTest, IfDiscardElseReturn) {
auto* src = R"(
fn f() {
if (true) {
discard;
} else {
return;
}
var remove_me = 1;
if (true) {
var remove_me_too = 1;
}
}
)";
auto* expect = R"(
fn f() {
if (true) {
discard;
} else {
return;
}
}
)";
auto got = Run<RemoveUnreachableStatements>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(RemoveUnreachableStatementsTest, IfDiscard) {
auto* src = R"(
fn f() {
if (true) {
discard;
}
var preserve_me = 1;
if (true) {
var preserve_me_too = 1;
}
}
)";
auto* expect = src;
auto got = Run<RemoveUnreachableStatements>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(RemoveUnreachableStatementsTest, IfReturn) {
auto* src = R"(
fn f() {
if (true) {
return;
}
var preserve_me = 1;
if (true) {
var preserve_me_too = 1;
}
}
)";
auto* expect = src;
auto got = Run<RemoveUnreachableStatements>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(RemoveUnreachableStatementsTest, IfElseDiscard) {
auto* src = R"(
fn f() {
if (true) {
} else {
discard;
}
var preserve_me = 1;
if (true) {
var preserve_me_too = 1;
}
}
)";
auto* expect = src;
auto got = Run<RemoveUnreachableStatements>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(RemoveUnreachableStatementsTest, IfElseReturn) {
auto* src = R"(
fn f() {
if (true) {
} else {
return;
}
var preserve_me = 1;
if (true) {
var preserve_me_too = 1;
}
}
)";
auto* expect = src;
auto got = Run<RemoveUnreachableStatements>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(RemoveUnreachableStatementsTest, LoopWithDiscard) {
auto* src = R"(
fn f() {
loop {
var a = 1;
discard;
continuing {
var b = 2;
}
}
var remove_me = 1;
if (true) {
var remove_me_too = 1;
}
}
)";
auto* expect = R"(
fn f() {
loop {
var a = 1;
discard;
continuing {
var b = 2;
}
}
}
)";
auto got = Run<RemoveUnreachableStatements>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(RemoveUnreachableStatementsTest, LoopWithConditionalBreak) {
auto* src = R"(
fn f() {
loop {
var a = 1;
if (true) {
break;
}
continuing {
var b = 2;
}
}
var preserve_me = 1;
if (true) {
var preserve_me_too = 1;
}
}
)";
auto* expect = src;
auto got = Run<RemoveUnreachableStatements>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(RemoveUnreachableStatementsTest, LoopWithConditionalBreakInContinuing) {
auto* src = R"(
fn f() {
loop {
continuing {
if (true) {
break;
}
}
}
var preserve_me = 1;
if (true) {
var preserve_me_too = 1;
}
}
)";
auto* expect = src;
auto got = Run<RemoveUnreachableStatements>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(RemoveUnreachableStatementsTest, SwitchDefaultDiscard) {
auto* src = R"(
fn f() {
switch(1) {
default: {
discard;
}
}
var remove_me = 1;
if (true) {
var remove_me_too = 1;
}
}
)";
auto* expect = R"(
fn f() {
switch(1) {
default: {
discard;
}
}
}
)";
auto got = Run<RemoveUnreachableStatements>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(RemoveUnreachableStatementsTest, SwitchCaseReturnDefaultDiscard) {
auto* src = R"(
fn f() {
switch(1) {
case 0: {
return;
}
default: {
discard;
}
}
var remove_me = 1;
if (true) {
var remove_me_too = 1;
}
}
)";
auto* expect = R"(
fn f() {
switch(1) {
case 0: {
return;
}
default: {
discard;
}
}
}
)";
auto got = Run<RemoveUnreachableStatements>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(RemoveUnreachableStatementsTest, SwitchCaseBreakDefaultDiscard) {
auto* src = R"(
fn f() {
switch(1) {
case 0: {
break;
}
default: {
discard;
}
}
var preserve_me = 1;
if (true) {
var preserve_me_too = 1;
}
}
)";
auto* expect = src;
auto got = Run<RemoveUnreachableStatements>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(RemoveUnreachableStatementsTest, SwitchCaseReturnDefaultBreak) {
auto* src = R"(
fn f() {
switch(1) {
case 0: {
return;
}
default: {
break;
}
}
var preserve_me = 1;
if (true) {
var preserve_me_too = 1;
}
}
)";
auto* expect = src;
auto got = Run<RemoveUnreachableStatements>(src);
EXPECT_EQ(expect, str(got));
}
} // namespace
} // namespace transform
} // namespace tint

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,97 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TINT_TRANSFORM_RENAMER_H_
#define SRC_TINT_TRANSFORM_RENAMER_H_
#include <string>
#include <unordered_map>
#include "src/tint/transform/transform.h"
namespace tint::transform {
/// Renamer is a Transform that renames all the symbols in a program.
class Renamer : public Castable<Renamer, Transform> {
public:
/// Data is outputted by the Renamer transform.
/// Data holds information about shader usage and constant buffer offsets.
struct Data : public Castable<Data, transform::Data> {
/// Remappings is a map of old symbol name to new symbol name
using Remappings = std::unordered_map<std::string, std::string>;
/// Constructor
/// @param remappings the symbol remappings
explicit Data(Remappings&& remappings);
/// Copy constructor
Data(const Data&);
/// Destructor
~Data() override;
/// A map of old symbol name to new symbol name
const Remappings remappings;
};
/// Target is an enumerator of rename targets that can be used
enum class Target {
/// Rename every symbol.
kAll,
/// Only rename symbols that are reserved keywords in GLSL.
kGlslKeywords,
/// Only rename symbols that are reserved keywords in HLSL.
kHlslKeywords,
/// Only rename symbols that are reserved keywords in MSL.
kMslKeywords,
};
/// Optional configuration options for the transform.
/// If omitted, then the renamer will use Target::kAll.
struct Config : public Castable<Config, transform::Data> {
/// Constructor
/// @param tgt the targets to rename
/// @param keep_unicode if false, symbols with non-ascii code-points are
/// renamed
explicit Config(Target tgt, bool keep_unicode = false);
/// Copy constructor
Config(const Config&);
/// Destructor
~Config() override;
/// The targets to rename
Target const target = Target::kAll;
/// If false, symbols with non-ascii code-points are renamed.
bool preserve_unicode = false;
};
/// Constructor using a the configuration provided in the input Data
Renamer();
/// Destructor
~Renamer() override;
/// Runs the transform on `program`, returning the transformation result.
/// @param program the source program to transform
/// @param data optional extra transform-specific input data
/// @returns the transformation result
Output Run(const Program* program, const DataMap& data = {}) const override;
};
} // namespace tint::transform
#endif // SRC_TINT_TRANSFORM_RENAMER_H_

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,321 @@
// Copyright 2020 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/robustness.h"
#include <algorithm>
#include <limits>
#include <utility>
#include "src/tint/program_builder.h"
#include "src/tint/sem/block_statement.h"
#include "src/tint/sem/call.h"
#include "src/tint/sem/expression.h"
#include "src/tint/sem/reference_type.h"
#include "src/tint/sem/statement.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::Robustness);
TINT_INSTANTIATE_TYPEINFO(tint::transform::Robustness::Config);
namespace tint {
namespace transform {
/// State holds the current transform state
struct Robustness::State {
/// The clone context
CloneContext& ctx;
/// Set of storage classes to not apply the transform to
std::unordered_set<ast::StorageClass> omitted_classes;
/// Applies the transformation state to `ctx`.
void Transform() {
ctx.ReplaceAll([&](const ast::IndexAccessorExpression* expr) {
return Transform(expr);
});
ctx.ReplaceAll(
[&](const ast::CallExpression* expr) { return Transform(expr); });
}
/// Apply bounds clamping to array, vector and matrix indexing
/// @param expr the array, vector or matrix index expression
/// @return the clamped replacement expression, or nullptr if `expr` should be
/// cloned without changes.
const ast::IndexAccessorExpression* Transform(
const ast::IndexAccessorExpression* expr) {
auto* ret_type = ctx.src->Sem().Get(expr->object)->Type();
auto* ref = ret_type->As<sem::Reference>();
if (ref && omitted_classes.count(ref->StorageClass()) != 0) {
return nullptr;
}
auto* ret_unwrapped = ret_type->UnwrapRef();
ProgramBuilder& b = *ctx.dst;
using u32 = ProgramBuilder::u32;
struct Value {
const ast::Expression* expr = nullptr; // If null, then is a constant
union {
uint32_t u32 = 0; // use if is_signed == false
int32_t i32; // use if is_signed == true
};
bool is_signed = false;
};
Value size; // size of the array, vector or matrix
size.is_signed = false; // size is always unsigned
if (auto* vec = ret_unwrapped->As<sem::Vector>()) {
size.u32 = vec->Width();
} else if (auto* arr = ret_unwrapped->As<sem::Array>()) {
size.u32 = arr->Count();
} else if (auto* mat = ret_unwrapped->As<sem::Matrix>()) {
// The row accessor would have been an embedded index accessor and already
// handled, so we just need to do columns here.
size.u32 = mat->columns();
} else {
return nullptr;
}
if (size.u32 == 0) {
if (!ret_unwrapped->Is<sem::Array>()) {
b.Diagnostics().add_error(diag::System::Transform,
"invalid 0 sized non-array", expr->source);
return nullptr;
}
// Runtime sized array
auto* arr = ctx.Clone(expr->object);
size.expr = b.Call("arrayLength", b.AddressOf(arr));
}
// Calculate the maximum possible index value (size-1u)
// Size must be positive (non-zero), so we can safely subtract 1 here
// without underflow.
Value limit;
limit.is_signed = false; // Like size, limit is always unsigned.
if (size.expr) {
// Dynamic size
limit.expr = b.Sub(size.expr, 1u);
} else {
// Constant size
limit.u32 = size.u32 - 1u;
}
Value idx; // index value
auto* idx_sem = ctx.src->Sem().Get(expr->index);
auto* idx_ty = idx_sem->Type()->UnwrapRef();
if (!idx_ty->IsAnyOf<sem::I32, sem::U32>()) {
TINT_ICE(Transform, b.Diagnostics())
<< "index must be u32 or i32, got " << idx_sem->Type()->type_name();
return nullptr;
}
if (auto idx_constant = idx_sem->ConstantValue()) {
// Constant value index
if (idx_constant.Type()->Is<sem::I32>()) {
idx.i32 = idx_constant.Elements()[0].i32;
idx.is_signed = true;
} else if (idx_constant.Type()->Is<sem::U32>()) {
idx.u32 = idx_constant.Elements()[0].u32;
idx.is_signed = false;
} else {
b.Diagnostics().add_error(diag::System::Transform,
"unsupported constant value for accessor: " +
idx_constant.Type()->type_name(),
expr->source);
return nullptr;
}
} else {
// Dynamic value index
idx.expr = ctx.Clone(expr->index);
idx.is_signed = idx_ty->Is<sem::I32>();
}
// Clamp the index so that it cannot exceed limit.
if (idx.expr || limit.expr) {
// One of, or both of idx and limit are non-constant.
// If the index is signed, cast it to a u32 (with clamping if constant).
if (idx.is_signed) {
if (idx.expr) {
// We don't use a max(idx, 0) here, as that incurs a runtime
// performance cost, and if the unsigned value will be clamped by
// limit, resulting in a value between [0..limit)
idx.expr = b.Construct<u32>(idx.expr);
idx.is_signed = false;
} else {
idx.u32 = static_cast<uint32_t>(std::max(idx.i32, 0));
idx.is_signed = false;
}
}
// Convert idx and limit to expressions, so we can emit `min(idx, limit)`.
if (!idx.expr) {
idx.expr = b.Expr(idx.u32);
}
if (!limit.expr) {
limit.expr = b.Expr(limit.u32);
}
// Perform the clamp with `min(idx, limit)`
idx.expr = b.Call("min", idx.expr, limit.expr);
} else {
// Both idx and max are constant.
if (idx.is_signed) {
// The index is signed. Calculate limit as signed.
int32_t signed_limit = static_cast<int32_t>(
std::min<uint32_t>(limit.u32, std::numeric_limits<int32_t>::max()));
idx.i32 = std::max(idx.i32, 0);
idx.i32 = std::min(idx.i32, signed_limit);
} else {
// The index is unsigned.
idx.u32 = std::min(idx.u32, limit.u32);
}
}
// Convert idx to an expression, so we can emit the new accessor.
if (!idx.expr) {
idx.expr = idx.is_signed
? static_cast<const ast::Expression*>(b.Expr(idx.i32))
: static_cast<const ast::Expression*>(b.Expr(idx.u32));
}
// Clone arguments outside of create() call to have deterministic ordering
auto src = ctx.Clone(expr->source);
auto* obj = ctx.Clone(expr->object);
return b.IndexAccessor(src, obj, idx.expr);
}
/// @param type builtin type
/// @returns true if the given builtin is a texture function that requires
/// argument clamping,
bool TextureBuiltinNeedsClamping(sem::BuiltinType type) {
return type == sem::BuiltinType::kTextureLoad ||
type == sem::BuiltinType::kTextureStore;
}
/// Apply bounds clamping to the coordinates, array index and level arguments
/// of the `textureLoad()` and `textureStore()` builtins.
/// @param expr the builtin call expression
/// @return the clamped replacement call expression, or nullptr if `expr`
/// should be cloned without changes.
const ast::CallExpression* Transform(const ast::CallExpression* expr) {
auto* call = ctx.src->Sem().Get(expr);
auto* call_target = call->Target();
auto* builtin = call_target->As<sem::Builtin>();
if (!builtin || !TextureBuiltinNeedsClamping(builtin->Type())) {
return nullptr; // No transform, just clone.
}
ProgramBuilder& b = *ctx.dst;
// Indices of the mandatory texture and coords parameters, and the optional
// array and level parameters.
auto& signature = builtin->Signature();
auto texture_idx = signature.IndexOf(sem::ParameterUsage::kTexture);
auto coords_idx = signature.IndexOf(sem::ParameterUsage::kCoords);
auto array_idx = signature.IndexOf(sem::ParameterUsage::kArrayIndex);
auto level_idx = signature.IndexOf(sem::ParameterUsage::kLevel);
auto* texture_arg = expr->args[texture_idx];
auto* coords_arg = expr->args[coords_idx];
auto* coords_ty = builtin->Parameters()[coords_idx]->Type();
// If the level is provided, then we need to clamp this. As the level is
// used by textureDimensions() and the texture[Load|Store]() calls, we need
// to clamp both usages.
// TODO(bclayton): We probably want to place this into a let so that the
// calculation can be reused. This is fiddly to get right.
std::function<const ast::Expression*()> level_arg;
if (level_idx >= 0) {
level_arg = [&] {
auto* arg = expr->args[level_idx];
auto* num_levels = b.Call("textureNumLevels", ctx.Clone(texture_arg));
auto* zero = b.Expr(0);
auto* max = ctx.dst->Sub(num_levels, 1);
auto* clamped = b.Call("clamp", ctx.Clone(arg), zero, max);
return clamped;
};
}
// Clamp the coordinates argument
{
auto* texture_dims =
level_arg
? b.Call("textureDimensions", ctx.Clone(texture_arg), level_arg())
: b.Call("textureDimensions", ctx.Clone(texture_arg));
auto* zero = b.Construct(CreateASTTypeFor(ctx, coords_ty));
auto* max = ctx.dst->Sub(
texture_dims, b.Construct(CreateASTTypeFor(ctx, coords_ty), 1));
auto* clamped_coords = b.Call("clamp", ctx.Clone(coords_arg), zero, max);
ctx.Replace(coords_arg, clamped_coords);
}
// Clamp the array_index argument, if provided
if (array_idx >= 0) {
auto* arg = expr->args[array_idx];
auto* num_layers = b.Call("textureNumLayers", ctx.Clone(texture_arg));
auto* zero = b.Expr(0);
auto* max = ctx.dst->Sub(num_layers, 1);
auto* clamped = b.Call("clamp", ctx.Clone(arg), zero, max);
ctx.Replace(arg, clamped);
}
// Clamp the level argument, if provided
if (level_idx >= 0) {
auto* arg = expr->args[level_idx];
ctx.Replace(arg, level_arg ? level_arg() : ctx.dst->Expr(0));
}
return nullptr; // Clone, which will use the argument replacements above.
}
};
Robustness::Config::Config() = default;
Robustness::Config::Config(const Config&) = default;
Robustness::Config::~Config() = default;
Robustness::Config& Robustness::Config::operator=(const Config&) = default;
Robustness::Robustness() = default;
Robustness::~Robustness() = default;
void Robustness::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const {
Config cfg;
if (auto* cfg_data = inputs.Get<Config>()) {
cfg = *cfg_data;
}
std::unordered_set<ast::StorageClass> omitted_classes;
for (auto sc : cfg.omitted_classes) {
switch (sc) {
case StorageClass::kUniform:
omitted_classes.insert(ast::StorageClass::kUniform);
break;
case StorageClass::kStorage:
omitted_classes.insert(ast::StorageClass::kStorage);
break;
}
}
State state{ctx, std::move(omitted_classes)};
state.Transform();
ctx.Clone();
}
} // namespace transform
} // namespace tint

View File

@@ -0,0 +1,88 @@
// Copyright 2020 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TINT_TRANSFORM_ROBUSTNESS_H_
#define SRC_TINT_TRANSFORM_ROBUSTNESS_H_
#include <unordered_set>
#include "src/tint/transform/transform.h"
// Forward declarations
namespace tint {
namespace ast {
class IndexAccessorExpression;
class CallExpression;
} // namespace ast
} // namespace tint
namespace tint {
namespace transform {
/// This transform is responsible for clamping all array accesses to be within
/// the bounds of the array. Any access before the start of the array will clamp
/// to zero and any access past the end of the array will clamp to
/// (array length - 1).
class Robustness : public Castable<Robustness, Transform> {
public:
/// Storage class to be skipped in the transform
enum class StorageClass {
kUniform,
kStorage,
};
/// Configuration options for the transform
struct Config : public Castable<Config, Data> {
/// Constructor
Config();
/// Copy constructor
Config(const Config&);
/// Destructor
~Config() override;
/// Assignment operator
/// @returns this Config
Config& operator=(const Config&);
/// Storage classes to omit from apply the transform to.
/// This allows for optimizing on hardware that provide safe accesses.
std::unordered_set<StorageClass> omitted_classes;
};
/// Constructor
Robustness();
/// Destructor
~Robustness() override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const override;
private:
struct State;
};
} // namespace transform
} // namespace tint
#endif // SRC_TINT_TRANSFORM_ROBUSTNESS_H_

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,239 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/simplify_pointers.h"
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>
#include "src/tint/program_builder.h"
#include "src/tint/sem/block_statement.h"
#include "src/tint/sem/function.h"
#include "src/tint/sem/statement.h"
#include "src/tint/sem/variable.h"
#include "src/tint/transform/unshadow.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::SimplifyPointers);
namespace tint {
namespace transform {
namespace {
/// PointerOp describes either possible indirection or address-of action on an
/// expression.
struct PointerOp {
/// Positive: Number of times the `expr` was dereferenced (*expr)
/// Negative: Number of times the `expr` was 'addressed-of' (&expr)
/// Zero: no pointer op on `expr`
int indirections = 0;
/// The expression being operated on
const ast::Expression* expr = nullptr;
};
} // namespace
/// The PIMPL state for the SimplifyPointers transform
struct SimplifyPointers::State {
/// The clone context
CloneContext& ctx;
/// Constructor
/// @param context the clone context
explicit State(CloneContext& context) : ctx(context) {}
/// Traverses the expression `expr` looking for non-literal array indexing
/// expressions that would affect the computed address of a pointer
/// expression. The function-like argument `cb` is called for each found.
/// @param expr the expression to traverse
/// @param cb a function-like object with the signature
/// `void(const ast::Expression*)`, which is called for each array index
/// expression
template <typename F>
static void CollectSavedArrayIndices(const ast::Expression* expr, F&& cb) {
if (auto* a = expr->As<ast::IndexAccessorExpression>()) {
CollectSavedArrayIndices(a->object, cb);
if (!a->index->Is<ast::LiteralExpression>()) {
cb(a->index);
}
return;
}
if (auto* m = expr->As<ast::MemberAccessorExpression>()) {
CollectSavedArrayIndices(m->structure, cb);
return;
}
if (auto* u = expr->As<ast::UnaryOpExpression>()) {
CollectSavedArrayIndices(u->expr, cb);
return;
}
// Note: Other ast::Expression types can be safely ignored as they cannot be
// used to generate a reference or pointer.
// See https://gpuweb.github.io/gpuweb/wgsl/#forming-references-and-pointers
}
/// Reduce walks the expression chain, collapsing all address-of and
/// indirection ops into a PointerOp.
/// @param in the expression to walk
/// @returns the reduced PointerOp
PointerOp Reduce(const ast::Expression* in) const {
PointerOp op{0, in};
while (true) {
if (auto* unary = op.expr->As<ast::UnaryOpExpression>()) {
switch (unary->op) {
case ast::UnaryOp::kIndirection:
op.indirections++;
op.expr = unary->expr;
continue;
case ast::UnaryOp::kAddressOf:
op.indirections--;
op.expr = unary->expr;
continue;
default:
break;
}
}
if (auto* user = ctx.src->Sem().Get<sem::VariableUser>(op.expr)) {
auto* var = user->Variable();
if (var->Is<sem::LocalVariable>() && //
var->Declaration()->is_const && //
var->Type()->Is<sem::Pointer>()) {
op.expr = var->Declaration()->constructor;
continue;
}
}
return op;
}
}
/// Performs the transformation
void Run() {
// A map of saved expressions to their saved variable name
std::unordered_map<const ast::Expression*, Symbol> saved_vars;
// Register the ast::Expression transform handler.
// This performs two different transformations:
// * Identifiers that resolve to the pointer-typed `let` declarations are
// replaced with the recursively inlined initializer expression for the
// `let` declaration.
// * Sub-expressions inside the pointer-typed `let` initializer expression
// that have been hoisted to a saved variable are replaced with the saved
// variable identifier.
ctx.ReplaceAll([&](const ast::Expression* expr) -> const ast::Expression* {
// Look to see if we need to swap this Expression with a saved variable.
auto it = saved_vars.find(expr);
if (it != saved_vars.end()) {
return ctx.dst->Expr(it->second);
}
// Reduce the expression, folding away chains of address-of / indirections
auto op = Reduce(expr);
// Clone the reduced root expression
expr = ctx.CloneWithoutTransform(op.expr);
// And reapply the minimum number of address-of / indirections
for (int i = 0; i < op.indirections; i++) {
expr = ctx.dst->Deref(expr);
}
for (int i = 0; i > op.indirections; i--) {
expr = ctx.dst->AddressOf(expr);
}
return expr;
});
// Find all the pointer-typed `let` declarations.
// Note that these must be function-scoped, as module-scoped `let`s are not
// permitted.
for (auto* node : ctx.src->ASTNodes().Objects()) {
if (auto* let = node->As<ast::VariableDeclStatement>()) {
if (!let->variable->is_const) {
continue; // Not a `let` declaration. Ignore.
}
auto* var = ctx.src->Sem().Get(let->variable);
if (!var->Type()->Is<sem::Pointer>()) {
continue; // Not a pointer type. Ignore.
}
// We're dealing with a pointer-typed `let` declaration.
// Scan the initializer expression for array index expressions that need
// to be hoist to temporary "saved" variables.
std::vector<const ast::VariableDeclStatement*> saved;
CollectSavedArrayIndices(
var->Declaration()->constructor,
[&](const ast::Expression* idx_expr) {
// We have a sub-expression that needs to be saved.
// Create a new variable
auto saved_name = ctx.dst->Symbols().New(
ctx.src->Symbols().NameFor(var->Declaration()->symbol) +
"_save");
auto* decl = ctx.dst->Decl(
ctx.dst->Const(saved_name, nullptr, ctx.Clone(idx_expr)));
saved.emplace_back(decl);
// Record the substitution of `idx_expr` to the saved variable
// with the symbol `saved_name`. This will be used by the
// ReplaceAll() handler above.
saved_vars.emplace(idx_expr, saved_name);
});
// Find the place to insert the saved declarations.
// Special care needs to be made for lets declared as the initializer
// part of for-loops. In this case the block will hold the for-loop
// statement, not the let.
if (!saved.empty()) {
auto* stmt = ctx.src->Sem().Get(let);
auto* block = stmt->Block();
// Find the statement owned by the block (either the let decl or a
// for-loop)
while (block != stmt->Parent()) {
stmt = stmt->Parent();
}
// Declare the stored variables just before stmt. Order here is
// important as order-of-operations needs to be preserved.
// CollectSavedArrayIndices() visits the LHS of an index accessor
// before the index expression.
for (auto* decl : saved) {
// Note that repeated calls to InsertBefore() with the same `before`
// argument will result in nodes to inserted in the order the
// calls are made (last call is inserted last).
ctx.InsertBefore(block->Declaration()->statements,
stmt->Declaration(), decl);
}
}
// As the original `let` declaration will be fully inlined, there's no
// need for the original declaration to exist. Remove it.
RemoveStatement(ctx, let);
}
}
ctx.Clone();
}
};
SimplifyPointers::SimplifyPointers() = default;
SimplifyPointers::~SimplifyPointers() = default;
void SimplifyPointers::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
State(ctx).Run();
}
} // namespace transform
} // namespace tint

View File

@@ -0,0 +1,60 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TINT_TRANSFORM_SIMPLIFY_POINTERS_H_
#define SRC_TINT_TRANSFORM_SIMPLIFY_POINTERS_H_
#include "src/tint/transform/transform.h"
namespace tint {
namespace transform {
/// SimplifyPointers is a Transform that moves all usage of function-scope
/// `let` statements of a pointer type into their places of usage, while also
/// simplifying any chains of address-of or indirections operators.
///
/// Parameters of a pointer type are not adjusted.
///
/// Note: SimplifyPointers does not operate on module-scope `let`s, as these
/// cannot be pointers: https://gpuweb.github.io/gpuweb/wgsl/#module-constants
/// `A module-scope let-declared constant must be of constructible type.`
///
/// @note Depends on the following transforms to have been run first:
/// * Unshadow
class SimplifyPointers : public Castable<SimplifyPointers, Transform> {
public:
/// Constructor
SimplifyPointers();
/// Destructor
~SimplifyPointers() override;
protected:
struct State;
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const override;
};
} // namespace transform
} // namespace tint
#endif // SRC_TINT_TRANSFORM_SIMPLIFY_POINTERS_H_

View File

@@ -0,0 +1,370 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/simplify_pointers.h"
#include "src/tint/transform/test_helper.h"
#include "src/tint/transform/unshadow.h"
namespace tint {
namespace transform {
namespace {
using SimplifyPointersTest = TransformTest;
TEST_F(SimplifyPointersTest, EmptyModule) {
auto* src = "";
auto* expect = "";
auto got = Run<Unshadow, SimplifyPointers>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(SimplifyPointersTest, FoldPointer) {
auto* src = R"(
fn f() {
var v : i32;
let p : ptr<function, i32> = &v;
let x : i32 = *p;
}
)";
auto* expect = R"(
fn f() {
var v : i32;
let x : i32 = v;
}
)";
auto got = Run<Unshadow, SimplifyPointers>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(SimplifyPointersTest, AddressOfDeref) {
auto* src = R"(
fn f() {
var v : i32;
let p : ptr<function, i32> = &(v);
let x : ptr<function, i32> = &(*(p));
let y : ptr<function, i32> = &(*(&(*(p))));
let z : ptr<function, i32> = &(*(&(*(&(*(&(*(p))))))));
var a = *x;
var b = *y;
var c = *z;
}
)";
auto* expect = R"(
fn f() {
var v : i32;
var a = v;
var b = v;
var c = v;
}
)";
auto got = Run<Unshadow, SimplifyPointers>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(SimplifyPointersTest, DerefAddressOf) {
auto* src = R"(
fn f() {
var v : i32;
let x : i32 = *(&(v));
let y : i32 = *(&(*(&(v))));
let z : i32 = *(&(*(&(*(&(*(&(v))))))));
}
)";
auto* expect = R"(
fn f() {
var v : i32;
let x : i32 = v;
let y : i32 = v;
let z : i32 = v;
}
)";
auto got = Run<Unshadow, SimplifyPointers>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(SimplifyPointersTest, ComplexChain) {
auto* src = R"(
fn f() {
var a : array<mat4x4<f32>, 4>;
let ap : ptr<function, array<mat4x4<f32>, 4>> = &a;
let mp : ptr<function, mat4x4<f32>> = &(*ap)[3];
let vp : ptr<function, vec4<f32>> = &(*mp)[2];
let v : vec4<f32> = *vp;
}
)";
auto* expect = R"(
fn f() {
var a : array<mat4x4<f32>, 4>;
let v : vec4<f32> = a[3][2];
}
)";
auto got = Run<Unshadow, SimplifyPointers>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(SimplifyPointersTest, SavedVars) {
auto* src = R"(
struct S {
i : i32;
};
fn arr() {
var a : array<S, 2>;
var i : i32 = 0;
var j : i32 = 0;
let p : ptr<function, i32> = &a[i + j].i;
i = 2;
*p = 4;
}
fn matrix() {
var m : mat3x3<f32>;
var i : i32 = 0;
var j : i32 = 0;
let p : ptr<function, vec3<f32>> = &m[i + j];
i = 2;
*p = vec3<f32>(4.0, 5.0, 6.0);
}
)";
auto* expect = R"(
struct S {
i : i32;
}
fn arr() {
var a : array<S, 2>;
var i : i32 = 0;
var j : i32 = 0;
let p_save = (i + j);
i = 2;
a[p_save].i = 4;
}
fn matrix() {
var m : mat3x3<f32>;
var i : i32 = 0;
var j : i32 = 0;
let p_save_1 = (i + j);
i = 2;
m[p_save_1] = vec3<f32>(4.0, 5.0, 6.0);
}
)";
auto got = Run<Unshadow, SimplifyPointers>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(SimplifyPointersTest, DontSaveLiterals) {
auto* src = R"(
fn f() {
var arr : array<i32, 2>;
let p1 : ptr<function, i32> = &arr[1];
*p1 = 4;
}
)";
auto* expect = R"(
fn f() {
var arr : array<i32, 2>;
arr[1] = 4;
}
)";
auto got = Run<Unshadow, SimplifyPointers>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(SimplifyPointersTest, SavedVarsChain) {
auto* src = R"(
fn f() {
var arr : array<array<i32, 2>, 2>;
let i : i32 = 0;
let j : i32 = 1;
let p : ptr<function, array<i32, 2>> = &arr[i];
let q : ptr<function, i32> = &(*p)[j];
*q = 12;
}
)";
auto* expect = R"(
fn f() {
var arr : array<array<i32, 2>, 2>;
let i : i32 = 0;
let j : i32 = 1;
let p_save = i;
let q_save = j;
arr[p_save][q_save] = 12;
}
)";
auto got = Run<Unshadow, SimplifyPointers>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(SimplifyPointersTest, ForLoopInit) {
auto* src = R"(
fn foo() -> i32 {
return 1;
}
@stage(fragment)
fn main() {
var arr = array<f32, 4>();
for (let a = &arr[foo()]; ;) {
let x = *a;
break;
}
}
)";
auto* expect = R"(
fn foo() -> i32 {
return 1;
}
@stage(fragment)
fn main() {
var arr = array<f32, 4>();
let a_save = foo();
for(; ; ) {
let x = arr[a_save];
break;
}
}
)";
auto got = Run<Unshadow, SimplifyPointers>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(SimplifyPointersTest, MultiSavedVarsInSinglePtrLetExpr) {
auto* src = R"(
fn x() -> i32 {
return 1;
}
fn y() -> i32 {
return 1;
}
fn z() -> i32 {
return 1;
}
struct Inner {
a : array<i32, 2>;
};
struct Outer {
a : array<Inner, 2>;
};
fn f() {
var arr : array<Outer, 2>;
let p : ptr<function, i32> = &arr[x()].a[y()].a[z()];
*p = 1;
*p = 2;
}
)";
auto* expect = R"(
fn x() -> i32 {
return 1;
}
fn y() -> i32 {
return 1;
}
fn z() -> i32 {
return 1;
}
struct Inner {
a : array<i32, 2>;
}
struct Outer {
a : array<Inner, 2>;
}
fn f() {
var arr : array<Outer, 2>;
let p_save = x();
let p_save_1 = y();
let p_save_2 = z();
arr[p_save].a[p_save_1].a[p_save_2] = 1;
arr[p_save].a[p_save_1].a[p_save_2] = 2;
}
)";
auto got = Run<Unshadow, SimplifyPointers>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(SimplifyPointersTest, ShadowPointer) {
auto* src = R"(
var<private> a : array<i32, 2>;
@stage(compute) @workgroup_size(1)
fn main() {
let x = &a;
var a : i32 = (*x)[0];
{
var a : i32 = (*x)[1];
}
}
)";
auto* expect = R"(
var<private> a : array<i32, 2>;
@stage(compute) @workgroup_size(1)
fn main() {
var a_1 : i32 = a[0];
{
var a_2 : i32 = a[1];
}
}
)";
auto got = Run<Unshadow, SimplifyPointers>(src);
EXPECT_EQ(expect, str(got));
}
} // namespace
} // namespace transform
} // namespace tint

View File

@@ -0,0 +1,117 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/single_entry_point.h"
#include <unordered_set>
#include <utility>
#include "src/tint/program_builder.h"
#include "src/tint/sem/function.h"
#include "src/tint/sem/variable.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::SingleEntryPoint);
TINT_INSTANTIATE_TYPEINFO(tint::transform::SingleEntryPoint::Config);
namespace tint {
namespace transform {
SingleEntryPoint::SingleEntryPoint() = default;
SingleEntryPoint::~SingleEntryPoint() = default;
void SingleEntryPoint::Run(CloneContext& ctx,
const DataMap& inputs,
DataMap&) const {
auto* cfg = inputs.Get<Config>();
if (cfg == nullptr) {
ctx.dst->Diagnostics().add_error(
diag::System::Transform,
"missing transform data for " + std::string(TypeInfo().name));
return;
}
// Find the target entry point.
const ast::Function* entry_point = nullptr;
for (auto* f : ctx.src->AST().Functions()) {
if (!f->IsEntryPoint()) {
continue;
}
if (ctx.src->Symbols().NameFor(f->symbol) == cfg->entry_point_name) {
entry_point = f;
break;
}
}
if (entry_point == nullptr) {
ctx.dst->Diagnostics().add_error(
diag::System::Transform,
"entry point '" + cfg->entry_point_name + "' not found");
return;
}
auto& sem = ctx.src->Sem();
// Build set of referenced module-scope variables for faster lookups later.
std::unordered_set<const ast::Variable*> referenced_vars;
for (auto* var : sem.Get(entry_point)->TransitivelyReferencedGlobals()) {
referenced_vars.emplace(var->Declaration());
}
// Clone any module-scope variables, types, and functions that are statically
// referenced by the target entry point.
for (auto* decl : ctx.src->AST().GlobalDeclarations()) {
if (auto* ty = decl->As<ast::TypeDecl>()) {
// TODO(jrprice): Strip unused types.
ctx.dst->AST().AddTypeDecl(ctx.Clone(ty));
} else if (auto* var = decl->As<ast::Variable>()) {
if (referenced_vars.count(var)) {
if (var->is_overridable) {
// It is an overridable constant
if (!ast::HasAttribute<ast::IdAttribute>(var->attributes)) {
// If the constant doesn't already have an @id() attribute, add one
// so that its allocated ID so that it won't be affected by other
// stripped away constants
auto* global = sem.Get(var)->As<sem::GlobalVariable>();
const auto* id = ctx.dst->Id(global->ConstantId());
ctx.InsertFront(var->attributes, id);
}
}
ctx.dst->AST().AddGlobalVariable(ctx.Clone(var));
}
} else if (auto* func = decl->As<ast::Function>()) {
if (sem.Get(func)->HasAncestorEntryPoint(entry_point->symbol)) {
ctx.dst->AST().AddFunction(ctx.Clone(func));
}
} else {
TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics())
<< "unhandled global declaration: " << decl->TypeInfo().name;
return;
}
}
// Clone the entry point.
ctx.dst->AST().AddFunction(ctx.Clone(entry_point));
}
SingleEntryPoint::Config::Config(std::string entry_point)
: entry_point_name(entry_point) {}
SingleEntryPoint::Config::Config(const Config&) = default;
SingleEntryPoint::Config::~Config() = default;
SingleEntryPoint::Config& SingleEntryPoint::Config::operator=(const Config&) =
default;
} // namespace transform
} // namespace tint

View File

@@ -0,0 +1,72 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TINT_TRANSFORM_SINGLE_ENTRY_POINT_H_
#define SRC_TINT_TRANSFORM_SINGLE_ENTRY_POINT_H_
#include <string>
#include "src/tint/transform/transform.h"
namespace tint {
namespace transform {
/// Strip all but one entry point a module.
///
/// All module-scope variables, types, and functions that are not used by the
/// target entry point will also be removed.
class SingleEntryPoint : public Castable<SingleEntryPoint, Transform> {
public:
/// Configuration options for the transform
struct Config : public Castable<Config, Data> {
/// Constructor
/// @param entry_point the name of the entry point to keep
explicit Config(std::string entry_point = "");
/// Copy constructor
Config(const Config&);
/// Destructor
~Config() override;
/// Assignment operator
/// @returns this Config
Config& operator=(const Config&);
/// The name of the entry point to keep.
std::string entry_point_name;
};
/// Constructor
SingleEntryPoint();
/// Destructor
~SingleEntryPoint() override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const override;
};
} // namespace transform
} // namespace tint
#endif // SRC_TINT_TRANSFORM_SINGLE_ENTRY_POINT_H_

View File

@@ -0,0 +1,517 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/single_entry_point.h"
#include <utility>
#include "src/tint/transform/test_helper.h"
namespace tint {
namespace transform {
namespace {
using SingleEntryPointTest = TransformTest;
TEST_F(SingleEntryPointTest, Error_MissingTransformData) {
auto* src = "";
auto* expect =
"error: missing transform data for tint::transform::SingleEntryPoint";
auto got = Run<SingleEntryPoint>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(SingleEntryPointTest, Error_NoEntryPoints) {
auto* src = "";
auto* expect = "error: entry point 'main' not found";
DataMap data;
data.Add<SingleEntryPoint::Config>("main");
auto got = Run<SingleEntryPoint>(src, data);
EXPECT_EQ(expect, str(got));
}
TEST_F(SingleEntryPointTest, Error_InvalidEntryPoint) {
auto* src = R"(
@stage(vertex)
fn main() -> @builtin(position) vec4<f32> {
return vec4<f32>();
}
)";
auto* expect = "error: entry point '_' not found";
SingleEntryPoint::Config cfg("_");
DataMap data;
data.Add<SingleEntryPoint::Config>(cfg);
auto got = Run<SingleEntryPoint>(src, data);
EXPECT_EQ(expect, str(got));
}
TEST_F(SingleEntryPointTest, Error_NotAnEntryPoint) {
auto* src = R"(
fn foo() {}
@stage(fragment)
fn main() {}
)";
auto* expect = "error: entry point 'foo' not found";
SingleEntryPoint::Config cfg("foo");
DataMap data;
data.Add<SingleEntryPoint::Config>(cfg);
auto got = Run<SingleEntryPoint>(src, data);
EXPECT_EQ(expect, str(got));
}
TEST_F(SingleEntryPointTest, SingleEntryPoint) {
auto* src = R"(
@stage(compute) @workgroup_size(1)
fn main() {
}
)";
SingleEntryPoint::Config cfg("main");
DataMap data;
data.Add<SingleEntryPoint::Config>(cfg);
auto got = Run<SingleEntryPoint>(src, data);
EXPECT_EQ(src, str(got));
}
TEST_F(SingleEntryPointTest, MultipleEntryPoints) {
auto* src = R"(
@stage(vertex)
fn vert_main() -> @builtin(position) vec4<f32> {
return vec4<f32>();
}
@stage(fragment)
fn frag_main() {
}
@stage(compute) @workgroup_size(1)
fn comp_main1() {
}
@stage(compute) @workgroup_size(1)
fn comp_main2() {
}
)";
auto* expect = R"(
@stage(compute) @workgroup_size(1)
fn comp_main1() {
}
)";
SingleEntryPoint::Config cfg("comp_main1");
DataMap data;
data.Add<SingleEntryPoint::Config>(cfg);
auto got = Run<SingleEntryPoint>(src, data);
EXPECT_EQ(expect, str(got));
}
TEST_F(SingleEntryPointTest, GlobalVariables) {
auto* src = R"(
var<private> a : f32;
var<private> b : f32;
var<private> c : f32;
var<private> d : f32;
@stage(vertex)
fn vert_main() -> @builtin(position) vec4<f32> {
a = 0.0;
return vec4<f32>();
}
@stage(fragment)
fn frag_main() {
b = 0.0;
}
@stage(compute) @workgroup_size(1)
fn comp_main1() {
c = 0.0;
}
@stage(compute) @workgroup_size(1)
fn comp_main2() {
d = 0.0;
}
)";
auto* expect = R"(
var<private> c : f32;
@stage(compute) @workgroup_size(1)
fn comp_main1() {
c = 0.0;
}
)";
SingleEntryPoint::Config cfg("comp_main1");
DataMap data;
data.Add<SingleEntryPoint::Config>(cfg);
auto got = Run<SingleEntryPoint>(src, data);
EXPECT_EQ(expect, str(got));
}
TEST_F(SingleEntryPointTest, GlobalConstants) {
auto* src = R"(
let a : f32 = 1.0;
let b : f32 = 1.0;
let c : f32 = 1.0;
let d : f32 = 1.0;
@stage(vertex)
fn vert_main() -> @builtin(position) vec4<f32> {
let local_a : f32 = a;
return vec4<f32>();
}
@stage(fragment)
fn frag_main() {
let local_b : f32 = b;
}
@stage(compute) @workgroup_size(1)
fn comp_main1() {
let local_c : f32 = c;
}
@stage(compute) @workgroup_size(1)
fn comp_main2() {
let local_d : f32 = d;
}
)";
auto* expect = R"(
let c : f32 = 1.0;
@stage(compute) @workgroup_size(1)
fn comp_main1() {
let local_c : f32 = c;
}
)";
SingleEntryPoint::Config cfg("comp_main1");
DataMap data;
data.Add<SingleEntryPoint::Config>(cfg);
auto got = Run<SingleEntryPoint>(src, data);
EXPECT_EQ(expect, str(got));
}
TEST_F(SingleEntryPointTest, WorkgroupSizeLetPreserved) {
auto* src = R"(
let size : i32 = 1;
@stage(compute) @workgroup_size(size)
fn main() {
}
)";
auto* expect = src;
SingleEntryPoint::Config cfg("main");
DataMap data;
data.Add<SingleEntryPoint::Config>(cfg);
auto got = Run<SingleEntryPoint>(src, data);
EXPECT_EQ(expect, str(got));
}
TEST_F(SingleEntryPointTest, OverridableConstants) {
auto* src = R"(
@id(1001) override c1 : u32 = 1u;
override c2 : u32 = 1u;
@id(0) override c3 : u32 = 1u;
@id(9999) override c4 : u32 = 1u;
@stage(compute) @workgroup_size(1)
fn comp_main1() {
let local_d = c1;
}
@stage(compute) @workgroup_size(1)
fn comp_main2() {
let local_d = c2;
}
@stage(compute) @workgroup_size(1)
fn comp_main3() {
let local_d = c3;
}
@stage(compute) @workgroup_size(1)
fn comp_main4() {
let local_d = c4;
}
@stage(compute) @workgroup_size(1)
fn comp_main5() {
let local_d = 1u;
}
)";
{
SingleEntryPoint::Config cfg("comp_main1");
auto* expect = R"(
@id(1001) override c1 : u32 = 1u;
@stage(compute) @workgroup_size(1)
fn comp_main1() {
let local_d = c1;
}
)";
DataMap data;
data.Add<SingleEntryPoint::Config>(cfg);
auto got = Run<SingleEntryPoint>(src, data);
EXPECT_EQ(expect, str(got));
}
{
SingleEntryPoint::Config cfg("comp_main2");
// The decorator is replaced with the one with explicit id
// And should not be affected by other constants stripped away
auto* expect = R"(
@id(1) override c2 : u32 = 1u;
@stage(compute) @workgroup_size(1)
fn comp_main2() {
let local_d = c2;
}
)";
DataMap data;
data.Add<SingleEntryPoint::Config>(cfg);
auto got = Run<SingleEntryPoint>(src, data);
EXPECT_EQ(expect, str(got));
}
{
SingleEntryPoint::Config cfg("comp_main3");
auto* expect = R"(
@id(0) override c3 : u32 = 1u;
@stage(compute) @workgroup_size(1)
fn comp_main3() {
let local_d = c3;
}
)";
DataMap data;
data.Add<SingleEntryPoint::Config>(cfg);
auto got = Run<SingleEntryPoint>(src, data);
EXPECT_EQ(expect, str(got));
}
{
SingleEntryPoint::Config cfg("comp_main4");
auto* expect = R"(
@id(9999) override c4 : u32 = 1u;
@stage(compute) @workgroup_size(1)
fn comp_main4() {
let local_d = c4;
}
)";
DataMap data;
data.Add<SingleEntryPoint::Config>(cfg);
auto got = Run<SingleEntryPoint>(src, data);
EXPECT_EQ(expect, str(got));
}
{
SingleEntryPoint::Config cfg("comp_main5");
auto* expect = R"(
@stage(compute) @workgroup_size(1)
fn comp_main5() {
let local_d = 1u;
}
)";
DataMap data;
data.Add<SingleEntryPoint::Config>(cfg);
auto got = Run<SingleEntryPoint>(src, data);
EXPECT_EQ(expect, str(got));
}
}
TEST_F(SingleEntryPointTest, CalledFunctions) {
auto* src = R"(
fn inner1() {
}
fn inner2() {
}
fn inner_shared() {
}
fn outer1() {
inner1();
inner_shared();
}
fn outer2() {
inner2();
inner_shared();
}
@stage(compute) @workgroup_size(1)
fn comp_main1() {
outer1();
}
@stage(compute) @workgroup_size(1)
fn comp_main2() {
outer2();
}
)";
auto* expect = R"(
fn inner1() {
}
fn inner_shared() {
}
fn outer1() {
inner1();
inner_shared();
}
@stage(compute) @workgroup_size(1)
fn comp_main1() {
outer1();
}
)";
SingleEntryPoint::Config cfg("comp_main1");
DataMap data;
data.Add<SingleEntryPoint::Config>(cfg);
auto got = Run<SingleEntryPoint>(src, data);
EXPECT_EQ(expect, str(got));
}
TEST_F(SingleEntryPointTest, GlobalsReferencedByCalledFunctions) {
auto* src = R"(
var<private> inner1_var : f32;
var<private> inner2_var : f32;
var<private> inner_shared_var : f32;
var<private> outer1_var : f32;
var<private> outer2_var : f32;
fn inner1() {
inner1_var = 0.0;
}
fn inner2() {
inner2_var = 0.0;
}
fn inner_shared() {
inner_shared_var = 0.0;
}
fn outer1() {
inner1();
inner_shared();
outer1_var = 0.0;
}
fn outer2() {
inner2();
inner_shared();
outer2_var = 0.0;
}
@stage(compute) @workgroup_size(1)
fn comp_main1() {
outer1();
}
@stage(compute) @workgroup_size(1)
fn comp_main2() {
outer2();
}
)";
auto* expect = R"(
var<private> inner1_var : f32;
var<private> inner_shared_var : f32;
var<private> outer1_var : f32;
fn inner1() {
inner1_var = 0.0;
}
fn inner_shared() {
inner_shared_var = 0.0;
}
fn outer1() {
inner1();
inner_shared();
outer1_var = 0.0;
}
@stage(compute) @workgroup_size(1)
fn comp_main1() {
outer1();
}
)";
SingleEntryPoint::Config cfg("comp_main1");
DataMap data;
data.Add<SingleEntryPoint::Config>(cfg);
auto got = Run<SingleEntryPoint>(src, data);
EXPECT_EQ(expect, str(got));
}
} // namespace
} // namespace transform
} // namespace tint

View File

@@ -0,0 +1,153 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TINT_TRANSFORM_TEST_HELPER_H_
#define SRC_TINT_TRANSFORM_TEST_HELPER_H_
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "gtest/gtest.h"
#include "src/tint/reader/wgsl/parser.h"
#include "src/tint/transform/manager.h"
#include "src/tint/writer/wgsl/generator.h"
namespace tint {
namespace transform {
/// @param program the program to get an output WGSL string from
/// @returns the output program as a WGSL string, or an error string if the
/// program is not valid.
inline std::string str(const Program& program) {
diag::Formatter::Style style;
style.print_newline_at_end = false;
if (!program.IsValid()) {
return diag::Formatter(style).format(program.Diagnostics());
}
writer::wgsl::Options options;
auto result = writer::wgsl::Generate(&program, options);
if (!result.success) {
return "WGSL writer failed:\n" + result.error;
}
auto res = result.wgsl;
if (res.empty()) {
return res;
}
// The WGSL sometimes has two trailing newlines. Strip them
while (res.back() == '\n') {
res.pop_back();
}
if (res.empty()) {
return res;
}
return "\n" + res + "\n";
}
/// Helper class for testing transforms
template <typename BASE>
class TransformTestBase : public BASE {
public:
/// Transforms and returns the WGSL source `in`, transformed using
/// `transform`.
/// @param transform the transform to apply
/// @param in the input WGSL source
/// @param data the optional DataMap to pass to Transform::Run()
/// @return the transformed output
Output Run(std::string in,
std::unique_ptr<transform::Transform> transform,
const DataMap& data = {}) {
std::vector<std::unique_ptr<transform::Transform>> transforms;
transforms.emplace_back(std::move(transform));
return Run(std::move(in), std::move(transforms), data);
}
/// Transforms and returns the WGSL source `in`, transformed using
/// a transform of type `TRANSFORM`.
/// @param in the input WGSL source
/// @param data the optional DataMap to pass to Transform::Run()
/// @return the transformed output
template <typename... TRANSFORMS>
Output Run(std::string in, const DataMap& data = {}) {
auto file = std::make_unique<Source::File>("test", in);
auto program = reader::wgsl::Parse(file.get());
// Keep this pointer alive after Transform() returns
files_.emplace_back(std::move(file));
return Run<TRANSFORMS...>(std::move(program), data);
}
/// Transforms and returns program `program`, transformed using a transform of
/// type `TRANSFORM`.
/// @param program the input Program
/// @param data the optional DataMap to pass to Transform::Run()
/// @return the transformed output
template <typename... TRANSFORMS>
Output Run(Program&& program, const DataMap& data = {}) {
if (!program.IsValid()) {
return Output(std::move(program));
}
Manager manager;
for (auto* transform_ptr :
std::initializer_list<Transform*>{new TRANSFORMS()...}) {
manager.append(std::unique_ptr<Transform>(transform_ptr));
}
return manager.Run(&program, data);
}
/// @param program the input program
/// @param data the optional DataMap to pass to Transform::Run()
/// @return true if the transform should be run for the given input.
template <typename TRANSFORM>
bool ShouldRun(Program&& program, const DataMap& data = {}) {
EXPECT_TRUE(program.IsValid()) << program.Diagnostics().str();
return TRANSFORM().ShouldRun(&program, data);
}
/// @param in the input WGSL source
/// @param data the optional DataMap to pass to Transform::Run()
/// @return true if the transform should be run for the given input.
template <typename TRANSFORM>
bool ShouldRun(std::string in, const DataMap& data = {}) {
auto file = std::make_unique<Source::File>("test", in);
auto program = reader::wgsl::Parse(file.get());
return ShouldRun<TRANSFORM>(std::move(program), data);
}
/// @param output the output of the transform
/// @returns the output program as a WGSL string, or an error string if the
/// program is not valid.
std::string str(const Output& output) {
return transform::str(output.program);
}
private:
std::vector<std::unique_ptr<Source::File>> files_;
};
using TransformTest = TransformTestBase<testing::Test>;
template <typename T>
using TransformTestWithParam = TransformTestBase<testing::TestWithParam<T>>;
} // namespace transform
} // namespace tint
#endif // SRC_TINT_TRANSFORM_TEST_HELPER_H_

View File

@@ -0,0 +1,160 @@
// Copyright 2020 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/transform.h"
#include <algorithm>
#include <string>
#include "src/tint/program_builder.h"
#include "src/tint/sem/atomic_type.h"
#include "src/tint/sem/block_statement.h"
#include "src/tint/sem/depth_multisampled_texture_type.h"
#include "src/tint/sem/for_loop_statement.h"
#include "src/tint/sem/reference_type.h"
#include "src/tint/sem/sampler_type.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::Transform);
TINT_INSTANTIATE_TYPEINFO(tint::transform::Data);
namespace tint {
namespace transform {
Data::Data() = default;
Data::Data(const Data&) = default;
Data::~Data() = default;
Data& Data::operator=(const Data&) = default;
DataMap::DataMap() = default;
DataMap::DataMap(DataMap&&) = default;
DataMap::~DataMap() = default;
DataMap& DataMap::operator=(DataMap&&) = default;
Output::Output() = default;
Output::Output(Program&& p) : program(std::move(p)) {}
Transform::Transform() = default;
Transform::~Transform() = default;
Output Transform::Run(const Program* program,
const DataMap& data /* = {} */) const {
ProgramBuilder builder;
CloneContext ctx(&builder, program);
Output output;
Run(ctx, data, output.data);
output.program = Program(std::move(builder));
return output;
}
void Transform::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
TINT_UNIMPLEMENTED(Transform, ctx.dst->Diagnostics())
<< "Transform::Run() unimplemented for " << TypeInfo().name;
}
bool Transform::ShouldRun(const Program*, const DataMap&) const {
return true;
}
void Transform::RemoveStatement(CloneContext& ctx, const ast::Statement* stmt) {
auto* sem = ctx.src->Sem().Get(stmt);
if (auto* block = tint::As<sem::BlockStatement>(sem->Parent())) {
ctx.Remove(block->Declaration()->statements, stmt);
return;
}
if (tint::Is<sem::ForLoopStatement>(sem->Parent())) {
ctx.Replace(stmt, static_cast<ast::Expression*>(nullptr));
return;
}
TINT_ICE(Transform, ctx.dst->Diagnostics())
<< "unable to remove statement from parent of type "
<< sem->TypeInfo().name;
}
const ast::Type* Transform::CreateASTTypeFor(CloneContext& ctx,
const sem::Type* ty) {
if (ty->Is<sem::Void>()) {
return ctx.dst->create<ast::Void>();
}
if (ty->Is<sem::I32>()) {
return ctx.dst->create<ast::I32>();
}
if (ty->Is<sem::U32>()) {
return ctx.dst->create<ast::U32>();
}
if (ty->Is<sem::F32>()) {
return ctx.dst->create<ast::F32>();
}
if (ty->Is<sem::Bool>()) {
return ctx.dst->create<ast::Bool>();
}
if (auto* m = ty->As<sem::Matrix>()) {
auto* el = CreateASTTypeFor(ctx, m->type());
return ctx.dst->create<ast::Matrix>(el, m->rows(), m->columns());
}
if (auto* v = ty->As<sem::Vector>()) {
auto* el = CreateASTTypeFor(ctx, v->type());
return ctx.dst->create<ast::Vector>(el, v->Width());
}
if (auto* a = ty->As<sem::Array>()) {
auto* el = CreateASTTypeFor(ctx, a->ElemType());
ast::AttributeList attrs;
if (!a->IsStrideImplicit()) {
attrs.emplace_back(ctx.dst->create<ast::StrideAttribute>(a->Stride()));
}
if (a->IsRuntimeSized()) {
return ctx.dst->ty.array(el, nullptr, std::move(attrs));
} else {
return ctx.dst->ty.array(el, a->Count(), std::move(attrs));
}
}
if (auto* s = ty->As<sem::Struct>()) {
return ctx.dst->create<ast::TypeName>(ctx.Clone(s->Declaration()->name));
}
if (auto* s = ty->As<sem::Reference>()) {
return CreateASTTypeFor(ctx, s->StoreType());
}
if (auto* a = ty->As<sem::Atomic>()) {
return ctx.dst->create<ast::Atomic>(CreateASTTypeFor(ctx, a->Type()));
}
if (auto* t = ty->As<sem::DepthTexture>()) {
return ctx.dst->create<ast::DepthTexture>(t->dim());
}
if (auto* t = ty->As<sem::DepthMultisampledTexture>()) {
return ctx.dst->create<ast::DepthMultisampledTexture>(t->dim());
}
if (ty->Is<sem::ExternalTexture>()) {
return ctx.dst->create<ast::ExternalTexture>();
}
if (auto* t = ty->As<sem::MultisampledTexture>()) {
return ctx.dst->create<ast::MultisampledTexture>(
t->dim(), CreateASTTypeFor(ctx, t->type()));
}
if (auto* t = ty->As<sem::SampledTexture>()) {
return ctx.dst->create<ast::SampledTexture>(
t->dim(), CreateASTTypeFor(ctx, t->type()));
}
if (auto* t = ty->As<sem::StorageTexture>()) {
return ctx.dst->create<ast::StorageTexture>(
t->dim(), t->texel_format(), CreateASTTypeFor(ctx, t->type()),
t->access());
}
if (auto* s = ty->As<sem::Sampler>()) {
return ctx.dst->create<ast::Sampler>(s->kind());
}
TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics())
<< "Unhandled type: " << ty->TypeInfo().name;
return nullptr;
}
} // namespace transform
} // namespace tint

View File

@@ -0,0 +1,199 @@
// Copyright 2020 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TINT_TRANSFORM_TRANSFORM_H_
#define SRC_TINT_TRANSFORM_TRANSFORM_H_
#include <memory>
#include <unordered_map>
#include <utility>
#include "src/tint/castable.h"
#include "src/tint/program.h"
namespace tint {
namespace transform {
/// Data is the base class for transforms that accept extra input or emit extra
/// output information along with a Program.
class Data : public Castable<Data> {
public:
/// Constructor
Data();
/// Copy constructor
Data(const Data&);
/// Destructor
~Data() override;
/// Assignment operator
/// @returns this Data
Data& operator=(const Data&);
};
/// DataMap is a map of Data unique pointers keyed by the Data's ClassID.
class DataMap {
public:
/// Constructor
DataMap();
/// Move constructor
DataMap(DataMap&&);
/// Constructor
/// @param data_unique_ptrs a variadic list of additional data unique_ptrs
/// produced by the transform
template <typename... DATA>
explicit DataMap(DATA... data_unique_ptrs) {
PutAll(std::forward<DATA>(data_unique_ptrs)...);
}
/// Destructor
~DataMap();
/// Move assignment operator
/// @param rhs the DataMap to move into this DataMap
/// @return this DataMap
DataMap& operator=(DataMap&& rhs);
/// Adds the data into DataMap keyed by the ClassID of type T.
/// @param data the data to add to the DataMap
template <typename T>
void Put(std::unique_ptr<T>&& data) {
static_assert(std::is_base_of<Data, T>::value,
"T does not derive from Data");
map_[&TypeInfo::Of<T>()] = std::move(data);
}
/// Creates the data of type `T` with the provided arguments and adds it into
/// DataMap keyed by the ClassID of type T.
/// @param args the arguments forwarded to the constructor for type T
template <typename T, typename... ARGS>
void Add(ARGS&&... args) {
Put(std::make_unique<T>(std::forward<ARGS>(args)...));
}
/// @returns a pointer to the Data placed into the DataMap with a call to
/// Put()
template <typename T>
T const* Get() const {
auto it = map_.find(&TypeInfo::Of<T>());
if (it == map_.end()) {
return nullptr;
}
return static_cast<T*>(it->second.get());
}
/// Add moves all the data from other into this DataMap
/// @param other the DataMap to move into this DataMap
void Add(DataMap&& other) {
for (auto& it : other.map_) {
map_.emplace(it.first, std::move(it.second));
}
other.map_.clear();
}
private:
template <typename T0>
void PutAll(T0&& first) {
Put(std::forward<T0>(first));
}
template <typename T0, typename... Tn>
void PutAll(T0&& first, Tn&&... remainder) {
Put(std::forward<T0>(first));
PutAll(std::forward<Tn>(remainder)...);
}
std::unordered_map<const TypeInfo*, std::unique_ptr<Data>> map_;
};
/// The return type of Run()
class Output {
public:
/// Constructor
Output();
/// Constructor
/// @param program the program to move into this Output
explicit Output(Program&& program);
/// Constructor
/// @param program_ the program to move into this Output
/// @param data_ a variadic list of additional data unique_ptrs produced by
/// the transform
template <typename... DATA>
Output(Program&& program_, DATA... data_)
: program(std::move(program_)), data(std::forward<DATA>(data_)...) {}
/// The transformed program. May be empty on error.
Program program;
/// Extra output generated by the transforms.
DataMap data;
};
/// Interface for Program transforms
class Transform : public Castable<Transform> {
public:
/// Constructor
Transform();
/// Destructor
~Transform() override;
/// Runs the transform on `program`, returning the transformation result.
/// @param program the source program to transform
/// @param data optional extra transform-specific input data
/// @returns the transformation result
virtual Output Run(const Program* program, const DataMap& data = {}) const;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
virtual bool ShouldRun(const Program* program,
const DataMap& data = {}) const;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
virtual void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const;
/// Removes the statement `stmt` from the transformed program.
/// RemoveStatement handles edge cases, like statements in the initializer and
/// continuing of for-loops.
/// @param ctx the clone context
/// @param stmt the statement to remove when the program is cloned
static void RemoveStatement(CloneContext& ctx, const ast::Statement* stmt);
/// CreateASTTypeFor constructs new ast::Type nodes that reconstructs the
/// semantic type `ty`.
/// @param ctx the clone context
/// @param ty the semantic type to reconstruct
/// @returns a ast::Type that when resolved, will produce the semantic type
/// `ty`.
static const ast::Type* CreateASTTypeFor(CloneContext& ctx,
const sem::Type* ty);
};
} // namespace transform
} // namespace tint
#endif // SRC_TINT_TRANSFORM_TRANSFORM_H_

View File

@@ -0,0 +1,123 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/transform.h"
#include "src/tint/clone_context.h"
#include "src/tint/program_builder.h"
#include "gtest/gtest.h"
namespace tint {
namespace transform {
namespace {
// Inherit from Transform so we have access to protected methods
struct CreateASTTypeForTest : public testing::Test, public Transform {
Output Run(const Program*, const DataMap&) const override { return {}; }
const ast::Type* create(
std::function<sem::Type*(ProgramBuilder&)> create_sem_type) {
ProgramBuilder sem_type_builder;
auto* sem_type = create_sem_type(sem_type_builder);
Program program(std::move(sem_type_builder));
CloneContext ctx(&ast_type_builder, &program, false);
return CreateASTTypeFor(ctx, sem_type);
}
ProgramBuilder ast_type_builder;
};
TEST_F(CreateASTTypeForTest, Basic) {
EXPECT_TRUE(create([](ProgramBuilder& b) {
return b.create<sem::I32>();
})->Is<ast::I32>());
EXPECT_TRUE(create([](ProgramBuilder& b) {
return b.create<sem::U32>();
})->Is<ast::U32>());
EXPECT_TRUE(create([](ProgramBuilder& b) {
return b.create<sem::F32>();
})->Is<ast::F32>());
EXPECT_TRUE(create([](ProgramBuilder& b) {
return b.create<sem::Bool>();
})->Is<ast::Bool>());
EXPECT_TRUE(create([](ProgramBuilder& b) {
return b.create<sem::Void>();
})->Is<ast::Void>());
}
TEST_F(CreateASTTypeForTest, Matrix) {
auto* mat = create([](ProgramBuilder& b) {
auto* column_type = b.create<sem::Vector>(b.create<sem::F32>(), 2u);
return b.create<sem::Matrix>(column_type, 3u);
});
ASSERT_TRUE(mat->Is<ast::Matrix>());
ASSERT_TRUE(mat->As<ast::Matrix>()->type->Is<ast::F32>());
ASSERT_EQ(mat->As<ast::Matrix>()->columns, 3u);
ASSERT_EQ(mat->As<ast::Matrix>()->rows, 2u);
}
TEST_F(CreateASTTypeForTest, Vector) {
auto* vec = create([](ProgramBuilder& b) {
return b.create<sem::Vector>(b.create<sem::F32>(), 2);
});
ASSERT_TRUE(vec->Is<ast::Vector>());
ASSERT_TRUE(vec->As<ast::Vector>()->type->Is<ast::F32>());
ASSERT_EQ(vec->As<ast::Vector>()->width, 2u);
}
TEST_F(CreateASTTypeForTest, ArrayImplicitStride) {
auto* arr = create([](ProgramBuilder& b) {
return b.create<sem::Array>(b.create<sem::F32>(), 2, 4, 4, 32u, 32u);
});
ASSERT_TRUE(arr->Is<ast::Array>());
ASSERT_TRUE(arr->As<ast::Array>()->type->Is<ast::F32>());
ASSERT_EQ(arr->As<ast::Array>()->attributes.size(), 0u);
auto* size = arr->As<ast::Array>()->count->As<ast::IntLiteralExpression>();
ASSERT_NE(size, nullptr);
EXPECT_EQ(size->ValueAsI32(), 2);
}
TEST_F(CreateASTTypeForTest, ArrayNonImplicitStride) {
auto* arr = create([](ProgramBuilder& b) {
return b.create<sem::Array>(b.create<sem::F32>(), 2, 4, 4, 64u, 32u);
});
ASSERT_TRUE(arr->Is<ast::Array>());
ASSERT_TRUE(arr->As<ast::Array>()->type->Is<ast::F32>());
ASSERT_EQ(arr->As<ast::Array>()->attributes.size(), 1u);
ASSERT_TRUE(arr->As<ast::Array>()->attributes[0]->Is<ast::StrideAttribute>());
ASSERT_EQ(
arr->As<ast::Array>()->attributes[0]->As<ast::StrideAttribute>()->stride,
64u);
auto* size = arr->As<ast::Array>()->count->As<ast::IntLiteralExpression>();
ASSERT_NE(size, nullptr);
EXPECT_EQ(size->ValueAsI32(), 2);
}
TEST_F(CreateASTTypeForTest, Struct) {
auto* str = create([](ProgramBuilder& b) {
auto* decl = b.Structure("S", {}, {});
return b.create<sem::Struct>(decl, decl->name, sem::StructMemberList{},
4 /* align */, 4 /* size */,
4 /* size_no_padding */);
});
ASSERT_TRUE(str->Is<ast::TypeName>());
EXPECT_EQ(ast_type_builder.Symbols().NameFor(str->As<ast::TypeName>()->name),
"S");
}
} // namespace
} // namespace transform
} // namespace tint

View File

@@ -0,0 +1,99 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/unshadow.h"
#include <memory>
#include <unordered_map>
#include <utility>
#include "src/tint/program_builder.h"
#include "src/tint/sem/block_statement.h"
#include "src/tint/sem/function.h"
#include "src/tint/sem/statement.h"
#include "src/tint/sem/variable.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::Unshadow);
namespace tint {
namespace transform {
/// The PIMPL state for the Unshadow transform
struct Unshadow::State {
/// The clone context
CloneContext& ctx;
/// Constructor
/// @param context the clone context
explicit State(CloneContext& context) : ctx(context) {}
/// Performs the transformation
void Run() {
auto& sem = ctx.src->Sem();
// Maps a variable to its new name.
std::unordered_map<const sem::Variable*, Symbol> renamed_to;
auto rename = [&](const sem::Variable* var) -> const ast::Variable* {
auto* decl = var->Declaration();
auto name = ctx.src->Symbols().NameFor(decl->symbol);
auto symbol = ctx.dst->Symbols().New(name);
renamed_to.emplace(var, symbol);
auto source = ctx.Clone(decl->source);
auto* type = ctx.Clone(decl->type);
auto* constructor = ctx.Clone(decl->constructor);
auto attributes = ctx.Clone(decl->attributes);
return ctx.dst->create<ast::Variable>(
source, symbol, decl->declared_storage_class, decl->declared_access,
type, decl->is_const, decl->is_overridable, constructor, attributes);
};
ctx.ReplaceAll([&](const ast::Variable* var) -> const ast::Variable* {
if (auto* local = sem.Get<sem::LocalVariable>(var)) {
if (local->Shadows()) {
return rename(local);
}
}
if (auto* param = sem.Get<sem::Parameter>(var)) {
if (param->Shadows()) {
return rename(param);
}
}
return nullptr;
});
ctx.ReplaceAll([&](const ast::IdentifierExpression* ident)
-> const tint::ast::IdentifierExpression* {
if (auto* user = sem.Get<sem::VariableUser>(ident)) {
auto it = renamed_to.find(user->Variable());
if (it != renamed_to.end()) {
return ctx.dst->Expr(it->second);
}
}
return nullptr;
});
ctx.Clone();
}
};
Unshadow::Unshadow() = default;
Unshadow::~Unshadow() = default;
void Unshadow::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
State(ctx).Run();
}
} // namespace transform
} // namespace tint

View File

@@ -0,0 +1,50 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TINT_TRANSFORM_UNSHADOW_H_
#define SRC_TINT_TRANSFORM_UNSHADOW_H_
#include "src/tint/transform/transform.h"
namespace tint {
namespace transform {
/// Unshadow is a Transform that renames any variables that shadow another
/// variable.
class Unshadow : public Castable<Unshadow, Transform> {
public:
/// Constructor
Unshadow();
/// Destructor
~Unshadow() override;
protected:
struct State;
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const override;
};
} // namespace transform
} // namespace tint
#endif // SRC_TINT_TRANSFORM_UNSHADOW_H_

View File

@@ -0,0 +1,609 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/unshadow.h"
#include "src/tint/transform/test_helper.h"
namespace tint {
namespace transform {
namespace {
using UnshadowTest = TransformTest;
TEST_F(UnshadowTest, EmptyModule) {
auto* src = "";
auto* expect = "";
auto got = Run<Unshadow>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(UnshadowTest, Noop) {
auto* src = R"(
var<private> a : i32;
let b : i32 = 1;
fn F(c : i32) {
var d : i32;
let e : i32 = 1;
{
var f : i32;
let g : i32 = 1;
}
}
)";
auto* expect = src;
auto got = Run<Unshadow>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(UnshadowTest, LocalShadowsAlias) {
auto* src = R"(
type a = i32;
fn X() {
var a = false;
}
fn Y() {
let a = true;
}
)";
auto* expect = R"(
type a = i32;
fn X() {
var a_1 = false;
}
fn Y() {
let a_2 = true;
}
)";
auto got = Run<Unshadow>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(UnshadowTest, LocalShadowsAlias_OutOfOrder) {
auto* src = R"(
fn X() {
var a = false;
}
fn Y() {
let a = true;
}
type a = i32;
)";
auto* expect = R"(
fn X() {
var a_1 = false;
}
fn Y() {
let a_2 = true;
}
type a = i32;
)";
auto got = Run<Unshadow>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(UnshadowTest, LocalShadowsStruct) {
auto* src = R"(
struct a {
m : i32;
};
fn X() {
var a = true;
}
fn Y() {
let a = false;
}
)";
auto* expect = R"(
struct a {
m : i32;
}
fn X() {
var a_1 = true;
}
fn Y() {
let a_2 = false;
}
)";
auto got = Run<Unshadow>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(UnshadowTest, LocalShadowsStruct_OutOfOrder) {
auto* src = R"(
fn X() {
var a = true;
}
fn Y() {
let a = false;
}
struct a {
m : i32;
};
)";
auto* expect = R"(
fn X() {
var a_1 = true;
}
fn Y() {
let a_2 = false;
}
struct a {
m : i32;
}
)";
auto got = Run<Unshadow>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(UnshadowTest, LocalShadowsFunction) {
auto* src = R"(
fn a() {
var a = true;
var b = false;
}
fn b() {
let a = true;
let b = false;
}
)";
auto* expect = R"(
fn a() {
var a_1 = true;
var b_1 = false;
}
fn b() {
let a_2 = true;
let b_2 = false;
}
)";
auto got = Run<Unshadow>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(UnshadowTest, LocalShadowsFunction_OutOfOrder) {
auto* src = R"(
fn b() {
let a = true;
let b = false;
}
fn a() {
var a = true;
var b = false;
}
)";
auto* expect = R"(
fn b() {
let a_1 = true;
let b_1 = false;
}
fn a() {
var a_2 = true;
var b_2 = false;
}
)";
auto got = Run<Unshadow>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(UnshadowTest, LocalShadowsGlobalVar) {
auto* src = R"(
var<private> a : i32;
fn X() {
var a = (a == 123);
}
fn Y() {
let a = (a == 321);
}
)";
auto* expect = R"(
var<private> a : i32;
fn X() {
var a_1 = (a == 123);
}
fn Y() {
let a_2 = (a == 321);
}
)";
auto got = Run<Unshadow>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(UnshadowTest, LocalShadowsGlobalVar_OutOfOrder) {
auto* src = R"(
fn X() {
var a = (a == 123);
}
fn Y() {
let a = (a == 321);
}
var<private> a : i32;
)";
auto* expect = R"(
fn X() {
var a_1 = (a == 123);
}
fn Y() {
let a_2 = (a == 321);
}
var<private> a : i32;
)";
auto got = Run<Unshadow>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(UnshadowTest, LocalShadowsGlobalLet) {
auto* src = R"(
let a : i32 = 1;
fn X() {
var a = (a == 123);
}
fn Y() {
let a = (a == 321);
}
)";
auto* expect = R"(
let a : i32 = 1;
fn X() {
var a_1 = (a == 123);
}
fn Y() {
let a_2 = (a == 321);
}
)";
auto got = Run<Unshadow>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(UnshadowTest, LocalShadowsGlobalLet_OutOfOrder) {
auto* src = R"(
fn X() {
var a = (a == 123);
}
fn Y() {
let a = (a == 321);
}
let a : i32 = 1;
)";
auto* expect = R"(
fn X() {
var a_1 = (a == 123);
}
fn Y() {
let a_2 = (a == 321);
}
let a : i32 = 1;
)";
auto got = Run<Unshadow>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(UnshadowTest, LocalShadowsLocalVar) {
auto* src = R"(
fn X() {
var a : i32;
{
var a = (a == 123);
}
{
let a = (a == 321);
}
}
)";
auto* expect = R"(
fn X() {
var a : i32;
{
var a_1 = (a == 123);
}
{
let a_2 = (a == 321);
}
}
)";
auto got = Run<Unshadow>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(UnshadowTest, LocalShadowsLocalLet) {
auto* src = R"(
fn X() {
let a = 1;
{
var a = (a == 123);
}
{
let a = (a == 321);
}
}
)";
auto* expect = R"(
fn X() {
let a = 1;
{
var a_1 = (a == 123);
}
{
let a_2 = (a == 321);
}
}
)";
auto got = Run<Unshadow>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(UnshadowTest, LocalShadowsParam) {
auto* src = R"(
fn F(a : i32) {
{
var a = (a == 123);
}
{
let a = (a == 321);
}
}
)";
auto* expect = R"(
fn F(a : i32) {
{
var a_1 = (a == 123);
}
{
let a_2 = (a == 321);
}
}
)";
auto got = Run<Unshadow>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(UnshadowTest, ParamShadowsFunction) {
auto* src = R"(
fn a(a : i32) {
{
var a = (a == 123);
}
{
let a = (a == 321);
}
}
)";
auto* expect = R"(
fn a(a_1 : i32) {
{
var a_2 = (a_1 == 123);
}
{
let a_3 = (a_1 == 321);
}
}
)";
auto got = Run<Unshadow>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(UnshadowTest, ParamShadowsGlobalVar) {
auto* src = R"(
var<private> a : i32;
fn F(a : bool) {
}
)";
auto* expect = R"(
var<private> a : i32;
fn F(a_1 : bool) {
}
)";
auto got = Run<Unshadow>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(UnshadowTest, ParamShadowsGlobalLet) {
auto* src = R"(
let a : i32 = 1;
fn F(a : bool) {
}
)";
auto* expect = R"(
let a : i32 = 1;
fn F(a_1 : bool) {
}
)";
auto got = Run<Unshadow>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(UnshadowTest, ParamShadowsGlobalLet_OutOfOrder) {
auto* src = R"(
fn F(a : bool) {
}
let a : i32 = 1;
)";
auto* expect = R"(
fn F(a_1 : bool) {
}
let a : i32 = 1;
)";
auto got = Run<Unshadow>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(UnshadowTest, ParamShadowsAlias) {
auto* src = R"(
type a = i32;
fn F(a : a) {
{
var a = (a == 123);
}
{
let a = (a == 321);
}
}
)";
auto* expect = R"(
type a = i32;
fn F(a_1 : a) {
{
var a_2 = (a_1 == 123);
}
{
let a_3 = (a_1 == 321);
}
}
)";
auto got = Run<Unshadow>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(UnshadowTest, ParamShadowsAlias_OutOfOrder) {
auto* src = R"(
fn F(a : a) {
{
var a = (a == 123);
}
{
let a = (a == 321);
}
}
type a = i32;
)";
auto* expect = R"(
fn F(a_1 : a) {
{
var a_2 = (a_1 == 123);
}
{
let a_3 = (a_1 == 321);
}
}
type a = i32;
)";
auto got = Run<Unshadow>(src);
EXPECT_EQ(expect, str(got));
}
} // namespace
} // namespace transform
} // namespace tint

View File

@@ -0,0 +1,327 @@
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/utils/hoist_to_decl_before.h"
#include <unordered_map>
#include "src/tint/ast/variable_decl_statement.h"
#include "src/tint/sem/block_statement.h"
#include "src/tint/sem/for_loop_statement.h"
#include "src/tint/sem/if_statement.h"
#include "src/tint/sem/reference_type.h"
#include "src/tint/sem/variable.h"
#include "src/tint/utils/reverse.h"
namespace tint::transform {
/// Private implementation of HoistToDeclBefore transform
class HoistToDeclBefore::State {
CloneContext& ctx;
ProgramBuilder& b;
/// Holds information about a for-loop that needs to be decomposed into a
/// loop, so that declaration statements can be inserted before the
/// condition expression or continuing statement.
struct LoopInfo {
ast::StatementList cond_decls;
ast::StatementList cont_decls;
};
/// Holds information about 'if's with 'else-if' statements that need to be
/// decomposed into 'if {else}' so that declaration statements can be
/// inserted before the condition expression.
struct IfInfo {
/// Info for each else-if that needs decomposing
struct ElseIfInfo {
/// Decls to insert before condition
ast::StatementList cond_decls;
};
/// 'else if's that need to be decomposed to 'else { if }'
std::unordered_map<const sem::ElseStatement*, ElseIfInfo> else_ifs;
};
/// For-loops that need to be decomposed to loops.
std::unordered_map<const sem::ForLoopStatement*, LoopInfo> loops;
/// If statements with 'else if's that need to be decomposed to 'else { if
/// }'
std::unordered_map<const sem::IfStatement*, IfInfo> ifs;
// Inserts `decl` before `sem_expr`, possibly marking a for-loop to be
// converted to a loop, or an else-if to an else { if }.
bool InsertBefore(const sem::Expression* sem_expr,
const ast::VariableDeclStatement* decl) {
auto* sem_stmt = sem_expr->Stmt();
auto* stmt = sem_stmt->Declaration();
if (auto* else_if = sem_stmt->As<sem::ElseStatement>()) {
// Expression used in 'else if' condition.
// Need to convert 'else if' to 'else { if }'.
auto& if_info = ifs[else_if->Parent()->As<sem::IfStatement>()];
if_info.else_ifs[else_if].cond_decls.push_back(decl);
return true;
}
if (auto* fl = sem_stmt->As<sem::ForLoopStatement>()) {
// Expression used in for-loop condition.
// For-loop needs to be decomposed to a loop.
loops[fl].cond_decls.emplace_back(decl);
return true;
}
auto* parent = sem_stmt->Parent(); // The statement's parent
if (auto* block = parent->As<sem::BlockStatement>()) {
// Expression's statement sits in a block. Simple case.
// Insert the decl before the parent statement
ctx.InsertBefore(block->Declaration()->statements, stmt, decl);
return true;
}
if (auto* fl = parent->As<sem::ForLoopStatement>()) {
// Expression is used in a for-loop. These require special care.
if (fl->Declaration()->initializer == stmt) {
// Expression used in for-loop initializer.
// Insert the let above the for-loop.
ctx.InsertBefore(fl->Block()->Declaration()->statements,
fl->Declaration(), decl);
return true;
}
if (fl->Declaration()->continuing == stmt) {
// Expression used in for-loop continuing.
// For-loop needs to be decomposed to a loop.
loops[fl].cont_decls.emplace_back(decl);
return true;
}
TINT_ICE(Transform, b.Diagnostics())
<< "unhandled use of expression in for-loop";
return false;
}
TINT_ICE(Transform, b.Diagnostics())
<< "unhandled expression parent statement type: "
<< parent->TypeInfo().name;
return false;
}
// Converts any for-loops marked for conversion to loops, inserting
// registered declaration statements before the condition or continuing
// statement.
void ForLoopsToLoops() {
if (loops.empty()) {
return;
}
// At least one for-loop needs to be transformed into a loop.
ctx.ReplaceAll(
[&](const ast::ForLoopStatement* stmt) -> const ast::Statement* {
auto& sem = ctx.src->Sem();
if (auto* fl = sem.Get(stmt)) {
if (auto it = loops.find(fl); it != loops.end()) {
auto& info = it->second;
auto* for_loop = fl->Declaration();
// For-loop needs to be decomposed to a loop.
// Build the loop body's statements.
// Start with any let declarations for the conditional
// expression.
auto body_stmts = info.cond_decls;
// If the for-loop has a condition, emit this next as:
// if (!cond) { break; }
if (auto* cond = for_loop->condition) {
// !condition
auto* not_cond = b.create<ast::UnaryOpExpression>(
ast::UnaryOp::kNot, ctx.Clone(cond));
// { break; }
auto* break_body = b.Block(b.create<ast::BreakStatement>());
// if (!condition) { break; }
body_stmts.emplace_back(b.If(not_cond, break_body));
}
// Next emit the for-loop body
body_stmts.emplace_back(ctx.Clone(for_loop->body));
// Finally create the continuing block if there was one.
const ast::BlockStatement* continuing = nullptr;
if (auto* cont = for_loop->continuing) {
// Continuing block starts with any let declarations used by
// the continuing.
auto cont_stmts = info.cont_decls;
cont_stmts.emplace_back(ctx.Clone(cont));
continuing = b.Block(cont_stmts);
}
auto* body = b.Block(body_stmts);
auto* loop = b.Loop(body, continuing);
if (auto* init = for_loop->initializer) {
return b.Block(ctx.Clone(init), loop);
}
return loop;
}
}
return nullptr;
});
}
void ElseIfsToElseWithNestedIfs() {
if (ifs.empty()) {
return;
}
ctx.ReplaceAll([&](const ast::IfStatement* if_stmt) //
-> const ast::IfStatement* {
auto& sem = ctx.src->Sem();
auto* sem_if = sem.Get(if_stmt);
if (!sem_if) {
return nullptr;
}
auto it = ifs.find(sem_if);
if (it == ifs.end()) {
return nullptr;
}
auto& if_info = it->second;
// This if statement has "else if"s that need to be converted to "else
// { if }"s
ast::ElseStatementList next_else_stmts;
next_else_stmts.reserve(if_stmt->else_statements.size());
for (auto* else_stmt : utils::Reverse(if_stmt->else_statements)) {
if (else_stmt->condition == nullptr) {
// The last 'else', keep as is
next_else_stmts.insert(next_else_stmts.begin(), ctx.Clone(else_stmt));
} else {
auto* sem_else_if = sem.Get(else_stmt);
auto it2 = if_info.else_ifs.find(sem_else_if);
if (it2 == if_info.else_ifs.end()) {
// 'else if' we don't need to modify (no decls to insert), so
// keep as is
next_else_stmts.insert(next_else_stmts.begin(),
ctx.Clone(else_stmt));
} else {
// 'else if' we need to replace with 'else <decls> { if }'
auto& else_if_info = it2->second;
// Build the else body's statements, starting with let decls for
// the conditional expression
auto& body_stmts = else_if_info.cond_decls;
// Build nested if
auto* cond = ctx.Clone(else_stmt->condition);
auto* body = ctx.Clone(else_stmt->body);
body_stmts.emplace_back(b.If(cond, body, next_else_stmts));
// Build else
auto* else_with_nested_if = b.Else(b.Block(body_stmts));
// This will be used in parent if (either another nested if, or
// top-level if)
next_else_stmts = {else_with_nested_if};
}
}
}
// Build a new top-level if with new else statements
if (next_else_stmts.empty()) {
TINT_ICE(Transform, b.Diagnostics())
<< "Expected else statements to insert into new if";
}
auto* cond = ctx.Clone(if_stmt->condition);
auto* body = ctx.Clone(if_stmt->body);
auto* new_if = b.If(cond, body, next_else_stmts);
return new_if;
});
}
public:
/// Constructor
/// @param ctx_in the clone context
explicit State(CloneContext& ctx_in) : ctx(ctx_in), b(*ctx_in.dst) {}
/// Hoists `expr` to a `let` or `var` with optional `decl_name`, inserting it
/// before `before_expr`.
/// @param before_expr expression to insert `expr` before
/// @param expr expression to hoist
/// @param as_const hoist to `let` if true, otherwise to `var`
/// @param decl_name optional name to use for the variable/constant name
/// @return true on success
bool HoistToDeclBefore(const sem::Expression* before_expr,
const ast::Expression* expr,
bool as_const,
const char* decl_name = "") {
auto name = b.Symbols().New(decl_name);
auto* sem_expr = ctx.src->Sem().Get(expr);
bool is_ref =
sem_expr &&
!sem_expr->Is<sem::VariableUser>() // Don't need to take a ref to a var
&& sem_expr->Type()->Is<sem::Reference>();
auto* expr_clone = ctx.Clone(expr);
if (is_ref) {
expr_clone = b.AddressOf(expr_clone);
}
// Construct the let/var that holds the hoisted expr
auto* v = as_const ? b.Const(name, nullptr, expr_clone)
: b.Var(name, nullptr, expr_clone);
auto* decl = b.Decl(v);
if (!InsertBefore(before_expr, decl)) {
return false;
}
// Replace the initializer expression with a reference to the let
const ast::Expression* new_expr = b.Expr(name);
if (is_ref) {
new_expr = b.Deref(new_expr);
}
ctx.Replace(expr, new_expr);
return true;
}
/// Applies any scheduled insertions from previous calls to Add() to
/// CloneContext. Call this once before ctx.Clone().
/// @return true on success
bool Apply() {
ForLoopsToLoops();
ElseIfsToElseWithNestedIfs();
return true;
}
};
HoistToDeclBefore::HoistToDeclBefore(CloneContext& ctx)
: state_(std::make_unique<State>(ctx)) {}
HoistToDeclBefore::~HoistToDeclBefore() {}
bool HoistToDeclBefore::Add(const sem::Expression* before_expr,
const ast::Expression* expr,
bool as_const,
const char* decl_name) {
return state_->HoistToDeclBefore(before_expr, expr, as_const, decl_name);
}
bool HoistToDeclBefore::Apply() {
return state_->Apply();
}
} // namespace tint::transform

View File

@@ -0,0 +1,61 @@
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TINT_TRANSFORM_UTILS_HOIST_TO_DECL_BEFORE_H_
#define SRC_TINT_TRANSFORM_UTILS_HOIST_TO_DECL_BEFORE_H_
#include <memory>
#include "src/tint/sem/expression.h"
#include "src/tint/transform/transform.h"
namespace tint::transform {
/// Utility class that can be used to hoist expressions before other
/// expressions, possibly converting 'for' loops to 'loop's and 'else if to
// 'else if'.
class HoistToDeclBefore {
public:
/// Constructor
/// @param ctx the clone context
explicit HoistToDeclBefore(CloneContext& ctx);
/// Destructor
~HoistToDeclBefore();
/// Hoists `expr` to a `let` or `var` with optional `decl_name`, inserting it
/// before `before_expr`.
/// @param before_expr expression to insert `expr` before
/// @param expr expression to hoist
/// @param as_const hoist to `let` if true, otherwise to `var`
/// @param decl_name optional name to use for the variable/constant name
/// @return true on success
bool Add(const sem::Expression* before_expr,
const ast::Expression* expr,
bool as_const,
const char* decl_name = "");
/// Applies any scheduled insertions from previous calls to Add() to
/// CloneContext. Call this once before ctx.Clone().
/// @return true on success
bool Apply();
private:
class State;
std::unique_ptr<State> state_;
};
} // namespace tint::transform
#endif // SRC_TINT_TRANSFORM_UTILS_HOIST_TO_DECL_BEFORE_H_

View File

@@ -0,0 +1,291 @@
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <utility>
#include "gtest/gtest-spi.h"
#include "src/tint/program_builder.h"
#include "src/tint/transform/test_helper.h"
#include "src/tint/transform/utils/hoist_to_decl_before.h"
namespace tint::transform {
namespace {
using HoistToDeclBeforeTest = ::testing::Test;
TEST_F(HoistToDeclBeforeTest, VarInit) {
// fn f() {
// var a = 1;
// }
ProgramBuilder b;
auto* expr = b.Expr(1);
auto* var = b.Decl(b.Var("a", nullptr, expr));
b.Func("f", {}, b.ty.void_(), {var});
Program original(std::move(b));
ProgramBuilder cloned_b;
CloneContext ctx(&cloned_b, &original);
HoistToDeclBefore hoistToDeclBefore(ctx);
auto* sem_expr = ctx.src->Sem().Get(expr);
hoistToDeclBefore.Add(sem_expr, expr, true);
hoistToDeclBefore.Apply();
ctx.Clone();
Program cloned(std::move(cloned_b));
auto* expect = R"(
fn f() {
let tint_symbol = 1;
var a = tint_symbol;
}
)";
EXPECT_EQ(expect, str(cloned));
}
TEST_F(HoistToDeclBeforeTest, ForLoopInit) {
// fn f() {
// for(var a = 1; true; ) {
// }
// }
ProgramBuilder b;
auto* expr = b.Expr(1);
auto* s =
b.For(b.Decl(b.Var("a", nullptr, expr)), b.Expr(true), {}, b.Block());
b.Func("f", {}, b.ty.void_(), {s});
Program original(std::move(b));
ProgramBuilder cloned_b;
CloneContext ctx(&cloned_b, &original);
HoistToDeclBefore hoistToDeclBefore(ctx);
auto* sem_expr = ctx.src->Sem().Get(expr);
hoistToDeclBefore.Add(sem_expr, expr, true);
hoistToDeclBefore.Apply();
ctx.Clone();
Program cloned(std::move(cloned_b));
auto* expect = R"(
fn f() {
let tint_symbol = 1;
for(var a = tint_symbol; true; ) {
}
}
)";
EXPECT_EQ(expect, str(cloned));
}
TEST_F(HoistToDeclBeforeTest, ForLoopCond) {
// fn f() {
// var a : bool;
// for(; a; ) {
// }
// }
ProgramBuilder b;
auto* var = b.Decl(b.Var("a", b.ty.bool_()));
auto* expr = b.Expr("a");
auto* s = b.For({}, expr, {}, b.Block());
b.Func("f", {}, b.ty.void_(), {var, s});
Program original(std::move(b));
ProgramBuilder cloned_b;
CloneContext ctx(&cloned_b, &original);
HoistToDeclBefore hoistToDeclBefore(ctx);
auto* sem_expr = ctx.src->Sem().Get(expr);
hoistToDeclBefore.Add(sem_expr, expr, true);
hoistToDeclBefore.Apply();
ctx.Clone();
Program cloned(std::move(cloned_b));
auto* expect = R"(
fn f() {
var a : bool;
loop {
let tint_symbol = a;
if (!(tint_symbol)) {
break;
}
{
}
}
}
)";
EXPECT_EQ(expect, str(cloned));
}
TEST_F(HoistToDeclBeforeTest, ForLoopCont) {
// fn f() {
// for(; true; var a = 1) {
// }
// }
ProgramBuilder b;
auto* expr = b.Expr(1);
auto* s =
b.For({}, b.Expr(true), b.Decl(b.Var("a", nullptr, expr)), b.Block());
b.Func("f", {}, b.ty.void_(), {s});
Program original(std::move(b));
ProgramBuilder cloned_b;
CloneContext ctx(&cloned_b, &original);
HoistToDeclBefore hoistToDeclBefore(ctx);
auto* sem_expr = ctx.src->Sem().Get(expr);
hoistToDeclBefore.Add(sem_expr, expr, true);
hoistToDeclBefore.Apply();
ctx.Clone();
Program cloned(std::move(cloned_b));
auto* expect = R"(
fn f() {
loop {
if (!(true)) {
break;
}
{
}
continuing {
let tint_symbol = 1;
var a = tint_symbol;
}
}
}
)";
EXPECT_EQ(expect, str(cloned));
}
TEST_F(HoistToDeclBeforeTest, ElseIf) {
// fn f() {
// var a : bool;
// if (true) {
// } else if (a) {
// } else {
// }
// }
ProgramBuilder b;
auto* var = b.Decl(b.Var("a", b.ty.bool_()));
auto* expr = b.Expr("a");
auto* s = b.If(b.Expr(true), b.Block(), //
b.Else(expr, b.Block()), //
b.Else(b.Block()));
b.Func("f", {}, b.ty.void_(), {var, s});
Program original(std::move(b));
ProgramBuilder cloned_b;
CloneContext ctx(&cloned_b, &original);
HoistToDeclBefore hoistToDeclBefore(ctx);
auto* sem_expr = ctx.src->Sem().Get(expr);
hoistToDeclBefore.Add(sem_expr, expr, true);
hoistToDeclBefore.Apply();
ctx.Clone();
Program cloned(std::move(cloned_b));
auto* expect = R"(
fn f() {
var a : bool;
if (true) {
} else {
let tint_symbol = a;
if (tint_symbol) {
} else {
}
}
}
)";
EXPECT_EQ(expect, str(cloned));
}
TEST_F(HoistToDeclBeforeTest, Array1D) {
// fn f() {
// var a : array<i32, 10>;
// var b = a[0];
// }
ProgramBuilder b;
auto* var1 = b.Decl(b.Var("a", b.ty.array<ProgramBuilder::i32, 10>()));
auto* expr = b.IndexAccessor("a", 0);
auto* var2 = b.Decl(b.Var("b", nullptr, expr));
b.Func("f", {}, b.ty.void_(), {var1, var2});
Program original(std::move(b));
ProgramBuilder cloned_b;
CloneContext ctx(&cloned_b, &original);
HoistToDeclBefore hoistToDeclBefore(ctx);
auto* sem_expr = ctx.src->Sem().Get(expr);
hoistToDeclBefore.Add(sem_expr, expr, true);
hoistToDeclBefore.Apply();
ctx.Clone();
Program cloned(std::move(cloned_b));
auto* expect = R"(
fn f() {
var a : array<i32, 10>;
let tint_symbol = &(a[0]);
var b = *(tint_symbol);
}
)";
EXPECT_EQ(expect, str(cloned));
}
TEST_F(HoistToDeclBeforeTest, Array2D) {
// fn f() {
// var a : array<array<i32, 10>, 10>;
// var b = a[0][0];
// }
ProgramBuilder b;
auto* var1 =
b.Decl(b.Var("a", b.ty.array(b.ty.array<ProgramBuilder::i32, 10>(), 10)));
auto* expr = b.IndexAccessor(b.IndexAccessor("a", 0), 0);
auto* var2 = b.Decl(b.Var("b", nullptr, expr));
b.Func("f", {}, b.ty.void_(), {var1, var2});
Program original(std::move(b));
ProgramBuilder cloned_b;
CloneContext ctx(&cloned_b, &original);
HoistToDeclBefore hoistToDeclBefore(ctx);
auto* sem_expr = ctx.src->Sem().Get(expr);
hoistToDeclBefore.Add(sem_expr, expr, true);
hoistToDeclBefore.Apply();
ctx.Clone();
Program cloned(std::move(cloned_b));
auto* expect = R"(
fn f() {
var a : array<array<i32, 10>, 10>;
let tint_symbol = &(a[0][0]);
var b = *(tint_symbol);
}
)";
EXPECT_EQ(expect, str(cloned));
}
} // namespace
} // namespace tint::transform

View File

@@ -0,0 +1,68 @@
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/var_for_dynamic_index.h"
#include "src/tint/program_builder.h"
#include "src/tint/transform/utils/hoist_to_decl_before.h"
namespace tint::transform {
VarForDynamicIndex::VarForDynamicIndex() = default;
VarForDynamicIndex::~VarForDynamicIndex() = default;
void VarForDynamicIndex::Run(CloneContext& ctx,
const DataMap&,
DataMap&) const {
HoistToDeclBefore hoist_to_decl_before(ctx);
// Extracts array and matrix values that are dynamically indexed to a
// temporary `var` local that is then indexed.
auto dynamic_index_to_var =
[&](const ast::IndexAccessorExpression* access_expr) {
auto* index_expr = access_expr->index;
auto* object_expr = access_expr->object;
auto& sem = ctx.src->Sem();
if (sem.Get(index_expr)->ConstantValue()) {
// Index expression resolves to a compile time value.
// As this isn't a dynamic index, we can ignore this.
return true;
}
auto* indexed = sem.Get(object_expr);
if (!indexed->Type()->IsAnyOf<sem::Array, sem::Matrix>()) {
// We only care about array and matrices.
return true;
}
// TODO(bclayton): group multiple accesses in the same object.
// e.g. arr[i] + arr[i+1] // Don't create two vars for this
return hoist_to_decl_before.Add(indexed, object_expr, false,
"var_for_index");
};
for (auto* node : ctx.src->ASTNodes().Objects()) {
if (auto* access_expr = node->As<ast::IndexAccessorExpression>()) {
if (!dynamic_index_to_var(access_expr)) {
return;
}
}
}
hoist_to_decl_before.Apply();
ctx.Clone();
}
} // namespace tint::transform

Some files were not shown because too many files have changed in this diff Show More