optimizations: Implement transform::ShouldRun()

This change adds an override for Transform::ShouldRun() for many of the transforms that can trivially detect whether running would be a no-op or not. Most programs do not require all the transforms to be run, and by skipping those that are not needed, significant performance wins can be had.

This change also removes Transform::Requires() and Program::HasTransformApplied(). This makes little sense now that transforms can be skipped, and the usefulness of this information has been severely reduced since the introduction of transforms that need to be run more than once.
Instread, just document on the transform class what the expectations are.

Issue: tint:1383
Change-Id: I1a6f27cc4ba61ca1475a4ba912c465db619f76c7
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/77121
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton 2022-01-25 21:36:04 +00:00 committed by Tint LUCI CQ
parent 12d54d746e
commit 800b8e3175
78 changed files with 909 additions and 298 deletions

View File

@ -57,7 +57,6 @@ Symbol CloneContext::Clone(Symbol s) {
void CloneContext::Clone() { void CloneContext::Clone() {
dst->AST().Copy(this, &src->AST()); dst->AST().Copy(this, &src->AST());
dst->SetTransformApplied(src->TransformsApplied());
} }
ast::FunctionList CloneContext::Clone(const ast::FunctionList& v) { ast::FunctionList CloneContext::Clone(const ast::FunctionList& v) {

View File

@ -42,7 +42,6 @@ Program::Program(Program&& program)
sem_(std::move(program.sem_)), sem_(std::move(program.sem_)),
symbols_(std::move(program.symbols_)), symbols_(std::move(program.symbols_)),
diagnostics_(std::move(program.diagnostics_)), diagnostics_(std::move(program.diagnostics_)),
transforms_applied_(std::move(program.transforms_applied_)),
is_valid_(program.is_valid_) { is_valid_(program.is_valid_) {
program.AssertNotMoved(); program.AssertNotMoved();
program.moved_ = true; program.moved_ = true;
@ -67,7 +66,6 @@ Program::Program(ProgramBuilder&& builder) {
sem_ = std::move(builder.Sem()); sem_ = std::move(builder.Sem());
symbols_ = std::move(builder.Symbols()); symbols_ = std::move(builder.Symbols());
diagnostics_.add(std::move(builder.Diagnostics())); diagnostics_.add(std::move(builder.Diagnostics()));
transforms_applied_ = builder.TransformsApplied();
builder.MarkAsMoved(); builder.MarkAsMoved();
if (!is_valid_ && !diagnostics_.contains_errors()) { if (!is_valid_ && !diagnostics_.contains_errors()) {
@ -92,7 +90,6 @@ Program& Program::operator=(Program&& program) {
sem_ = std::move(program.sem_); sem_ = std::move(program.sem_);
symbols_ = std::move(program.symbols_); symbols_ = std::move(program.symbols_);
diagnostics_ = std::move(program.diagnostics_); diagnostics_ = std::move(program.diagnostics_);
transforms_applied_ = std::move(program.transforms_applied_);
is_valid_ = program.is_valid_; is_valid_ = program.is_valid_;
return *this; return *this;
} }

View File

@ -126,25 +126,6 @@ class Program {
/// information /// information
bool IsValid() const; bool IsValid() const;
/// @return the TypeInfo pointers of all transforms that have been applied to
/// this program.
std::unordered_set<const TypeInfo*> TransformsApplied() const {
return transforms_applied_;
}
/// @param transform the TypeInfo of the transform
/// @returns true if the transform with the given TypeInfo was applied to the
/// Program
bool HasTransformApplied(const TypeInfo* transform) const {
return transforms_applied_.count(transform);
}
/// @returns true if the transform of type `T` was applied.
template <typename T>
bool HasTransformApplied() const {
return HasTransformApplied(&TypeInfo::Of<T>());
}
/// Helper for returning the resolved semantic type of the expression `expr`. /// Helper for returning the resolved semantic type of the expression `expr`.
/// @param expr the AST expression /// @param expr the AST expression
/// @return the resolved semantic type for the expression, or nullptr if the /// @return the resolved semantic type for the expression, or nullptr if the
@ -184,7 +165,6 @@ class Program {
sem::Info sem_; sem::Info sem_;
SymbolTable symbols_{id_}; SymbolTable symbols_{id_};
diag::List diagnostics_; diag::List diagnostics_;
std::unordered_set<const TypeInfo*> transforms_applied_;
bool is_valid_ = false; // Not valid until it is built bool is_valid_ = false; // Not valid until it is built
bool moved_ = false; bool moved_ = false;
}; };

View File

@ -38,8 +38,7 @@ ProgramBuilder::ProgramBuilder(ProgramBuilder&& rhs)
ast_(rhs.ast_), ast_(rhs.ast_),
sem_(std::move(rhs.sem_)), sem_(std::move(rhs.sem_)),
symbols_(std::move(rhs.symbols_)), symbols_(std::move(rhs.symbols_)),
diagnostics_(std::move(rhs.diagnostics_)), diagnostics_(std::move(rhs.diagnostics_)) {
transforms_applied_(std::move(rhs.transforms_applied_)) {
rhs.MarkAsMoved(); rhs.MarkAsMoved();
} }
@ -56,7 +55,7 @@ ProgramBuilder& ProgramBuilder::operator=(ProgramBuilder&& rhs) {
sem_ = std::move(rhs.sem_); sem_ = std::move(rhs.sem_);
symbols_ = std::move(rhs.symbols_); symbols_ = std::move(rhs.symbols_);
diagnostics_ = std::move(rhs.diagnostics_); diagnostics_ = std::move(rhs.diagnostics_);
transforms_applied_ = std::move(rhs.transforms_applied_);
return *this; return *this;
} }
@ -69,7 +68,6 @@ ProgramBuilder ProgramBuilder::Wrap(const Program* program) {
builder.sem_ = sem::Info::Wrap(program->Sem()); builder.sem_ = sem::Info::Wrap(program->Sem());
builder.symbols_ = program->Symbols(); builder.symbols_ = program->Symbols();
builder.diagnostics_ = program->Diagnostics(); builder.diagnostics_ = program->Diagnostics();
builder.transforms_applied_ = program->TransformsApplied();
return builder; return builder;
} }

View File

@ -2469,49 +2469,6 @@ class ProgramBuilder {
source_ = Source(loc); source_ = Source(loc);
} }
/// Marks that the given transform has been applied to this program.
/// @param transform the transform that has been applied
void SetTransformApplied(const CastableBase* transform) {
transforms_applied_.emplace(&transform->TypeInfo());
}
/// Marks that the given transform `T` has been applied to this program.
template <typename T>
void SetTransformApplied() {
transforms_applied_.emplace(&TypeInfo::Of<T>());
}
/// Marks that the transforms with the given TypeInfos have been applied to
/// this program.
/// @param transforms the set of transform TypeInfos that has been applied
void SetTransformApplied(
const std::unordered_set<const TypeInfo*>& transforms) {
for (auto* transform : transforms) {
transforms_applied_.emplace(transform);
}
}
/// Unmarks that the given transform `T` has been applied to this program.
template <typename T>
void UnsetTransformApplied() {
auto it = transforms_applied_.find(&TypeInfo::Of<T>());
if (it != transforms_applied_.end()) {
transforms_applied_.erase(it);
}
}
/// @returns true if the transform of type `T` was applied.
template <typename T>
bool HasTransformApplied() {
return transforms_applied_.count(&TypeInfo::Of<T>());
}
/// @return the TypeInfo pointers of all transforms that have been applied to
/// this program.
std::unordered_set<const TypeInfo*> TransformsApplied() const {
return transforms_applied_;
}
/// Helper for returning the resolved semantic type of the expression `expr`. /// Helper for returning the resolved semantic type of the expression `expr`.
/// @note As the Resolver is run when the Program is built, this will only be /// @note As the Resolver is run when the Program is built, this will only be
/// useful for the Resolver itself and tests that use their own Resolver. /// useful for the Resolver itself and tests that use their own Resolver.
@ -2592,7 +2549,6 @@ class ProgramBuilder {
sem::Info sem_; sem::Info sem_;
SymbolTable symbols_{id_}; SymbolTable symbols_{id_};
diag::List diagnostics_; diag::List diagnostics_;
std::unordered_set<const TypeInfo*> transforms_applied_;
/// The source to use when creating AST nodes without providing a Source as /// The source to use when creating AST nodes without providing a Source as
/// the first argument. /// the first argument.

View File

@ -27,15 +27,19 @@ AddEmptyEntryPoint::AddEmptyEntryPoint() = default;
AddEmptyEntryPoint::~AddEmptyEntryPoint() = default; AddEmptyEntryPoint::~AddEmptyEntryPoint() = default;
bool AddEmptyEntryPoint::ShouldRun(const Program* program,
const DataMap&) const {
for (auto* func : program->AST().Functions()) {
if (func->IsEntryPoint()) {
return false;
}
}
return true;
}
void AddEmptyEntryPoint::Run(CloneContext& ctx, void AddEmptyEntryPoint::Run(CloneContext& ctx,
const DataMap&, const DataMap&,
DataMap&) const { DataMap&) const {
for (auto* func : ctx.src->AST().Functions()) {
if (func->IsEntryPoint()) {
ctx.Clone();
return;
}
}
ctx.dst->Func(ctx.dst->Symbols().New("unused_entry_point"), {}, ctx.dst->Func(ctx.dst->Symbols().New("unused_entry_point"), {},
ctx.dst->ty.void_(), {}, ctx.dst->ty.void_(), {},
{ctx.dst->Stage(ast::PipelineStage::kCompute), {ctx.dst->Stage(ast::PipelineStage::kCompute),

View File

@ -28,6 +28,12 @@ class AddEmptyEntryPoint : public Castable<AddEmptyEntryPoint, Transform> {
/// Destructor /// Destructor
~AddEmptyEntryPoint() override; ~AddEmptyEntryPoint() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program,
const DataMap& data = {}) const override;
protected: protected:
/// Runs the transform using the CloneContext built for transforming a /// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext. /// program. Run() is responsible for calling Clone() on the CloneContext.

View File

@ -24,6 +24,21 @@ namespace {
using AddEmptyEntryPointTest = TransformTest; using AddEmptyEntryPointTest = TransformTest;
TEST_F(AddEmptyEntryPointTest, ShouldRunEmptyModule) {
auto* src = R"()";
EXPECT_TRUE(ShouldRun<AddEmptyEntryPoint>(src));
}
TEST_F(AddEmptyEntryPointTest, ShouldRunExistingEntryPoint) {
auto* src = R"(
[[stage(compute), workgroup_size(1)]]
fn existing() {}
)";
EXPECT_FALSE(ShouldRun<AddEmptyEntryPoint>(src));
}
TEST_F(AddEmptyEntryPointTest, EmptyModule) { TEST_F(AddEmptyEntryPointTest, EmptyModule) {
auto* src = R"()"; auto* src = R"()";

View File

@ -20,6 +20,7 @@
#include "src/program_builder.h" #include "src/program_builder.h"
#include "src/sem/call.h" #include "src/sem/call.h"
#include "src/sem/function.h"
#include "src/sem/variable.h" #include "src/sem/variable.h"
#include "src/transform/simplify_pointers.h" #include "src/transform/simplify_pointers.h"
@ -93,13 +94,23 @@ static void IterateArrayLengthOnStorageVar(CloneContext& ctx, F&& functor) {
} }
} }
bool ArrayLengthFromUniform::ShouldRun(const Program* program,
const DataMap&) const {
for (auto* fn : program->AST().Functions()) {
if (auto* sem_fn = program->Sem().Get(fn)) {
for (auto* intrinsic : sem_fn->DirectlyCalledIntrinsics()) {
if (intrinsic->Type() == sem::IntrinsicType::kArrayLength) {
return true;
}
}
}
}
return false;
}
void ArrayLengthFromUniform::Run(CloneContext& ctx, void ArrayLengthFromUniform::Run(CloneContext& ctx,
const DataMap& inputs, const DataMap& inputs,
DataMap& outputs) const { DataMap& outputs) const {
if (!Requires<SimplifyPointers>(ctx)) {
return;
}
auto* cfg = inputs.Get<Config>(); auto* cfg = inputs.Get<Config>();
if (cfg == nullptr) { if (cfg == nullptr) {
ctx.dst->Diagnostics().add_error( ctx.dst->Diagnostics().add_error(

View File

@ -49,6 +49,9 @@ namespace transform {
/// This transform assumes that the `SimplifyPointers` /// This transform assumes that the `SimplifyPointers`
/// transforms have been run before it so that arguments to the arrayLength /// transforms have been run before it so that arguments to the arrayLength
/// builtin always have the form `&resource.array`. /// builtin always have the form `&resource.array`.
///
/// @note Depends on the following transforms to have been run first:
/// * SimplifyPointers
class ArrayLengthFromUniform class ArrayLengthFromUniform
: public Castable<ArrayLengthFromUniform, Transform> { : public Castable<ArrayLengthFromUniform, Transform> {
public: public:
@ -81,6 +84,8 @@ class ArrayLengthFromUniform
}; };
/// Information produced about what the transform did. /// Information produced about what the transform did.
/// If there were no calls to the arrayLength() intrinsic, then no Result will
/// be emitted.
struct Result : public Castable<Result, transform::Data> { struct Result : public Castable<Result, transform::Data> {
/// Constructor /// Constructor
/// @param used_size_indices Indices into the UBO that are statically used. /// @param used_size_indices Indices into the UBO that are statically used.
@ -96,6 +101,12 @@ class ArrayLengthFromUniform
const std::unordered_set<uint32_t> used_size_indices; const std::unordered_set<uint32_t> used_size_indices;
}; };
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program,
const DataMap& data = {}) const override;
protected: protected:
/// Runs the transform using the CloneContext built for transforming a /// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext. /// program. Run() is responsible for calling Clone() on the CloneContext.

View File

@ -26,8 +26,61 @@ namespace {
using ArrayLengthFromUniformTest = TransformTest; using ArrayLengthFromUniformTest = TransformTest;
TEST_F(ArrayLengthFromUniformTest, ShouldRunEmptyModule) {
auto* src = R"()";
EXPECT_FALSE(ShouldRun<ArrayLengthFromUniform>(src));
}
TEST_F(ArrayLengthFromUniformTest, ShouldRunNoArrayLength) {
auto* src = R"(
struct SB {
x : i32;
arr : array<i32>;
};
[[group(0), binding(0)]] var<storage, read> sb : SB;
[[stage(compute), workgroup_size(1)]]
fn main() {
}
)";
EXPECT_FALSE(ShouldRun<ArrayLengthFromUniform>(src));
}
TEST_F(ArrayLengthFromUniformTest, ShouldRunWithArrayLength) {
auto* src = R"(
struct SB {
x : i32;
arr : array<i32>;
};
[[group(0), binding(0)]] var<storage, read> sb : SB;
[[stage(compute), workgroup_size(1)]]
fn main() {
var len : u32 = arrayLength(&sb.arr);
}
)";
EXPECT_TRUE(ShouldRun<ArrayLengthFromUniform>(src));
}
TEST_F(ArrayLengthFromUniformTest, Error_MissingTransformData) { TEST_F(ArrayLengthFromUniformTest, Error_MissingTransformData) {
auto* src = ""; auto* src = R"(
struct SB {
x : i32;
arr : array<i32>;
};
[[group(0), binding(0)]] var<storage, read> sb : SB;
[[stage(compute), workgroup_size(1)]]
fn main() {
var len : u32 = arrayLength(&sb.arr);
}
)";
auto* expect = auto* expect =
"error: missing transform data for " "error: missing transform data for "
@ -38,18 +91,6 @@ TEST_F(ArrayLengthFromUniformTest, Error_MissingTransformData) {
EXPECT_EQ(expect, str(got)); EXPECT_EQ(expect, str(got));
} }
TEST_F(ArrayLengthFromUniformTest, Error_MissingSimplifyPointers) {
auto* src = "";
auto* expect =
"error: tint::transform::ArrayLengthFromUniform depends on "
"tint::transform::SimplifyPointers but the dependency was not run";
auto got = Run<ArrayLengthFromUniform>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(ArrayLengthFromUniformTest, Basic) { TEST_F(ArrayLengthFromUniformTest, Basic) {
auto* src = R"( auto* src = R"(
@group(0) @binding(0) var<storage, read> sb : array<i32>; @group(0) @binding(0) var<storage, read> sb : array<i32>;
@ -426,8 +467,7 @@ fn main() {
auto got = Run<Unshadow, SimplifyPointers, ArrayLengthFromUniform>(src, data); auto got = Run<Unshadow, SimplifyPointers, ArrayLengthFromUniform>(src, data);
EXPECT_EQ(src, str(got)); EXPECT_EQ(src, str(got));
EXPECT_EQ(std::unordered_set<uint32_t>(), EXPECT_EQ(got.data.Get<ArrayLengthFromUniform::Result>(), nullptr);
got.data.Get<ArrayLengthFromUniform::Result>()->used_size_indices);
} }
TEST_F(ArrayLengthFromUniformTest, MissingBindingPointToIndexMapping) { TEST_F(ArrayLengthFromUniformTest, MissingBindingPointToIndexMapping) {

View File

@ -42,6 +42,14 @@ BindingRemapper::Remappings::~Remappings() = default;
BindingRemapper::BindingRemapper() = default; BindingRemapper::BindingRemapper() = default;
BindingRemapper::~BindingRemapper() = default; BindingRemapper::~BindingRemapper() = default;
bool BindingRemapper::ShouldRun(const Program*, const DataMap& inputs) const {
if (auto* remappings = inputs.Get<Remappings>()) {
return !remappings->binding_points.empty() ||
!remappings->access_controls.empty();
}
return false;
}
void BindingRemapper::Run(CloneContext& ctx, void BindingRemapper::Run(CloneContext& ctx,
const DataMap& inputs, const DataMap& inputs,
DataMap&) const { DataMap&) const {

View File

@ -68,6 +68,12 @@ class BindingRemapper : public Castable<BindingRemapper, Transform> {
BindingRemapper(); BindingRemapper();
~BindingRemapper() override; ~BindingRemapper() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program,
const DataMap& data = {}) const override;
protected: protected:
/// Runs the transform using the CloneContext built for transforming a /// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext. /// program. Run() is responsible for calling Clone() on the CloneContext.

View File

@ -24,6 +24,47 @@ namespace {
using BindingRemapperTest = TransformTest; using BindingRemapperTest = TransformTest;
TEST_F(BindingRemapperTest, ShouldRunNoRemappings) {
auto* src = R"()";
EXPECT_FALSE(ShouldRun<BindingRemapper>(src));
}
TEST_F(BindingRemapperTest, ShouldRunEmptyRemappings) {
auto* src = R"()";
DataMap data;
data.Add<BindingRemapper::Remappings>(BindingRemapper::BindingPoints{},
BindingRemapper::AccessControls{});
EXPECT_FALSE(ShouldRun<BindingRemapper>(src, data));
}
TEST_F(BindingRemapperTest, ShouldRunBindingPointRemappings) {
auto* src = R"()";
DataMap data;
data.Add<BindingRemapper::Remappings>(
BindingRemapper::BindingPoints{
{{2, 1}, {1, 2}},
},
BindingRemapper::AccessControls{});
EXPECT_TRUE(ShouldRun<BindingRemapper>(src, data));
}
TEST_F(BindingRemapperTest, ShouldRunAccessControlRemappings) {
auto* src = R"()";
DataMap data;
data.Add<BindingRemapper::Remappings>(BindingRemapper::BindingPoints{},
BindingRemapper::AccessControls{
{{2, 1}, ast::Access::kWrite},
});
EXPECT_TRUE(ShouldRun<BindingRemapper>(src, data));
}
TEST_F(BindingRemapperTest, NoRemappings) { TEST_F(BindingRemapperTest, NoRemappings) {
auto* src = R"( auto* src = R"(
struct S { struct S {
@ -359,17 +400,18 @@ TEST_F(BindingRemapperTest, NoData) {
auto* src = R"( auto* src = R"(
struct S { struct S {
a : f32; a : f32;
}; }
@group(2) @binding(1) var<storage, read> a : S; @group(2) @binding(1) var<storage, read> a : S;
@group(3) @binding(2) var<storage, read> b : S; @group(3) @binding(2) var<storage, read> b : S;
@stage(compute) @workgroup_size(1) @stage(compute) @workgroup_size(1)
fn f() {} fn f() {
}
)"; )";
auto* expect = auto* expect = src;
"error: missing transform data for tint::transform::BindingRemapper";
auto got = Run<BindingRemapper>(src); auto got = Run<BindingRemapper>(src);

View File

@ -22,6 +22,7 @@
#include "src/program_builder.h" #include "src/program_builder.h"
#include "src/sem/block_statement.h" #include "src/sem/block_statement.h"
#include "src/sem/call.h" #include "src/sem/call.h"
#include "src/sem/function.h"
#include "src/sem/statement.h" #include "src/sem/statement.h"
#include "src/sem/struct.h" #include "src/sem/struct.h"
#include "src/sem/variable.h" #include "src/sem/variable.h"
@ -71,13 +72,24 @@ CalculateArrayLength::BufferSizeIntrinsic::Clone(CloneContext* ctx) const {
CalculateArrayLength::CalculateArrayLength() = default; CalculateArrayLength::CalculateArrayLength() = default;
CalculateArrayLength::~CalculateArrayLength() = default; CalculateArrayLength::~CalculateArrayLength() = default;
bool CalculateArrayLength::ShouldRun(const Program* program,
const DataMap&) const {
for (auto* fn : program->AST().Functions()) {
if (auto* sem_fn = program->Sem().Get(fn)) {
for (auto* intrinsic : sem_fn->DirectlyCalledIntrinsics()) {
if (intrinsic->Type() == sem::IntrinsicType::kArrayLength) {
return true;
}
}
}
}
return false;
}
void CalculateArrayLength::Run(CloneContext& ctx, void CalculateArrayLength::Run(CloneContext& ctx,
const DataMap&, const DataMap&,
DataMap&) const { DataMap&) const {
auto& sem = ctx.src->Sem(); auto& sem = ctx.src->Sem();
if (!Requires<SimplifyPointers>(ctx)) {
return;
}
// get_buffer_size_intrinsic() emits the function decorated with // get_buffer_size_intrinsic() emits the function decorated with
// BufferSizeIntrinsic that is transformed by the HLSL writer into a call to // BufferSizeIntrinsic that is transformed by the HLSL writer into a call to

View File

@ -29,6 +29,9 @@ namespace transform {
/// CalculateArrayLength is a transform used to replace calls to arrayLength() /// CalculateArrayLength is a transform used to replace calls to arrayLength()
/// with a value calculated from the size of the storage buffer. /// with a value calculated from the size of the storage buffer.
///
/// @note Depends on the following transforms to have been run first:
/// * SimplifyPointers
class CalculateArrayLength : public Castable<CalculateArrayLength, Transform> { class CalculateArrayLength : public Castable<CalculateArrayLength, Transform> {
public: public:
/// BufferSizeIntrinsic is an InternalDecoration that's applied to intrinsic /// BufferSizeIntrinsic is an InternalDecoration that's applied to intrinsic
@ -56,6 +59,12 @@ class CalculateArrayLength : public Castable<CalculateArrayLength, Transform> {
/// Destructor /// Destructor
~CalculateArrayLength() override; ~CalculateArrayLength() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program,
const DataMap& data = {}) const override;
protected: protected:
/// Runs the transform using the CloneContext built for transforming a /// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext. /// program. Run() is responsible for calling Clone() on the CloneContext.

View File

@ -24,16 +24,45 @@ namespace {
using CalculateArrayLengthTest = TransformTest; using CalculateArrayLengthTest = TransformTest;
TEST_F(CalculateArrayLengthTest, Error_MissingCalculateArrayLength) { TEST_F(CalculateArrayLengthTest, ShouldRunEmptyModule) {
auto* src = ""; auto* src = R"()";
auto* expect = EXPECT_FALSE(ShouldRun<CalculateArrayLength>(src));
"error: tint::transform::CalculateArrayLength depends on " }
"tint::transform::SimplifyPointers but the dependency was not run";
auto got = Run<CalculateArrayLength>(src); TEST_F(CalculateArrayLengthTest, ShouldRunNoArrayLength) {
auto* src = R"(
struct SB {
x : i32;
arr : array<i32>;
};
EXPECT_EQ(expect, str(got)); [[group(0), binding(0)]] var<storage, read> sb : SB;
[[stage(compute), workgroup_size(1)]]
fn main() {
}
)";
EXPECT_FALSE(ShouldRun<CalculateArrayLength>(src));
}
TEST_F(CalculateArrayLengthTest, ShouldRunWithArrayLength) {
auto* src = R"(
struct SB {
x : i32;
arr : array<i32>;
};
[[group(0), binding(0)]] var<storage, read> sb : SB;
[[stage(compute), workgroup_size(1)]]
fn main() {
var len : u32 = arrayLength(&sb.arr);
}
)";
EXPECT_TRUE(ShouldRun<CalculateArrayLength>(src));
} }
TEST_F(CalculateArrayLengthTest, Basic) { TEST_F(CalculateArrayLengthTest, Basic) {

View File

@ -551,10 +551,6 @@ struct CanonicalizeEntryPointIO::State {
void CanonicalizeEntryPointIO::Run(CloneContext& ctx, void CanonicalizeEntryPointIO::Run(CloneContext& ctx,
const DataMap& inputs, const DataMap& inputs,
DataMap&) const { DataMap&) const {
if (!Requires<Unshadow>(ctx)) {
return;
}
auto* cfg = inputs.Get<Config>(); auto* cfg = inputs.Get<Config>();
if (cfg == nullptr) { if (cfg == nullptr) {
ctx.dst->Diagnostics().add_error( ctx.dst->Diagnostics().add_error(

View File

@ -80,6 +80,9 @@ namespace transform {
/// return wrapper_result; /// return wrapper_result;
/// } /// }
/// ``` /// ```
///
/// @note Depends on the following transforms to have been run first:
/// * Unshadow
class CanonicalizeEntryPointIO class CanonicalizeEntryPointIO
: public Castable<CanonicalizeEntryPointIO, Transform> { : public Castable<CanonicalizeEntryPointIO, Transform> {
public: public:

View File

@ -23,18 +23,6 @@ namespace {
using CanonicalizeEntryPointIOTest = TransformTest; using CanonicalizeEntryPointIOTest = TransformTest;
TEST_F(CanonicalizeEntryPointIOTest, Error_MissingUnshadow) {
auto* src = "";
auto* expect =
"error: tint::transform::CanonicalizeEntryPointIO depends on "
"tint::transform::Unshadow but the dependency was not run";
auto got = Run<CanonicalizeEntryPointIO>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(CanonicalizeEntryPointIOTest, Error_MissingTransformData) { TEST_F(CanonicalizeEntryPointIOTest, Error_MissingTransformData) {
auto* src = ""; auto* src = "";

View File

@ -790,6 +790,19 @@ const DecomposeMemoryAccess::Intrinsic* DecomposeMemoryAccess::Intrinsic::Clone(
DecomposeMemoryAccess::DecomposeMemoryAccess() = default; DecomposeMemoryAccess::DecomposeMemoryAccess() = default;
DecomposeMemoryAccess::~DecomposeMemoryAccess() = default; DecomposeMemoryAccess::~DecomposeMemoryAccess() = default;
bool DecomposeMemoryAccess::ShouldRun(const Program* program,
const DataMap&) const {
for (auto* decl : program->AST().GlobalDeclarations()) {
if (auto* var = program->Sem().Get<sem::Variable>(decl)) {
if (var->StorageClass() == ast::StorageClass::kStorage ||
var->StorageClass() == ast::StorageClass::kUniform) {
return true;
}
}
}
return false;
}
void DecomposeMemoryAccess::Run(CloneContext& ctx, void DecomposeMemoryAccess::Run(CloneContext& ctx,
const DataMap&, const DataMap&,
DataMap&) const { DataMap&) const {

View File

@ -105,6 +105,12 @@ class DecomposeMemoryAccess
/// Destructor /// Destructor
~DecomposeMemoryAccess() override; ~DecomposeMemoryAccess() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program,
const DataMap& data = {}) const override;
protected: protected:
/// Runs the transform using the CloneContext built for transforming a /// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext. /// program. Run() is responsible for calling Clone() on the CloneContext.

View File

@ -22,6 +22,34 @@ namespace {
using DecomposeMemoryAccessTest = TransformTest; using DecomposeMemoryAccessTest = TransformTest;
TEST_F(DecomposeMemoryAccessTest, ShouldRunEmptyModule) {
auto* src = R"()";
EXPECT_FALSE(ShouldRun<DecomposeMemoryAccess>(src));
}
TEST_F(DecomposeMemoryAccessTest, ShouldRunStorageBuffer) {
auto* src = R"(
struct Buffer {
i : i32;
};
[[group(0), binding(0)]] var<storage, read_write> sb : Buffer;
)";
EXPECT_TRUE(ShouldRun<DecomposeMemoryAccess>(src));
}
TEST_F(DecomposeMemoryAccessTest, ShouldRunUniformBuffer) {
auto* src = R"(
struct Buffer {
i : i32;
};
[[group(0), binding(0)]] var<uniform> ub : Buffer;
)";
EXPECT_TRUE(ShouldRun<DecomposeMemoryAccess>(src));
}
TEST_F(DecomposeMemoryAccessTest, SB_BasicLoad) { TEST_F(DecomposeMemoryAccessTest, SB_BasicLoad) {
auto* src = R"( auto* src = R"(
struct SB { struct SB {

View File

@ -107,7 +107,8 @@ DecomposeStridedMatrix::DecomposeStridedMatrix() = default;
DecomposeStridedMatrix::~DecomposeStridedMatrix() = default; DecomposeStridedMatrix::~DecomposeStridedMatrix() = default;
bool DecomposeStridedMatrix::ShouldRun(const Program* program) const { bool DecomposeStridedMatrix::ShouldRun(const Program* program,
const DataMap&) const {
bool should_run = false; bool should_run = false;
GatherCustomStrideMatrixMembers( GatherCustomStrideMatrixMembers(
program, [&](const sem::StructMember*, sem::Matrix*, uint32_t) { program, [&](const sem::StructMember*, sem::Matrix*, uint32_t) {
@ -120,10 +121,6 @@ bool DecomposeStridedMatrix::ShouldRun(const Program* program) const {
void DecomposeStridedMatrix::Run(CloneContext& ctx, void DecomposeStridedMatrix::Run(CloneContext& ctx,
const DataMap&, const DataMap&,
DataMap&) const { DataMap&) const {
if (!Requires<SimplifyPointers>(ctx)) {
return;
}
// Scan the program for all storage and uniform structure matrix members with // Scan the program for all storage and uniform structure matrix members with
// a custom stride attribute. Replace these matrices with an equivalent array, // a custom stride attribute. Replace these matrices with an equivalent array,
// and populate the `decomposed` map with the members that have been replaced. // and populate the `decomposed` map with the members that have been replaced.

View File

@ -25,6 +25,9 @@ namespace transform {
/// of N column vectors. /// of N column vectors.
/// This transform is used by the SPIR-V reader to handle the SPIR-V /// This transform is used by the SPIR-V reader to handle the SPIR-V
/// MatrixStride decoration. /// MatrixStride decoration.
///
/// @note Depends on the following transforms to have been run first:
/// * SimplifyPointers
class DecomposeStridedMatrix class DecomposeStridedMatrix
: public Castable<DecomposeStridedMatrix, Transform> { : public Castable<DecomposeStridedMatrix, Transform> {
public: public:
@ -35,8 +38,10 @@ class DecomposeStridedMatrix
~DecomposeStridedMatrix() override; ~DecomposeStridedMatrix() override;
/// @param program the program to inspect /// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program /// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program) const override; bool ShouldRun(const Program* program,
const DataMap& data = {}) const override;
protected: protected:
/// Runs the transform using the CloneContext built for transforming a /// Runs the transform using the CloneContext built for transforming a

View File

@ -31,55 +31,25 @@ namespace {
using DecomposeStridedMatrixTest = TransformTest; using DecomposeStridedMatrixTest = TransformTest;
using f32 = ProgramBuilder::f32; using f32 = ProgramBuilder::f32;
TEST_F(DecomposeStridedMatrixTest, ShouldRunEmptyModule) {
auto* src = R"()";
EXPECT_FALSE(ShouldRun<DecomposeStridedMatrix>(src));
}
TEST_F(DecomposeStridedMatrixTest, ShouldRunNonStridedMatrox) {
auto* src = R"(
var<private> m : mat3x2<f32>;
)";
EXPECT_FALSE(ShouldRun<DecomposeStridedMatrix>(src));
}
TEST_F(DecomposeStridedMatrixTest, Empty) { TEST_F(DecomposeStridedMatrixTest, Empty) {
auto* src = R"()"; auto* src = R"()";
auto* expect = src; auto* expect = src;
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(src); auto got = Run<DecomposeStridedMatrix>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStridedMatrixTest, MissingDependencySimplify) {
// struct S {
// [[offset(16), stride(32)]]
// [[internal(ignore_stride_decoration)]]
// m : mat2x2<f32>;
// };
// [[group(0), binding(0)]] var<uniform> s : S;
//
// [[stage(compute), workgroup_size(1)]]
// fn f() {
// let x : mat2x2<f32> = s.m;
// }
ProgramBuilder b;
auto* S = b.Structure(
"S",
{
b.Member(
"m", b.ty.mat2x2<f32>(),
{
b.create<ast::StructMemberOffsetDecoration>(16),
b.create<ast::StrideDecoration>(32),
b.Disable(ast::DisabledValidation::kIgnoreStrideDecoration),
}),
});
b.Global("s", b.ty.Of(S), ast::StorageClass::kUniform,
b.GroupAndBinding(0, 0));
b.Func(
"f", {}, b.ty.void_(),
{
b.Decl(b.Const("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))),
},
{
b.Stage(ast::PipelineStage::kCompute),
b.WorkgroupSize(1),
});
auto* expect =
R"(error: tint::transform::DecomposeStridedMatrix depends on tint::transform::SimplifyPointers but the dependency was not run)";
auto got = Run<DecomposeStridedMatrix>(Program(std::move(b)));
EXPECT_EQ(expect, str(got)); EXPECT_EQ(expect, str(got));
} }

View File

@ -57,6 +57,15 @@ FirstIndexOffset::Data::~Data() = default;
FirstIndexOffset::FirstIndexOffset() = default; FirstIndexOffset::FirstIndexOffset() = default;
FirstIndexOffset::~FirstIndexOffset() = default; FirstIndexOffset::~FirstIndexOffset() = default;
bool FirstIndexOffset::ShouldRun(const Program* program, const DataMap&) const {
for (auto* fn : program->AST().Functions()) {
if (fn->PipelineStage() == ast::PipelineStage::kVertex) {
return true;
}
}
return false;
}
void FirstIndexOffset::Run(CloneContext& ctx, void FirstIndexOffset::Run(CloneContext& ctx,
const DataMap& inputs, const DataMap& inputs,
DataMap& outputs) const { DataMap& outputs) const {

View File

@ -115,6 +115,12 @@ class FirstIndexOffset : public Castable<FirstIndexOffset, Transform> {
/// Destructor /// Destructor
~FirstIndexOffset() override; ~FirstIndexOffset() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program,
const DataMap& data = {}) const override;
protected: protected:
/// Runs the transform using the CloneContext built for transforming a /// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext. /// program. Run() is responsible for calling Clone() on the CloneContext.

View File

@ -26,6 +26,34 @@ namespace {
using FirstIndexOffsetTest = TransformTest; using FirstIndexOffsetTest = TransformTest;
TEST_F(FirstIndexOffsetTest, ShouldRunEmptyModule) {
auto* src = R"()";
EXPECT_FALSE(ShouldRun<FirstIndexOffset>(src));
}
TEST_F(FirstIndexOffsetTest, ShouldRunFragmentStage) {
auto* src = R"(
[[stage(fragment)]]
fn entry() {
return;
}
)";
EXPECT_FALSE(ShouldRun<FirstIndexOffset>(src));
}
TEST_F(FirstIndexOffsetTest, ShouldRunVertexStage) {
auto* src = R"(
[[stage(vertex)]]
fn entry() -> [[builtin(position)]] vec4<f32> {
return vec4<f32>();
}
)";
EXPECT_TRUE(ShouldRun<FirstIndexOffset>(src));
}
TEST_F(FirstIndexOffsetTest, EmptyModule) { TEST_F(FirstIndexOffsetTest, EmptyModule) {
auto* src = ""; auto* src = "";
auto* expect = ""; auto* expect = "";
@ -38,6 +66,26 @@ TEST_F(FirstIndexOffsetTest, EmptyModule) {
auto* data = got.data.Get<FirstIndexOffset::Data>(); auto* data = got.data.Get<FirstIndexOffset::Data>();
EXPECT_EQ(data, nullptr);
}
TEST_F(FirstIndexOffsetTest, BasicVertexShader) {
auto* src = R"(
@stage(vertex)
fn entry() -> @builtin(position) vec4<f32> {
return vec4<f32>();
}
)";
auto* expect = src;
DataMap config;
config.Add<FirstIndexOffset::BindingPoint>(0, 0);
auto got = Run<FirstIndexOffset>(src, std::move(config));
EXPECT_EQ(expect, str(got));
auto* data = got.data.Get<FirstIndexOffset::Data>();
ASSERT_NE(data, nullptr); ASSERT_NE(data, nullptr);
EXPECT_EQ(data->has_vertex_index, false); EXPECT_EQ(data->has_vertex_index, false);
EXPECT_EQ(data->has_instance_index, false); EXPECT_EQ(data->has_instance_index, false);

View File

@ -25,6 +25,15 @@ ForLoopToLoop::ForLoopToLoop() = default;
ForLoopToLoop::~ForLoopToLoop() = default; ForLoopToLoop::~ForLoopToLoop() = default;
bool ForLoopToLoop::ShouldRun(const Program* program, const DataMap&) const {
for (auto* node : program->ASTNodes().Objects()) {
if (node->Is<ast::ForLoopStatement>()) {
return true;
}
}
return false;
}
void ForLoopToLoop::Run(CloneContext& ctx, const DataMap&, DataMap&) const { void ForLoopToLoop::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
ctx.ReplaceAll( ctx.ReplaceAll(
[&](const ast::ForLoopStatement* for_loop) -> const ast::Statement* { [&](const ast::ForLoopStatement* for_loop) -> const ast::Statement* {

View File

@ -30,6 +30,12 @@ class ForLoopToLoop : public Castable<ForLoopToLoop, Transform> {
/// Destructor /// Destructor
~ForLoopToLoop() override; ~ForLoopToLoop() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program,
const DataMap& data = {}) const override;
protected: protected:
/// Runs the transform using the CloneContext built for transforming a /// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext. /// program. Run() is responsible for calling Clone() on the CloneContext.

View File

@ -22,6 +22,24 @@ namespace {
using ForLoopToLoopTest = TransformTest; using ForLoopToLoopTest = TransformTest;
TEST_F(ForLoopToLoopTest, ShouldRunEmptyModule) {
auto* src = R"()";
EXPECT_FALSE(ShouldRun<ForLoopToLoop>(src));
}
TEST_F(ForLoopToLoopTest, ShouldRunHasForLoop) {
auto* src = R"(
fn f() {
for (;;) {
break;
}
}
)";
EXPECT_TRUE(ShouldRun<ForLoopToLoop>(src));
}
TEST_F(ForLoopToLoopTest, EmptyModule) { TEST_F(ForLoopToLoopTest, EmptyModule) {
auto* src = ""; auto* src = "";
auto* expect = src; auto* expect = src;

View File

@ -113,7 +113,6 @@ Output Glsl::Run(const Program* in, const DataMap& inputs) const {
ProgramBuilder builder; ProgramBuilder builder;
CloneContext ctx(&builder, &out.program); CloneContext ctx(&builder, &out.program);
ctx.Clone(); ctx.Clone();
builder.SetTransformApplied(this);
return Output{Program(std::move(builder))}; return Output{Program(std::move(builder))};
} }

View File

@ -216,15 +216,9 @@ LocalizeStructArrayAssignment::~LocalizeStructArrayAssignment() = default;
void LocalizeStructArrayAssignment::Run(CloneContext& ctx, void LocalizeStructArrayAssignment::Run(CloneContext& ctx,
const DataMap&, const DataMap&,
DataMap&) const { DataMap&) const {
if (!Requires<SimplifyPointers>(ctx)) {
return;
}
State state(ctx); State state(ctx);
state.Run(); state.Run();
// This transform may introduce pointers
ctx.dst->UnsetTransformApplied<transform::SimplifyPointers>();
} }
} // namespace transform } // namespace transform
} // namespace tint } // namespace tint

View File

@ -25,6 +25,9 @@ namespace transform {
/// temporary local variable, assigns to the local variable, and copies the /// temporary local variable, assigns to the local variable, and copies the
/// array back. This is to work around FXC's compilation failure for these cases /// array back. This is to work around FXC's compilation failure for these cases
/// (see crbug.com/tint/1206). /// (see crbug.com/tint/1206).
///
/// @note Depends on the following transforms to have been run first:
/// * SimplifyPointers
class LocalizeStructArrayAssignment class LocalizeStructArrayAssignment
: public Castable<LocalizeStructArrayAssignment, Transform> { : public Castable<LocalizeStructArrayAssignment, Transform> {
public: public:

View File

@ -24,16 +24,6 @@ namespace {
using LocalizeStructArrayAssignmentTest = TransformTest; using LocalizeStructArrayAssignmentTest = TransformTest;
TEST_F(LocalizeStructArrayAssignmentTest, MissingSimplifyPointers) {
auto* src = R"()";
auto* expect =
"error: tint::transform::LocalizeStructArrayAssignment depends on "
"tint::transform::SimplifyPointers but the dependency was not run";
auto got = Run<LocalizeStructArrayAssignment>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(LocalizeStructArrayAssignmentTest, EmptyModule) { TEST_F(LocalizeStructArrayAssignmentTest, EmptyModule) {
auto* src = R"()"; auto* src = R"()";
auto* expect = src; auto* expect = src;

View File

@ -56,6 +56,15 @@ LoopToForLoop::LoopToForLoop() = default;
LoopToForLoop::~LoopToForLoop() = default; LoopToForLoop::~LoopToForLoop() = default;
bool LoopToForLoop::ShouldRun(const Program* program, const DataMap&) const {
for (auto* node : program->ASTNodes().Objects()) {
if (node->Is<ast::LoopStatement>()) {
return true;
}
}
return false;
}
void LoopToForLoop::Run(CloneContext& ctx, const DataMap&, DataMap&) const { void LoopToForLoop::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
ctx.ReplaceAll([&](const ast::LoopStatement* loop) -> const ast::Statement* { ctx.ReplaceAll([&](const ast::LoopStatement* loop) -> const ast::Statement* {
// For loop condition is taken from the first statement in the loop. // For loop condition is taken from the first statement in the loop.

View File

@ -30,6 +30,12 @@ class LoopToForLoop : public Castable<LoopToForLoop, Transform> {
/// Destructor /// Destructor
~LoopToForLoop() override; ~LoopToForLoop() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program,
const DataMap& data = {}) const override;
protected: protected:
/// Runs the transform using the CloneContext built for transforming a /// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext. /// program. Run() is responsible for calling Clone() on the CloneContext.

View File

@ -22,6 +22,24 @@ namespace {
using LoopToForLoopTest = TransformTest; using LoopToForLoopTest = TransformTest;
TEST_F(LoopToForLoopTest, ShouldRunEmptyModule) {
auto* src = R"()";
EXPECT_FALSE(ShouldRun<LoopToForLoop>(src));
}
TEST_F(LoopToForLoopTest, ShouldRunHasForLoop) {
auto* src = R"(
fn f() {
loop {
break;
}
}
)";
EXPECT_TRUE(ShouldRun<LoopToForLoop>(src));
}
TEST_F(LoopToForLoopTest, EmptyModule) { TEST_F(LoopToForLoopTest, EmptyModule) {
auto* src = ""; auto* src = "";
auto* expect = ""; auto* expect = "";

View File

@ -53,7 +53,7 @@ Output Manager::Run(const Program* program, const DataMap& data) const {
Output out; Output out;
for (const auto& transform : transforms_) { for (const auto& transform : transforms_) {
if (!transform->ShouldRun(in)) { if (!transform->ShouldRun(in, data)) {
TINT_IF_PRINT_PROGRAM(std::cout << "Skipping " TINT_IF_PRINT_PROGRAM(std::cout << "Skipping "
<< transform->TypeInfo().name); << transform->TypeInfo().name);
continue; continue;

View File

@ -377,6 +377,16 @@ ModuleScopeVarToEntryPointParam::ModuleScopeVarToEntryPointParam() = default;
ModuleScopeVarToEntryPointParam::~ModuleScopeVarToEntryPointParam() = default; ModuleScopeVarToEntryPointParam::~ModuleScopeVarToEntryPointParam() = default;
bool ModuleScopeVarToEntryPointParam::ShouldRun(const Program* program,
const DataMap&) const {
for (auto* decl : program->AST().GlobalDeclarations()) {
if (decl->Is<ast::Variable>()) {
return true;
}
}
return false;
}
void ModuleScopeVarToEntryPointParam::Run(CloneContext& ctx, void ModuleScopeVarToEntryPointParam::Run(CloneContext& ctx,
const DataMap&, const DataMap&,
DataMap&) const { DataMap&) const {

View File

@ -70,6 +70,12 @@ class ModuleScopeVarToEntryPointParam
/// Destructor /// Destructor
~ModuleScopeVarToEntryPointParam() override; ~ModuleScopeVarToEntryPointParam() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program,
const DataMap& data = {}) const override;
protected: protected:
/// Runs the transform using the CloneContext built for transforming a /// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext. /// program. Run() is responsible for calling Clone() on the CloneContext.

View File

@ -24,6 +24,20 @@ namespace {
using ModuleScopeVarToEntryPointParamTest = TransformTest; using ModuleScopeVarToEntryPointParamTest = TransformTest;
TEST_F(ModuleScopeVarToEntryPointParamTest, ShouldRunEmptyModule) {
auto* src = R"()";
EXPECT_FALSE(ShouldRun<ModuleScopeVarToEntryPointParam>(src));
}
TEST_F(ModuleScopeVarToEntryPointParamTest, ShouldRunHasGlobal) {
auto* src = R"(
var<private> v : i32;
)";
EXPECT_TRUE(ShouldRun<ModuleScopeVarToEntryPointParam>(src));
}
TEST_F(ModuleScopeVarToEntryPointParamTest, Basic) { TEST_F(ModuleScopeVarToEntryPointParamTest, Basic) {
auto* src = R"( auto* src = R"(
var<private> p : f32; var<private> p : f32;

View File

@ -81,7 +81,7 @@ struct MultiplanarExternalTexture::State {
// represent the secondary plane and one uniform buffer for the // represent the secondary plane and one uniform buffer for the
// ExternalTextureParams struct). // ExternalTextureParams struct).
ctx.ReplaceAll([&](const ast::Variable* var) -> const ast::Variable* { ctx.ReplaceAll([&](const ast::Variable* var) -> const ast::Variable* {
if (!::tint::Is<ast::ExternalTexture>(var->type)) { if (!sem.Get<sem::ExternalTexture>(var->type)) {
return nullptr; return nullptr;
} }
@ -201,7 +201,7 @@ struct MultiplanarExternalTexture::State {
// functions. // functions.
ctx.ReplaceAll([&](const ast::Function* fn) -> const ast::Function* { ctx.ReplaceAll([&](const ast::Function* fn) -> const ast::Function* {
for (const ast::Variable* param : fn->params) { for (const ast::Variable* param : fn->params) {
if (::tint::Is<ast::ExternalTexture>(param->type)) { if (sem.Get<sem::ExternalTexture>(param->type)) {
// If we find a texture_external, we must ensure the // If we find a texture_external, we must ensure the
// ExternalTextureParams struct exists. // ExternalTextureParams struct exists.
if (!params_struct_sym.IsValid()) { if (!params_struct_sym.IsValid()) {
@ -407,6 +407,18 @@ MultiplanarExternalTexture::NewBindingPoints::~NewBindingPoints() = default;
MultiplanarExternalTexture::MultiplanarExternalTexture() = default; MultiplanarExternalTexture::MultiplanarExternalTexture() = default;
MultiplanarExternalTexture::~MultiplanarExternalTexture() = default; MultiplanarExternalTexture::~MultiplanarExternalTexture() = default;
bool MultiplanarExternalTexture::ShouldRun(const Program* program,
const DataMap&) const {
for (auto* node : program->ASTNodes().Objects()) {
if (auto* ty = node->As<ast::Type>()) {
if (program->Sem().Get<sem::ExternalTexture>(ty)) {
return true;
}
}
}
return false;
}
// Within this transform, an instance of a texture_external binding is unpacked // Within this transform, an instance of a texture_external binding is unpacked
// into two texture_2d<f32> bindings representing two possible planes of a // into two texture_2d<f32> bindings representing two possible planes of a
// single texture and a uniform buffer binding representing a struct of // single texture and a uniform buffer binding representing a struct of

View File

@ -76,6 +76,12 @@ class MultiplanarExternalTexture
/// Destructor /// Destructor
~MultiplanarExternalTexture() override; ~MultiplanarExternalTexture() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program,
const DataMap& data = {}) const override;
protected: protected:
struct State; struct State;

View File

@ -21,6 +21,35 @@ namespace {
using MultiplanarExternalTextureTest = TransformTest; using MultiplanarExternalTextureTest = TransformTest;
TEST_F(MultiplanarExternalTextureTest, ShouldRunEmptyModule) {
auto* src = R"()";
EXPECT_FALSE(ShouldRun<MultiplanarExternalTexture>(src));
}
TEST_F(MultiplanarExternalTextureTest, ShouldRunHasExternalTextureAlias) {
auto* src = R"(
type ET = texture_external;
)";
EXPECT_TRUE(ShouldRun<MultiplanarExternalTexture>(src));
}
TEST_F(MultiplanarExternalTextureTest, ShouldRunHasExternalTextureGlobal) {
auto* src = R"(
[[group(0), binding(0)]] var ext_tex : texture_external;
)";
EXPECT_TRUE(ShouldRun<MultiplanarExternalTexture>(src));
}
TEST_F(MultiplanarExternalTextureTest, ShouldRunHasExternalTextureParam) {
auto* src = R"(
fn f(ext_tex : texture_external) {}
)";
EXPECT_TRUE(ShouldRun<MultiplanarExternalTexture>(src));
}
// Running the transform without passing in data for the new bindings should // Running the transform without passing in data for the new bindings should
// result in an error. // result in an error.
TEST_F(MultiplanarExternalTextureTest, ErrorNoPassedData) { TEST_F(MultiplanarExternalTextureTest, ErrorNoPassedData) {
@ -651,6 +680,106 @@ fn main() {
EXPECT_EQ(expect, str(got)); EXPECT_EQ(expect, str(got));
} }
// Tests that the transform works with a function using an external texture,
// even if there's no external texture declared at module scope.
TEST_F(MultiplanarExternalTextureTest,
ExternalTexturePassedAsParamWithoutGlobalDecl) {
auto* src = R"(
fn f(ext_tex : texture_external) -> vec2<i32> {
return textureDimensions(ext_tex);
}
)";
auto* expect = R"(
struct ExternalTextureParams {
numPlanes : u32;
vr : f32;
ug : f32;
vg : f32;
ub : f32;
}
fn f(ext_tex : texture_2d<f32>, ext_tex_plane_1 : texture_2d<f32>, ext_tex_params : ExternalTextureParams) -> vec2<i32> {
return textureDimensions(ext_tex);
}
)";
DataMap data;
data.Add<MultiplanarExternalTexture::NewBindingPoints>(
MultiplanarExternalTexture::BindingsMap{{{0, 0}, {{0, 1}, {0, 2}}}});
auto got = Run<MultiplanarExternalTexture>(src, data);
EXPECT_EQ(expect, str(got));
}
// Tests that the the transform handles aliases to external textures
TEST_F(MultiplanarExternalTextureTest, ExternalTextureAlias) {
auto* src = R"(
type ET = texture_external;
fn f(t : ET, s : sampler) {
textureSampleLevel(t, s, vec2<f32>(1.0, 2.0));
}
[[group(0), binding(0)]] var ext_tex : ET;
[[group(0), binding(1)]] var smp : sampler;
[[stage(fragment)]]
fn main() {
f(ext_tex, smp);
}
)";
auto* expect = R"(
type ET = texture_external;
struct ExternalTextureParams {
numPlanes : u32;
vr : f32;
ug : f32;
vg : f32;
ub : f32;
}
fn textureSampleExternal(plane0 : texture_2d<f32>, plane1 : texture_2d<f32>, smp : sampler, coord : vec2<f32>, params : ExternalTextureParams) -> vec4<f32> {
if ((params.numPlanes == 1u)) {
return textureSampleLevel(plane0, smp, coord, 0.0);
}
let y = (textureSampleLevel(plane0, smp, coord, 0.0).r - 0.0625);
let uv = (textureSampleLevel(plane1, smp, coord, 0.0).rg - 0.5);
let u = uv.x;
let v = uv.y;
let r = ((1.164000034 * y) + (params.vr * v));
let g = (((1.164000034 * y) - (params.ug * u)) - (params.vg * v));
let b = ((1.164000034 * y) + (params.ub * u));
return vec4<f32>(r, g, b, 1.0);
}
fn f(t : texture_2d<f32>, ext_tex_plane_1 : texture_2d<f32>, ext_tex_params : ExternalTextureParams, s : sampler) {
textureSampleExternal(t, ext_tex_plane_1, s, vec2<f32>(1.0, 2.0), ext_tex_params);
}
@group(0) @binding(2) var ext_tex_plane_1_1 : texture_2d<f32>;
@group(0) @binding(3) var<uniform> ext_tex_params_1 : ExternalTextureParams;
@group(0) @binding(0) var ext_tex : texture_2d<f32>;
@group(0) @binding(1) var smp : sampler;
@stage(fragment)
fn main() {
f(ext_tex, ext_tex_plane_1_1, ext_tex_params_1, smp);
}
)";
DataMap data;
data.Add<MultiplanarExternalTexture::NewBindingPoints>(
MultiplanarExternalTexture::BindingsMap{
{{0, 0}, {{0, 2}, {0, 3}}},
});
auto got = Run<MultiplanarExternalTexture>(src, data);
EXPECT_EQ(expect, str(got));
}
} // namespace } // namespace
} // namespace transform } // namespace transform
} // namespace tint } // namespace tint

View File

@ -52,13 +52,21 @@ struct Accessor {
NumWorkgroupsFromUniform::NumWorkgroupsFromUniform() = default; NumWorkgroupsFromUniform::NumWorkgroupsFromUniform() = default;
NumWorkgroupsFromUniform::~NumWorkgroupsFromUniform() = default; NumWorkgroupsFromUniform::~NumWorkgroupsFromUniform() = default;
bool NumWorkgroupsFromUniform::ShouldRun(const Program* program,
const DataMap&) const {
for (auto* node : program->ASTNodes().Objects()) {
if (auto* deco = node->As<ast::BuiltinDecoration>()) {
if (deco->builtin == ast::Builtin::kNumWorkgroups) {
return true;
}
}
}
return false;
}
void NumWorkgroupsFromUniform::Run(CloneContext& ctx, void NumWorkgroupsFromUniform::Run(CloneContext& ctx,
const DataMap& inputs, const DataMap& inputs,
DataMap&) const { DataMap&) const {
if (!Requires<CanonicalizeEntryPointIO>(ctx)) {
return;
}
auto* cfg = inputs.Get<Config>(); auto* cfg = inputs.Get<Config>();
if (cfg == nullptr) { if (cfg == nullptr) {
ctx.dst->Diagnostics().add_error( ctx.dst->Diagnostics().add_error(

View File

@ -39,6 +39,9 @@ namespace transform {
/// ``` /// ```
/// The binding group and number used for this uniform buffer is provided via /// The binding group and number used for this uniform buffer is provided via
/// the `Config` transform input. /// the `Config` transform input.
///
/// @note Depends on the following transforms to have been run first:
/// * CanonicalizeEntryPointIO
class NumWorkgroupsFromUniform class NumWorkgroupsFromUniform
: public Castable<NumWorkgroupsFromUniform, Transform> { : public Castable<NumWorkgroupsFromUniform, Transform> {
public: public:
@ -63,6 +66,12 @@ class NumWorkgroupsFromUniform
sem::BindingPoint ubo_binding; sem::BindingPoint ubo_binding;
}; };
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program,
const DataMap& data = {}) const override;
protected: protected:
/// Runs the transform using the CloneContext built for transforming a /// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext. /// program. Run() is responsible for calling Clone() on the CloneContext.

View File

@ -26,8 +26,28 @@ namespace {
using NumWorkgroupsFromUniformTest = TransformTest; using NumWorkgroupsFromUniformTest = TransformTest;
TEST_F(NumWorkgroupsFromUniformTest, ShouldRunEmptyModule) {
auto* src = R"()";
EXPECT_FALSE(ShouldRun<NumWorkgroupsFromUniform>(src));
}
TEST_F(NumWorkgroupsFromUniformTest, ShouldRunHasNumWorkgroups) {
auto* src = R"(
[[stage(compute), workgroup_size(1)]]
fn main([[builtin(num_workgroups)]] num_wgs : vec3<u32>) {
}
)";
EXPECT_TRUE(ShouldRun<NumWorkgroupsFromUniform>(src));
}
TEST_F(NumWorkgroupsFromUniformTest, Error_MissingTransformData) { TEST_F(NumWorkgroupsFromUniformTest, Error_MissingTransformData) {
auto* src = ""; auto* src = R"(
[[stage(compute), workgroup_size(1)]]
fn main([[builtin(num_workgroups)]] num_wgs : vec3<u32>) {
}
)";
auto* expect = auto* expect =
"error: missing transform data for " "error: missing transform data for "
@ -42,19 +62,6 @@ TEST_F(NumWorkgroupsFromUniformTest, Error_MissingTransformData) {
EXPECT_EQ(expect, str(got)); EXPECT_EQ(expect, str(got));
} }
TEST_F(NumWorkgroupsFromUniformTest, Error_MissingCanonicalizeEntryPointIO) {
auto* src = "";
auto* expect =
"error: tint::transform::NumWorkgroupsFromUniform depends on "
"tint::transform::CanonicalizeEntryPointIO but the dependency was not "
"run";
auto got = Run<NumWorkgroupsFromUniform>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(NumWorkgroupsFromUniformTest, Basic) { TEST_F(NumWorkgroupsFromUniformTest, Basic) {
auto* src = R"( auto* src = R"(
@stage(compute) @workgroup_size(1) @stage(compute) @workgroup_size(1)

View File

@ -97,6 +97,19 @@ PadArrayElements::PadArrayElements() = default;
PadArrayElements::~PadArrayElements() = default; PadArrayElements::~PadArrayElements() = default;
bool PadArrayElements::ShouldRun(const Program* program, const DataMap&) const {
for (auto* node : program->ASTNodes().Objects()) {
if (auto* var = node->As<ast::Type>()) {
if (auto* arr = program->Sem().Get<sem::Array>(var)) {
if (!arr->IsStrideImplicit()) {
return true;
}
}
}
}
return false;
}
void PadArrayElements::Run(CloneContext& ctx, const DataMap&, DataMap&) const { void PadArrayElements::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
auto& sem = ctx.src->Sem(); auto& sem = ctx.src->Sem();

View File

@ -38,6 +38,12 @@ class PadArrayElements : public Castable<PadArrayElements, Transform> {
/// Destructor /// Destructor
~PadArrayElements() override; ~PadArrayElements() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program,
const DataMap& data = {}) const override;
protected: protected:
/// Runs the transform using the CloneContext built for transforming a /// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext. /// program. Run() is responsible for calling Clone() on the CloneContext.

View File

@ -24,6 +24,28 @@ namespace {
using PadArrayElementsTest = TransformTest; using PadArrayElementsTest = TransformTest;
TEST_F(PadArrayElementsTest, ShouldRunEmptyModule) {
auto* src = R"()";
EXPECT_FALSE(ShouldRun<PadArrayElements>(src));
}
TEST_F(PadArrayElementsTest, ShouldRunHasImplicitArrayStride) {
auto* src = R"(
var<private> arr : array<i32, 4>;
)";
EXPECT_FALSE(ShouldRun<PadArrayElements>(src));
}
TEST_F(PadArrayElementsTest, ShouldRunHasExplicitArrayStride) {
auto* src = R"(
var<private> arr : [[stride(8)]] array<i32, 4>;
)";
EXPECT_TRUE(ShouldRun<PadArrayElements>(src));
}
TEST_F(PadArrayElementsTest, EmptyModule) { TEST_F(PadArrayElementsTest, EmptyModule) {
auto* src = ""; auto* src = "";
auto* expect = ""; auto* expect = "";

View File

@ -68,6 +68,15 @@ RemovePhonies::RemovePhonies() = default;
RemovePhonies::~RemovePhonies() = default; RemovePhonies::~RemovePhonies() = default;
bool RemovePhonies::ShouldRun(const Program* program, const DataMap&) const {
for (auto* node : program->ASTNodes().Objects()) {
if (node->Is<ast::PhonyExpression>()) {
return true;
}
}
return false;
}
void RemovePhonies::Run(CloneContext& ctx, const DataMap&, DataMap&) const { void RemovePhonies::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
auto& sem = ctx.src->Sem(); auto& sem = ctx.src->Sem();

View File

@ -34,6 +34,12 @@ class RemovePhonies : public Castable<RemovePhonies, Transform> {
/// Destructor /// Destructor
~RemovePhonies() override; ~RemovePhonies() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program,
const DataMap& data = {}) const override;
protected: protected:
/// Runs the transform using the CloneContext built for transforming a /// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext. /// program. Run() is responsible for calling Clone() on the CloneContext.

View File

@ -26,6 +26,22 @@ namespace {
using RemovePhoniesTest = TransformTest; using RemovePhoniesTest = TransformTest;
TEST_F(RemovePhoniesTest, ShouldRunEmptyModule) {
auto* src = R"()";
EXPECT_FALSE(ShouldRun<RemovePhonies>(src));
}
TEST_F(RemovePhoniesTest, ShouldRunHasPhony) {
auto* src = R"(
fn f() {
_ = 1;
}
)";
EXPECT_TRUE(ShouldRun<RemovePhonies>(src));
}
TEST_F(RemovePhoniesTest, EmptyModule) { TEST_F(RemovePhoniesTest, EmptyModule) {
auto* src = ""; auto* src = "";
auto* expect = ""; auto* expect = "";

View File

@ -37,6 +37,18 @@ RemoveUnreachableStatements::RemoveUnreachableStatements() = default;
RemoveUnreachableStatements::~RemoveUnreachableStatements() = default; RemoveUnreachableStatements::~RemoveUnreachableStatements() = default;
bool RemoveUnreachableStatements::ShouldRun(const Program* program,
const DataMap&) const {
for (auto* node : program->ASTNodes().Objects()) {
if (auto* stmt = program->Sem().Get<sem::Statement>(node)) {
if (!stmt->IsReachable()) {
return true;
}
}
}
return false;
}
void RemoveUnreachableStatements::Run(CloneContext& ctx, void RemoveUnreachableStatements::Run(CloneContext& ctx,
const DataMap&, const DataMap&,
DataMap&) const { DataMap&) const {

View File

@ -34,6 +34,12 @@ class RemoveUnreachableStatements
/// Destructor /// Destructor
~RemoveUnreachableStatements() override; ~RemoveUnreachableStatements() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program,
const DataMap& data = {}) const override;
protected: protected:
/// Runs the transform using the CloneContext built for transforming a /// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext. /// program. Run() is responsible for calling Clone() on the CloneContext.

View File

@ -22,6 +22,37 @@ namespace {
using RemoveUnreachableStatementsTest = TransformTest; using RemoveUnreachableStatementsTest = TransformTest;
TEST_F(RemoveUnreachableStatementsTest, ShouldRunEmptyModule) {
auto* src = R"()";
EXPECT_FALSE(ShouldRun<RemoveUnreachableStatements>(src));
}
TEST_F(RemoveUnreachableStatementsTest, ShouldRunHasNoUnreachable) {
auto* src = R"(
fn f() {
if (true) {
var x = 1;
}
}
)";
EXPECT_FALSE(ShouldRun<RemoveUnreachableStatements>(src));
}
TEST_F(RemoveUnreachableStatementsTest, ShouldRunHasUnreachable) {
auto* src = R"(
fn f() {
return;
if (true) {
var x = 1;
}
}
)";
EXPECT_TRUE(ShouldRun<RemoveUnreachableStatements>(src));
}
TEST_F(RemoveUnreachableStatementsTest, EmptyModule) { TEST_F(RemoveUnreachableStatementsTest, EmptyModule) {
auto* src = ""; auto* src = "";
auto* expect = ""; auto* expect = "";

View File

@ -232,9 +232,6 @@ SimplifyPointers::SimplifyPointers() = default;
SimplifyPointers::~SimplifyPointers() = default; SimplifyPointers::~SimplifyPointers() = default;
void SimplifyPointers::Run(CloneContext& ctx, const DataMap&, DataMap&) const { void SimplifyPointers::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
if (!Requires<Unshadow>(ctx)) {
return;
}
State(ctx).Run(); State(ctx).Run();
} }

View File

@ -29,6 +29,9 @@ namespace transform {
/// Note: SimplifyPointers does not operate on module-scope `let`s, as these /// Note: SimplifyPointers does not operate on module-scope `let`s, as these
/// cannot be pointers: https://gpuweb.github.io/gpuweb/wgsl/#module-constants /// cannot be pointers: https://gpuweb.github.io/gpuweb/wgsl/#module-constants
/// `A module-scope let-declared constant must be of constructible type.` /// `A module-scope let-declared constant must be of constructible type.`
///
/// @note Depends on the following transforms to have been run first:
/// * Unshadow
class SimplifyPointers : public Castable<SimplifyPointers, Transform> { class SimplifyPointers : public Castable<SimplifyPointers, Transform> {
public: public:
/// Constructor /// Constructor

View File

@ -23,18 +23,6 @@ namespace {
using SimplifyPointersTest = TransformTest; using SimplifyPointersTest = TransformTest;
TEST_F(SimplifyPointersTest, Error_MissingSimplifyPointers) {
auto* src = "";
auto* expect =
"error: tint::transform::SimplifyPointers depends on "
"tint::transform::Unshadow but the dependency was not run";
auto got = Run<SimplifyPointers>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(SimplifyPointersTest, EmptyModule) { TEST_F(SimplifyPointersTest, EmptyModule) {
auto* src = ""; auto* src = "";
auto* expect = ""; auto* expect = "";

View File

@ -107,11 +107,6 @@ void SingleEntryPoint::Run(CloneContext& ctx,
// Clone the entry point. // Clone the entry point.
ctx.dst->AST().AddFunction(ctx.Clone(entry_point)); ctx.dst->AST().AddFunction(ctx.Clone(entry_point));
// Retain the list of applied transforms.
// We need to do this manually since we are not going to use the top-level
// ctx.Clone() function.
ctx.dst->SetTransformApplied(ctx.src->TransformsApplied());
} }
SingleEntryPoint::Config::Config(std::string entry_point) SingleEntryPoint::Config::Config(std::string entry_point)

View File

@ -81,6 +81,17 @@ class TransformTestBase : public BASE {
return manager.Run(&program, data); return manager.Run(&program, data);
} }
/// @param in the input WGSL source
/// @param data the optional DataMap to pass to Transform::Run()
/// @return true if the transform should be run for the given input.
template <typename TRANSFORM>
bool ShouldRun(std::string in, const DataMap& data = {}) {
auto file = std::make_unique<Source::File>("test", in);
auto program = reader::wgsl::Parse(file.get());
EXPECT_TRUE(program.IsValid()) << program.Diagnostics().str();
return TRANSFORM().ShouldRun(&program, data);
}
/// @param output the output of the transform /// @param output the output of the transform
/// @returns the output program as a WGSL string, or an error string if the /// @returns the output program as a WGSL string, or an error string if the
/// program is not valid. /// program is not valid.

View File

@ -52,7 +52,6 @@ Output Transform::Run(const Program* program,
CloneContext ctx(&builder, program); CloneContext ctx(&builder, program);
Output output; Output output;
Run(ctx, data, output.data); Run(ctx, data, output.data);
builder.SetTransformApplied(this);
output.program = Program(std::move(builder)); output.program = Program(std::move(builder));
return output; return output;
} }
@ -62,22 +61,7 @@ void Transform::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
<< "Transform::Run() unimplemented for " << TypeInfo().name; << "Transform::Run() unimplemented for " << TypeInfo().name;
} }
bool Transform::ShouldRun(const Program*) const { bool Transform::ShouldRun(const Program*, const DataMap&) const {
return true;
}
bool Transform::Requires(
CloneContext& ctx,
std::initializer_list<const ::tint::TypeInfo*> deps) const {
for (auto* dep : deps) {
if (!ctx.src->HasTransformApplied(dep)) {
ctx.dst->Diagnostics().add_error(
diag::System::Transform, std::string(TypeInfo().name) +
" depends on " + std::string(dep->name) +
" but the dependency was not run");
return false;
}
}
return true; return true;
} }

View File

@ -160,8 +160,10 @@ class Transform : public Castable<Transform> {
virtual Output Run(const Program* program, const DataMap& data = {}) const; virtual Output Run(const Program* program, const DataMap& data = {}) const;
/// @param program the program to inspect /// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program /// @returns true if this transform should be run for the given program
virtual bool ShouldRun(const Program* program) const; virtual bool ShouldRun(const Program* program,
const DataMap& data = {}) const;
protected: protected:
/// Runs the transform using the CloneContext built for transforming a /// Runs the transform using the CloneContext built for transforming a
@ -174,23 +176,6 @@ class Transform : public Castable<Transform> {
const DataMap& inputs, const DataMap& inputs,
DataMap& outputs) const; DataMap& outputs) const;
/// Requires appends an error diagnostic to `ctx.dst` if the template type
/// transforms were not already run on `ctx.src`.
/// @param ctx the CloneContext
/// @returns true if all dependency transforms have been run
template <typename... TRANSFORMS>
bool Requires(CloneContext& ctx) const {
return Requires(ctx, {&::tint::TypeInfo::Of<TRANSFORMS>()...});
}
/// Requires appends an error diagnostic to `ctx.dst` if the list of
/// Transforms were not already run on `ctx.src`.
/// @param ctx the CloneContext
/// @param deps the list of Transform TypeInfos
/// @returns true if all dependency transforms have been run
bool Requires(CloneContext& ctx,
std::initializer_list<const ::tint::TypeInfo*> deps) const;
/// Removes the statement `stmt` from the transformed program. /// Removes the statement `stmt` from the transformed program.
/// RemoveStatement handles edge cases, like statements in the initializer and /// RemoveStatement handles edge cases, like statements in the initializer and
/// continuing of for-loops. /// continuing of for-loops.

View File

@ -32,6 +32,22 @@ VectorizeScalarMatrixConstructors::VectorizeScalarMatrixConstructors() =
VectorizeScalarMatrixConstructors::~VectorizeScalarMatrixConstructors() = VectorizeScalarMatrixConstructors::~VectorizeScalarMatrixConstructors() =
default; default;
bool VectorizeScalarMatrixConstructors::ShouldRun(const Program* program,
const DataMap&) const {
for (auto* node : program->ASTNodes().Objects()) {
if (auto* call = program->Sem().Get<sem::Call>(node)) {
if (call->Target()->Is<sem::TypeConstructor>() &&
call->Type()->Is<sem::Matrix>()) {
auto& args = call->Arguments();
if (args.size() > 0 && args[0]->Type()->is_scalar()) {
return true;
}
}
}
}
return false;
}
void VectorizeScalarMatrixConstructors::Run(CloneContext& ctx, void VectorizeScalarMatrixConstructors::Run(CloneContext& ctx,
const DataMap&, const DataMap&,
DataMap&) const { DataMap&) const {

View File

@ -30,6 +30,12 @@ class VectorizeScalarMatrixConstructors
/// Destructor /// Destructor
~VectorizeScalarMatrixConstructors() override; ~VectorizeScalarMatrixConstructors() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program,
const DataMap& data = {}) const override;
protected: protected:
/// Runs the transform using the CloneContext built for transforming a /// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext. /// program. Run() is responsible for calling Clone() on the CloneContext.

View File

@ -27,6 +27,12 @@ namespace {
using VectorizeScalarMatrixConstructorsTest = using VectorizeScalarMatrixConstructorsTest =
TransformTestWithParam<std::pair<uint32_t, uint32_t>>; TransformTestWithParam<std::pair<uint32_t, uint32_t>>;
TEST_F(VectorizeScalarMatrixConstructorsTest, ShouldRunEmptyModule) {
auto* src = R"()";
EXPECT_FALSE(ShouldRun<VectorizeScalarMatrixConstructors>(src));
}
TEST_P(VectorizeScalarMatrixConstructorsTest, Basic) { TEST_P(VectorizeScalarMatrixConstructorsTest, Basic) {
uint32_t cols = GetParam().first; uint32_t cols = GetParam().first;
uint32_t rows = GetParam().second; uint32_t rows = GetParam().second;
@ -63,6 +69,8 @@ fn main() {
auto src = utils::ReplaceAll(tmpl, "${values}", scalar_values); auto src = utils::ReplaceAll(tmpl, "${values}", scalar_values);
auto expect = utils::ReplaceAll(tmpl, "${values}", vector_values); auto expect = utils::ReplaceAll(tmpl, "${values}", vector_values);
EXPECT_TRUE(ShouldRun<VectorizeScalarMatrixConstructors>(src));
auto got = Run<VectorizeScalarMatrixConstructors>(src); auto got = Run<VectorizeScalarMatrixConstructors>(src);
EXPECT_EQ(expect, str(got)); EXPECT_EQ(expect, str(got));
@ -92,6 +100,8 @@ fn main() {
auto src = utils::ReplaceAll(tmpl, "${columns}", columns); auto src = utils::ReplaceAll(tmpl, "${columns}", columns);
auto expect = src; auto expect = src;
EXPECT_FALSE(ShouldRun<VectorizeScalarMatrixConstructors>(src));
auto got = Run<VectorizeScalarMatrixConstructors>(src); auto got = Run<VectorizeScalarMatrixConstructors>(src);
EXPECT_EQ(expect, str(got)); EXPECT_EQ(expect, str(got));

View File

@ -38,6 +38,16 @@ WrapArraysInStructs::WrapArraysInStructs() = default;
WrapArraysInStructs::~WrapArraysInStructs() = default; WrapArraysInStructs::~WrapArraysInStructs() = default;
bool WrapArraysInStructs::ShouldRun(const Program* program,
const DataMap&) const {
for (auto* node : program->ASTNodes().Objects()) {
if (program->Sem().Get<sem::Array>(node->As<ast::Type>())) {
return true;
}
}
return false;
}
void WrapArraysInStructs::Run(CloneContext& ctx, void WrapArraysInStructs::Run(CloneContext& ctx,
const DataMap&, const DataMap&,
DataMap&) const { DataMap&) const {

View File

@ -44,6 +44,12 @@ class WrapArraysInStructs : public Castable<WrapArraysInStructs, Transform> {
/// Destructor /// Destructor
~WrapArraysInStructs() override; ~WrapArraysInStructs() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program,
const DataMap& data = {}) const override;
protected: protected:
/// Runs the transform using the CloneContext built for transforming a /// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext. /// program. Run() is responsible for calling Clone() on the CloneContext.

View File

@ -25,9 +25,23 @@ namespace {
using WrapArraysInStructsTest = TransformTest; using WrapArraysInStructsTest = TransformTest;
TEST_F(WrapArraysInStructsTest, ShouldRunEmptyModule) {
auto* src = R"()";
EXPECT_FALSE(ShouldRun<WrapArraysInStructs>(src));
}
TEST_F(WrapArraysInStructsTest, ShouldRunHasArray) {
auto* src = R"(
var<private> arr : array<i32, 4>;
)";
EXPECT_TRUE(ShouldRun<WrapArraysInStructs>(src));
}
TEST_F(WrapArraysInStructsTest, EmptyModule) { TEST_F(WrapArraysInStructsTest, EmptyModule) {
auto* src = ""; auto* src = R"()";
auto* expect = ""; auto* expect = src;
auto got = Run<WrapArraysInStructs>(src); auto got = Run<WrapArraysInStructs>(src);

View File

@ -433,6 +433,18 @@ ZeroInitWorkgroupMemory::ZeroInitWorkgroupMemory() = default;
ZeroInitWorkgroupMemory::~ZeroInitWorkgroupMemory() = default; ZeroInitWorkgroupMemory::~ZeroInitWorkgroupMemory() = default;
bool ZeroInitWorkgroupMemory::ShouldRun(const Program* program,
const DataMap&) const {
for (auto* decl : program->AST().GlobalDeclarations()) {
if (auto* var = decl->As<ast::Variable>()) {
if (var->declared_storage_class == ast::StorageClass::kWorkgroup) {
return true;
}
}
}
return false;
}
void ZeroInitWorkgroupMemory::Run(CloneContext& ctx, void ZeroInitWorkgroupMemory::Run(CloneContext& ctx,
const DataMap&, const DataMap&,
DataMap&) const { DataMap&) const {

View File

@ -32,6 +32,12 @@ class ZeroInitWorkgroupMemory
/// Destructor /// Destructor
~ZeroInitWorkgroupMemory() override; ~ZeroInitWorkgroupMemory() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program,
const DataMap& data = {}) const override;
protected: protected:
/// Runs the transform using the CloneContext built for transforming a /// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext. /// program. Run() is responsible for calling Clone() on the CloneContext.

View File

@ -24,6 +24,28 @@ namespace {
using ZeroInitWorkgroupMemoryTest = TransformTest; using ZeroInitWorkgroupMemoryTest = TransformTest;
TEST_F(ZeroInitWorkgroupMemoryTest, ShouldRunEmptyModule) {
auto* src = R"()";
EXPECT_FALSE(ShouldRun<ZeroInitWorkgroupMemory>(src));
}
TEST_F(ZeroInitWorkgroupMemoryTest, ShouldRunHasNoWorkgroupVars) {
auto* src = R"(
var<private> v : i32;
)";
EXPECT_FALSE(ShouldRun<ZeroInitWorkgroupMemory>(src));
}
TEST_F(ZeroInitWorkgroupMemoryTest, ShouldRunHasWorkgroupVars) {
auto* src = R"(
var<workgroup> a : i32;
)";
EXPECT_TRUE(ShouldRun<ZeroInitWorkgroupMemory>(src));
}
TEST_F(ZeroInitWorkgroupMemoryTest, EmptyModule) { TEST_F(ZeroInitWorkgroupMemoryTest, EmptyModule) {
auto* src = ""; auto* src = "";
auto* expect = src; auto* expect = src;

View File

@ -124,14 +124,6 @@ GeneratorImpl::GeneratorImpl(const Program* program) : TextGenerator(program) {}
GeneratorImpl::~GeneratorImpl() = default; GeneratorImpl::~GeneratorImpl() = default;
bool GeneratorImpl::Generate() { bool GeneratorImpl::Generate() {
if (!builder_.HasTransformApplied<transform::Glsl>()) {
diagnostics_.add_error(
diag::System::Writer,
"GLSL writer requires the transform::Glsl sanitizer to have been "
"applied to the input program");
return false;
}
const TypeInfo* last_kind = nullptr; const TypeInfo* last_kind = nullptr;
size_t last_padding_line = 0; size_t last_padding_line = 0;

View File

@ -21,16 +21,6 @@ namespace {
using GlslGeneratorImplTest = TestHelper; using GlslGeneratorImplTest = TestHelper;
TEST_F(GlslGeneratorImplTest, ErrorIfSanitizerNotRun) {
auto program = std::make_unique<Program>(std::move(*this));
GeneratorImpl gen(program.get());
EXPECT_FALSE(gen.Generate());
EXPECT_EQ(
gen.error(),
"error: GLSL writer requires the transform::Glsl sanitizer to have been "
"applied to the input program");
}
TEST_F(GlslGeneratorImplTest, Generate) { TEST_F(GlslGeneratorImplTest, Generate) {
Func("my_func", ast::VariableList{}, ty.void_(), ast::StatementList{}, Func("my_func", ast::VariableList{}, ty.void_(), ast::StatementList{},
ast::DecorationList{}); ast::DecorationList{});

View File

@ -44,9 +44,6 @@ class TestHelperBase : public BODY, public ProgramBuilder {
if (gen_) { if (gen_) {
return *gen_; return *gen_;
} }
// Fake that the GLSL sanitizer has been applied, so that we can unit test
// the writer without it erroring.
SetTransformApplied<transform::Glsl>();
[&]() { [&]() {
ASSERT_TRUE(IsValid()) << "Builder program is not valid\n" ASSERT_TRUE(IsValid()) << "Builder program is not valid\n"
<< diag::Formatter().format(Diagnostics()); << diag::Formatter().format(Diagnostics());

View File

@ -187,9 +187,10 @@ SanitizedResult Sanitize(
if (!result.program.IsValid()) { if (!result.program.IsValid()) {
return result; return result;
} }
result.used_array_length_from_uniform_indices = if (auto* res = out.data.Get<transform::ArrayLengthFromUniform::Result>()) {
std::move(out.data.Get<transform::ArrayLengthFromUniform::Result>() result.used_array_length_from_uniform_indices =
->used_size_indices); std::move(res->used_size_indices);
}
result.needs_storage_buffer_sizes = result.needs_storage_buffer_sizes =
!result.used_array_length_from_uniform_indices.empty(); !result.used_array_length_from_uniform_indices.empty();
return result; return result;