optimizations: Implement transform::ShouldRun()

This change adds an override for Transform::ShouldRun() for many of the transforms that can trivially detect whether running would be a no-op or not. Most programs do not require all the transforms to be run, and by skipping those that are not needed, significant performance wins can be had.

This change also removes Transform::Requires() and Program::HasTransformApplied(). This makes little sense now that transforms can be skipped, and the usefulness of this information has been severely reduced since the introduction of transforms that need to be run more than once.
Instread, just document on the transform class what the expectations are.

Issue: tint:1383
Change-Id: I1a6f27cc4ba61ca1475a4ba912c465db619f76c7
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/77121
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton 2022-01-25 21:36:04 +00:00 committed by Tint LUCI CQ
parent 12d54d746e
commit 800b8e3175
78 changed files with 909 additions and 298 deletions

View File

@ -57,7 +57,6 @@ Symbol CloneContext::Clone(Symbol s) {
void CloneContext::Clone() {
dst->AST().Copy(this, &src->AST());
dst->SetTransformApplied(src->TransformsApplied());
}
ast::FunctionList CloneContext::Clone(const ast::FunctionList& v) {

View File

@ -42,7 +42,6 @@ Program::Program(Program&& program)
sem_(std::move(program.sem_)),
symbols_(std::move(program.symbols_)),
diagnostics_(std::move(program.diagnostics_)),
transforms_applied_(std::move(program.transforms_applied_)),
is_valid_(program.is_valid_) {
program.AssertNotMoved();
program.moved_ = true;
@ -67,7 +66,6 @@ Program::Program(ProgramBuilder&& builder) {
sem_ = std::move(builder.Sem());
symbols_ = std::move(builder.Symbols());
diagnostics_.add(std::move(builder.Diagnostics()));
transforms_applied_ = builder.TransformsApplied();
builder.MarkAsMoved();
if (!is_valid_ && !diagnostics_.contains_errors()) {
@ -92,7 +90,6 @@ Program& Program::operator=(Program&& program) {
sem_ = std::move(program.sem_);
symbols_ = std::move(program.symbols_);
diagnostics_ = std::move(program.diagnostics_);
transforms_applied_ = std::move(program.transforms_applied_);
is_valid_ = program.is_valid_;
return *this;
}

View File

@ -126,25 +126,6 @@ class Program {
/// information
bool IsValid() const;
/// @return the TypeInfo pointers of all transforms that have been applied to
/// this program.
std::unordered_set<const TypeInfo*> TransformsApplied() const {
return transforms_applied_;
}
/// @param transform the TypeInfo of the transform
/// @returns true if the transform with the given TypeInfo was applied to the
/// Program
bool HasTransformApplied(const TypeInfo* transform) const {
return transforms_applied_.count(transform);
}
/// @returns true if the transform of type `T` was applied.
template <typename T>
bool HasTransformApplied() const {
return HasTransformApplied(&TypeInfo::Of<T>());
}
/// Helper for returning the resolved semantic type of the expression `expr`.
/// @param expr the AST expression
/// @return the resolved semantic type for the expression, or nullptr if the
@ -184,7 +165,6 @@ class Program {
sem::Info sem_;
SymbolTable symbols_{id_};
diag::List diagnostics_;
std::unordered_set<const TypeInfo*> transforms_applied_;
bool is_valid_ = false; // Not valid until it is built
bool moved_ = false;
};

View File

@ -38,8 +38,7 @@ ProgramBuilder::ProgramBuilder(ProgramBuilder&& rhs)
ast_(rhs.ast_),
sem_(std::move(rhs.sem_)),
symbols_(std::move(rhs.symbols_)),
diagnostics_(std::move(rhs.diagnostics_)),
transforms_applied_(std::move(rhs.transforms_applied_)) {
diagnostics_(std::move(rhs.diagnostics_)) {
rhs.MarkAsMoved();
}
@ -56,7 +55,7 @@ ProgramBuilder& ProgramBuilder::operator=(ProgramBuilder&& rhs) {
sem_ = std::move(rhs.sem_);
symbols_ = std::move(rhs.symbols_);
diagnostics_ = std::move(rhs.diagnostics_);
transforms_applied_ = std::move(rhs.transforms_applied_);
return *this;
}
@ -69,7 +68,6 @@ ProgramBuilder ProgramBuilder::Wrap(const Program* program) {
builder.sem_ = sem::Info::Wrap(program->Sem());
builder.symbols_ = program->Symbols();
builder.diagnostics_ = program->Diagnostics();
builder.transforms_applied_ = program->TransformsApplied();
return builder;
}

View File

@ -2469,49 +2469,6 @@ class ProgramBuilder {
source_ = Source(loc);
}
/// Marks that the given transform has been applied to this program.
/// @param transform the transform that has been applied
void SetTransformApplied(const CastableBase* transform) {
transforms_applied_.emplace(&transform->TypeInfo());
}
/// Marks that the given transform `T` has been applied to this program.
template <typename T>
void SetTransformApplied() {
transforms_applied_.emplace(&TypeInfo::Of<T>());
}
/// Marks that the transforms with the given TypeInfos have been applied to
/// this program.
/// @param transforms the set of transform TypeInfos that has been applied
void SetTransformApplied(
const std::unordered_set<const TypeInfo*>& transforms) {
for (auto* transform : transforms) {
transforms_applied_.emplace(transform);
}
}
/// Unmarks that the given transform `T` has been applied to this program.
template <typename T>
void UnsetTransformApplied() {
auto it = transforms_applied_.find(&TypeInfo::Of<T>());
if (it != transforms_applied_.end()) {
transforms_applied_.erase(it);
}
}
/// @returns true if the transform of type `T` was applied.
template <typename T>
bool HasTransformApplied() {
return transforms_applied_.count(&TypeInfo::Of<T>());
}
/// @return the TypeInfo pointers of all transforms that have been applied to
/// this program.
std::unordered_set<const TypeInfo*> TransformsApplied() const {
return transforms_applied_;
}
/// Helper for returning the resolved semantic type of the expression `expr`.
/// @note As the Resolver is run when the Program is built, this will only be
/// useful for the Resolver itself and tests that use their own Resolver.
@ -2592,7 +2549,6 @@ class ProgramBuilder {
sem::Info sem_;
SymbolTable symbols_{id_};
diag::List diagnostics_;
std::unordered_set<const TypeInfo*> transforms_applied_;
/// The source to use when creating AST nodes without providing a Source as
/// the first argument.

View File

@ -27,15 +27,19 @@ AddEmptyEntryPoint::AddEmptyEntryPoint() = default;
AddEmptyEntryPoint::~AddEmptyEntryPoint() = default;
bool AddEmptyEntryPoint::ShouldRun(const Program* program,
const DataMap&) const {
for (auto* func : program->AST().Functions()) {
if (func->IsEntryPoint()) {
return false;
}
}
return true;
}
void AddEmptyEntryPoint::Run(CloneContext& ctx,
const DataMap&,
DataMap&) const {
for (auto* func : ctx.src->AST().Functions()) {
if (func->IsEntryPoint()) {
ctx.Clone();
return;
}
}
ctx.dst->Func(ctx.dst->Symbols().New("unused_entry_point"), {},
ctx.dst->ty.void_(), {},
{ctx.dst->Stage(ast::PipelineStage::kCompute),

View File

@ -28,6 +28,12 @@ class AddEmptyEntryPoint : public Castable<AddEmptyEntryPoint, Transform> {
/// Destructor
~AddEmptyEntryPoint() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program,
const DataMap& data = {}) const override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.

View File

@ -24,6 +24,21 @@ namespace {
using AddEmptyEntryPointTest = TransformTest;
TEST_F(AddEmptyEntryPointTest, ShouldRunEmptyModule) {
auto* src = R"()";
EXPECT_TRUE(ShouldRun<AddEmptyEntryPoint>(src));
}
TEST_F(AddEmptyEntryPointTest, ShouldRunExistingEntryPoint) {
auto* src = R"(
[[stage(compute), workgroup_size(1)]]
fn existing() {}
)";
EXPECT_FALSE(ShouldRun<AddEmptyEntryPoint>(src));
}
TEST_F(AddEmptyEntryPointTest, EmptyModule) {
auto* src = R"()";

View File

@ -20,6 +20,7 @@
#include "src/program_builder.h"
#include "src/sem/call.h"
#include "src/sem/function.h"
#include "src/sem/variable.h"
#include "src/transform/simplify_pointers.h"
@ -93,13 +94,23 @@ static void IterateArrayLengthOnStorageVar(CloneContext& ctx, F&& functor) {
}
}
bool ArrayLengthFromUniform::ShouldRun(const Program* program,
const DataMap&) const {
for (auto* fn : program->AST().Functions()) {
if (auto* sem_fn = program->Sem().Get(fn)) {
for (auto* intrinsic : sem_fn->DirectlyCalledIntrinsics()) {
if (intrinsic->Type() == sem::IntrinsicType::kArrayLength) {
return true;
}
}
}
}
return false;
}
void ArrayLengthFromUniform::Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const {
if (!Requires<SimplifyPointers>(ctx)) {
return;
}
auto* cfg = inputs.Get<Config>();
if (cfg == nullptr) {
ctx.dst->Diagnostics().add_error(

View File

@ -49,6 +49,9 @@ namespace transform {
/// This transform assumes that the `SimplifyPointers`
/// transforms have been run before it so that arguments to the arrayLength
/// builtin always have the form `&resource.array`.
///
/// @note Depends on the following transforms to have been run first:
/// * SimplifyPointers
class ArrayLengthFromUniform
: public Castable<ArrayLengthFromUniform, Transform> {
public:
@ -81,6 +84,8 @@ class ArrayLengthFromUniform
};
/// Information produced about what the transform did.
/// If there were no calls to the arrayLength() intrinsic, then no Result will
/// be emitted.
struct Result : public Castable<Result, transform::Data> {
/// Constructor
/// @param used_size_indices Indices into the UBO that are statically used.
@ -96,6 +101,12 @@ class ArrayLengthFromUniform
const std::unordered_set<uint32_t> used_size_indices;
};
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program,
const DataMap& data = {}) const override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.

View File

@ -26,8 +26,61 @@ namespace {
using ArrayLengthFromUniformTest = TransformTest;
TEST_F(ArrayLengthFromUniformTest, ShouldRunEmptyModule) {
auto* src = R"()";
EXPECT_FALSE(ShouldRun<ArrayLengthFromUniform>(src));
}
TEST_F(ArrayLengthFromUniformTest, ShouldRunNoArrayLength) {
auto* src = R"(
struct SB {
x : i32;
arr : array<i32>;
};
[[group(0), binding(0)]] var<storage, read> sb : SB;
[[stage(compute), workgroup_size(1)]]
fn main() {
}
)";
EXPECT_FALSE(ShouldRun<ArrayLengthFromUniform>(src));
}
TEST_F(ArrayLengthFromUniformTest, ShouldRunWithArrayLength) {
auto* src = R"(
struct SB {
x : i32;
arr : array<i32>;
};
[[group(0), binding(0)]] var<storage, read> sb : SB;
[[stage(compute), workgroup_size(1)]]
fn main() {
var len : u32 = arrayLength(&sb.arr);
}
)";
EXPECT_TRUE(ShouldRun<ArrayLengthFromUniform>(src));
}
TEST_F(ArrayLengthFromUniformTest, Error_MissingTransformData) {
auto* src = "";
auto* src = R"(
struct SB {
x : i32;
arr : array<i32>;
};
[[group(0), binding(0)]] var<storage, read> sb : SB;
[[stage(compute), workgroup_size(1)]]
fn main() {
var len : u32 = arrayLength(&sb.arr);
}
)";
auto* expect =
"error: missing transform data for "
@ -38,18 +91,6 @@ TEST_F(ArrayLengthFromUniformTest, Error_MissingTransformData) {
EXPECT_EQ(expect, str(got));
}
TEST_F(ArrayLengthFromUniformTest, Error_MissingSimplifyPointers) {
auto* src = "";
auto* expect =
"error: tint::transform::ArrayLengthFromUniform depends on "
"tint::transform::SimplifyPointers but the dependency was not run";
auto got = Run<ArrayLengthFromUniform>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(ArrayLengthFromUniformTest, Basic) {
auto* src = R"(
@group(0) @binding(0) var<storage, read> sb : array<i32>;
@ -426,8 +467,7 @@ fn main() {
auto got = Run<Unshadow, SimplifyPointers, ArrayLengthFromUniform>(src, data);
EXPECT_EQ(src, str(got));
EXPECT_EQ(std::unordered_set<uint32_t>(),
got.data.Get<ArrayLengthFromUniform::Result>()->used_size_indices);
EXPECT_EQ(got.data.Get<ArrayLengthFromUniform::Result>(), nullptr);
}
TEST_F(ArrayLengthFromUniformTest, MissingBindingPointToIndexMapping) {

View File

@ -42,6 +42,14 @@ BindingRemapper::Remappings::~Remappings() = default;
BindingRemapper::BindingRemapper() = default;
BindingRemapper::~BindingRemapper() = default;
bool BindingRemapper::ShouldRun(const Program*, const DataMap& inputs) const {
if (auto* remappings = inputs.Get<Remappings>()) {
return !remappings->binding_points.empty() ||
!remappings->access_controls.empty();
}
return false;
}
void BindingRemapper::Run(CloneContext& ctx,
const DataMap& inputs,
DataMap&) const {

View File

@ -68,6 +68,12 @@ class BindingRemapper : public Castable<BindingRemapper, Transform> {
BindingRemapper();
~BindingRemapper() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program,
const DataMap& data = {}) const override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.

View File

@ -24,6 +24,47 @@ namespace {
using BindingRemapperTest = TransformTest;
TEST_F(BindingRemapperTest, ShouldRunNoRemappings) {
auto* src = R"()";
EXPECT_FALSE(ShouldRun<BindingRemapper>(src));
}
TEST_F(BindingRemapperTest, ShouldRunEmptyRemappings) {
auto* src = R"()";
DataMap data;
data.Add<BindingRemapper::Remappings>(BindingRemapper::BindingPoints{},
BindingRemapper::AccessControls{});
EXPECT_FALSE(ShouldRun<BindingRemapper>(src, data));
}
TEST_F(BindingRemapperTest, ShouldRunBindingPointRemappings) {
auto* src = R"()";
DataMap data;
data.Add<BindingRemapper::Remappings>(
BindingRemapper::BindingPoints{
{{2, 1}, {1, 2}},
},
BindingRemapper::AccessControls{});
EXPECT_TRUE(ShouldRun<BindingRemapper>(src, data));
}
TEST_F(BindingRemapperTest, ShouldRunAccessControlRemappings) {
auto* src = R"()";
DataMap data;
data.Add<BindingRemapper::Remappings>(BindingRemapper::BindingPoints{},
BindingRemapper::AccessControls{
{{2, 1}, ast::Access::kWrite},
});
EXPECT_TRUE(ShouldRun<BindingRemapper>(src, data));
}
TEST_F(BindingRemapperTest, NoRemappings) {
auto* src = R"(
struct S {
@ -359,17 +400,18 @@ TEST_F(BindingRemapperTest, NoData) {
auto* src = R"(
struct S {
a : f32;
};
}
@group(2) @binding(1) var<storage, read> a : S;
@group(3) @binding(2) var<storage, read> b : S;
@stage(compute) @workgroup_size(1)
fn f() {}
fn f() {
}
)";
auto* expect =
"error: missing transform data for tint::transform::BindingRemapper";
auto* expect = src;
auto got = Run<BindingRemapper>(src);

View File

@ -22,6 +22,7 @@
#include "src/program_builder.h"
#include "src/sem/block_statement.h"
#include "src/sem/call.h"
#include "src/sem/function.h"
#include "src/sem/statement.h"
#include "src/sem/struct.h"
#include "src/sem/variable.h"
@ -71,13 +72,24 @@ CalculateArrayLength::BufferSizeIntrinsic::Clone(CloneContext* ctx) const {
CalculateArrayLength::CalculateArrayLength() = default;
CalculateArrayLength::~CalculateArrayLength() = default;
bool CalculateArrayLength::ShouldRun(const Program* program,
const DataMap&) const {
for (auto* fn : program->AST().Functions()) {
if (auto* sem_fn = program->Sem().Get(fn)) {
for (auto* intrinsic : sem_fn->DirectlyCalledIntrinsics()) {
if (intrinsic->Type() == sem::IntrinsicType::kArrayLength) {
return true;
}
}
}
}
return false;
}
void CalculateArrayLength::Run(CloneContext& ctx,
const DataMap&,
DataMap&) const {
auto& sem = ctx.src->Sem();
if (!Requires<SimplifyPointers>(ctx)) {
return;
}
// get_buffer_size_intrinsic() emits the function decorated with
// BufferSizeIntrinsic that is transformed by the HLSL writer into a call to

View File

@ -29,6 +29,9 @@ namespace transform {
/// CalculateArrayLength is a transform used to replace calls to arrayLength()
/// with a value calculated from the size of the storage buffer.
///
/// @note Depends on the following transforms to have been run first:
/// * SimplifyPointers
class CalculateArrayLength : public Castable<CalculateArrayLength, Transform> {
public:
/// BufferSizeIntrinsic is an InternalDecoration that's applied to intrinsic
@ -56,6 +59,12 @@ class CalculateArrayLength : public Castable<CalculateArrayLength, Transform> {
/// Destructor
~CalculateArrayLength() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program,
const DataMap& data = {}) const override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.

View File

@ -24,16 +24,45 @@ namespace {
using CalculateArrayLengthTest = TransformTest;
TEST_F(CalculateArrayLengthTest, Error_MissingCalculateArrayLength) {
auto* src = "";
TEST_F(CalculateArrayLengthTest, ShouldRunEmptyModule) {
auto* src = R"()";
auto* expect =
"error: tint::transform::CalculateArrayLength depends on "
"tint::transform::SimplifyPointers but the dependency was not run";
EXPECT_FALSE(ShouldRun<CalculateArrayLength>(src));
}
auto got = Run<CalculateArrayLength>(src);
TEST_F(CalculateArrayLengthTest, ShouldRunNoArrayLength) {
auto* src = R"(
struct SB {
x : i32;
arr : array<i32>;
};
EXPECT_EQ(expect, str(got));
[[group(0), binding(0)]] var<storage, read> sb : SB;
[[stage(compute), workgroup_size(1)]]
fn main() {
}
)";
EXPECT_FALSE(ShouldRun<CalculateArrayLength>(src));
}
TEST_F(CalculateArrayLengthTest, ShouldRunWithArrayLength) {
auto* src = R"(
struct SB {
x : i32;
arr : array<i32>;
};
[[group(0), binding(0)]] var<storage, read> sb : SB;
[[stage(compute), workgroup_size(1)]]
fn main() {
var len : u32 = arrayLength(&sb.arr);
}
)";
EXPECT_TRUE(ShouldRun<CalculateArrayLength>(src));
}
TEST_F(CalculateArrayLengthTest, Basic) {

View File

@ -551,10 +551,6 @@ struct CanonicalizeEntryPointIO::State {
void CanonicalizeEntryPointIO::Run(CloneContext& ctx,
const DataMap& inputs,
DataMap&) const {
if (!Requires<Unshadow>(ctx)) {
return;
}
auto* cfg = inputs.Get<Config>();
if (cfg == nullptr) {
ctx.dst->Diagnostics().add_error(

View File

@ -80,6 +80,9 @@ namespace transform {
/// return wrapper_result;
/// }
/// ```
///
/// @note Depends on the following transforms to have been run first:
/// * Unshadow
class CanonicalizeEntryPointIO
: public Castable<CanonicalizeEntryPointIO, Transform> {
public:

View File

@ -23,18 +23,6 @@ namespace {
using CanonicalizeEntryPointIOTest = TransformTest;
TEST_F(CanonicalizeEntryPointIOTest, Error_MissingUnshadow) {
auto* src = "";
auto* expect =
"error: tint::transform::CanonicalizeEntryPointIO depends on "
"tint::transform::Unshadow but the dependency was not run";
auto got = Run<CanonicalizeEntryPointIO>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(CanonicalizeEntryPointIOTest, Error_MissingTransformData) {
auto* src = "";

View File

@ -790,6 +790,19 @@ const DecomposeMemoryAccess::Intrinsic* DecomposeMemoryAccess::Intrinsic::Clone(
DecomposeMemoryAccess::DecomposeMemoryAccess() = default;
DecomposeMemoryAccess::~DecomposeMemoryAccess() = default;
bool DecomposeMemoryAccess::ShouldRun(const Program* program,
const DataMap&) const {
for (auto* decl : program->AST().GlobalDeclarations()) {
if (auto* var = program->Sem().Get<sem::Variable>(decl)) {
if (var->StorageClass() == ast::StorageClass::kStorage ||
var->StorageClass() == ast::StorageClass::kUniform) {
return true;
}
}
}
return false;
}
void DecomposeMemoryAccess::Run(CloneContext& ctx,
const DataMap&,
DataMap&) const {

View File

@ -105,6 +105,12 @@ class DecomposeMemoryAccess
/// Destructor
~DecomposeMemoryAccess() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program,
const DataMap& data = {}) const override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.

View File

@ -22,6 +22,34 @@ namespace {
using DecomposeMemoryAccessTest = TransformTest;
TEST_F(DecomposeMemoryAccessTest, ShouldRunEmptyModule) {
auto* src = R"()";
EXPECT_FALSE(ShouldRun<DecomposeMemoryAccess>(src));
}
TEST_F(DecomposeMemoryAccessTest, ShouldRunStorageBuffer) {
auto* src = R"(
struct Buffer {
i : i32;
};
[[group(0), binding(0)]] var<storage, read_write> sb : Buffer;
)";
EXPECT_TRUE(ShouldRun<DecomposeMemoryAccess>(src));
}
TEST_F(DecomposeMemoryAccessTest, ShouldRunUniformBuffer) {
auto* src = R"(
struct Buffer {
i : i32;
};
[[group(0), binding(0)]] var<uniform> ub : Buffer;
)";
EXPECT_TRUE(ShouldRun<DecomposeMemoryAccess>(src));
}
TEST_F(DecomposeMemoryAccessTest, SB_BasicLoad) {
auto* src = R"(
struct SB {

View File

@ -107,7 +107,8 @@ DecomposeStridedMatrix::DecomposeStridedMatrix() = default;
DecomposeStridedMatrix::~DecomposeStridedMatrix() = default;
bool DecomposeStridedMatrix::ShouldRun(const Program* program) const {
bool DecomposeStridedMatrix::ShouldRun(const Program* program,
const DataMap&) const {
bool should_run = false;
GatherCustomStrideMatrixMembers(
program, [&](const sem::StructMember*, sem::Matrix*, uint32_t) {
@ -120,10 +121,6 @@ bool DecomposeStridedMatrix::ShouldRun(const Program* program) const {
void DecomposeStridedMatrix::Run(CloneContext& ctx,
const DataMap&,
DataMap&) const {
if (!Requires<SimplifyPointers>(ctx)) {
return;
}
// Scan the program for all storage and uniform structure matrix members with
// a custom stride attribute. Replace these matrices with an equivalent array,
// and populate the `decomposed` map with the members that have been replaced.

View File

@ -25,6 +25,9 @@ namespace transform {
/// of N column vectors.
/// This transform is used by the SPIR-V reader to handle the SPIR-V
/// MatrixStride decoration.
///
/// @note Depends on the following transforms to have been run first:
/// * SimplifyPointers
class DecomposeStridedMatrix
: public Castable<DecomposeStridedMatrix, Transform> {
public:
@ -35,8 +38,10 @@ class DecomposeStridedMatrix
~DecomposeStridedMatrix() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program) const override;
bool ShouldRun(const Program* program,
const DataMap& data = {}) const override;
protected:
/// Runs the transform using the CloneContext built for transforming a

View File

@ -31,55 +31,25 @@ namespace {
using DecomposeStridedMatrixTest = TransformTest;
using f32 = ProgramBuilder::f32;
TEST_F(DecomposeStridedMatrixTest, ShouldRunEmptyModule) {
auto* src = R"()";
EXPECT_FALSE(ShouldRun<DecomposeStridedMatrix>(src));
}
TEST_F(DecomposeStridedMatrixTest, ShouldRunNonStridedMatrox) {
auto* src = R"(
var<private> m : mat3x2<f32>;
)";
EXPECT_FALSE(ShouldRun<DecomposeStridedMatrix>(src));
}
TEST_F(DecomposeStridedMatrixTest, Empty) {
auto* src = R"()";
auto* expect = src;
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedMatrixTest, MissingDependencySimplify) {
// struct S {
// [[offset(16), stride(32)]]
// [[internal(ignore_stride_decoration)]]
// m : mat2x2<f32>;
// };
// [[group(0), binding(0)]] var<uniform> s : S;
//
// [[stage(compute), workgroup_size(1)]]
// fn f() {
// let x : mat2x2<f32> = s.m;
// }
ProgramBuilder b;
auto* S = b.Structure(
"S",
{
b.Member(
"m", b.ty.mat2x2<f32>(),
{
b.create<ast::StructMemberOffsetDecoration>(16),
b.create<ast::StrideDecoration>(32),
b.Disable(ast::DisabledValidation::kIgnoreStrideDecoration),
}),
});
b.Global("s", b.ty.Of(S), ast::StorageClass::kUniform,
b.GroupAndBinding(0, 0));
b.Func(
"f", {}, b.ty.void_(),
{
b.Decl(b.Const("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))),
},
{
b.Stage(ast::PipelineStage::kCompute),
b.WorkgroupSize(1),
});
auto* expect =
R"(error: tint::transform::DecomposeStridedMatrix depends on tint::transform::SimplifyPointers but the dependency was not run)";
auto got = Run<DecomposeStridedMatrix>(Program(std::move(b)));
auto got = Run<DecomposeStridedMatrix>(src);
EXPECT_EQ(expect, str(got));
}

View File

@ -57,6 +57,15 @@ FirstIndexOffset::Data::~Data() = default;
FirstIndexOffset::FirstIndexOffset() = default;
FirstIndexOffset::~FirstIndexOffset() = default;
bool FirstIndexOffset::ShouldRun(const Program* program, const DataMap&) const {
for (auto* fn : program->AST().Functions()) {
if (fn->PipelineStage() == ast::PipelineStage::kVertex) {
return true;
}
}
return false;
}
void FirstIndexOffset::Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) const {

View File

@ -115,6 +115,12 @@ class FirstIndexOffset : public Castable<FirstIndexOffset, Transform> {
/// Destructor
~FirstIndexOffset() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program,
const DataMap& data = {}) const override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.

View File

@ -26,6 +26,34 @@ namespace {
using FirstIndexOffsetTest = TransformTest;
TEST_F(FirstIndexOffsetTest, ShouldRunEmptyModule) {
auto* src = R"()";
EXPECT_FALSE(ShouldRun<FirstIndexOffset>(src));
}
TEST_F(FirstIndexOffsetTest, ShouldRunFragmentStage) {
auto* src = R"(
[[stage(fragment)]]
fn entry() {
return;
}
)";
EXPECT_FALSE(ShouldRun<FirstIndexOffset>(src));
}
TEST_F(FirstIndexOffsetTest, ShouldRunVertexStage) {
auto* src = R"(
[[stage(vertex)]]
fn entry() -> [[builtin(position)]] vec4<f32> {
return vec4<f32>();
}
)";
EXPECT_TRUE(ShouldRun<FirstIndexOffset>(src));
}
TEST_F(FirstIndexOffsetTest, EmptyModule) {
auto* src = "";
auto* expect = "";
@ -38,6 +66,26 @@ TEST_F(FirstIndexOffsetTest, EmptyModule) {
auto* data = got.data.Get<FirstIndexOffset::Data>();
EXPECT_EQ(data, nullptr);
}
TEST_F(FirstIndexOffsetTest, BasicVertexShader) {
auto* src = R"(
@stage(vertex)
fn entry() -> @builtin(position) vec4<f32> {
return vec4<f32>();
}
)";
auto* expect = src;
DataMap config;
config.Add<FirstIndexOffset::BindingPoint>(0, 0);
auto got = Run<FirstIndexOffset>(src, std::move(config));
EXPECT_EQ(expect, str(got));
auto* data = got.data.Get<FirstIndexOffset::Data>();
ASSERT_NE(data, nullptr);
EXPECT_EQ(data->has_vertex_index, false);
EXPECT_EQ(data->has_instance_index, false);

View File

@ -25,6 +25,15 @@ ForLoopToLoop::ForLoopToLoop() = default;
ForLoopToLoop::~ForLoopToLoop() = default;
bool ForLoopToLoop::ShouldRun(const Program* program, const DataMap&) const {
for (auto* node : program->ASTNodes().Objects()) {
if (node->Is<ast::ForLoopStatement>()) {
return true;
}
}
return false;
}
void ForLoopToLoop::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
ctx.ReplaceAll(
[&](const ast::ForLoopStatement* for_loop) -> const ast::Statement* {

View File

@ -30,6 +30,12 @@ class ForLoopToLoop : public Castable<ForLoopToLoop, Transform> {
/// Destructor
~ForLoopToLoop() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program,
const DataMap& data = {}) const override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.

View File

@ -22,6 +22,24 @@ namespace {
using ForLoopToLoopTest = TransformTest;
TEST_F(ForLoopToLoopTest, ShouldRunEmptyModule) {
auto* src = R"()";
EXPECT_FALSE(ShouldRun<ForLoopToLoop>(src));
}
TEST_F(ForLoopToLoopTest, ShouldRunHasForLoop) {
auto* src = R"(
fn f() {
for (;;) {
break;
}
}
)";
EXPECT_TRUE(ShouldRun<ForLoopToLoop>(src));
}
TEST_F(ForLoopToLoopTest, EmptyModule) {
auto* src = "";
auto* expect = src;

View File

@ -113,7 +113,6 @@ Output Glsl::Run(const Program* in, const DataMap& inputs) const {
ProgramBuilder builder;
CloneContext ctx(&builder, &out.program);
ctx.Clone();
builder.SetTransformApplied(this);
return Output{Program(std::move(builder))};
}

View File

@ -216,15 +216,9 @@ LocalizeStructArrayAssignment::~LocalizeStructArrayAssignment() = default;
void LocalizeStructArrayAssignment::Run(CloneContext& ctx,
const DataMap&,
DataMap&) const {
if (!Requires<SimplifyPointers>(ctx)) {
return;
}
State state(ctx);
state.Run();
// This transform may introduce pointers
ctx.dst->UnsetTransformApplied<transform::SimplifyPointers>();
}
} // namespace transform
} // namespace tint

View File

@ -25,6 +25,9 @@ namespace transform {
/// temporary local variable, assigns to the local variable, and copies the
/// array back. This is to work around FXC's compilation failure for these cases
/// (see crbug.com/tint/1206).
///
/// @note Depends on the following transforms to have been run first:
/// * SimplifyPointers
class LocalizeStructArrayAssignment
: public Castable<LocalizeStructArrayAssignment, Transform> {
public:

View File

@ -24,16 +24,6 @@ namespace {
using LocalizeStructArrayAssignmentTest = TransformTest;
TEST_F(LocalizeStructArrayAssignmentTest, MissingSimplifyPointers) {
auto* src = R"()";
auto* expect =
"error: tint::transform::LocalizeStructArrayAssignment depends on "
"tint::transform::SimplifyPointers but the dependency was not run";
auto got = Run<LocalizeStructArrayAssignment>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(LocalizeStructArrayAssignmentTest, EmptyModule) {
auto* src = R"()";
auto* expect = src;

View File

@ -56,6 +56,15 @@ LoopToForLoop::LoopToForLoop() = default;
LoopToForLoop::~LoopToForLoop() = default;
bool LoopToForLoop::ShouldRun(const Program* program, const DataMap&) const {
for (auto* node : program->ASTNodes().Objects()) {
if (node->Is<ast::LoopStatement>()) {
return true;
}
}
return false;
}
void LoopToForLoop::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
ctx.ReplaceAll([&](const ast::LoopStatement* loop) -> const ast::Statement* {
// For loop condition is taken from the first statement in the loop.

View File

@ -30,6 +30,12 @@ class LoopToForLoop : public Castable<LoopToForLoop, Transform> {
/// Destructor
~LoopToForLoop() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program,
const DataMap& data = {}) const override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.

View File

@ -22,6 +22,24 @@ namespace {
using LoopToForLoopTest = TransformTest;
TEST_F(LoopToForLoopTest, ShouldRunEmptyModule) {
auto* src = R"()";
EXPECT_FALSE(ShouldRun<LoopToForLoop>(src));
}
TEST_F(LoopToForLoopTest, ShouldRunHasForLoop) {
auto* src = R"(
fn f() {
loop {
break;
}
}
)";
EXPECT_TRUE(ShouldRun<LoopToForLoop>(src));
}
TEST_F(LoopToForLoopTest, EmptyModule) {
auto* src = "";
auto* expect = "";

View File

@ -53,7 +53,7 @@ Output Manager::Run(const Program* program, const DataMap& data) const {
Output out;
for (const auto& transform : transforms_) {
if (!transform->ShouldRun(in)) {
if (!transform->ShouldRun(in, data)) {
TINT_IF_PRINT_PROGRAM(std::cout << "Skipping "
<< transform->TypeInfo().name);
continue;

View File

@ -377,6 +377,16 @@ ModuleScopeVarToEntryPointParam::ModuleScopeVarToEntryPointParam() = default;
ModuleScopeVarToEntryPointParam::~ModuleScopeVarToEntryPointParam() = default;
bool ModuleScopeVarToEntryPointParam::ShouldRun(const Program* program,
const DataMap&) const {
for (auto* decl : program->AST().GlobalDeclarations()) {
if (decl->Is<ast::Variable>()) {
return true;
}
}
return false;
}
void ModuleScopeVarToEntryPointParam::Run(CloneContext& ctx,
const DataMap&,
DataMap&) const {

View File

@ -70,6 +70,12 @@ class ModuleScopeVarToEntryPointParam
/// Destructor
~ModuleScopeVarToEntryPointParam() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program,
const DataMap& data = {}) const override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.

View File

@ -24,6 +24,20 @@ namespace {
using ModuleScopeVarToEntryPointParamTest = TransformTest;
TEST_F(ModuleScopeVarToEntryPointParamTest, ShouldRunEmptyModule) {
auto* src = R"()";
EXPECT_FALSE(ShouldRun<ModuleScopeVarToEntryPointParam>(src));
}
TEST_F(ModuleScopeVarToEntryPointParamTest, ShouldRunHasGlobal) {
auto* src = R"(
var<private> v : i32;
)";
EXPECT_TRUE(ShouldRun<ModuleScopeVarToEntryPointParam>(src));
}
TEST_F(ModuleScopeVarToEntryPointParamTest, Basic) {
auto* src = R"(
var<private> p : f32;

View File

@ -81,7 +81,7 @@ struct MultiplanarExternalTexture::State {
// represent the secondary plane and one uniform buffer for the
// ExternalTextureParams struct).
ctx.ReplaceAll([&](const ast::Variable* var) -> const ast::Variable* {
if (!::tint::Is<ast::ExternalTexture>(var->type)) {
if (!sem.Get<sem::ExternalTexture>(var->type)) {
return nullptr;
}
@ -201,7 +201,7 @@ struct MultiplanarExternalTexture::State {
// functions.
ctx.ReplaceAll([&](const ast::Function* fn) -> const ast::Function* {
for (const ast::Variable* param : fn->params) {
if (::tint::Is<ast::ExternalTexture>(param->type)) {
if (sem.Get<sem::ExternalTexture>(param->type)) {
// If we find a texture_external, we must ensure the
// ExternalTextureParams struct exists.
if (!params_struct_sym.IsValid()) {
@ -407,6 +407,18 @@ MultiplanarExternalTexture::NewBindingPoints::~NewBindingPoints() = default;
MultiplanarExternalTexture::MultiplanarExternalTexture() = default;
MultiplanarExternalTexture::~MultiplanarExternalTexture() = default;
bool MultiplanarExternalTexture::ShouldRun(const Program* program,
const DataMap&) const {
for (auto* node : program->ASTNodes().Objects()) {
if (auto* ty = node->As<ast::Type>()) {
if (program->Sem().Get<sem::ExternalTexture>(ty)) {
return true;
}
}
}
return false;
}
// Within this transform, an instance of a texture_external binding is unpacked
// into two texture_2d<f32> bindings representing two possible planes of a
// single texture and a uniform buffer binding representing a struct of

View File

@ -76,6 +76,12 @@ class MultiplanarExternalTexture
/// Destructor
~MultiplanarExternalTexture() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program,
const DataMap& data = {}) const override;
protected:
struct State;

View File

@ -21,6 +21,35 @@ namespace {
using MultiplanarExternalTextureTest = TransformTest;
TEST_F(MultiplanarExternalTextureTest, ShouldRunEmptyModule) {
auto* src = R"()";
EXPECT_FALSE(ShouldRun<MultiplanarExternalTexture>(src));
}
TEST_F(MultiplanarExternalTextureTest, ShouldRunHasExternalTextureAlias) {
auto* src = R"(
type ET = texture_external;
)";
EXPECT_TRUE(ShouldRun<MultiplanarExternalTexture>(src));
}
TEST_F(MultiplanarExternalTextureTest, ShouldRunHasExternalTextureGlobal) {
auto* src = R"(
[[group(0), binding(0)]] var ext_tex : texture_external;
)";
EXPECT_TRUE(ShouldRun<MultiplanarExternalTexture>(src));
}
TEST_F(MultiplanarExternalTextureTest, ShouldRunHasExternalTextureParam) {
auto* src = R"(
fn f(ext_tex : texture_external) {}
)";
EXPECT_TRUE(ShouldRun<MultiplanarExternalTexture>(src));
}
// Running the transform without passing in data for the new bindings should
// result in an error.
TEST_F(MultiplanarExternalTextureTest, ErrorNoPassedData) {
@ -651,6 +680,106 @@ fn main() {
EXPECT_EQ(expect, str(got));
}
// Tests that the transform works with a function using an external texture,
// even if there's no external texture declared at module scope.
TEST_F(MultiplanarExternalTextureTest,
ExternalTexturePassedAsParamWithoutGlobalDecl) {
auto* src = R"(
fn f(ext_tex : texture_external) -> vec2<i32> {
return textureDimensions(ext_tex);
}
)";
auto* expect = R"(
struct ExternalTextureParams {
numPlanes : u32;
vr : f32;
ug : f32;
vg : f32;
ub : f32;
}
fn f(ext_tex : texture_2d<f32>, ext_tex_plane_1 : texture_2d<f32>, ext_tex_params : ExternalTextureParams) -> vec2<i32> {
return textureDimensions(ext_tex);
}
)";
DataMap data;
data.Add<MultiplanarExternalTexture::NewBindingPoints>(
MultiplanarExternalTexture::BindingsMap{{{0, 0}, {{0, 1}, {0, 2}}}});
auto got = Run<MultiplanarExternalTexture>(src, data);
EXPECT_EQ(expect, str(got));
}
// Tests that the the transform handles aliases to external textures
TEST_F(MultiplanarExternalTextureTest, ExternalTextureAlias) {
auto* src = R"(
type ET = texture_external;
fn f(t : ET, s : sampler) {
textureSampleLevel(t, s, vec2<f32>(1.0, 2.0));
}
[[group(0), binding(0)]] var ext_tex : ET;
[[group(0), binding(1)]] var smp : sampler;
[[stage(fragment)]]
fn main() {
f(ext_tex, smp);
}
)";
auto* expect = R"(
type ET = texture_external;
struct ExternalTextureParams {
numPlanes : u32;
vr : f32;
ug : f32;
vg : f32;
ub : f32;
}
fn textureSampleExternal(plane0 : texture_2d<f32>, plane1 : texture_2d<f32>, smp : sampler, coord : vec2<f32>, params : ExternalTextureParams) -> vec4<f32> {
if ((params.numPlanes == 1u)) {
return textureSampleLevel(plane0, smp, coord, 0.0);
}
let y = (textureSampleLevel(plane0, smp, coord, 0.0).r - 0.0625);
let uv = (textureSampleLevel(plane1, smp, coord, 0.0).rg - 0.5);
let u = uv.x;
let v = uv.y;
let r = ((1.164000034 * y) + (params.vr * v));
let g = (((1.164000034 * y) - (params.ug * u)) - (params.vg * v));
let b = ((1.164000034 * y) + (params.ub * u));
return vec4<f32>(r, g, b, 1.0);
}
fn f(t : texture_2d<f32>, ext_tex_plane_1 : texture_2d<f32>, ext_tex_params : ExternalTextureParams, s : sampler) {
textureSampleExternal(t, ext_tex_plane_1, s, vec2<f32>(1.0, 2.0), ext_tex_params);
}
@group(0) @binding(2) var ext_tex_plane_1_1 : texture_2d<f32>;
@group(0) @binding(3) var<uniform> ext_tex_params_1 : ExternalTextureParams;
@group(0) @binding(0) var ext_tex : texture_2d<f32>;
@group(0) @binding(1) var smp : sampler;
@stage(fragment)
fn main() {
f(ext_tex, ext_tex_plane_1_1, ext_tex_params_1, smp);
}
)";
DataMap data;
data.Add<MultiplanarExternalTexture::NewBindingPoints>(
MultiplanarExternalTexture::BindingsMap{
{{0, 0}, {{0, 2}, {0, 3}}},
});
auto got = Run<MultiplanarExternalTexture>(src, data);
EXPECT_EQ(expect, str(got));
}
} // namespace
} // namespace transform
} // namespace tint

View File

@ -52,13 +52,21 @@ struct Accessor {
NumWorkgroupsFromUniform::NumWorkgroupsFromUniform() = default;
NumWorkgroupsFromUniform::~NumWorkgroupsFromUniform() = default;
bool NumWorkgroupsFromUniform::ShouldRun(const Program* program,
const DataMap&) const {
for (auto* node : program->ASTNodes().Objects()) {
if (auto* deco = node->As<ast::BuiltinDecoration>()) {
if (deco->builtin == ast::Builtin::kNumWorkgroups) {
return true;
}
}
}
return false;
}
void NumWorkgroupsFromUniform::Run(CloneContext& ctx,
const DataMap& inputs,
DataMap&) const {
if (!Requires<CanonicalizeEntryPointIO>(ctx)) {
return;
}
auto* cfg = inputs.Get<Config>();
if (cfg == nullptr) {
ctx.dst->Diagnostics().add_error(

View File

@ -39,6 +39,9 @@ namespace transform {
/// ```
/// The binding group and number used for this uniform buffer is provided via
/// the `Config` transform input.
///
/// @note Depends on the following transforms to have been run first:
/// * CanonicalizeEntryPointIO
class NumWorkgroupsFromUniform
: public Castable<NumWorkgroupsFromUniform, Transform> {
public:
@ -63,6 +66,12 @@ class NumWorkgroupsFromUniform
sem::BindingPoint ubo_binding;
};
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program,
const DataMap& data = {}) const override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.

View File

@ -26,8 +26,28 @@ namespace {
using NumWorkgroupsFromUniformTest = TransformTest;
TEST_F(NumWorkgroupsFromUniformTest, ShouldRunEmptyModule) {
auto* src = R"()";
EXPECT_FALSE(ShouldRun<NumWorkgroupsFromUniform>(src));
}
TEST_F(NumWorkgroupsFromUniformTest, ShouldRunHasNumWorkgroups) {
auto* src = R"(
[[stage(compute), workgroup_size(1)]]
fn main([[builtin(num_workgroups)]] num_wgs : vec3<u32>) {
}
)";
EXPECT_TRUE(ShouldRun<NumWorkgroupsFromUniform>(src));
}
TEST_F(NumWorkgroupsFromUniformTest, Error_MissingTransformData) {
auto* src = "";
auto* src = R"(
[[stage(compute), workgroup_size(1)]]
fn main([[builtin(num_workgroups)]] num_wgs : vec3<u32>) {
}
)";
auto* expect =
"error: missing transform data for "
@ -42,19 +62,6 @@ TEST_F(NumWorkgroupsFromUniformTest, Error_MissingTransformData) {
EXPECT_EQ(expect, str(got));
}
TEST_F(NumWorkgroupsFromUniformTest, Error_MissingCanonicalizeEntryPointIO) {
auto* src = "";
auto* expect =
"error: tint::transform::NumWorkgroupsFromUniform depends on "
"tint::transform::CanonicalizeEntryPointIO but the dependency was not "
"run";
auto got = Run<NumWorkgroupsFromUniform>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(NumWorkgroupsFromUniformTest, Basic) {
auto* src = R"(
@stage(compute) @workgroup_size(1)

View File

@ -97,6 +97,19 @@ PadArrayElements::PadArrayElements() = default;
PadArrayElements::~PadArrayElements() = default;
bool PadArrayElements::ShouldRun(const Program* program, const DataMap&) const {
for (auto* node : program->ASTNodes().Objects()) {
if (auto* var = node->As<ast::Type>()) {
if (auto* arr = program->Sem().Get<sem::Array>(var)) {
if (!arr->IsStrideImplicit()) {
return true;
}
}
}
}
return false;
}
void PadArrayElements::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
auto& sem = ctx.src->Sem();

View File

@ -38,6 +38,12 @@ class PadArrayElements : public Castable<PadArrayElements, Transform> {
/// Destructor
~PadArrayElements() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program,
const DataMap& data = {}) const override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.

View File

@ -24,6 +24,28 @@ namespace {
using PadArrayElementsTest = TransformTest;
TEST_F(PadArrayElementsTest, ShouldRunEmptyModule) {
auto* src = R"()";
EXPECT_FALSE(ShouldRun<PadArrayElements>(src));
}
TEST_F(PadArrayElementsTest, ShouldRunHasImplicitArrayStride) {
auto* src = R"(
var<private> arr : array<i32, 4>;
)";
EXPECT_FALSE(ShouldRun<PadArrayElements>(src));
}
TEST_F(PadArrayElementsTest, ShouldRunHasExplicitArrayStride) {
auto* src = R"(
var<private> arr : [[stride(8)]] array<i32, 4>;
)";
EXPECT_TRUE(ShouldRun<PadArrayElements>(src));
}
TEST_F(PadArrayElementsTest, EmptyModule) {
auto* src = "";
auto* expect = "";

View File

@ -68,6 +68,15 @@ RemovePhonies::RemovePhonies() = default;
RemovePhonies::~RemovePhonies() = default;
bool RemovePhonies::ShouldRun(const Program* program, const DataMap&) const {
for (auto* node : program->ASTNodes().Objects()) {
if (node->Is<ast::PhonyExpression>()) {
return true;
}
}
return false;
}
void RemovePhonies::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
auto& sem = ctx.src->Sem();

View File

@ -34,6 +34,12 @@ class RemovePhonies : public Castable<RemovePhonies, Transform> {
/// Destructor
~RemovePhonies() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program,
const DataMap& data = {}) const override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.

View File

@ -26,6 +26,22 @@ namespace {
using RemovePhoniesTest = TransformTest;
TEST_F(RemovePhoniesTest, ShouldRunEmptyModule) {
auto* src = R"()";
EXPECT_FALSE(ShouldRun<RemovePhonies>(src));
}
TEST_F(RemovePhoniesTest, ShouldRunHasPhony) {
auto* src = R"(
fn f() {
_ = 1;
}
)";
EXPECT_TRUE(ShouldRun<RemovePhonies>(src));
}
TEST_F(RemovePhoniesTest, EmptyModule) {
auto* src = "";
auto* expect = "";

View File

@ -37,6 +37,18 @@ RemoveUnreachableStatements::RemoveUnreachableStatements() = default;
RemoveUnreachableStatements::~RemoveUnreachableStatements() = default;
bool RemoveUnreachableStatements::ShouldRun(const Program* program,
const DataMap&) const {
for (auto* node : program->ASTNodes().Objects()) {
if (auto* stmt = program->Sem().Get<sem::Statement>(node)) {
if (!stmt->IsReachable()) {
return true;
}
}
}
return false;
}
void RemoveUnreachableStatements::Run(CloneContext& ctx,
const DataMap&,
DataMap&) const {

View File

@ -34,6 +34,12 @@ class RemoveUnreachableStatements
/// Destructor
~RemoveUnreachableStatements() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program,
const DataMap& data = {}) const override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.

View File

@ -22,6 +22,37 @@ namespace {
using RemoveUnreachableStatementsTest = TransformTest;
TEST_F(RemoveUnreachableStatementsTest, ShouldRunEmptyModule) {
auto* src = R"()";
EXPECT_FALSE(ShouldRun<RemoveUnreachableStatements>(src));
}
TEST_F(RemoveUnreachableStatementsTest, ShouldRunHasNoUnreachable) {
auto* src = R"(
fn f() {
if (true) {
var x = 1;
}
}
)";
EXPECT_FALSE(ShouldRun<RemoveUnreachableStatements>(src));
}
TEST_F(RemoveUnreachableStatementsTest, ShouldRunHasUnreachable) {
auto* src = R"(
fn f() {
return;
if (true) {
var x = 1;
}
}
)";
EXPECT_TRUE(ShouldRun<RemoveUnreachableStatements>(src));
}
TEST_F(RemoveUnreachableStatementsTest, EmptyModule) {
auto* src = "";
auto* expect = "";

View File

@ -232,9 +232,6 @@ SimplifyPointers::SimplifyPointers() = default;
SimplifyPointers::~SimplifyPointers() = default;
void SimplifyPointers::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
if (!Requires<Unshadow>(ctx)) {
return;
}
State(ctx).Run();
}

View File

@ -29,6 +29,9 @@ namespace transform {
/// Note: SimplifyPointers does not operate on module-scope `let`s, as these
/// cannot be pointers: https://gpuweb.github.io/gpuweb/wgsl/#module-constants
/// `A module-scope let-declared constant must be of constructible type.`
///
/// @note Depends on the following transforms to have been run first:
/// * Unshadow
class SimplifyPointers : public Castable<SimplifyPointers, Transform> {
public:
/// Constructor

View File

@ -23,18 +23,6 @@ namespace {
using SimplifyPointersTest = TransformTest;
TEST_F(SimplifyPointersTest, Error_MissingSimplifyPointers) {
auto* src = "";
auto* expect =
"error: tint::transform::SimplifyPointers depends on "
"tint::transform::Unshadow but the dependency was not run";
auto got = Run<SimplifyPointers>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(SimplifyPointersTest, EmptyModule) {
auto* src = "";
auto* expect = "";

View File

@ -107,11 +107,6 @@ void SingleEntryPoint::Run(CloneContext& ctx,
// Clone the entry point.
ctx.dst->AST().AddFunction(ctx.Clone(entry_point));
// Retain the list of applied transforms.
// We need to do this manually since we are not going to use the top-level
// ctx.Clone() function.
ctx.dst->SetTransformApplied(ctx.src->TransformsApplied());
}
SingleEntryPoint::Config::Config(std::string entry_point)

View File

@ -81,6 +81,17 @@ class TransformTestBase : public BASE {
return manager.Run(&program, data);
}
/// @param in the input WGSL source
/// @param data the optional DataMap to pass to Transform::Run()
/// @return true if the transform should be run for the given input.
template <typename TRANSFORM>
bool ShouldRun(std::string in, const DataMap& data = {}) {
auto file = std::make_unique<Source::File>("test", in);
auto program = reader::wgsl::Parse(file.get());
EXPECT_TRUE(program.IsValid()) << program.Diagnostics().str();
return TRANSFORM().ShouldRun(&program, data);
}
/// @param output the output of the transform
/// @returns the output program as a WGSL string, or an error string if the
/// program is not valid.

View File

@ -52,7 +52,6 @@ Output Transform::Run(const Program* program,
CloneContext ctx(&builder, program);
Output output;
Run(ctx, data, output.data);
builder.SetTransformApplied(this);
output.program = Program(std::move(builder));
return output;
}
@ -62,22 +61,7 @@ void Transform::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
<< "Transform::Run() unimplemented for " << TypeInfo().name;
}
bool Transform::ShouldRun(const Program*) const {
return true;
}
bool Transform::Requires(
CloneContext& ctx,
std::initializer_list<const ::tint::TypeInfo*> deps) const {
for (auto* dep : deps) {
if (!ctx.src->HasTransformApplied(dep)) {
ctx.dst->Diagnostics().add_error(
diag::System::Transform, std::string(TypeInfo().name) +
" depends on " + std::string(dep->name) +
" but the dependency was not run");
return false;
}
}
bool Transform::ShouldRun(const Program*, const DataMap&) const {
return true;
}

View File

@ -160,8 +160,10 @@ class Transform : public Castable<Transform> {
virtual Output Run(const Program* program, const DataMap& data = {}) const;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
virtual bool ShouldRun(const Program* program) const;
virtual bool ShouldRun(const Program* program,
const DataMap& data = {}) const;
protected:
/// Runs the transform using the CloneContext built for transforming a
@ -174,23 +176,6 @@ class Transform : public Castable<Transform> {
const DataMap& inputs,
DataMap& outputs) const;
/// Requires appends an error diagnostic to `ctx.dst` if the template type
/// transforms were not already run on `ctx.src`.
/// @param ctx the CloneContext
/// @returns true if all dependency transforms have been run
template <typename... TRANSFORMS>
bool Requires(CloneContext& ctx) const {
return Requires(ctx, {&::tint::TypeInfo::Of<TRANSFORMS>()...});
}
/// Requires appends an error diagnostic to `ctx.dst` if the list of
/// Transforms were not already run on `ctx.src`.
/// @param ctx the CloneContext
/// @param deps the list of Transform TypeInfos
/// @returns true if all dependency transforms have been run
bool Requires(CloneContext& ctx,
std::initializer_list<const ::tint::TypeInfo*> deps) const;
/// Removes the statement `stmt` from the transformed program.
/// RemoveStatement handles edge cases, like statements in the initializer and
/// continuing of for-loops.

View File

@ -32,6 +32,22 @@ VectorizeScalarMatrixConstructors::VectorizeScalarMatrixConstructors() =
VectorizeScalarMatrixConstructors::~VectorizeScalarMatrixConstructors() =
default;
bool VectorizeScalarMatrixConstructors::ShouldRun(const Program* program,
const DataMap&) const {
for (auto* node : program->ASTNodes().Objects()) {
if (auto* call = program->Sem().Get<sem::Call>(node)) {
if (call->Target()->Is<sem::TypeConstructor>() &&
call->Type()->Is<sem::Matrix>()) {
auto& args = call->Arguments();
if (args.size() > 0 && args[0]->Type()->is_scalar()) {
return true;
}
}
}
}
return false;
}
void VectorizeScalarMatrixConstructors::Run(CloneContext& ctx,
const DataMap&,
DataMap&) const {

View File

@ -30,6 +30,12 @@ class VectorizeScalarMatrixConstructors
/// Destructor
~VectorizeScalarMatrixConstructors() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program,
const DataMap& data = {}) const override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.

View File

@ -27,6 +27,12 @@ namespace {
using VectorizeScalarMatrixConstructorsTest =
TransformTestWithParam<std::pair<uint32_t, uint32_t>>;
TEST_F(VectorizeScalarMatrixConstructorsTest, ShouldRunEmptyModule) {
auto* src = R"()";
EXPECT_FALSE(ShouldRun<VectorizeScalarMatrixConstructors>(src));
}
TEST_P(VectorizeScalarMatrixConstructorsTest, Basic) {
uint32_t cols = GetParam().first;
uint32_t rows = GetParam().second;
@ -63,6 +69,8 @@ fn main() {
auto src = utils::ReplaceAll(tmpl, "${values}", scalar_values);
auto expect = utils::ReplaceAll(tmpl, "${values}", vector_values);
EXPECT_TRUE(ShouldRun<VectorizeScalarMatrixConstructors>(src));
auto got = Run<VectorizeScalarMatrixConstructors>(src);
EXPECT_EQ(expect, str(got));
@ -92,6 +100,8 @@ fn main() {
auto src = utils::ReplaceAll(tmpl, "${columns}", columns);
auto expect = src;
EXPECT_FALSE(ShouldRun<VectorizeScalarMatrixConstructors>(src));
auto got = Run<VectorizeScalarMatrixConstructors>(src);
EXPECT_EQ(expect, str(got));

View File

@ -38,6 +38,16 @@ WrapArraysInStructs::WrapArraysInStructs() = default;
WrapArraysInStructs::~WrapArraysInStructs() = default;
bool WrapArraysInStructs::ShouldRun(const Program* program,
const DataMap&) const {
for (auto* node : program->ASTNodes().Objects()) {
if (program->Sem().Get<sem::Array>(node->As<ast::Type>())) {
return true;
}
}
return false;
}
void WrapArraysInStructs::Run(CloneContext& ctx,
const DataMap&,
DataMap&) const {

View File

@ -44,6 +44,12 @@ class WrapArraysInStructs : public Castable<WrapArraysInStructs, Transform> {
/// Destructor
~WrapArraysInStructs() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program,
const DataMap& data = {}) const override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.

View File

@ -25,9 +25,23 @@ namespace {
using WrapArraysInStructsTest = TransformTest;
TEST_F(WrapArraysInStructsTest, ShouldRunEmptyModule) {
auto* src = R"()";
EXPECT_FALSE(ShouldRun<WrapArraysInStructs>(src));
}
TEST_F(WrapArraysInStructsTest, ShouldRunHasArray) {
auto* src = R"(
var<private> arr : array<i32, 4>;
)";
EXPECT_TRUE(ShouldRun<WrapArraysInStructs>(src));
}
TEST_F(WrapArraysInStructsTest, EmptyModule) {
auto* src = "";
auto* expect = "";
auto* src = R"()";
auto* expect = src;
auto got = Run<WrapArraysInStructs>(src);

View File

@ -433,6 +433,18 @@ ZeroInitWorkgroupMemory::ZeroInitWorkgroupMemory() = default;
ZeroInitWorkgroupMemory::~ZeroInitWorkgroupMemory() = default;
bool ZeroInitWorkgroupMemory::ShouldRun(const Program* program,
const DataMap&) const {
for (auto* decl : program->AST().GlobalDeclarations()) {
if (auto* var = decl->As<ast::Variable>()) {
if (var->declared_storage_class == ast::StorageClass::kWorkgroup) {
return true;
}
}
}
return false;
}
void ZeroInitWorkgroupMemory::Run(CloneContext& ctx,
const DataMap&,
DataMap&) const {

View File

@ -32,6 +32,12 @@ class ZeroInitWorkgroupMemory
/// Destructor
~ZeroInitWorkgroupMemory() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program,
const DataMap& data = {}) const override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.

View File

@ -24,6 +24,28 @@ namespace {
using ZeroInitWorkgroupMemoryTest = TransformTest;
TEST_F(ZeroInitWorkgroupMemoryTest, ShouldRunEmptyModule) {
auto* src = R"()";
EXPECT_FALSE(ShouldRun<ZeroInitWorkgroupMemory>(src));
}
TEST_F(ZeroInitWorkgroupMemoryTest, ShouldRunHasNoWorkgroupVars) {
auto* src = R"(
var<private> v : i32;
)";
EXPECT_FALSE(ShouldRun<ZeroInitWorkgroupMemory>(src));
}
TEST_F(ZeroInitWorkgroupMemoryTest, ShouldRunHasWorkgroupVars) {
auto* src = R"(
var<workgroup> a : i32;
)";
EXPECT_TRUE(ShouldRun<ZeroInitWorkgroupMemory>(src));
}
TEST_F(ZeroInitWorkgroupMemoryTest, EmptyModule) {
auto* src = "";
auto* expect = src;

View File

@ -124,14 +124,6 @@ GeneratorImpl::GeneratorImpl(const Program* program) : TextGenerator(program) {}
GeneratorImpl::~GeneratorImpl() = default;
bool GeneratorImpl::Generate() {
if (!builder_.HasTransformApplied<transform::Glsl>()) {
diagnostics_.add_error(
diag::System::Writer,
"GLSL writer requires the transform::Glsl sanitizer to have been "
"applied to the input program");
return false;
}
const TypeInfo* last_kind = nullptr;
size_t last_padding_line = 0;

View File

@ -21,16 +21,6 @@ namespace {
using GlslGeneratorImplTest = TestHelper;
TEST_F(GlslGeneratorImplTest, ErrorIfSanitizerNotRun) {
auto program = std::make_unique<Program>(std::move(*this));
GeneratorImpl gen(program.get());
EXPECT_FALSE(gen.Generate());
EXPECT_EQ(
gen.error(),
"error: GLSL writer requires the transform::Glsl sanitizer to have been "
"applied to the input program");
}
TEST_F(GlslGeneratorImplTest, Generate) {
Func("my_func", ast::VariableList{}, ty.void_(), ast::StatementList{},
ast::DecorationList{});

View File

@ -44,9 +44,6 @@ class TestHelperBase : public BODY, public ProgramBuilder {
if (gen_) {
return *gen_;
}
// Fake that the GLSL sanitizer has been applied, so that we can unit test
// the writer without it erroring.
SetTransformApplied<transform::Glsl>();
[&]() {
ASSERT_TRUE(IsValid()) << "Builder program is not valid\n"
<< diag::Formatter().format(Diagnostics());

View File

@ -187,9 +187,10 @@ SanitizedResult Sanitize(
if (!result.program.IsValid()) {
return result;
}
result.used_array_length_from_uniform_indices =
std::move(out.data.Get<transform::ArrayLengthFromUniform::Result>()
->used_size_indices);
if (auto* res = out.data.Get<transform::ArrayLengthFromUniform::Result>()) {
result.used_array_length_from_uniform_indices =
std::move(res->used_size_indices);
}
result.needs_storage_buffer_sizes =
!result.used_array_length_from_uniform_indices.empty();
return result;