optimizations: Implement transform::ShouldRun()
This change adds an override for Transform::ShouldRun() for many of the transforms that can trivially detect whether running would be a no-op or not. Most programs do not require all the transforms to be run, and by skipping those that are not needed, significant performance wins can be had. This change also removes Transform::Requires() and Program::HasTransformApplied(). This makes little sense now that transforms can be skipped, and the usefulness of this information has been severely reduced since the introduction of transforms that need to be run more than once. Instread, just document on the transform class what the expectations are. Issue: tint:1383 Change-Id: I1a6f27cc4ba61ca1475a4ba912c465db619f76c7 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/77121 Reviewed-by: Antonio Maiorano <amaiorano@google.com> Kokoro: Kokoro <noreply+kokoro@google.com> Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
parent
12d54d746e
commit
800b8e3175
|
@ -57,7 +57,6 @@ Symbol CloneContext::Clone(Symbol s) {
|
|||
|
||||
void CloneContext::Clone() {
|
||||
dst->AST().Copy(this, &src->AST());
|
||||
dst->SetTransformApplied(src->TransformsApplied());
|
||||
}
|
||||
|
||||
ast::FunctionList CloneContext::Clone(const ast::FunctionList& v) {
|
||||
|
|
|
@ -42,7 +42,6 @@ Program::Program(Program&& program)
|
|||
sem_(std::move(program.sem_)),
|
||||
symbols_(std::move(program.symbols_)),
|
||||
diagnostics_(std::move(program.diagnostics_)),
|
||||
transforms_applied_(std::move(program.transforms_applied_)),
|
||||
is_valid_(program.is_valid_) {
|
||||
program.AssertNotMoved();
|
||||
program.moved_ = true;
|
||||
|
@ -67,7 +66,6 @@ Program::Program(ProgramBuilder&& builder) {
|
|||
sem_ = std::move(builder.Sem());
|
||||
symbols_ = std::move(builder.Symbols());
|
||||
diagnostics_.add(std::move(builder.Diagnostics()));
|
||||
transforms_applied_ = builder.TransformsApplied();
|
||||
builder.MarkAsMoved();
|
||||
|
||||
if (!is_valid_ && !diagnostics_.contains_errors()) {
|
||||
|
@ -92,7 +90,6 @@ Program& Program::operator=(Program&& program) {
|
|||
sem_ = std::move(program.sem_);
|
||||
symbols_ = std::move(program.symbols_);
|
||||
diagnostics_ = std::move(program.diagnostics_);
|
||||
transforms_applied_ = std::move(program.transforms_applied_);
|
||||
is_valid_ = program.is_valid_;
|
||||
return *this;
|
||||
}
|
||||
|
|
|
@ -126,25 +126,6 @@ class Program {
|
|||
/// information
|
||||
bool IsValid() const;
|
||||
|
||||
/// @return the TypeInfo pointers of all transforms that have been applied to
|
||||
/// this program.
|
||||
std::unordered_set<const TypeInfo*> TransformsApplied() const {
|
||||
return transforms_applied_;
|
||||
}
|
||||
|
||||
/// @param transform the TypeInfo of the transform
|
||||
/// @returns true if the transform with the given TypeInfo was applied to the
|
||||
/// Program
|
||||
bool HasTransformApplied(const TypeInfo* transform) const {
|
||||
return transforms_applied_.count(transform);
|
||||
}
|
||||
|
||||
/// @returns true if the transform of type `T` was applied.
|
||||
template <typename T>
|
||||
bool HasTransformApplied() const {
|
||||
return HasTransformApplied(&TypeInfo::Of<T>());
|
||||
}
|
||||
|
||||
/// Helper for returning the resolved semantic type of the expression `expr`.
|
||||
/// @param expr the AST expression
|
||||
/// @return the resolved semantic type for the expression, or nullptr if the
|
||||
|
@ -184,7 +165,6 @@ class Program {
|
|||
sem::Info sem_;
|
||||
SymbolTable symbols_{id_};
|
||||
diag::List diagnostics_;
|
||||
std::unordered_set<const TypeInfo*> transforms_applied_;
|
||||
bool is_valid_ = false; // Not valid until it is built
|
||||
bool moved_ = false;
|
||||
};
|
||||
|
|
|
@ -38,8 +38,7 @@ ProgramBuilder::ProgramBuilder(ProgramBuilder&& rhs)
|
|||
ast_(rhs.ast_),
|
||||
sem_(std::move(rhs.sem_)),
|
||||
symbols_(std::move(rhs.symbols_)),
|
||||
diagnostics_(std::move(rhs.diagnostics_)),
|
||||
transforms_applied_(std::move(rhs.transforms_applied_)) {
|
||||
diagnostics_(std::move(rhs.diagnostics_)) {
|
||||
rhs.MarkAsMoved();
|
||||
}
|
||||
|
||||
|
@ -56,7 +55,7 @@ ProgramBuilder& ProgramBuilder::operator=(ProgramBuilder&& rhs) {
|
|||
sem_ = std::move(rhs.sem_);
|
||||
symbols_ = std::move(rhs.symbols_);
|
||||
diagnostics_ = std::move(rhs.diagnostics_);
|
||||
transforms_applied_ = std::move(rhs.transforms_applied_);
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
@ -69,7 +68,6 @@ ProgramBuilder ProgramBuilder::Wrap(const Program* program) {
|
|||
builder.sem_ = sem::Info::Wrap(program->Sem());
|
||||
builder.symbols_ = program->Symbols();
|
||||
builder.diagnostics_ = program->Diagnostics();
|
||||
builder.transforms_applied_ = program->TransformsApplied();
|
||||
return builder;
|
||||
}
|
||||
|
||||
|
|
|
@ -2469,49 +2469,6 @@ class ProgramBuilder {
|
|||
source_ = Source(loc);
|
||||
}
|
||||
|
||||
/// Marks that the given transform has been applied to this program.
|
||||
/// @param transform the transform that has been applied
|
||||
void SetTransformApplied(const CastableBase* transform) {
|
||||
transforms_applied_.emplace(&transform->TypeInfo());
|
||||
}
|
||||
|
||||
/// Marks that the given transform `T` has been applied to this program.
|
||||
template <typename T>
|
||||
void SetTransformApplied() {
|
||||
transforms_applied_.emplace(&TypeInfo::Of<T>());
|
||||
}
|
||||
|
||||
/// Marks that the transforms with the given TypeInfos have been applied to
|
||||
/// this program.
|
||||
/// @param transforms the set of transform TypeInfos that has been applied
|
||||
void SetTransformApplied(
|
||||
const std::unordered_set<const TypeInfo*>& transforms) {
|
||||
for (auto* transform : transforms) {
|
||||
transforms_applied_.emplace(transform);
|
||||
}
|
||||
}
|
||||
|
||||
/// Unmarks that the given transform `T` has been applied to this program.
|
||||
template <typename T>
|
||||
void UnsetTransformApplied() {
|
||||
auto it = transforms_applied_.find(&TypeInfo::Of<T>());
|
||||
if (it != transforms_applied_.end()) {
|
||||
transforms_applied_.erase(it);
|
||||
}
|
||||
}
|
||||
|
||||
/// @returns true if the transform of type `T` was applied.
|
||||
template <typename T>
|
||||
bool HasTransformApplied() {
|
||||
return transforms_applied_.count(&TypeInfo::Of<T>());
|
||||
}
|
||||
|
||||
/// @return the TypeInfo pointers of all transforms that have been applied to
|
||||
/// this program.
|
||||
std::unordered_set<const TypeInfo*> TransformsApplied() const {
|
||||
return transforms_applied_;
|
||||
}
|
||||
|
||||
/// Helper for returning the resolved semantic type of the expression `expr`.
|
||||
/// @note As the Resolver is run when the Program is built, this will only be
|
||||
/// useful for the Resolver itself and tests that use their own Resolver.
|
||||
|
@ -2592,7 +2549,6 @@ class ProgramBuilder {
|
|||
sem::Info sem_;
|
||||
SymbolTable symbols_{id_};
|
||||
diag::List diagnostics_;
|
||||
std::unordered_set<const TypeInfo*> transforms_applied_;
|
||||
|
||||
/// The source to use when creating AST nodes without providing a Source as
|
||||
/// the first argument.
|
||||
|
|
|
@ -27,15 +27,19 @@ AddEmptyEntryPoint::AddEmptyEntryPoint() = default;
|
|||
|
||||
AddEmptyEntryPoint::~AddEmptyEntryPoint() = default;
|
||||
|
||||
bool AddEmptyEntryPoint::ShouldRun(const Program* program,
|
||||
const DataMap&) const {
|
||||
for (auto* func : program->AST().Functions()) {
|
||||
if (func->IsEntryPoint()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void AddEmptyEntryPoint::Run(CloneContext& ctx,
|
||||
const DataMap&,
|
||||
DataMap&) const {
|
||||
for (auto* func : ctx.src->AST().Functions()) {
|
||||
if (func->IsEntryPoint()) {
|
||||
ctx.Clone();
|
||||
return;
|
||||
}
|
||||
}
|
||||
ctx.dst->Func(ctx.dst->Symbols().New("unused_entry_point"), {},
|
||||
ctx.dst->ty.void_(), {},
|
||||
{ctx.dst->Stage(ast::PipelineStage::kCompute),
|
||||
|
|
|
@ -28,6 +28,12 @@ class AddEmptyEntryPoint : public Castable<AddEmptyEntryPoint, Transform> {
|
|||
/// Destructor
|
||||
~AddEmptyEntryPoint() override;
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program,
|
||||
const DataMap& data = {}) const override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
|
|
|
@ -24,6 +24,21 @@ namespace {
|
|||
|
||||
using AddEmptyEntryPointTest = TransformTest;
|
||||
|
||||
TEST_F(AddEmptyEntryPointTest, ShouldRunEmptyModule) {
|
||||
auto* src = R"()";
|
||||
|
||||
EXPECT_TRUE(ShouldRun<AddEmptyEntryPoint>(src));
|
||||
}
|
||||
|
||||
TEST_F(AddEmptyEntryPointTest, ShouldRunExistingEntryPoint) {
|
||||
auto* src = R"(
|
||||
[[stage(compute), workgroup_size(1)]]
|
||||
fn existing() {}
|
||||
)";
|
||||
|
||||
EXPECT_FALSE(ShouldRun<AddEmptyEntryPoint>(src));
|
||||
}
|
||||
|
||||
TEST_F(AddEmptyEntryPointTest, EmptyModule) {
|
||||
auto* src = R"()";
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
|
||||
#include "src/program_builder.h"
|
||||
#include "src/sem/call.h"
|
||||
#include "src/sem/function.h"
|
||||
#include "src/sem/variable.h"
|
||||
#include "src/transform/simplify_pointers.h"
|
||||
|
||||
|
@ -93,13 +94,23 @@ static void IterateArrayLengthOnStorageVar(CloneContext& ctx, F&& functor) {
|
|||
}
|
||||
}
|
||||
|
||||
bool ArrayLengthFromUniform::ShouldRun(const Program* program,
|
||||
const DataMap&) const {
|
||||
for (auto* fn : program->AST().Functions()) {
|
||||
if (auto* sem_fn = program->Sem().Get(fn)) {
|
||||
for (auto* intrinsic : sem_fn->DirectlyCalledIntrinsics()) {
|
||||
if (intrinsic->Type() == sem::IntrinsicType::kArrayLength) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void ArrayLengthFromUniform::Run(CloneContext& ctx,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const {
|
||||
if (!Requires<SimplifyPointers>(ctx)) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto* cfg = inputs.Get<Config>();
|
||||
if (cfg == nullptr) {
|
||||
ctx.dst->Diagnostics().add_error(
|
||||
|
|
|
@ -49,6 +49,9 @@ namespace transform {
|
|||
/// This transform assumes that the `SimplifyPointers`
|
||||
/// transforms have been run before it so that arguments to the arrayLength
|
||||
/// builtin always have the form `&resource.array`.
|
||||
///
|
||||
/// @note Depends on the following transforms to have been run first:
|
||||
/// * SimplifyPointers
|
||||
class ArrayLengthFromUniform
|
||||
: public Castable<ArrayLengthFromUniform, Transform> {
|
||||
public:
|
||||
|
@ -81,6 +84,8 @@ class ArrayLengthFromUniform
|
|||
};
|
||||
|
||||
/// Information produced about what the transform did.
|
||||
/// If there were no calls to the arrayLength() intrinsic, then no Result will
|
||||
/// be emitted.
|
||||
struct Result : public Castable<Result, transform::Data> {
|
||||
/// Constructor
|
||||
/// @param used_size_indices Indices into the UBO that are statically used.
|
||||
|
@ -96,6 +101,12 @@ class ArrayLengthFromUniform
|
|||
const std::unordered_set<uint32_t> used_size_indices;
|
||||
};
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program,
|
||||
const DataMap& data = {}) const override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
|
|
|
@ -26,8 +26,61 @@ namespace {
|
|||
|
||||
using ArrayLengthFromUniformTest = TransformTest;
|
||||
|
||||
TEST_F(ArrayLengthFromUniformTest, ShouldRunEmptyModule) {
|
||||
auto* src = R"()";
|
||||
|
||||
EXPECT_FALSE(ShouldRun<ArrayLengthFromUniform>(src));
|
||||
}
|
||||
|
||||
TEST_F(ArrayLengthFromUniformTest, ShouldRunNoArrayLength) {
|
||||
auto* src = R"(
|
||||
struct SB {
|
||||
x : i32;
|
||||
arr : array<i32>;
|
||||
};
|
||||
|
||||
[[group(0), binding(0)]] var<storage, read> sb : SB;
|
||||
|
||||
[[stage(compute), workgroup_size(1)]]
|
||||
fn main() {
|
||||
}
|
||||
)";
|
||||
|
||||
EXPECT_FALSE(ShouldRun<ArrayLengthFromUniform>(src));
|
||||
}
|
||||
|
||||
TEST_F(ArrayLengthFromUniformTest, ShouldRunWithArrayLength) {
|
||||
auto* src = R"(
|
||||
struct SB {
|
||||
x : i32;
|
||||
arr : array<i32>;
|
||||
};
|
||||
|
||||
[[group(0), binding(0)]] var<storage, read> sb : SB;
|
||||
|
||||
[[stage(compute), workgroup_size(1)]]
|
||||
fn main() {
|
||||
var len : u32 = arrayLength(&sb.arr);
|
||||
}
|
||||
)";
|
||||
|
||||
EXPECT_TRUE(ShouldRun<ArrayLengthFromUniform>(src));
|
||||
}
|
||||
|
||||
TEST_F(ArrayLengthFromUniformTest, Error_MissingTransformData) {
|
||||
auto* src = "";
|
||||
auto* src = R"(
|
||||
struct SB {
|
||||
x : i32;
|
||||
arr : array<i32>;
|
||||
};
|
||||
|
||||
[[group(0), binding(0)]] var<storage, read> sb : SB;
|
||||
|
||||
[[stage(compute), workgroup_size(1)]]
|
||||
fn main() {
|
||||
var len : u32 = arrayLength(&sb.arr);
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect =
|
||||
"error: missing transform data for "
|
||||
|
@ -38,18 +91,6 @@ TEST_F(ArrayLengthFromUniformTest, Error_MissingTransformData) {
|
|||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(ArrayLengthFromUniformTest, Error_MissingSimplifyPointers) {
|
||||
auto* src = "";
|
||||
|
||||
auto* expect =
|
||||
"error: tint::transform::ArrayLengthFromUniform depends on "
|
||||
"tint::transform::SimplifyPointers but the dependency was not run";
|
||||
|
||||
auto got = Run<ArrayLengthFromUniform>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(ArrayLengthFromUniformTest, Basic) {
|
||||
auto* src = R"(
|
||||
@group(0) @binding(0) var<storage, read> sb : array<i32>;
|
||||
|
@ -426,8 +467,7 @@ fn main() {
|
|||
auto got = Run<Unshadow, SimplifyPointers, ArrayLengthFromUniform>(src, data);
|
||||
|
||||
EXPECT_EQ(src, str(got));
|
||||
EXPECT_EQ(std::unordered_set<uint32_t>(),
|
||||
got.data.Get<ArrayLengthFromUniform::Result>()->used_size_indices);
|
||||
EXPECT_EQ(got.data.Get<ArrayLengthFromUniform::Result>(), nullptr);
|
||||
}
|
||||
|
||||
TEST_F(ArrayLengthFromUniformTest, MissingBindingPointToIndexMapping) {
|
||||
|
|
|
@ -42,6 +42,14 @@ BindingRemapper::Remappings::~Remappings() = default;
|
|||
BindingRemapper::BindingRemapper() = default;
|
||||
BindingRemapper::~BindingRemapper() = default;
|
||||
|
||||
bool BindingRemapper::ShouldRun(const Program*, const DataMap& inputs) const {
|
||||
if (auto* remappings = inputs.Get<Remappings>()) {
|
||||
return !remappings->binding_points.empty() ||
|
||||
!remappings->access_controls.empty();
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void BindingRemapper::Run(CloneContext& ctx,
|
||||
const DataMap& inputs,
|
||||
DataMap&) const {
|
||||
|
|
|
@ -68,6 +68,12 @@ class BindingRemapper : public Castable<BindingRemapper, Transform> {
|
|||
BindingRemapper();
|
||||
~BindingRemapper() override;
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program,
|
||||
const DataMap& data = {}) const override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
|
|
|
@ -24,6 +24,47 @@ namespace {
|
|||
|
||||
using BindingRemapperTest = TransformTest;
|
||||
|
||||
TEST_F(BindingRemapperTest, ShouldRunNoRemappings) {
|
||||
auto* src = R"()";
|
||||
|
||||
EXPECT_FALSE(ShouldRun<BindingRemapper>(src));
|
||||
}
|
||||
|
||||
TEST_F(BindingRemapperTest, ShouldRunEmptyRemappings) {
|
||||
auto* src = R"()";
|
||||
|
||||
DataMap data;
|
||||
data.Add<BindingRemapper::Remappings>(BindingRemapper::BindingPoints{},
|
||||
BindingRemapper::AccessControls{});
|
||||
|
||||
EXPECT_FALSE(ShouldRun<BindingRemapper>(src, data));
|
||||
}
|
||||
|
||||
TEST_F(BindingRemapperTest, ShouldRunBindingPointRemappings) {
|
||||
auto* src = R"()";
|
||||
|
||||
DataMap data;
|
||||
data.Add<BindingRemapper::Remappings>(
|
||||
BindingRemapper::BindingPoints{
|
||||
{{2, 1}, {1, 2}},
|
||||
},
|
||||
BindingRemapper::AccessControls{});
|
||||
|
||||
EXPECT_TRUE(ShouldRun<BindingRemapper>(src, data));
|
||||
}
|
||||
|
||||
TEST_F(BindingRemapperTest, ShouldRunAccessControlRemappings) {
|
||||
auto* src = R"()";
|
||||
|
||||
DataMap data;
|
||||
data.Add<BindingRemapper::Remappings>(BindingRemapper::BindingPoints{},
|
||||
BindingRemapper::AccessControls{
|
||||
{{2, 1}, ast::Access::kWrite},
|
||||
});
|
||||
|
||||
EXPECT_TRUE(ShouldRun<BindingRemapper>(src, data));
|
||||
}
|
||||
|
||||
TEST_F(BindingRemapperTest, NoRemappings) {
|
||||
auto* src = R"(
|
||||
struct S {
|
||||
|
@ -359,17 +400,18 @@ TEST_F(BindingRemapperTest, NoData) {
|
|||
auto* src = R"(
|
||||
struct S {
|
||||
a : f32;
|
||||
};
|
||||
}
|
||||
|
||||
@group(2) @binding(1) var<storage, read> a : S;
|
||||
|
||||
@group(3) @binding(2) var<storage, read> b : S;
|
||||
|
||||
@stage(compute) @workgroup_size(1)
|
||||
fn f() {}
|
||||
fn f() {
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect =
|
||||
"error: missing transform data for tint::transform::BindingRemapper";
|
||||
auto* expect = src;
|
||||
|
||||
auto got = Run<BindingRemapper>(src);
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include "src/program_builder.h"
|
||||
#include "src/sem/block_statement.h"
|
||||
#include "src/sem/call.h"
|
||||
#include "src/sem/function.h"
|
||||
#include "src/sem/statement.h"
|
||||
#include "src/sem/struct.h"
|
||||
#include "src/sem/variable.h"
|
||||
|
@ -71,13 +72,24 @@ CalculateArrayLength::BufferSizeIntrinsic::Clone(CloneContext* ctx) const {
|
|||
CalculateArrayLength::CalculateArrayLength() = default;
|
||||
CalculateArrayLength::~CalculateArrayLength() = default;
|
||||
|
||||
bool CalculateArrayLength::ShouldRun(const Program* program,
|
||||
const DataMap&) const {
|
||||
for (auto* fn : program->AST().Functions()) {
|
||||
if (auto* sem_fn = program->Sem().Get(fn)) {
|
||||
for (auto* intrinsic : sem_fn->DirectlyCalledIntrinsics()) {
|
||||
if (intrinsic->Type() == sem::IntrinsicType::kArrayLength) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void CalculateArrayLength::Run(CloneContext& ctx,
|
||||
const DataMap&,
|
||||
DataMap&) const {
|
||||
auto& sem = ctx.src->Sem();
|
||||
if (!Requires<SimplifyPointers>(ctx)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// get_buffer_size_intrinsic() emits the function decorated with
|
||||
// BufferSizeIntrinsic that is transformed by the HLSL writer into a call to
|
||||
|
|
|
@ -29,6 +29,9 @@ namespace transform {
|
|||
|
||||
/// CalculateArrayLength is a transform used to replace calls to arrayLength()
|
||||
/// with a value calculated from the size of the storage buffer.
|
||||
///
|
||||
/// @note Depends on the following transforms to have been run first:
|
||||
/// * SimplifyPointers
|
||||
class CalculateArrayLength : public Castable<CalculateArrayLength, Transform> {
|
||||
public:
|
||||
/// BufferSizeIntrinsic is an InternalDecoration that's applied to intrinsic
|
||||
|
@ -56,6 +59,12 @@ class CalculateArrayLength : public Castable<CalculateArrayLength, Transform> {
|
|||
/// Destructor
|
||||
~CalculateArrayLength() override;
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program,
|
||||
const DataMap& data = {}) const override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
|
|
|
@ -24,16 +24,45 @@ namespace {
|
|||
|
||||
using CalculateArrayLengthTest = TransformTest;
|
||||
|
||||
TEST_F(CalculateArrayLengthTest, Error_MissingCalculateArrayLength) {
|
||||
auto* src = "";
|
||||
TEST_F(CalculateArrayLengthTest, ShouldRunEmptyModule) {
|
||||
auto* src = R"()";
|
||||
|
||||
auto* expect =
|
||||
"error: tint::transform::CalculateArrayLength depends on "
|
||||
"tint::transform::SimplifyPointers but the dependency was not run";
|
||||
EXPECT_FALSE(ShouldRun<CalculateArrayLength>(src));
|
||||
}
|
||||
|
||||
auto got = Run<CalculateArrayLength>(src);
|
||||
TEST_F(CalculateArrayLengthTest, ShouldRunNoArrayLength) {
|
||||
auto* src = R"(
|
||||
struct SB {
|
||||
x : i32;
|
||||
arr : array<i32>;
|
||||
};
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
[[group(0), binding(0)]] var<storage, read> sb : SB;
|
||||
|
||||
[[stage(compute), workgroup_size(1)]]
|
||||
fn main() {
|
||||
}
|
||||
)";
|
||||
|
||||
EXPECT_FALSE(ShouldRun<CalculateArrayLength>(src));
|
||||
}
|
||||
|
||||
TEST_F(CalculateArrayLengthTest, ShouldRunWithArrayLength) {
|
||||
auto* src = R"(
|
||||
struct SB {
|
||||
x : i32;
|
||||
arr : array<i32>;
|
||||
};
|
||||
|
||||
[[group(0), binding(0)]] var<storage, read> sb : SB;
|
||||
|
||||
[[stage(compute), workgroup_size(1)]]
|
||||
fn main() {
|
||||
var len : u32 = arrayLength(&sb.arr);
|
||||
}
|
||||
)";
|
||||
|
||||
EXPECT_TRUE(ShouldRun<CalculateArrayLength>(src));
|
||||
}
|
||||
|
||||
TEST_F(CalculateArrayLengthTest, Basic) {
|
||||
|
|
|
@ -551,10 +551,6 @@ struct CanonicalizeEntryPointIO::State {
|
|||
void CanonicalizeEntryPointIO::Run(CloneContext& ctx,
|
||||
const DataMap& inputs,
|
||||
DataMap&) const {
|
||||
if (!Requires<Unshadow>(ctx)) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto* cfg = inputs.Get<Config>();
|
||||
if (cfg == nullptr) {
|
||||
ctx.dst->Diagnostics().add_error(
|
||||
|
|
|
@ -80,6 +80,9 @@ namespace transform {
|
|||
/// return wrapper_result;
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// @note Depends on the following transforms to have been run first:
|
||||
/// * Unshadow
|
||||
class CanonicalizeEntryPointIO
|
||||
: public Castable<CanonicalizeEntryPointIO, Transform> {
|
||||
public:
|
||||
|
|
|
@ -23,18 +23,6 @@ namespace {
|
|||
|
||||
using CanonicalizeEntryPointIOTest = TransformTest;
|
||||
|
||||
TEST_F(CanonicalizeEntryPointIOTest, Error_MissingUnshadow) {
|
||||
auto* src = "";
|
||||
|
||||
auto* expect =
|
||||
"error: tint::transform::CanonicalizeEntryPointIO depends on "
|
||||
"tint::transform::Unshadow but the dependency was not run";
|
||||
|
||||
auto got = Run<CanonicalizeEntryPointIO>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(CanonicalizeEntryPointIOTest, Error_MissingTransformData) {
|
||||
auto* src = "";
|
||||
|
||||
|
|
|
@ -790,6 +790,19 @@ const DecomposeMemoryAccess::Intrinsic* DecomposeMemoryAccess::Intrinsic::Clone(
|
|||
DecomposeMemoryAccess::DecomposeMemoryAccess() = default;
|
||||
DecomposeMemoryAccess::~DecomposeMemoryAccess() = default;
|
||||
|
||||
bool DecomposeMemoryAccess::ShouldRun(const Program* program,
|
||||
const DataMap&) const {
|
||||
for (auto* decl : program->AST().GlobalDeclarations()) {
|
||||
if (auto* var = program->Sem().Get<sem::Variable>(decl)) {
|
||||
if (var->StorageClass() == ast::StorageClass::kStorage ||
|
||||
var->StorageClass() == ast::StorageClass::kUniform) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void DecomposeMemoryAccess::Run(CloneContext& ctx,
|
||||
const DataMap&,
|
||||
DataMap&) const {
|
||||
|
|
|
@ -105,6 +105,12 @@ class DecomposeMemoryAccess
|
|||
/// Destructor
|
||||
~DecomposeMemoryAccess() override;
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program,
|
||||
const DataMap& data = {}) const override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
|
|
|
@ -22,6 +22,34 @@ namespace {
|
|||
|
||||
using DecomposeMemoryAccessTest = TransformTest;
|
||||
|
||||
TEST_F(DecomposeMemoryAccessTest, ShouldRunEmptyModule) {
|
||||
auto* src = R"()";
|
||||
|
||||
EXPECT_FALSE(ShouldRun<DecomposeMemoryAccess>(src));
|
||||
}
|
||||
|
||||
TEST_F(DecomposeMemoryAccessTest, ShouldRunStorageBuffer) {
|
||||
auto* src = R"(
|
||||
struct Buffer {
|
||||
i : i32;
|
||||
};
|
||||
[[group(0), binding(0)]] var<storage, read_write> sb : Buffer;
|
||||
)";
|
||||
|
||||
EXPECT_TRUE(ShouldRun<DecomposeMemoryAccess>(src));
|
||||
}
|
||||
|
||||
TEST_F(DecomposeMemoryAccessTest, ShouldRunUniformBuffer) {
|
||||
auto* src = R"(
|
||||
struct Buffer {
|
||||
i : i32;
|
||||
};
|
||||
[[group(0), binding(0)]] var<uniform> ub : Buffer;
|
||||
)";
|
||||
|
||||
EXPECT_TRUE(ShouldRun<DecomposeMemoryAccess>(src));
|
||||
}
|
||||
|
||||
TEST_F(DecomposeMemoryAccessTest, SB_BasicLoad) {
|
||||
auto* src = R"(
|
||||
struct SB {
|
||||
|
|
|
@ -107,7 +107,8 @@ DecomposeStridedMatrix::DecomposeStridedMatrix() = default;
|
|||
|
||||
DecomposeStridedMatrix::~DecomposeStridedMatrix() = default;
|
||||
|
||||
bool DecomposeStridedMatrix::ShouldRun(const Program* program) const {
|
||||
bool DecomposeStridedMatrix::ShouldRun(const Program* program,
|
||||
const DataMap&) const {
|
||||
bool should_run = false;
|
||||
GatherCustomStrideMatrixMembers(
|
||||
program, [&](const sem::StructMember*, sem::Matrix*, uint32_t) {
|
||||
|
@ -120,10 +121,6 @@ bool DecomposeStridedMatrix::ShouldRun(const Program* program) const {
|
|||
void DecomposeStridedMatrix::Run(CloneContext& ctx,
|
||||
const DataMap&,
|
||||
DataMap&) const {
|
||||
if (!Requires<SimplifyPointers>(ctx)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Scan the program for all storage and uniform structure matrix members with
|
||||
// a custom stride attribute. Replace these matrices with an equivalent array,
|
||||
// and populate the `decomposed` map with the members that have been replaced.
|
||||
|
|
|
@ -25,6 +25,9 @@ namespace transform {
|
|||
/// of N column vectors.
|
||||
/// This transform is used by the SPIR-V reader to handle the SPIR-V
|
||||
/// MatrixStride decoration.
|
||||
///
|
||||
/// @note Depends on the following transforms to have been run first:
|
||||
/// * SimplifyPointers
|
||||
class DecomposeStridedMatrix
|
||||
: public Castable<DecomposeStridedMatrix, Transform> {
|
||||
public:
|
||||
|
@ -35,8 +38,10 @@ class DecomposeStridedMatrix
|
|||
~DecomposeStridedMatrix() override;
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program) const override;
|
||||
bool ShouldRun(const Program* program,
|
||||
const DataMap& data = {}) const override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
|
|
|
@ -31,55 +31,25 @@ namespace {
|
|||
using DecomposeStridedMatrixTest = TransformTest;
|
||||
using f32 = ProgramBuilder::f32;
|
||||
|
||||
TEST_F(DecomposeStridedMatrixTest, ShouldRunEmptyModule) {
|
||||
auto* src = R"()";
|
||||
|
||||
EXPECT_FALSE(ShouldRun<DecomposeStridedMatrix>(src));
|
||||
}
|
||||
|
||||
TEST_F(DecomposeStridedMatrixTest, ShouldRunNonStridedMatrox) {
|
||||
auto* src = R"(
|
||||
var<private> m : mat3x2<f32>;
|
||||
)";
|
||||
|
||||
EXPECT_FALSE(ShouldRun<DecomposeStridedMatrix>(src));
|
||||
}
|
||||
|
||||
TEST_F(DecomposeStridedMatrixTest, Empty) {
|
||||
auto* src = R"()";
|
||||
auto* expect = src;
|
||||
|
||||
auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(DecomposeStridedMatrixTest, MissingDependencySimplify) {
|
||||
// struct S {
|
||||
// [[offset(16), stride(32)]]
|
||||
// [[internal(ignore_stride_decoration)]]
|
||||
// m : mat2x2<f32>;
|
||||
// };
|
||||
// [[group(0), binding(0)]] var<uniform> s : S;
|
||||
//
|
||||
// [[stage(compute), workgroup_size(1)]]
|
||||
// fn f() {
|
||||
// let x : mat2x2<f32> = s.m;
|
||||
// }
|
||||
ProgramBuilder b;
|
||||
auto* S = b.Structure(
|
||||
"S",
|
||||
{
|
||||
b.Member(
|
||||
"m", b.ty.mat2x2<f32>(),
|
||||
{
|
||||
b.create<ast::StructMemberOffsetDecoration>(16),
|
||||
b.create<ast::StrideDecoration>(32),
|
||||
b.Disable(ast::DisabledValidation::kIgnoreStrideDecoration),
|
||||
}),
|
||||
});
|
||||
b.Global("s", b.ty.Of(S), ast::StorageClass::kUniform,
|
||||
b.GroupAndBinding(0, 0));
|
||||
b.Func(
|
||||
"f", {}, b.ty.void_(),
|
||||
{
|
||||
b.Decl(b.Const("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))),
|
||||
},
|
||||
{
|
||||
b.Stage(ast::PipelineStage::kCompute),
|
||||
b.WorkgroupSize(1),
|
||||
});
|
||||
|
||||
auto* expect =
|
||||
R"(error: tint::transform::DecomposeStridedMatrix depends on tint::transform::SimplifyPointers but the dependency was not run)";
|
||||
|
||||
auto got = Run<DecomposeStridedMatrix>(Program(std::move(b)));
|
||||
auto got = Run<DecomposeStridedMatrix>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
|
|
@ -57,6 +57,15 @@ FirstIndexOffset::Data::~Data() = default;
|
|||
FirstIndexOffset::FirstIndexOffset() = default;
|
||||
FirstIndexOffset::~FirstIndexOffset() = default;
|
||||
|
||||
bool FirstIndexOffset::ShouldRun(const Program* program, const DataMap&) const {
|
||||
for (auto* fn : program->AST().Functions()) {
|
||||
if (fn->PipelineStage() == ast::PipelineStage::kVertex) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void FirstIndexOffset::Run(CloneContext& ctx,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const {
|
||||
|
|
|
@ -115,6 +115,12 @@ class FirstIndexOffset : public Castable<FirstIndexOffset, Transform> {
|
|||
/// Destructor
|
||||
~FirstIndexOffset() override;
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program,
|
||||
const DataMap& data = {}) const override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
|
|
|
@ -26,6 +26,34 @@ namespace {
|
|||
|
||||
using FirstIndexOffsetTest = TransformTest;
|
||||
|
||||
TEST_F(FirstIndexOffsetTest, ShouldRunEmptyModule) {
|
||||
auto* src = R"()";
|
||||
|
||||
EXPECT_FALSE(ShouldRun<FirstIndexOffset>(src));
|
||||
}
|
||||
|
||||
TEST_F(FirstIndexOffsetTest, ShouldRunFragmentStage) {
|
||||
auto* src = R"(
|
||||
[[stage(fragment)]]
|
||||
fn entry() {
|
||||
return;
|
||||
}
|
||||
)";
|
||||
|
||||
EXPECT_FALSE(ShouldRun<FirstIndexOffset>(src));
|
||||
}
|
||||
|
||||
TEST_F(FirstIndexOffsetTest, ShouldRunVertexStage) {
|
||||
auto* src = R"(
|
||||
[[stage(vertex)]]
|
||||
fn entry() -> [[builtin(position)]] vec4<f32> {
|
||||
return vec4<f32>();
|
||||
}
|
||||
)";
|
||||
|
||||
EXPECT_TRUE(ShouldRun<FirstIndexOffset>(src));
|
||||
}
|
||||
|
||||
TEST_F(FirstIndexOffsetTest, EmptyModule) {
|
||||
auto* src = "";
|
||||
auto* expect = "";
|
||||
|
@ -38,6 +66,26 @@ TEST_F(FirstIndexOffsetTest, EmptyModule) {
|
|||
|
||||
auto* data = got.data.Get<FirstIndexOffset::Data>();
|
||||
|
||||
EXPECT_EQ(data, nullptr);
|
||||
}
|
||||
|
||||
TEST_F(FirstIndexOffsetTest, BasicVertexShader) {
|
||||
auto* src = R"(
|
||||
@stage(vertex)
|
||||
fn entry() -> @builtin(position) vec4<f32> {
|
||||
return vec4<f32>();
|
||||
}
|
||||
)";
|
||||
auto* expect = src;
|
||||
|
||||
DataMap config;
|
||||
config.Add<FirstIndexOffset::BindingPoint>(0, 0);
|
||||
auto got = Run<FirstIndexOffset>(src, std::move(config));
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
|
||||
auto* data = got.data.Get<FirstIndexOffset::Data>();
|
||||
|
||||
ASSERT_NE(data, nullptr);
|
||||
EXPECT_EQ(data->has_vertex_index, false);
|
||||
EXPECT_EQ(data->has_instance_index, false);
|
||||
|
|
|
@ -25,6 +25,15 @@ ForLoopToLoop::ForLoopToLoop() = default;
|
|||
|
||||
ForLoopToLoop::~ForLoopToLoop() = default;
|
||||
|
||||
bool ForLoopToLoop::ShouldRun(const Program* program, const DataMap&) const {
|
||||
for (auto* node : program->ASTNodes().Objects()) {
|
||||
if (node->Is<ast::ForLoopStatement>()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void ForLoopToLoop::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
|
||||
ctx.ReplaceAll(
|
||||
[&](const ast::ForLoopStatement* for_loop) -> const ast::Statement* {
|
||||
|
|
|
@ -30,6 +30,12 @@ class ForLoopToLoop : public Castable<ForLoopToLoop, Transform> {
|
|||
/// Destructor
|
||||
~ForLoopToLoop() override;
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program,
|
||||
const DataMap& data = {}) const override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
|
|
|
@ -22,6 +22,24 @@ namespace {
|
|||
|
||||
using ForLoopToLoopTest = TransformTest;
|
||||
|
||||
TEST_F(ForLoopToLoopTest, ShouldRunEmptyModule) {
|
||||
auto* src = R"()";
|
||||
|
||||
EXPECT_FALSE(ShouldRun<ForLoopToLoop>(src));
|
||||
}
|
||||
|
||||
TEST_F(ForLoopToLoopTest, ShouldRunHasForLoop) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
for (;;) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
EXPECT_TRUE(ShouldRun<ForLoopToLoop>(src));
|
||||
}
|
||||
|
||||
TEST_F(ForLoopToLoopTest, EmptyModule) {
|
||||
auto* src = "";
|
||||
auto* expect = src;
|
||||
|
|
|
@ -113,7 +113,6 @@ Output Glsl::Run(const Program* in, const DataMap& inputs) const {
|
|||
ProgramBuilder builder;
|
||||
CloneContext ctx(&builder, &out.program);
|
||||
ctx.Clone();
|
||||
builder.SetTransformApplied(this);
|
||||
return Output{Program(std::move(builder))};
|
||||
}
|
||||
|
||||
|
|
|
@ -216,15 +216,9 @@ LocalizeStructArrayAssignment::~LocalizeStructArrayAssignment() = default;
|
|||
void LocalizeStructArrayAssignment::Run(CloneContext& ctx,
|
||||
const DataMap&,
|
||||
DataMap&) const {
|
||||
if (!Requires<SimplifyPointers>(ctx)) {
|
||||
return;
|
||||
}
|
||||
|
||||
State state(ctx);
|
||||
state.Run();
|
||||
|
||||
// This transform may introduce pointers
|
||||
ctx.dst->UnsetTransformApplied<transform::SimplifyPointers>();
|
||||
}
|
||||
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
|
|
|
@ -25,6 +25,9 @@ namespace transform {
|
|||
/// temporary local variable, assigns to the local variable, and copies the
|
||||
/// array back. This is to work around FXC's compilation failure for these cases
|
||||
/// (see crbug.com/tint/1206).
|
||||
///
|
||||
/// @note Depends on the following transforms to have been run first:
|
||||
/// * SimplifyPointers
|
||||
class LocalizeStructArrayAssignment
|
||||
: public Castable<LocalizeStructArrayAssignment, Transform> {
|
||||
public:
|
||||
|
|
|
@ -24,16 +24,6 @@ namespace {
|
|||
|
||||
using LocalizeStructArrayAssignmentTest = TransformTest;
|
||||
|
||||
TEST_F(LocalizeStructArrayAssignmentTest, MissingSimplifyPointers) {
|
||||
auto* src = R"()";
|
||||
auto* expect =
|
||||
"error: tint::transform::LocalizeStructArrayAssignment depends on "
|
||||
"tint::transform::SimplifyPointers but the dependency was not run";
|
||||
|
||||
auto got = Run<LocalizeStructArrayAssignment>(src);
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(LocalizeStructArrayAssignmentTest, EmptyModule) {
|
||||
auto* src = R"()";
|
||||
auto* expect = src;
|
||||
|
|
|
@ -56,6 +56,15 @@ LoopToForLoop::LoopToForLoop() = default;
|
|||
|
||||
LoopToForLoop::~LoopToForLoop() = default;
|
||||
|
||||
bool LoopToForLoop::ShouldRun(const Program* program, const DataMap&) const {
|
||||
for (auto* node : program->ASTNodes().Objects()) {
|
||||
if (node->Is<ast::LoopStatement>()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void LoopToForLoop::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
|
||||
ctx.ReplaceAll([&](const ast::LoopStatement* loop) -> const ast::Statement* {
|
||||
// For loop condition is taken from the first statement in the loop.
|
||||
|
|
|
@ -30,6 +30,12 @@ class LoopToForLoop : public Castable<LoopToForLoop, Transform> {
|
|||
/// Destructor
|
||||
~LoopToForLoop() override;
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program,
|
||||
const DataMap& data = {}) const override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
|
|
|
@ -22,6 +22,24 @@ namespace {
|
|||
|
||||
using LoopToForLoopTest = TransformTest;
|
||||
|
||||
TEST_F(LoopToForLoopTest, ShouldRunEmptyModule) {
|
||||
auto* src = R"()";
|
||||
|
||||
EXPECT_FALSE(ShouldRun<LoopToForLoop>(src));
|
||||
}
|
||||
|
||||
TEST_F(LoopToForLoopTest, ShouldRunHasForLoop) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
loop {
|
||||
break;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
EXPECT_TRUE(ShouldRun<LoopToForLoop>(src));
|
||||
}
|
||||
|
||||
TEST_F(LoopToForLoopTest, EmptyModule) {
|
||||
auto* src = "";
|
||||
auto* expect = "";
|
||||
|
|
|
@ -53,7 +53,7 @@ Output Manager::Run(const Program* program, const DataMap& data) const {
|
|||
|
||||
Output out;
|
||||
for (const auto& transform : transforms_) {
|
||||
if (!transform->ShouldRun(in)) {
|
||||
if (!transform->ShouldRun(in, data)) {
|
||||
TINT_IF_PRINT_PROGRAM(std::cout << "Skipping "
|
||||
<< transform->TypeInfo().name);
|
||||
continue;
|
||||
|
|
|
@ -377,6 +377,16 @@ ModuleScopeVarToEntryPointParam::ModuleScopeVarToEntryPointParam() = default;
|
|||
|
||||
ModuleScopeVarToEntryPointParam::~ModuleScopeVarToEntryPointParam() = default;
|
||||
|
||||
bool ModuleScopeVarToEntryPointParam::ShouldRun(const Program* program,
|
||||
const DataMap&) const {
|
||||
for (auto* decl : program->AST().GlobalDeclarations()) {
|
||||
if (decl->Is<ast::Variable>()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void ModuleScopeVarToEntryPointParam::Run(CloneContext& ctx,
|
||||
const DataMap&,
|
||||
DataMap&) const {
|
||||
|
|
|
@ -70,6 +70,12 @@ class ModuleScopeVarToEntryPointParam
|
|||
/// Destructor
|
||||
~ModuleScopeVarToEntryPointParam() override;
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program,
|
||||
const DataMap& data = {}) const override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
|
|
|
@ -24,6 +24,20 @@ namespace {
|
|||
|
||||
using ModuleScopeVarToEntryPointParamTest = TransformTest;
|
||||
|
||||
TEST_F(ModuleScopeVarToEntryPointParamTest, ShouldRunEmptyModule) {
|
||||
auto* src = R"()";
|
||||
|
||||
EXPECT_FALSE(ShouldRun<ModuleScopeVarToEntryPointParam>(src));
|
||||
}
|
||||
|
||||
TEST_F(ModuleScopeVarToEntryPointParamTest, ShouldRunHasGlobal) {
|
||||
auto* src = R"(
|
||||
var<private> v : i32;
|
||||
)";
|
||||
|
||||
EXPECT_TRUE(ShouldRun<ModuleScopeVarToEntryPointParam>(src));
|
||||
}
|
||||
|
||||
TEST_F(ModuleScopeVarToEntryPointParamTest, Basic) {
|
||||
auto* src = R"(
|
||||
var<private> p : f32;
|
||||
|
|
|
@ -81,7 +81,7 @@ struct MultiplanarExternalTexture::State {
|
|||
// represent the secondary plane and one uniform buffer for the
|
||||
// ExternalTextureParams struct).
|
||||
ctx.ReplaceAll([&](const ast::Variable* var) -> const ast::Variable* {
|
||||
if (!::tint::Is<ast::ExternalTexture>(var->type)) {
|
||||
if (!sem.Get<sem::ExternalTexture>(var->type)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
@ -201,7 +201,7 @@ struct MultiplanarExternalTexture::State {
|
|||
// functions.
|
||||
ctx.ReplaceAll([&](const ast::Function* fn) -> const ast::Function* {
|
||||
for (const ast::Variable* param : fn->params) {
|
||||
if (::tint::Is<ast::ExternalTexture>(param->type)) {
|
||||
if (sem.Get<sem::ExternalTexture>(param->type)) {
|
||||
// If we find a texture_external, we must ensure the
|
||||
// ExternalTextureParams struct exists.
|
||||
if (!params_struct_sym.IsValid()) {
|
||||
|
@ -407,6 +407,18 @@ MultiplanarExternalTexture::NewBindingPoints::~NewBindingPoints() = default;
|
|||
MultiplanarExternalTexture::MultiplanarExternalTexture() = default;
|
||||
MultiplanarExternalTexture::~MultiplanarExternalTexture() = default;
|
||||
|
||||
bool MultiplanarExternalTexture::ShouldRun(const Program* program,
|
||||
const DataMap&) const {
|
||||
for (auto* node : program->ASTNodes().Objects()) {
|
||||
if (auto* ty = node->As<ast::Type>()) {
|
||||
if (program->Sem().Get<sem::ExternalTexture>(ty)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Within this transform, an instance of a texture_external binding is unpacked
|
||||
// into two texture_2d<f32> bindings representing two possible planes of a
|
||||
// single texture and a uniform buffer binding representing a struct of
|
||||
|
|
|
@ -76,6 +76,12 @@ class MultiplanarExternalTexture
|
|||
/// Destructor
|
||||
~MultiplanarExternalTexture() override;
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program,
|
||||
const DataMap& data = {}) const override;
|
||||
|
||||
protected:
|
||||
struct State;
|
||||
|
||||
|
|
|
@ -21,6 +21,35 @@ namespace {
|
|||
|
||||
using MultiplanarExternalTextureTest = TransformTest;
|
||||
|
||||
TEST_F(MultiplanarExternalTextureTest, ShouldRunEmptyModule) {
|
||||
auto* src = R"()";
|
||||
|
||||
EXPECT_FALSE(ShouldRun<MultiplanarExternalTexture>(src));
|
||||
}
|
||||
|
||||
TEST_F(MultiplanarExternalTextureTest, ShouldRunHasExternalTextureAlias) {
|
||||
auto* src = R"(
|
||||
type ET = texture_external;
|
||||
)";
|
||||
|
||||
EXPECT_TRUE(ShouldRun<MultiplanarExternalTexture>(src));
|
||||
}
|
||||
TEST_F(MultiplanarExternalTextureTest, ShouldRunHasExternalTextureGlobal) {
|
||||
auto* src = R"(
|
||||
[[group(0), binding(0)]] var ext_tex : texture_external;
|
||||
)";
|
||||
|
||||
EXPECT_TRUE(ShouldRun<MultiplanarExternalTexture>(src));
|
||||
}
|
||||
|
||||
TEST_F(MultiplanarExternalTextureTest, ShouldRunHasExternalTextureParam) {
|
||||
auto* src = R"(
|
||||
fn f(ext_tex : texture_external) {}
|
||||
)";
|
||||
|
||||
EXPECT_TRUE(ShouldRun<MultiplanarExternalTexture>(src));
|
||||
}
|
||||
|
||||
// Running the transform without passing in data for the new bindings should
|
||||
// result in an error.
|
||||
TEST_F(MultiplanarExternalTextureTest, ErrorNoPassedData) {
|
||||
|
@ -651,6 +680,106 @@ fn main() {
|
|||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
// Tests that the transform works with a function using an external texture,
|
||||
// even if there's no external texture declared at module scope.
|
||||
TEST_F(MultiplanarExternalTextureTest,
|
||||
ExternalTexturePassedAsParamWithoutGlobalDecl) {
|
||||
auto* src = R"(
|
||||
fn f(ext_tex : texture_external) -> vec2<i32> {
|
||||
return textureDimensions(ext_tex);
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
struct ExternalTextureParams {
|
||||
numPlanes : u32;
|
||||
vr : f32;
|
||||
ug : f32;
|
||||
vg : f32;
|
||||
ub : f32;
|
||||
}
|
||||
|
||||
fn f(ext_tex : texture_2d<f32>, ext_tex_plane_1 : texture_2d<f32>, ext_tex_params : ExternalTextureParams) -> vec2<i32> {
|
||||
return textureDimensions(ext_tex);
|
||||
}
|
||||
)";
|
||||
|
||||
DataMap data;
|
||||
data.Add<MultiplanarExternalTexture::NewBindingPoints>(
|
||||
MultiplanarExternalTexture::BindingsMap{{{0, 0}, {{0, 1}, {0, 2}}}});
|
||||
auto got = Run<MultiplanarExternalTexture>(src, data);
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
// Tests that the the transform handles aliases to external textures
|
||||
TEST_F(MultiplanarExternalTextureTest, ExternalTextureAlias) {
|
||||
auto* src = R"(
|
||||
type ET = texture_external;
|
||||
|
||||
fn f(t : ET, s : sampler) {
|
||||
textureSampleLevel(t, s, vec2<f32>(1.0, 2.0));
|
||||
}
|
||||
|
||||
[[group(0), binding(0)]] var ext_tex : ET;
|
||||
[[group(0), binding(1)]] var smp : sampler;
|
||||
|
||||
[[stage(fragment)]]
|
||||
fn main() {
|
||||
f(ext_tex, smp);
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
type ET = texture_external;
|
||||
|
||||
struct ExternalTextureParams {
|
||||
numPlanes : u32;
|
||||
vr : f32;
|
||||
ug : f32;
|
||||
vg : f32;
|
||||
ub : f32;
|
||||
}
|
||||
|
||||
fn textureSampleExternal(plane0 : texture_2d<f32>, plane1 : texture_2d<f32>, smp : sampler, coord : vec2<f32>, params : ExternalTextureParams) -> vec4<f32> {
|
||||
if ((params.numPlanes == 1u)) {
|
||||
return textureSampleLevel(plane0, smp, coord, 0.0);
|
||||
}
|
||||
let y = (textureSampleLevel(plane0, smp, coord, 0.0).r - 0.0625);
|
||||
let uv = (textureSampleLevel(plane1, smp, coord, 0.0).rg - 0.5);
|
||||
let u = uv.x;
|
||||
let v = uv.y;
|
||||
let r = ((1.164000034 * y) + (params.vr * v));
|
||||
let g = (((1.164000034 * y) - (params.ug * u)) - (params.vg * v));
|
||||
let b = ((1.164000034 * y) + (params.ub * u));
|
||||
return vec4<f32>(r, g, b, 1.0);
|
||||
}
|
||||
|
||||
fn f(t : texture_2d<f32>, ext_tex_plane_1 : texture_2d<f32>, ext_tex_params : ExternalTextureParams, s : sampler) {
|
||||
textureSampleExternal(t, ext_tex_plane_1, s, vec2<f32>(1.0, 2.0), ext_tex_params);
|
||||
}
|
||||
|
||||
@group(0) @binding(2) var ext_tex_plane_1_1 : texture_2d<f32>;
|
||||
|
||||
@group(0) @binding(3) var<uniform> ext_tex_params_1 : ExternalTextureParams;
|
||||
|
||||
@group(0) @binding(0) var ext_tex : texture_2d<f32>;
|
||||
|
||||
@group(0) @binding(1) var smp : sampler;
|
||||
|
||||
@stage(fragment)
|
||||
fn main() {
|
||||
f(ext_tex, ext_tex_plane_1_1, ext_tex_params_1, smp);
|
||||
}
|
||||
)";
|
||||
DataMap data;
|
||||
data.Add<MultiplanarExternalTexture::NewBindingPoints>(
|
||||
MultiplanarExternalTexture::BindingsMap{
|
||||
{{0, 0}, {{0, 2}, {0, 3}}},
|
||||
});
|
||||
auto got = Run<MultiplanarExternalTexture>(src, data);
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace transform
|
||||
} // namespace tint
|
||||
|
|
|
@ -52,13 +52,21 @@ struct Accessor {
|
|||
NumWorkgroupsFromUniform::NumWorkgroupsFromUniform() = default;
|
||||
NumWorkgroupsFromUniform::~NumWorkgroupsFromUniform() = default;
|
||||
|
||||
bool NumWorkgroupsFromUniform::ShouldRun(const Program* program,
|
||||
const DataMap&) const {
|
||||
for (auto* node : program->ASTNodes().Objects()) {
|
||||
if (auto* deco = node->As<ast::BuiltinDecoration>()) {
|
||||
if (deco->builtin == ast::Builtin::kNumWorkgroups) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void NumWorkgroupsFromUniform::Run(CloneContext& ctx,
|
||||
const DataMap& inputs,
|
||||
DataMap&) const {
|
||||
if (!Requires<CanonicalizeEntryPointIO>(ctx)) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto* cfg = inputs.Get<Config>();
|
||||
if (cfg == nullptr) {
|
||||
ctx.dst->Diagnostics().add_error(
|
||||
|
|
|
@ -39,6 +39,9 @@ namespace transform {
|
|||
/// ```
|
||||
/// The binding group and number used for this uniform buffer is provided via
|
||||
/// the `Config` transform input.
|
||||
///
|
||||
/// @note Depends on the following transforms to have been run first:
|
||||
/// * CanonicalizeEntryPointIO
|
||||
class NumWorkgroupsFromUniform
|
||||
: public Castable<NumWorkgroupsFromUniform, Transform> {
|
||||
public:
|
||||
|
@ -63,6 +66,12 @@ class NumWorkgroupsFromUniform
|
|||
sem::BindingPoint ubo_binding;
|
||||
};
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program,
|
||||
const DataMap& data = {}) const override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
|
|
|
@ -26,8 +26,28 @@ namespace {
|
|||
|
||||
using NumWorkgroupsFromUniformTest = TransformTest;
|
||||
|
||||
TEST_F(NumWorkgroupsFromUniformTest, ShouldRunEmptyModule) {
|
||||
auto* src = R"()";
|
||||
|
||||
EXPECT_FALSE(ShouldRun<NumWorkgroupsFromUniform>(src));
|
||||
}
|
||||
|
||||
TEST_F(NumWorkgroupsFromUniformTest, ShouldRunHasNumWorkgroups) {
|
||||
auto* src = R"(
|
||||
[[stage(compute), workgroup_size(1)]]
|
||||
fn main([[builtin(num_workgroups)]] num_wgs : vec3<u32>) {
|
||||
}
|
||||
)";
|
||||
|
||||
EXPECT_TRUE(ShouldRun<NumWorkgroupsFromUniform>(src));
|
||||
}
|
||||
|
||||
TEST_F(NumWorkgroupsFromUniformTest, Error_MissingTransformData) {
|
||||
auto* src = "";
|
||||
auto* src = R"(
|
||||
[[stage(compute), workgroup_size(1)]]
|
||||
fn main([[builtin(num_workgroups)]] num_wgs : vec3<u32>) {
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect =
|
||||
"error: missing transform data for "
|
||||
|
@ -42,19 +62,6 @@ TEST_F(NumWorkgroupsFromUniformTest, Error_MissingTransformData) {
|
|||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(NumWorkgroupsFromUniformTest, Error_MissingCanonicalizeEntryPointIO) {
|
||||
auto* src = "";
|
||||
|
||||
auto* expect =
|
||||
"error: tint::transform::NumWorkgroupsFromUniform depends on "
|
||||
"tint::transform::CanonicalizeEntryPointIO but the dependency was not "
|
||||
"run";
|
||||
|
||||
auto got = Run<NumWorkgroupsFromUniform>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(NumWorkgroupsFromUniformTest, Basic) {
|
||||
auto* src = R"(
|
||||
@stage(compute) @workgroup_size(1)
|
||||
|
|
|
@ -97,6 +97,19 @@ PadArrayElements::PadArrayElements() = default;
|
|||
|
||||
PadArrayElements::~PadArrayElements() = default;
|
||||
|
||||
bool PadArrayElements::ShouldRun(const Program* program, const DataMap&) const {
|
||||
for (auto* node : program->ASTNodes().Objects()) {
|
||||
if (auto* var = node->As<ast::Type>()) {
|
||||
if (auto* arr = program->Sem().Get<sem::Array>(var)) {
|
||||
if (!arr->IsStrideImplicit()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void PadArrayElements::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
|
||||
auto& sem = ctx.src->Sem();
|
||||
|
||||
|
|
|
@ -38,6 +38,12 @@ class PadArrayElements : public Castable<PadArrayElements, Transform> {
|
|||
/// Destructor
|
||||
~PadArrayElements() override;
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program,
|
||||
const DataMap& data = {}) const override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
|
|
|
@ -24,6 +24,28 @@ namespace {
|
|||
|
||||
using PadArrayElementsTest = TransformTest;
|
||||
|
||||
TEST_F(PadArrayElementsTest, ShouldRunEmptyModule) {
|
||||
auto* src = R"()";
|
||||
|
||||
EXPECT_FALSE(ShouldRun<PadArrayElements>(src));
|
||||
}
|
||||
|
||||
TEST_F(PadArrayElementsTest, ShouldRunHasImplicitArrayStride) {
|
||||
auto* src = R"(
|
||||
var<private> arr : array<i32, 4>;
|
||||
)";
|
||||
|
||||
EXPECT_FALSE(ShouldRun<PadArrayElements>(src));
|
||||
}
|
||||
|
||||
TEST_F(PadArrayElementsTest, ShouldRunHasExplicitArrayStride) {
|
||||
auto* src = R"(
|
||||
var<private> arr : [[stride(8)]] array<i32, 4>;
|
||||
)";
|
||||
|
||||
EXPECT_TRUE(ShouldRun<PadArrayElements>(src));
|
||||
}
|
||||
|
||||
TEST_F(PadArrayElementsTest, EmptyModule) {
|
||||
auto* src = "";
|
||||
auto* expect = "";
|
||||
|
|
|
@ -68,6 +68,15 @@ RemovePhonies::RemovePhonies() = default;
|
|||
|
||||
RemovePhonies::~RemovePhonies() = default;
|
||||
|
||||
bool RemovePhonies::ShouldRun(const Program* program, const DataMap&) const {
|
||||
for (auto* node : program->ASTNodes().Objects()) {
|
||||
if (node->Is<ast::PhonyExpression>()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void RemovePhonies::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
|
||||
auto& sem = ctx.src->Sem();
|
||||
|
||||
|
|
|
@ -34,6 +34,12 @@ class RemovePhonies : public Castable<RemovePhonies, Transform> {
|
|||
/// Destructor
|
||||
~RemovePhonies() override;
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program,
|
||||
const DataMap& data = {}) const override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
|
|
|
@ -26,6 +26,22 @@ namespace {
|
|||
|
||||
using RemovePhoniesTest = TransformTest;
|
||||
|
||||
TEST_F(RemovePhoniesTest, ShouldRunEmptyModule) {
|
||||
auto* src = R"()";
|
||||
|
||||
EXPECT_FALSE(ShouldRun<RemovePhonies>(src));
|
||||
}
|
||||
|
||||
TEST_F(RemovePhoniesTest, ShouldRunHasPhony) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
_ = 1;
|
||||
}
|
||||
)";
|
||||
|
||||
EXPECT_TRUE(ShouldRun<RemovePhonies>(src));
|
||||
}
|
||||
|
||||
TEST_F(RemovePhoniesTest, EmptyModule) {
|
||||
auto* src = "";
|
||||
auto* expect = "";
|
||||
|
|
|
@ -37,6 +37,18 @@ RemoveUnreachableStatements::RemoveUnreachableStatements() = default;
|
|||
|
||||
RemoveUnreachableStatements::~RemoveUnreachableStatements() = default;
|
||||
|
||||
bool RemoveUnreachableStatements::ShouldRun(const Program* program,
|
||||
const DataMap&) const {
|
||||
for (auto* node : program->ASTNodes().Objects()) {
|
||||
if (auto* stmt = program->Sem().Get<sem::Statement>(node)) {
|
||||
if (!stmt->IsReachable()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void RemoveUnreachableStatements::Run(CloneContext& ctx,
|
||||
const DataMap&,
|
||||
DataMap&) const {
|
||||
|
|
|
@ -34,6 +34,12 @@ class RemoveUnreachableStatements
|
|||
/// Destructor
|
||||
~RemoveUnreachableStatements() override;
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program,
|
||||
const DataMap& data = {}) const override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
|
|
|
@ -22,6 +22,37 @@ namespace {
|
|||
|
||||
using RemoveUnreachableStatementsTest = TransformTest;
|
||||
|
||||
TEST_F(RemoveUnreachableStatementsTest, ShouldRunEmptyModule) {
|
||||
auto* src = R"()";
|
||||
|
||||
EXPECT_FALSE(ShouldRun<RemoveUnreachableStatements>(src));
|
||||
}
|
||||
|
||||
TEST_F(RemoveUnreachableStatementsTest, ShouldRunHasNoUnreachable) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
if (true) {
|
||||
var x = 1;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
EXPECT_FALSE(ShouldRun<RemoveUnreachableStatements>(src));
|
||||
}
|
||||
|
||||
TEST_F(RemoveUnreachableStatementsTest, ShouldRunHasUnreachable) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
return;
|
||||
if (true) {
|
||||
var x = 1;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
EXPECT_TRUE(ShouldRun<RemoveUnreachableStatements>(src));
|
||||
}
|
||||
|
||||
TEST_F(RemoveUnreachableStatementsTest, EmptyModule) {
|
||||
auto* src = "";
|
||||
auto* expect = "";
|
||||
|
|
|
@ -232,9 +232,6 @@ SimplifyPointers::SimplifyPointers() = default;
|
|||
SimplifyPointers::~SimplifyPointers() = default;
|
||||
|
||||
void SimplifyPointers::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
|
||||
if (!Requires<Unshadow>(ctx)) {
|
||||
return;
|
||||
}
|
||||
State(ctx).Run();
|
||||
}
|
||||
|
||||
|
|
|
@ -29,6 +29,9 @@ namespace transform {
|
|||
/// Note: SimplifyPointers does not operate on module-scope `let`s, as these
|
||||
/// cannot be pointers: https://gpuweb.github.io/gpuweb/wgsl/#module-constants
|
||||
/// `A module-scope let-declared constant must be of constructible type.`
|
||||
///
|
||||
/// @note Depends on the following transforms to have been run first:
|
||||
/// * Unshadow
|
||||
class SimplifyPointers : public Castable<SimplifyPointers, Transform> {
|
||||
public:
|
||||
/// Constructor
|
||||
|
|
|
@ -23,18 +23,6 @@ namespace {
|
|||
|
||||
using SimplifyPointersTest = TransformTest;
|
||||
|
||||
TEST_F(SimplifyPointersTest, Error_MissingSimplifyPointers) {
|
||||
auto* src = "";
|
||||
|
||||
auto* expect =
|
||||
"error: tint::transform::SimplifyPointers depends on "
|
||||
"tint::transform::Unshadow but the dependency was not run";
|
||||
|
||||
auto got = Run<SimplifyPointers>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(SimplifyPointersTest, EmptyModule) {
|
||||
auto* src = "";
|
||||
auto* expect = "";
|
||||
|
|
|
@ -107,11 +107,6 @@ void SingleEntryPoint::Run(CloneContext& ctx,
|
|||
|
||||
// Clone the entry point.
|
||||
ctx.dst->AST().AddFunction(ctx.Clone(entry_point));
|
||||
|
||||
// Retain the list of applied transforms.
|
||||
// We need to do this manually since we are not going to use the top-level
|
||||
// ctx.Clone() function.
|
||||
ctx.dst->SetTransformApplied(ctx.src->TransformsApplied());
|
||||
}
|
||||
|
||||
SingleEntryPoint::Config::Config(std::string entry_point)
|
||||
|
|
|
@ -81,6 +81,17 @@ class TransformTestBase : public BASE {
|
|||
return manager.Run(&program, data);
|
||||
}
|
||||
|
||||
/// @param in the input WGSL source
|
||||
/// @param data the optional DataMap to pass to Transform::Run()
|
||||
/// @return true if the transform should be run for the given input.
|
||||
template <typename TRANSFORM>
|
||||
bool ShouldRun(std::string in, const DataMap& data = {}) {
|
||||
auto file = std::make_unique<Source::File>("test", in);
|
||||
auto program = reader::wgsl::Parse(file.get());
|
||||
EXPECT_TRUE(program.IsValid()) << program.Diagnostics().str();
|
||||
return TRANSFORM().ShouldRun(&program, data);
|
||||
}
|
||||
|
||||
/// @param output the output of the transform
|
||||
/// @returns the output program as a WGSL string, or an error string if the
|
||||
/// program is not valid.
|
||||
|
|
|
@ -52,7 +52,6 @@ Output Transform::Run(const Program* program,
|
|||
CloneContext ctx(&builder, program);
|
||||
Output output;
|
||||
Run(ctx, data, output.data);
|
||||
builder.SetTransformApplied(this);
|
||||
output.program = Program(std::move(builder));
|
||||
return output;
|
||||
}
|
||||
|
@ -62,22 +61,7 @@ void Transform::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
|
|||
<< "Transform::Run() unimplemented for " << TypeInfo().name;
|
||||
}
|
||||
|
||||
bool Transform::ShouldRun(const Program*) const {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Transform::Requires(
|
||||
CloneContext& ctx,
|
||||
std::initializer_list<const ::tint::TypeInfo*> deps) const {
|
||||
for (auto* dep : deps) {
|
||||
if (!ctx.src->HasTransformApplied(dep)) {
|
||||
ctx.dst->Diagnostics().add_error(
|
||||
diag::System::Transform, std::string(TypeInfo().name) +
|
||||
" depends on " + std::string(dep->name) +
|
||||
" but the dependency was not run");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
bool Transform::ShouldRun(const Program*, const DataMap&) const {
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
@ -160,8 +160,10 @@ class Transform : public Castable<Transform> {
|
|||
virtual Output Run(const Program* program, const DataMap& data = {}) const;
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
virtual bool ShouldRun(const Program* program) const;
|
||||
virtual bool ShouldRun(const Program* program,
|
||||
const DataMap& data = {}) const;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
|
@ -174,23 +176,6 @@ class Transform : public Castable<Transform> {
|
|||
const DataMap& inputs,
|
||||
DataMap& outputs) const;
|
||||
|
||||
/// Requires appends an error diagnostic to `ctx.dst` if the template type
|
||||
/// transforms were not already run on `ctx.src`.
|
||||
/// @param ctx the CloneContext
|
||||
/// @returns true if all dependency transforms have been run
|
||||
template <typename... TRANSFORMS>
|
||||
bool Requires(CloneContext& ctx) const {
|
||||
return Requires(ctx, {&::tint::TypeInfo::Of<TRANSFORMS>()...});
|
||||
}
|
||||
|
||||
/// Requires appends an error diagnostic to `ctx.dst` if the list of
|
||||
/// Transforms were not already run on `ctx.src`.
|
||||
/// @param ctx the CloneContext
|
||||
/// @param deps the list of Transform TypeInfos
|
||||
/// @returns true if all dependency transforms have been run
|
||||
bool Requires(CloneContext& ctx,
|
||||
std::initializer_list<const ::tint::TypeInfo*> deps) const;
|
||||
|
||||
/// Removes the statement `stmt` from the transformed program.
|
||||
/// RemoveStatement handles edge cases, like statements in the initializer and
|
||||
/// continuing of for-loops.
|
||||
|
|
|
@ -32,6 +32,22 @@ VectorizeScalarMatrixConstructors::VectorizeScalarMatrixConstructors() =
|
|||
VectorizeScalarMatrixConstructors::~VectorizeScalarMatrixConstructors() =
|
||||
default;
|
||||
|
||||
bool VectorizeScalarMatrixConstructors::ShouldRun(const Program* program,
|
||||
const DataMap&) const {
|
||||
for (auto* node : program->ASTNodes().Objects()) {
|
||||
if (auto* call = program->Sem().Get<sem::Call>(node)) {
|
||||
if (call->Target()->Is<sem::TypeConstructor>() &&
|
||||
call->Type()->Is<sem::Matrix>()) {
|
||||
auto& args = call->Arguments();
|
||||
if (args.size() > 0 && args[0]->Type()->is_scalar()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void VectorizeScalarMatrixConstructors::Run(CloneContext& ctx,
|
||||
const DataMap&,
|
||||
DataMap&) const {
|
||||
|
|
|
@ -30,6 +30,12 @@ class VectorizeScalarMatrixConstructors
|
|||
/// Destructor
|
||||
~VectorizeScalarMatrixConstructors() override;
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program,
|
||||
const DataMap& data = {}) const override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
|
|
|
@ -27,6 +27,12 @@ namespace {
|
|||
using VectorizeScalarMatrixConstructorsTest =
|
||||
TransformTestWithParam<std::pair<uint32_t, uint32_t>>;
|
||||
|
||||
TEST_F(VectorizeScalarMatrixConstructorsTest, ShouldRunEmptyModule) {
|
||||
auto* src = R"()";
|
||||
|
||||
EXPECT_FALSE(ShouldRun<VectorizeScalarMatrixConstructors>(src));
|
||||
}
|
||||
|
||||
TEST_P(VectorizeScalarMatrixConstructorsTest, Basic) {
|
||||
uint32_t cols = GetParam().first;
|
||||
uint32_t rows = GetParam().second;
|
||||
|
@ -63,6 +69,8 @@ fn main() {
|
|||
auto src = utils::ReplaceAll(tmpl, "${values}", scalar_values);
|
||||
auto expect = utils::ReplaceAll(tmpl, "${values}", vector_values);
|
||||
|
||||
EXPECT_TRUE(ShouldRun<VectorizeScalarMatrixConstructors>(src));
|
||||
|
||||
auto got = Run<VectorizeScalarMatrixConstructors>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
|
@ -92,6 +100,8 @@ fn main() {
|
|||
auto src = utils::ReplaceAll(tmpl, "${columns}", columns);
|
||||
auto expect = src;
|
||||
|
||||
EXPECT_FALSE(ShouldRun<VectorizeScalarMatrixConstructors>(src));
|
||||
|
||||
auto got = Run<VectorizeScalarMatrixConstructors>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
|
|
|
@ -38,6 +38,16 @@ WrapArraysInStructs::WrapArraysInStructs() = default;
|
|||
|
||||
WrapArraysInStructs::~WrapArraysInStructs() = default;
|
||||
|
||||
bool WrapArraysInStructs::ShouldRun(const Program* program,
|
||||
const DataMap&) const {
|
||||
for (auto* node : program->ASTNodes().Objects()) {
|
||||
if (program->Sem().Get<sem::Array>(node->As<ast::Type>())) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void WrapArraysInStructs::Run(CloneContext& ctx,
|
||||
const DataMap&,
|
||||
DataMap&) const {
|
||||
|
|
|
@ -44,6 +44,12 @@ class WrapArraysInStructs : public Castable<WrapArraysInStructs, Transform> {
|
|||
/// Destructor
|
||||
~WrapArraysInStructs() override;
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program,
|
||||
const DataMap& data = {}) const override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
|
|
|
@ -25,9 +25,23 @@ namespace {
|
|||
|
||||
using WrapArraysInStructsTest = TransformTest;
|
||||
|
||||
TEST_F(WrapArraysInStructsTest, ShouldRunEmptyModule) {
|
||||
auto* src = R"()";
|
||||
|
||||
EXPECT_FALSE(ShouldRun<WrapArraysInStructs>(src));
|
||||
}
|
||||
|
||||
TEST_F(WrapArraysInStructsTest, ShouldRunHasArray) {
|
||||
auto* src = R"(
|
||||
var<private> arr : array<i32, 4>;
|
||||
)";
|
||||
|
||||
EXPECT_TRUE(ShouldRun<WrapArraysInStructs>(src));
|
||||
}
|
||||
|
||||
TEST_F(WrapArraysInStructsTest, EmptyModule) {
|
||||
auto* src = "";
|
||||
auto* expect = "";
|
||||
auto* src = R"()";
|
||||
auto* expect = src;
|
||||
|
||||
auto got = Run<WrapArraysInStructs>(src);
|
||||
|
||||
|
|
|
@ -433,6 +433,18 @@ ZeroInitWorkgroupMemory::ZeroInitWorkgroupMemory() = default;
|
|||
|
||||
ZeroInitWorkgroupMemory::~ZeroInitWorkgroupMemory() = default;
|
||||
|
||||
bool ZeroInitWorkgroupMemory::ShouldRun(const Program* program,
|
||||
const DataMap&) const {
|
||||
for (auto* decl : program->AST().GlobalDeclarations()) {
|
||||
if (auto* var = decl->As<ast::Variable>()) {
|
||||
if (var->declared_storage_class == ast::StorageClass::kWorkgroup) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void ZeroInitWorkgroupMemory::Run(CloneContext& ctx,
|
||||
const DataMap&,
|
||||
DataMap&) const {
|
||||
|
|
|
@ -32,6 +32,12 @@ class ZeroInitWorkgroupMemory
|
|||
/// Destructor
|
||||
~ZeroInitWorkgroupMemory() override;
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program,
|
||||
const DataMap& data = {}) const override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
|
|
|
@ -24,6 +24,28 @@ namespace {
|
|||
|
||||
using ZeroInitWorkgroupMemoryTest = TransformTest;
|
||||
|
||||
TEST_F(ZeroInitWorkgroupMemoryTest, ShouldRunEmptyModule) {
|
||||
auto* src = R"()";
|
||||
|
||||
EXPECT_FALSE(ShouldRun<ZeroInitWorkgroupMemory>(src));
|
||||
}
|
||||
|
||||
TEST_F(ZeroInitWorkgroupMemoryTest, ShouldRunHasNoWorkgroupVars) {
|
||||
auto* src = R"(
|
||||
var<private> v : i32;
|
||||
)";
|
||||
|
||||
EXPECT_FALSE(ShouldRun<ZeroInitWorkgroupMemory>(src));
|
||||
}
|
||||
|
||||
TEST_F(ZeroInitWorkgroupMemoryTest, ShouldRunHasWorkgroupVars) {
|
||||
auto* src = R"(
|
||||
var<workgroup> a : i32;
|
||||
)";
|
||||
|
||||
EXPECT_TRUE(ShouldRun<ZeroInitWorkgroupMemory>(src));
|
||||
}
|
||||
|
||||
TEST_F(ZeroInitWorkgroupMemoryTest, EmptyModule) {
|
||||
auto* src = "";
|
||||
auto* expect = src;
|
||||
|
|
|
@ -124,14 +124,6 @@ GeneratorImpl::GeneratorImpl(const Program* program) : TextGenerator(program) {}
|
|||
GeneratorImpl::~GeneratorImpl() = default;
|
||||
|
||||
bool GeneratorImpl::Generate() {
|
||||
if (!builder_.HasTransformApplied<transform::Glsl>()) {
|
||||
diagnostics_.add_error(
|
||||
diag::System::Writer,
|
||||
"GLSL writer requires the transform::Glsl sanitizer to have been "
|
||||
"applied to the input program");
|
||||
return false;
|
||||
}
|
||||
|
||||
const TypeInfo* last_kind = nullptr;
|
||||
size_t last_padding_line = 0;
|
||||
|
||||
|
|
|
@ -21,16 +21,6 @@ namespace {
|
|||
|
||||
using GlslGeneratorImplTest = TestHelper;
|
||||
|
||||
TEST_F(GlslGeneratorImplTest, ErrorIfSanitizerNotRun) {
|
||||
auto program = std::make_unique<Program>(std::move(*this));
|
||||
GeneratorImpl gen(program.get());
|
||||
EXPECT_FALSE(gen.Generate());
|
||||
EXPECT_EQ(
|
||||
gen.error(),
|
||||
"error: GLSL writer requires the transform::Glsl sanitizer to have been "
|
||||
"applied to the input program");
|
||||
}
|
||||
|
||||
TEST_F(GlslGeneratorImplTest, Generate) {
|
||||
Func("my_func", ast::VariableList{}, ty.void_(), ast::StatementList{},
|
||||
ast::DecorationList{});
|
||||
|
|
|
@ -44,9 +44,6 @@ class TestHelperBase : public BODY, public ProgramBuilder {
|
|||
if (gen_) {
|
||||
return *gen_;
|
||||
}
|
||||
// Fake that the GLSL sanitizer has been applied, so that we can unit test
|
||||
// the writer without it erroring.
|
||||
SetTransformApplied<transform::Glsl>();
|
||||
[&]() {
|
||||
ASSERT_TRUE(IsValid()) << "Builder program is not valid\n"
|
||||
<< diag::Formatter().format(Diagnostics());
|
||||
|
|
|
@ -187,9 +187,10 @@ SanitizedResult Sanitize(
|
|||
if (!result.program.IsValid()) {
|
||||
return result;
|
||||
}
|
||||
result.used_array_length_from_uniform_indices =
|
||||
std::move(out.data.Get<transform::ArrayLengthFromUniform::Result>()
|
||||
->used_size_indices);
|
||||
if (auto* res = out.data.Get<transform::ArrayLengthFromUniform::Result>()) {
|
||||
result.used_array_length_from_uniform_indices =
|
||||
std::move(res->used_size_indices);
|
||||
}
|
||||
result.needs_storage_buffer_sizes =
|
||||
!result.used_array_length_from_uniform_indices.empty();
|
||||
return result;
|
||||
|
|
Loading…
Reference in New Issue