diff --git a/src/clone_context.cc b/src/clone_context.cc index 31b49cd171..71d8851b0b 100644 --- a/src/clone_context.cc +++ b/src/clone_context.cc @@ -57,7 +57,6 @@ Symbol CloneContext::Clone(Symbol s) { void CloneContext::Clone() { dst->AST().Copy(this, &src->AST()); - dst->SetTransformApplied(src->TransformsApplied()); } ast::FunctionList CloneContext::Clone(const ast::FunctionList& v) { diff --git a/src/program.cc b/src/program.cc index b838c7d6d5..e1fdafb0e9 100644 --- a/src/program.cc +++ b/src/program.cc @@ -42,7 +42,6 @@ Program::Program(Program&& program) sem_(std::move(program.sem_)), symbols_(std::move(program.symbols_)), diagnostics_(std::move(program.diagnostics_)), - transforms_applied_(std::move(program.transforms_applied_)), is_valid_(program.is_valid_) { program.AssertNotMoved(); program.moved_ = true; @@ -67,7 +66,6 @@ Program::Program(ProgramBuilder&& builder) { sem_ = std::move(builder.Sem()); symbols_ = std::move(builder.Symbols()); diagnostics_.add(std::move(builder.Diagnostics())); - transforms_applied_ = builder.TransformsApplied(); builder.MarkAsMoved(); if (!is_valid_ && !diagnostics_.contains_errors()) { @@ -92,7 +90,6 @@ Program& Program::operator=(Program&& program) { sem_ = std::move(program.sem_); symbols_ = std::move(program.symbols_); diagnostics_ = std::move(program.diagnostics_); - transforms_applied_ = std::move(program.transforms_applied_); is_valid_ = program.is_valid_; return *this; } diff --git a/src/program.h b/src/program.h index f5ceca7252..a44c69222f 100644 --- a/src/program.h +++ b/src/program.h @@ -126,25 +126,6 @@ class Program { /// information bool IsValid() const; - /// @return the TypeInfo pointers of all transforms that have been applied to - /// this program. - std::unordered_set TransformsApplied() const { - return transforms_applied_; - } - - /// @param transform the TypeInfo of the transform - /// @returns true if the transform with the given TypeInfo was applied to the - /// Program - bool HasTransformApplied(const TypeInfo* transform) const { - return transforms_applied_.count(transform); - } - - /// @returns true if the transform of type `T` was applied. - template - bool HasTransformApplied() const { - return HasTransformApplied(&TypeInfo::Of()); - } - /// Helper for returning the resolved semantic type of the expression `expr`. /// @param expr the AST expression /// @return the resolved semantic type for the expression, or nullptr if the @@ -184,7 +165,6 @@ class Program { sem::Info sem_; SymbolTable symbols_{id_}; diag::List diagnostics_; - std::unordered_set transforms_applied_; bool is_valid_ = false; // Not valid until it is built bool moved_ = false; }; diff --git a/src/program_builder.cc b/src/program_builder.cc index c7e6751420..08635aa338 100644 --- a/src/program_builder.cc +++ b/src/program_builder.cc @@ -38,8 +38,7 @@ ProgramBuilder::ProgramBuilder(ProgramBuilder&& rhs) ast_(rhs.ast_), sem_(std::move(rhs.sem_)), symbols_(std::move(rhs.symbols_)), - diagnostics_(std::move(rhs.diagnostics_)), - transforms_applied_(std::move(rhs.transforms_applied_)) { + diagnostics_(std::move(rhs.diagnostics_)) { rhs.MarkAsMoved(); } @@ -56,7 +55,7 @@ ProgramBuilder& ProgramBuilder::operator=(ProgramBuilder&& rhs) { sem_ = std::move(rhs.sem_); symbols_ = std::move(rhs.symbols_); diagnostics_ = std::move(rhs.diagnostics_); - transforms_applied_ = std::move(rhs.transforms_applied_); + return *this; } @@ -69,7 +68,6 @@ ProgramBuilder ProgramBuilder::Wrap(const Program* program) { builder.sem_ = sem::Info::Wrap(program->Sem()); builder.symbols_ = program->Symbols(); builder.diagnostics_ = program->Diagnostics(); - builder.transforms_applied_ = program->TransformsApplied(); return builder; } diff --git a/src/program_builder.h b/src/program_builder.h index 0fd946fb25..83710e235a 100644 --- a/src/program_builder.h +++ b/src/program_builder.h @@ -2469,49 +2469,6 @@ class ProgramBuilder { source_ = Source(loc); } - /// Marks that the given transform has been applied to this program. - /// @param transform the transform that has been applied - void SetTransformApplied(const CastableBase* transform) { - transforms_applied_.emplace(&transform->TypeInfo()); - } - - /// Marks that the given transform `T` has been applied to this program. - template - void SetTransformApplied() { - transforms_applied_.emplace(&TypeInfo::Of()); - } - - /// Marks that the transforms with the given TypeInfos have been applied to - /// this program. - /// @param transforms the set of transform TypeInfos that has been applied - void SetTransformApplied( - const std::unordered_set& transforms) { - for (auto* transform : transforms) { - transforms_applied_.emplace(transform); - } - } - - /// Unmarks that the given transform `T` has been applied to this program. - template - void UnsetTransformApplied() { - auto it = transforms_applied_.find(&TypeInfo::Of()); - if (it != transforms_applied_.end()) { - transforms_applied_.erase(it); - } - } - - /// @returns true if the transform of type `T` was applied. - template - bool HasTransformApplied() { - return transforms_applied_.count(&TypeInfo::Of()); - } - - /// @return the TypeInfo pointers of all transforms that have been applied to - /// this program. - std::unordered_set TransformsApplied() const { - return transforms_applied_; - } - /// Helper for returning the resolved semantic type of the expression `expr`. /// @note As the Resolver is run when the Program is built, this will only be /// useful for the Resolver itself and tests that use their own Resolver. @@ -2592,7 +2549,6 @@ class ProgramBuilder { sem::Info sem_; SymbolTable symbols_{id_}; diag::List diagnostics_; - std::unordered_set transforms_applied_; /// The source to use when creating AST nodes without providing a Source as /// the first argument. diff --git a/src/transform/add_empty_entry_point.cc b/src/transform/add_empty_entry_point.cc index cbc7e0cc07..824c636589 100644 --- a/src/transform/add_empty_entry_point.cc +++ b/src/transform/add_empty_entry_point.cc @@ -27,15 +27,19 @@ AddEmptyEntryPoint::AddEmptyEntryPoint() = default; AddEmptyEntryPoint::~AddEmptyEntryPoint() = default; +bool AddEmptyEntryPoint::ShouldRun(const Program* program, + const DataMap&) const { + for (auto* func : program->AST().Functions()) { + if (func->IsEntryPoint()) { + return false; + } + } + return true; +} + void AddEmptyEntryPoint::Run(CloneContext& ctx, const DataMap&, DataMap&) const { - for (auto* func : ctx.src->AST().Functions()) { - if (func->IsEntryPoint()) { - ctx.Clone(); - return; - } - } ctx.dst->Func(ctx.dst->Symbols().New("unused_entry_point"), {}, ctx.dst->ty.void_(), {}, {ctx.dst->Stage(ast::PipelineStage::kCompute), diff --git a/src/transform/add_empty_entry_point.h b/src/transform/add_empty_entry_point.h index e955080562..31821acb28 100644 --- a/src/transform/add_empty_entry_point.h +++ b/src/transform/add_empty_entry_point.h @@ -28,6 +28,12 @@ class AddEmptyEntryPoint : public Castable { /// Destructor ~AddEmptyEntryPoint() override; + /// @param program the program to inspect + /// @param data optional extra transform-specific input data + /// @returns true if this transform should be run for the given program + bool ShouldRun(const Program* program, + const DataMap& data = {}) const override; + protected: /// Runs the transform using the CloneContext built for transforming a /// program. Run() is responsible for calling Clone() on the CloneContext. diff --git a/src/transform/add_empty_entry_point_test.cc b/src/transform/add_empty_entry_point_test.cc index 4cee0bc601..df25e8894d 100644 --- a/src/transform/add_empty_entry_point_test.cc +++ b/src/transform/add_empty_entry_point_test.cc @@ -24,6 +24,21 @@ namespace { using AddEmptyEntryPointTest = TransformTest; +TEST_F(AddEmptyEntryPointTest, ShouldRunEmptyModule) { + auto* src = R"()"; + + EXPECT_TRUE(ShouldRun(src)); +} + +TEST_F(AddEmptyEntryPointTest, ShouldRunExistingEntryPoint) { + auto* src = R"( +[[stage(compute), workgroup_size(1)]] +fn existing() {} +)"; + + EXPECT_FALSE(ShouldRun(src)); +} + TEST_F(AddEmptyEntryPointTest, EmptyModule) { auto* src = R"()"; diff --git a/src/transform/array_length_from_uniform.cc b/src/transform/array_length_from_uniform.cc index b34dad3672..80edb88b95 100644 --- a/src/transform/array_length_from_uniform.cc +++ b/src/transform/array_length_from_uniform.cc @@ -20,6 +20,7 @@ #include "src/program_builder.h" #include "src/sem/call.h" +#include "src/sem/function.h" #include "src/sem/variable.h" #include "src/transform/simplify_pointers.h" @@ -93,13 +94,23 @@ static void IterateArrayLengthOnStorageVar(CloneContext& ctx, F&& functor) { } } +bool ArrayLengthFromUniform::ShouldRun(const Program* program, + const DataMap&) const { + for (auto* fn : program->AST().Functions()) { + if (auto* sem_fn = program->Sem().Get(fn)) { + for (auto* intrinsic : sem_fn->DirectlyCalledIntrinsics()) { + if (intrinsic->Type() == sem::IntrinsicType::kArrayLength) { + return true; + } + } + } + } + return false; +} + void ArrayLengthFromUniform::Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const { - if (!Requires(ctx)) { - return; - } - auto* cfg = inputs.Get(); if (cfg == nullptr) { ctx.dst->Diagnostics().add_error( diff --git a/src/transform/array_length_from_uniform.h b/src/transform/array_length_from_uniform.h index bfeaffe033..6947350c63 100644 --- a/src/transform/array_length_from_uniform.h +++ b/src/transform/array_length_from_uniform.h @@ -49,6 +49,9 @@ namespace transform { /// This transform assumes that the `SimplifyPointers` /// transforms have been run before it so that arguments to the arrayLength /// builtin always have the form `&resource.array`. +/// +/// @note Depends on the following transforms to have been run first: +/// * SimplifyPointers class ArrayLengthFromUniform : public Castable { public: @@ -81,6 +84,8 @@ class ArrayLengthFromUniform }; /// Information produced about what the transform did. + /// If there were no calls to the arrayLength() intrinsic, then no Result will + /// be emitted. struct Result : public Castable { /// Constructor /// @param used_size_indices Indices into the UBO that are statically used. @@ -96,6 +101,12 @@ class ArrayLengthFromUniform const std::unordered_set used_size_indices; }; + /// @param program the program to inspect + /// @param data optional extra transform-specific input data + /// @returns true if this transform should be run for the given program + bool ShouldRun(const Program* program, + const DataMap& data = {}) const override; + protected: /// Runs the transform using the CloneContext built for transforming a /// program. Run() is responsible for calling Clone() on the CloneContext. diff --git a/src/transform/array_length_from_uniform_test.cc b/src/transform/array_length_from_uniform_test.cc index 944fc6a4f4..01940c6acc 100644 --- a/src/transform/array_length_from_uniform_test.cc +++ b/src/transform/array_length_from_uniform_test.cc @@ -26,8 +26,61 @@ namespace { using ArrayLengthFromUniformTest = TransformTest; +TEST_F(ArrayLengthFromUniformTest, ShouldRunEmptyModule) { + auto* src = R"()"; + + EXPECT_FALSE(ShouldRun(src)); +} + +TEST_F(ArrayLengthFromUniformTest, ShouldRunNoArrayLength) { + auto* src = R"( +struct SB { + x : i32; + arr : array; +}; + +[[group(0), binding(0)]] var sb : SB; + +[[stage(compute), workgroup_size(1)]] +fn main() { +} +)"; + + EXPECT_FALSE(ShouldRun(src)); +} + +TEST_F(ArrayLengthFromUniformTest, ShouldRunWithArrayLength) { + auto* src = R"( +struct SB { + x : i32; + arr : array; +}; + +[[group(0), binding(0)]] var sb : SB; + +[[stage(compute), workgroup_size(1)]] +fn main() { + var len : u32 = arrayLength(&sb.arr); +} +)"; + + EXPECT_TRUE(ShouldRun(src)); +} + TEST_F(ArrayLengthFromUniformTest, Error_MissingTransformData) { - auto* src = ""; + auto* src = R"( +struct SB { + x : i32; + arr : array; +}; + +[[group(0), binding(0)]] var sb : SB; + +[[stage(compute), workgroup_size(1)]] +fn main() { + var len : u32 = arrayLength(&sb.arr); +} +)"; auto* expect = "error: missing transform data for " @@ -38,18 +91,6 @@ TEST_F(ArrayLengthFromUniformTest, Error_MissingTransformData) { EXPECT_EQ(expect, str(got)); } -TEST_F(ArrayLengthFromUniformTest, Error_MissingSimplifyPointers) { - auto* src = ""; - - auto* expect = - "error: tint::transform::ArrayLengthFromUniform depends on " - "tint::transform::SimplifyPointers but the dependency was not run"; - - auto got = Run(src); - - EXPECT_EQ(expect, str(got)); -} - TEST_F(ArrayLengthFromUniformTest, Basic) { auto* src = R"( @group(0) @binding(0) var sb : array; @@ -426,8 +467,7 @@ fn main() { auto got = Run(src, data); EXPECT_EQ(src, str(got)); - EXPECT_EQ(std::unordered_set(), - got.data.Get()->used_size_indices); + EXPECT_EQ(got.data.Get(), nullptr); } TEST_F(ArrayLengthFromUniformTest, MissingBindingPointToIndexMapping) { diff --git a/src/transform/binding_remapper.cc b/src/transform/binding_remapper.cc index feff411ebd..178753b1db 100644 --- a/src/transform/binding_remapper.cc +++ b/src/transform/binding_remapper.cc @@ -42,6 +42,14 @@ BindingRemapper::Remappings::~Remappings() = default; BindingRemapper::BindingRemapper() = default; BindingRemapper::~BindingRemapper() = default; +bool BindingRemapper::ShouldRun(const Program*, const DataMap& inputs) const { + if (auto* remappings = inputs.Get()) { + return !remappings->binding_points.empty() || + !remappings->access_controls.empty(); + } + return false; +} + void BindingRemapper::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const { diff --git a/src/transform/binding_remapper.h b/src/transform/binding_remapper.h index da837bddf4..0af33f6f57 100644 --- a/src/transform/binding_remapper.h +++ b/src/transform/binding_remapper.h @@ -68,6 +68,12 @@ class BindingRemapper : public Castable { BindingRemapper(); ~BindingRemapper() override; + /// @param program the program to inspect + /// @param data optional extra transform-specific input data + /// @returns true if this transform should be run for the given program + bool ShouldRun(const Program* program, + const DataMap& data = {}) const override; + protected: /// Runs the transform using the CloneContext built for transforming a /// program. Run() is responsible for calling Clone() on the CloneContext. diff --git a/src/transform/binding_remapper_test.cc b/src/transform/binding_remapper_test.cc index 5827cb2af9..3145f13bfb 100644 --- a/src/transform/binding_remapper_test.cc +++ b/src/transform/binding_remapper_test.cc @@ -24,6 +24,47 @@ namespace { using BindingRemapperTest = TransformTest; +TEST_F(BindingRemapperTest, ShouldRunNoRemappings) { + auto* src = R"()"; + + EXPECT_FALSE(ShouldRun(src)); +} + +TEST_F(BindingRemapperTest, ShouldRunEmptyRemappings) { + auto* src = R"()"; + + DataMap data; + data.Add(BindingRemapper::BindingPoints{}, + BindingRemapper::AccessControls{}); + + EXPECT_FALSE(ShouldRun(src, data)); +} + +TEST_F(BindingRemapperTest, ShouldRunBindingPointRemappings) { + auto* src = R"()"; + + DataMap data; + data.Add( + BindingRemapper::BindingPoints{ + {{2, 1}, {1, 2}}, + }, + BindingRemapper::AccessControls{}); + + EXPECT_TRUE(ShouldRun(src, data)); +} + +TEST_F(BindingRemapperTest, ShouldRunAccessControlRemappings) { + auto* src = R"()"; + + DataMap data; + data.Add(BindingRemapper::BindingPoints{}, + BindingRemapper::AccessControls{ + {{2, 1}, ast::Access::kWrite}, + }); + + EXPECT_TRUE(ShouldRun(src, data)); +} + TEST_F(BindingRemapperTest, NoRemappings) { auto* src = R"( struct S { @@ -359,17 +400,18 @@ TEST_F(BindingRemapperTest, NoData) { auto* src = R"( struct S { a : f32; -}; +} @group(2) @binding(1) var a : S; + @group(3) @binding(2) var b : S; @stage(compute) @workgroup_size(1) -fn f() {} +fn f() { +} )"; - auto* expect = - "error: missing transform data for tint::transform::BindingRemapper"; + auto* expect = src; auto got = Run(src); diff --git a/src/transform/calculate_array_length.cc b/src/transform/calculate_array_length.cc index 9d560ed9f6..c54c3e8aaf 100644 --- a/src/transform/calculate_array_length.cc +++ b/src/transform/calculate_array_length.cc @@ -22,6 +22,7 @@ #include "src/program_builder.h" #include "src/sem/block_statement.h" #include "src/sem/call.h" +#include "src/sem/function.h" #include "src/sem/statement.h" #include "src/sem/struct.h" #include "src/sem/variable.h" @@ -71,13 +72,24 @@ CalculateArrayLength::BufferSizeIntrinsic::Clone(CloneContext* ctx) const { CalculateArrayLength::CalculateArrayLength() = default; CalculateArrayLength::~CalculateArrayLength() = default; +bool CalculateArrayLength::ShouldRun(const Program* program, + const DataMap&) const { + for (auto* fn : program->AST().Functions()) { + if (auto* sem_fn = program->Sem().Get(fn)) { + for (auto* intrinsic : sem_fn->DirectlyCalledIntrinsics()) { + if (intrinsic->Type() == sem::IntrinsicType::kArrayLength) { + return true; + } + } + } + } + return false; +} + void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) const { auto& sem = ctx.src->Sem(); - if (!Requires(ctx)) { - return; - } // get_buffer_size_intrinsic() emits the function decorated with // BufferSizeIntrinsic that is transformed by the HLSL writer into a call to diff --git a/src/transform/calculate_array_length.h b/src/transform/calculate_array_length.h index 6394dd7272..1ae84740bc 100644 --- a/src/transform/calculate_array_length.h +++ b/src/transform/calculate_array_length.h @@ -29,6 +29,9 @@ namespace transform { /// CalculateArrayLength is a transform used to replace calls to arrayLength() /// with a value calculated from the size of the storage buffer. +/// +/// @note Depends on the following transforms to have been run first: +/// * SimplifyPointers class CalculateArrayLength : public Castable { public: /// BufferSizeIntrinsic is an InternalDecoration that's applied to intrinsic @@ -56,6 +59,12 @@ class CalculateArrayLength : public Castable { /// Destructor ~CalculateArrayLength() override; + /// @param program the program to inspect + /// @param data optional extra transform-specific input data + /// @returns true if this transform should be run for the given program + bool ShouldRun(const Program* program, + const DataMap& data = {}) const override; + protected: /// Runs the transform using the CloneContext built for transforming a /// program. Run() is responsible for calling Clone() on the CloneContext. diff --git a/src/transform/calculate_array_length_test.cc b/src/transform/calculate_array_length_test.cc index 88643aecc7..cebb27647e 100644 --- a/src/transform/calculate_array_length_test.cc +++ b/src/transform/calculate_array_length_test.cc @@ -24,16 +24,45 @@ namespace { using CalculateArrayLengthTest = TransformTest; -TEST_F(CalculateArrayLengthTest, Error_MissingCalculateArrayLength) { - auto* src = ""; +TEST_F(CalculateArrayLengthTest, ShouldRunEmptyModule) { + auto* src = R"()"; - auto* expect = - "error: tint::transform::CalculateArrayLength depends on " - "tint::transform::SimplifyPointers but the dependency was not run"; + EXPECT_FALSE(ShouldRun(src)); +} - auto got = Run(src); +TEST_F(CalculateArrayLengthTest, ShouldRunNoArrayLength) { + auto* src = R"( +struct SB { + x : i32; + arr : array; +}; - EXPECT_EQ(expect, str(got)); +[[group(0), binding(0)]] var sb : SB; + +[[stage(compute), workgroup_size(1)]] +fn main() { +} +)"; + + EXPECT_FALSE(ShouldRun(src)); +} + +TEST_F(CalculateArrayLengthTest, ShouldRunWithArrayLength) { + auto* src = R"( +struct SB { + x : i32; + arr : array; +}; + +[[group(0), binding(0)]] var sb : SB; + +[[stage(compute), workgroup_size(1)]] +fn main() { + var len : u32 = arrayLength(&sb.arr); +} +)"; + + EXPECT_TRUE(ShouldRun(src)); } TEST_F(CalculateArrayLengthTest, Basic) { diff --git a/src/transform/canonicalize_entry_point_io.cc b/src/transform/canonicalize_entry_point_io.cc index d0734a06be..88e6b1604a 100644 --- a/src/transform/canonicalize_entry_point_io.cc +++ b/src/transform/canonicalize_entry_point_io.cc @@ -551,10 +551,6 @@ struct CanonicalizeEntryPointIO::State { void CanonicalizeEntryPointIO::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const { - if (!Requires(ctx)) { - return; - } - auto* cfg = inputs.Get(); if (cfg == nullptr) { ctx.dst->Diagnostics().add_error( diff --git a/src/transform/canonicalize_entry_point_io.h b/src/transform/canonicalize_entry_point_io.h index 6271444c6d..e3c1dce25a 100644 --- a/src/transform/canonicalize_entry_point_io.h +++ b/src/transform/canonicalize_entry_point_io.h @@ -80,6 +80,9 @@ namespace transform { /// return wrapper_result; /// } /// ``` +/// +/// @note Depends on the following transforms to have been run first: +/// * Unshadow class CanonicalizeEntryPointIO : public Castable { public: diff --git a/src/transform/canonicalize_entry_point_io_test.cc b/src/transform/canonicalize_entry_point_io_test.cc index 9a5f184638..3d89959ae2 100644 --- a/src/transform/canonicalize_entry_point_io_test.cc +++ b/src/transform/canonicalize_entry_point_io_test.cc @@ -23,18 +23,6 @@ namespace { using CanonicalizeEntryPointIOTest = TransformTest; -TEST_F(CanonicalizeEntryPointIOTest, Error_MissingUnshadow) { - auto* src = ""; - - auto* expect = - "error: tint::transform::CanonicalizeEntryPointIO depends on " - "tint::transform::Unshadow but the dependency was not run"; - - auto got = Run(src); - - EXPECT_EQ(expect, str(got)); -} - TEST_F(CanonicalizeEntryPointIOTest, Error_MissingTransformData) { auto* src = ""; diff --git a/src/transform/decompose_memory_access.cc b/src/transform/decompose_memory_access.cc index 81511dcf6f..99c5623c84 100644 --- a/src/transform/decompose_memory_access.cc +++ b/src/transform/decompose_memory_access.cc @@ -790,6 +790,19 @@ const DecomposeMemoryAccess::Intrinsic* DecomposeMemoryAccess::Intrinsic::Clone( DecomposeMemoryAccess::DecomposeMemoryAccess() = default; DecomposeMemoryAccess::~DecomposeMemoryAccess() = default; +bool DecomposeMemoryAccess::ShouldRun(const Program* program, + const DataMap&) const { + for (auto* decl : program->AST().GlobalDeclarations()) { + if (auto* var = program->Sem().Get(decl)) { + if (var->StorageClass() == ast::StorageClass::kStorage || + var->StorageClass() == ast::StorageClass::kUniform) { + return true; + } + } + } + return false; +} + void DecomposeMemoryAccess::Run(CloneContext& ctx, const DataMap&, DataMap&) const { diff --git a/src/transform/decompose_memory_access.h b/src/transform/decompose_memory_access.h index 76fbbc4690..2139f4b6ed 100644 --- a/src/transform/decompose_memory_access.h +++ b/src/transform/decompose_memory_access.h @@ -105,6 +105,12 @@ class DecomposeMemoryAccess /// Destructor ~DecomposeMemoryAccess() override; + /// @param program the program to inspect + /// @param data optional extra transform-specific input data + /// @returns true if this transform should be run for the given program + bool ShouldRun(const Program* program, + const DataMap& data = {}) const override; + protected: /// Runs the transform using the CloneContext built for transforming a /// program. Run() is responsible for calling Clone() on the CloneContext. diff --git a/src/transform/decompose_memory_access_test.cc b/src/transform/decompose_memory_access_test.cc index 2fd6e10060..dffb3f2200 100644 --- a/src/transform/decompose_memory_access_test.cc +++ b/src/transform/decompose_memory_access_test.cc @@ -22,6 +22,34 @@ namespace { using DecomposeMemoryAccessTest = TransformTest; +TEST_F(DecomposeMemoryAccessTest, ShouldRunEmptyModule) { + auto* src = R"()"; + + EXPECT_FALSE(ShouldRun(src)); +} + +TEST_F(DecomposeMemoryAccessTest, ShouldRunStorageBuffer) { + auto* src = R"( +struct Buffer { + i : i32; +}; +[[group(0), binding(0)]] var sb : Buffer; +)"; + + EXPECT_TRUE(ShouldRun(src)); +} + +TEST_F(DecomposeMemoryAccessTest, ShouldRunUniformBuffer) { + auto* src = R"( +struct Buffer { + i : i32; +}; +[[group(0), binding(0)]] var ub : Buffer; +)"; + + EXPECT_TRUE(ShouldRun(src)); +} + TEST_F(DecomposeMemoryAccessTest, SB_BasicLoad) { auto* src = R"( struct SB { diff --git a/src/transform/decompose_strided_matrix.cc b/src/transform/decompose_strided_matrix.cc index 11eb5c6e5b..669efc8567 100644 --- a/src/transform/decompose_strided_matrix.cc +++ b/src/transform/decompose_strided_matrix.cc @@ -107,7 +107,8 @@ DecomposeStridedMatrix::DecomposeStridedMatrix() = default; DecomposeStridedMatrix::~DecomposeStridedMatrix() = default; -bool DecomposeStridedMatrix::ShouldRun(const Program* program) const { +bool DecomposeStridedMatrix::ShouldRun(const Program* program, + const DataMap&) const { bool should_run = false; GatherCustomStrideMatrixMembers( program, [&](const sem::StructMember*, sem::Matrix*, uint32_t) { @@ -120,10 +121,6 @@ bool DecomposeStridedMatrix::ShouldRun(const Program* program) const { void DecomposeStridedMatrix::Run(CloneContext& ctx, const DataMap&, DataMap&) const { - if (!Requires(ctx)) { - return; - } - // Scan the program for all storage and uniform structure matrix members with // a custom stride attribute. Replace these matrices with an equivalent array, // and populate the `decomposed` map with the members that have been replaced. diff --git a/src/transform/decompose_strided_matrix.h b/src/transform/decompose_strided_matrix.h index 587b6760ef..e7535772ce 100644 --- a/src/transform/decompose_strided_matrix.h +++ b/src/transform/decompose_strided_matrix.h @@ -25,6 +25,9 @@ namespace transform { /// of N column vectors. /// This transform is used by the SPIR-V reader to handle the SPIR-V /// MatrixStride decoration. +/// +/// @note Depends on the following transforms to have been run first: +/// * SimplifyPointers class DecomposeStridedMatrix : public Castable { public: @@ -35,8 +38,10 @@ class DecomposeStridedMatrix ~DecomposeStridedMatrix() override; /// @param program the program to inspect + /// @param data optional extra transform-specific input data /// @returns true if this transform should be run for the given program - bool ShouldRun(const Program* program) const override; + bool ShouldRun(const Program* program, + const DataMap& data = {}) const override; protected: /// Runs the transform using the CloneContext built for transforming a diff --git a/src/transform/decompose_strided_matrix_test.cc b/src/transform/decompose_strided_matrix_test.cc index 97aa39144a..fe3db50f65 100644 --- a/src/transform/decompose_strided_matrix_test.cc +++ b/src/transform/decompose_strided_matrix_test.cc @@ -31,55 +31,25 @@ namespace { using DecomposeStridedMatrixTest = TransformTest; using f32 = ProgramBuilder::f32; +TEST_F(DecomposeStridedMatrixTest, ShouldRunEmptyModule) { + auto* src = R"()"; + + EXPECT_FALSE(ShouldRun(src)); +} + +TEST_F(DecomposeStridedMatrixTest, ShouldRunNonStridedMatrox) { + auto* src = R"( +var m : mat3x2; +)"; + + EXPECT_FALSE(ShouldRun(src)); +} + TEST_F(DecomposeStridedMatrixTest, Empty) { auto* src = R"()"; auto* expect = src; - auto got = Run(src); - - EXPECT_EQ(expect, str(got)); -} - -TEST_F(DecomposeStridedMatrixTest, MissingDependencySimplify) { - // struct S { - // [[offset(16), stride(32)]] - // [[internal(ignore_stride_decoration)]] - // m : mat2x2; - // }; - // [[group(0), binding(0)]] var s : S; - // - // [[stage(compute), workgroup_size(1)]] - // fn f() { - // let x : mat2x2 = s.m; - // } - ProgramBuilder b; - auto* S = b.Structure( - "S", - { - b.Member( - "m", b.ty.mat2x2(), - { - b.create(16), - b.create(32), - b.Disable(ast::DisabledValidation::kIgnoreStrideDecoration), - }), - }); - b.Global("s", b.ty.Of(S), ast::StorageClass::kUniform, - b.GroupAndBinding(0, 0)); - b.Func( - "f", {}, b.ty.void_(), - { - b.Decl(b.Const("x", b.ty.mat2x2(), b.MemberAccessor("s", "m"))), - }, - { - b.Stage(ast::PipelineStage::kCompute), - b.WorkgroupSize(1), - }); - - auto* expect = - R"(error: tint::transform::DecomposeStridedMatrix depends on tint::transform::SimplifyPointers but the dependency was not run)"; - - auto got = Run(Program(std::move(b))); + auto got = Run(src); EXPECT_EQ(expect, str(got)); } diff --git a/src/transform/first_index_offset.cc b/src/transform/first_index_offset.cc index 50a65fac1e..971901323b 100644 --- a/src/transform/first_index_offset.cc +++ b/src/transform/first_index_offset.cc @@ -57,6 +57,15 @@ FirstIndexOffset::Data::~Data() = default; FirstIndexOffset::FirstIndexOffset() = default; FirstIndexOffset::~FirstIndexOffset() = default; +bool FirstIndexOffset::ShouldRun(const Program* program, const DataMap&) const { + for (auto* fn : program->AST().Functions()) { + if (fn->PipelineStage() == ast::PipelineStage::kVertex) { + return true; + } + } + return false; +} + void FirstIndexOffset::Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const { diff --git a/src/transform/first_index_offset.h b/src/transform/first_index_offset.h index 8b7e0e0441..23e4be0e90 100644 --- a/src/transform/first_index_offset.h +++ b/src/transform/first_index_offset.h @@ -115,6 +115,12 @@ class FirstIndexOffset : public Castable { /// Destructor ~FirstIndexOffset() override; + /// @param program the program to inspect + /// @param data optional extra transform-specific input data + /// @returns true if this transform should be run for the given program + bool ShouldRun(const Program* program, + const DataMap& data = {}) const override; + protected: /// Runs the transform using the CloneContext built for transforming a /// program. Run() is responsible for calling Clone() on the CloneContext. diff --git a/src/transform/first_index_offset_test.cc b/src/transform/first_index_offset_test.cc index c0c95f6f80..05fefa84ae 100644 --- a/src/transform/first_index_offset_test.cc +++ b/src/transform/first_index_offset_test.cc @@ -26,6 +26,34 @@ namespace { using FirstIndexOffsetTest = TransformTest; +TEST_F(FirstIndexOffsetTest, ShouldRunEmptyModule) { + auto* src = R"()"; + + EXPECT_FALSE(ShouldRun(src)); +} + +TEST_F(FirstIndexOffsetTest, ShouldRunFragmentStage) { + auto* src = R"( +[[stage(fragment)]] +fn entry() { + return; +} +)"; + + EXPECT_FALSE(ShouldRun(src)); +} + +TEST_F(FirstIndexOffsetTest, ShouldRunVertexStage) { + auto* src = R"( +[[stage(vertex)]] +fn entry() -> [[builtin(position)]] vec4 { + return vec4(); +} +)"; + + EXPECT_TRUE(ShouldRun(src)); +} + TEST_F(FirstIndexOffsetTest, EmptyModule) { auto* src = ""; auto* expect = ""; @@ -38,6 +66,26 @@ TEST_F(FirstIndexOffsetTest, EmptyModule) { auto* data = got.data.Get(); + EXPECT_EQ(data, nullptr); +} + +TEST_F(FirstIndexOffsetTest, BasicVertexShader) { + auto* src = R"( +@stage(vertex) +fn entry() -> @builtin(position) vec4 { + return vec4(); +} +)"; + auto* expect = src; + + DataMap config; + config.Add(0, 0); + auto got = Run(src, std::move(config)); + + EXPECT_EQ(expect, str(got)); + + auto* data = got.data.Get(); + ASSERT_NE(data, nullptr); EXPECT_EQ(data->has_vertex_index, false); EXPECT_EQ(data->has_instance_index, false); diff --git a/src/transform/for_loop_to_loop.cc b/src/transform/for_loop_to_loop.cc index 6277aedf52..3afd8a2fb6 100644 --- a/src/transform/for_loop_to_loop.cc +++ b/src/transform/for_loop_to_loop.cc @@ -25,6 +25,15 @@ ForLoopToLoop::ForLoopToLoop() = default; ForLoopToLoop::~ForLoopToLoop() = default; +bool ForLoopToLoop::ShouldRun(const Program* program, const DataMap&) const { + for (auto* node : program->ASTNodes().Objects()) { + if (node->Is()) { + return true; + } + } + return false; +} + void ForLoopToLoop::Run(CloneContext& ctx, const DataMap&, DataMap&) const { ctx.ReplaceAll( [&](const ast::ForLoopStatement* for_loop) -> const ast::Statement* { diff --git a/src/transform/for_loop_to_loop.h b/src/transform/for_loop_to_loop.h index 80a64e068c..26cdc29335 100644 --- a/src/transform/for_loop_to_loop.h +++ b/src/transform/for_loop_to_loop.h @@ -30,6 +30,12 @@ class ForLoopToLoop : public Castable { /// Destructor ~ForLoopToLoop() override; + /// @param program the program to inspect + /// @param data optional extra transform-specific input data + /// @returns true if this transform should be run for the given program + bool ShouldRun(const Program* program, + const DataMap& data = {}) const override; + protected: /// Runs the transform using the CloneContext built for transforming a /// program. Run() is responsible for calling Clone() on the CloneContext. diff --git a/src/transform/for_loop_to_loop_test.cc b/src/transform/for_loop_to_loop_test.cc index d8b2f1fcba..5f335933ed 100644 --- a/src/transform/for_loop_to_loop_test.cc +++ b/src/transform/for_loop_to_loop_test.cc @@ -22,6 +22,24 @@ namespace { using ForLoopToLoopTest = TransformTest; +TEST_F(ForLoopToLoopTest, ShouldRunEmptyModule) { + auto* src = R"()"; + + EXPECT_FALSE(ShouldRun(src)); +} + +TEST_F(ForLoopToLoopTest, ShouldRunHasForLoop) { + auto* src = R"( +fn f() { + for (;;) { + break; + } +} +)"; + + EXPECT_TRUE(ShouldRun(src)); +} + TEST_F(ForLoopToLoopTest, EmptyModule) { auto* src = ""; auto* expect = src; diff --git a/src/transform/glsl.cc b/src/transform/glsl.cc index af97f0b1e5..595f2a0fd9 100644 --- a/src/transform/glsl.cc +++ b/src/transform/glsl.cc @@ -113,7 +113,6 @@ Output Glsl::Run(const Program* in, const DataMap& inputs) const { ProgramBuilder builder; CloneContext ctx(&builder, &out.program); ctx.Clone(); - builder.SetTransformApplied(this); return Output{Program(std::move(builder))}; } diff --git a/src/transform/localize_struct_array_assignment.cc b/src/transform/localize_struct_array_assignment.cc index af3c99ed6f..43137715cb 100644 --- a/src/transform/localize_struct_array_assignment.cc +++ b/src/transform/localize_struct_array_assignment.cc @@ -216,15 +216,9 @@ LocalizeStructArrayAssignment::~LocalizeStructArrayAssignment() = default; void LocalizeStructArrayAssignment::Run(CloneContext& ctx, const DataMap&, DataMap&) const { - if (!Requires(ctx)) { - return; - } - State state(ctx); state.Run(); - - // This transform may introduce pointers - ctx.dst->UnsetTransformApplied(); } + } // namespace transform } // namespace tint diff --git a/src/transform/localize_struct_array_assignment.h b/src/transform/localize_struct_array_assignment.h index 7fa18a2302..beab7bada4 100644 --- a/src/transform/localize_struct_array_assignment.h +++ b/src/transform/localize_struct_array_assignment.h @@ -25,6 +25,9 @@ namespace transform { /// temporary local variable, assigns to the local variable, and copies the /// array back. This is to work around FXC's compilation failure for these cases /// (see crbug.com/tint/1206). +/// +/// @note Depends on the following transforms to have been run first: +/// * SimplifyPointers class LocalizeStructArrayAssignment : public Castable { public: diff --git a/src/transform/localize_struct_array_assignment_test.cc b/src/transform/localize_struct_array_assignment_test.cc index 878a89c1c3..0147d8ebda 100644 --- a/src/transform/localize_struct_array_assignment_test.cc +++ b/src/transform/localize_struct_array_assignment_test.cc @@ -24,16 +24,6 @@ namespace { using LocalizeStructArrayAssignmentTest = TransformTest; -TEST_F(LocalizeStructArrayAssignmentTest, MissingSimplifyPointers) { - auto* src = R"()"; - auto* expect = - "error: tint::transform::LocalizeStructArrayAssignment depends on " - "tint::transform::SimplifyPointers but the dependency was not run"; - - auto got = Run(src); - EXPECT_EQ(expect, str(got)); -} - TEST_F(LocalizeStructArrayAssignmentTest, EmptyModule) { auto* src = R"()"; auto* expect = src; diff --git a/src/transform/loop_to_for_loop.cc b/src/transform/loop_to_for_loop.cc index 9378a7179b..735a115644 100644 --- a/src/transform/loop_to_for_loop.cc +++ b/src/transform/loop_to_for_loop.cc @@ -56,6 +56,15 @@ LoopToForLoop::LoopToForLoop() = default; LoopToForLoop::~LoopToForLoop() = default; +bool LoopToForLoop::ShouldRun(const Program* program, const DataMap&) const { + for (auto* node : program->ASTNodes().Objects()) { + if (node->Is()) { + return true; + } + } + return false; +} + void LoopToForLoop::Run(CloneContext& ctx, const DataMap&, DataMap&) const { ctx.ReplaceAll([&](const ast::LoopStatement* loop) -> const ast::Statement* { // For loop condition is taken from the first statement in the loop. diff --git a/src/transform/loop_to_for_loop.h b/src/transform/loop_to_for_loop.h index f3f729d714..b3d323b4f3 100644 --- a/src/transform/loop_to_for_loop.h +++ b/src/transform/loop_to_for_loop.h @@ -30,6 +30,12 @@ class LoopToForLoop : public Castable { /// Destructor ~LoopToForLoop() override; + /// @param program the program to inspect + /// @param data optional extra transform-specific input data + /// @returns true if this transform should be run for the given program + bool ShouldRun(const Program* program, + const DataMap& data = {}) const override; + protected: /// Runs the transform using the CloneContext built for transforming a /// program. Run() is responsible for calling Clone() on the CloneContext. diff --git a/src/transform/loop_to_for_loop_test.cc b/src/transform/loop_to_for_loop_test.cc index 402c69ad27..812150976e 100644 --- a/src/transform/loop_to_for_loop_test.cc +++ b/src/transform/loop_to_for_loop_test.cc @@ -22,6 +22,24 @@ namespace { using LoopToForLoopTest = TransformTest; +TEST_F(LoopToForLoopTest, ShouldRunEmptyModule) { + auto* src = R"()"; + + EXPECT_FALSE(ShouldRun(src)); +} + +TEST_F(LoopToForLoopTest, ShouldRunHasForLoop) { + auto* src = R"( +fn f() { + loop { + break; + } +} +)"; + + EXPECT_TRUE(ShouldRun(src)); +} + TEST_F(LoopToForLoopTest, EmptyModule) { auto* src = ""; auto* expect = ""; diff --git a/src/transform/manager.cc b/src/transform/manager.cc index 8841dd74dc..8f0a78fed3 100644 --- a/src/transform/manager.cc +++ b/src/transform/manager.cc @@ -53,7 +53,7 @@ Output Manager::Run(const Program* program, const DataMap& data) const { Output out; for (const auto& transform : transforms_) { - if (!transform->ShouldRun(in)) { + if (!transform->ShouldRun(in, data)) { TINT_IF_PRINT_PROGRAM(std::cout << "Skipping " << transform->TypeInfo().name); continue; diff --git a/src/transform/module_scope_var_to_entry_point_param.cc b/src/transform/module_scope_var_to_entry_point_param.cc index 9005450da6..84bd5535ea 100644 --- a/src/transform/module_scope_var_to_entry_point_param.cc +++ b/src/transform/module_scope_var_to_entry_point_param.cc @@ -377,6 +377,16 @@ ModuleScopeVarToEntryPointParam::ModuleScopeVarToEntryPointParam() = default; ModuleScopeVarToEntryPointParam::~ModuleScopeVarToEntryPointParam() = default; +bool ModuleScopeVarToEntryPointParam::ShouldRun(const Program* program, + const DataMap&) const { + for (auto* decl : program->AST().GlobalDeclarations()) { + if (decl->Is()) { + return true; + } + } + return false; +} + void ModuleScopeVarToEntryPointParam::Run(CloneContext& ctx, const DataMap&, DataMap&) const { diff --git a/src/transform/module_scope_var_to_entry_point_param.h b/src/transform/module_scope_var_to_entry_point_param.h index 0ba88a5855..85570a067a 100644 --- a/src/transform/module_scope_var_to_entry_point_param.h +++ b/src/transform/module_scope_var_to_entry_point_param.h @@ -70,6 +70,12 @@ class ModuleScopeVarToEntryPointParam /// Destructor ~ModuleScopeVarToEntryPointParam() override; + /// @param program the program to inspect + /// @param data optional extra transform-specific input data + /// @returns true if this transform should be run for the given program + bool ShouldRun(const Program* program, + const DataMap& data = {}) const override; + protected: /// Runs the transform using the CloneContext built for transforming a /// program. Run() is responsible for calling Clone() on the CloneContext. diff --git a/src/transform/module_scope_var_to_entry_point_param_test.cc b/src/transform/module_scope_var_to_entry_point_param_test.cc index d8b0f63744..52967a796d 100644 --- a/src/transform/module_scope_var_to_entry_point_param_test.cc +++ b/src/transform/module_scope_var_to_entry_point_param_test.cc @@ -24,6 +24,20 @@ namespace { using ModuleScopeVarToEntryPointParamTest = TransformTest; +TEST_F(ModuleScopeVarToEntryPointParamTest, ShouldRunEmptyModule) { + auto* src = R"()"; + + EXPECT_FALSE(ShouldRun(src)); +} + +TEST_F(ModuleScopeVarToEntryPointParamTest, ShouldRunHasGlobal) { + auto* src = R"( +var v : i32; +)"; + + EXPECT_TRUE(ShouldRun(src)); +} + TEST_F(ModuleScopeVarToEntryPointParamTest, Basic) { auto* src = R"( var p : f32; diff --git a/src/transform/multiplanar_external_texture.cc b/src/transform/multiplanar_external_texture.cc index f18c0cc091..e88507a686 100644 --- a/src/transform/multiplanar_external_texture.cc +++ b/src/transform/multiplanar_external_texture.cc @@ -81,7 +81,7 @@ struct MultiplanarExternalTexture::State { // represent the secondary plane and one uniform buffer for the // ExternalTextureParams struct). ctx.ReplaceAll([&](const ast::Variable* var) -> const ast::Variable* { - if (!::tint::Is(var->type)) { + if (!sem.Get(var->type)) { return nullptr; } @@ -201,7 +201,7 @@ struct MultiplanarExternalTexture::State { // functions. ctx.ReplaceAll([&](const ast::Function* fn) -> const ast::Function* { for (const ast::Variable* param : fn->params) { - if (::tint::Is(param->type)) { + if (sem.Get(param->type)) { // If we find a texture_external, we must ensure the // ExternalTextureParams struct exists. if (!params_struct_sym.IsValid()) { @@ -407,6 +407,18 @@ MultiplanarExternalTexture::NewBindingPoints::~NewBindingPoints() = default; MultiplanarExternalTexture::MultiplanarExternalTexture() = default; MultiplanarExternalTexture::~MultiplanarExternalTexture() = default; +bool MultiplanarExternalTexture::ShouldRun(const Program* program, + const DataMap&) const { + for (auto* node : program->ASTNodes().Objects()) { + if (auto* ty = node->As()) { + if (program->Sem().Get(ty)) { + return true; + } + } + } + return false; +} + // Within this transform, an instance of a texture_external binding is unpacked // into two texture_2d bindings representing two possible planes of a // single texture and a uniform buffer binding representing a struct of diff --git a/src/transform/multiplanar_external_texture.h b/src/transform/multiplanar_external_texture.h index 0a999da5c4..ecbad6510c 100644 --- a/src/transform/multiplanar_external_texture.h +++ b/src/transform/multiplanar_external_texture.h @@ -76,6 +76,12 @@ class MultiplanarExternalTexture /// Destructor ~MultiplanarExternalTexture() override; + /// @param program the program to inspect + /// @param data optional extra transform-specific input data + /// @returns true if this transform should be run for the given program + bool ShouldRun(const Program* program, + const DataMap& data = {}) const override; + protected: struct State; diff --git a/src/transform/multiplanar_external_texture_test.cc b/src/transform/multiplanar_external_texture_test.cc index be8137a815..0d8df30e7c 100644 --- a/src/transform/multiplanar_external_texture_test.cc +++ b/src/transform/multiplanar_external_texture_test.cc @@ -21,6 +21,35 @@ namespace { using MultiplanarExternalTextureTest = TransformTest; +TEST_F(MultiplanarExternalTextureTest, ShouldRunEmptyModule) { + auto* src = R"()"; + + EXPECT_FALSE(ShouldRun(src)); +} + +TEST_F(MultiplanarExternalTextureTest, ShouldRunHasExternalTextureAlias) { + auto* src = R"( +type ET = texture_external; +)"; + + EXPECT_TRUE(ShouldRun(src)); +} +TEST_F(MultiplanarExternalTextureTest, ShouldRunHasExternalTextureGlobal) { + auto* src = R"( +[[group(0), binding(0)]] var ext_tex : texture_external; +)"; + + EXPECT_TRUE(ShouldRun(src)); +} + +TEST_F(MultiplanarExternalTextureTest, ShouldRunHasExternalTextureParam) { + auto* src = R"( +fn f(ext_tex : texture_external) {} +)"; + + EXPECT_TRUE(ShouldRun(src)); +} + // Running the transform without passing in data for the new bindings should // result in an error. TEST_F(MultiplanarExternalTextureTest, ErrorNoPassedData) { @@ -651,6 +680,106 @@ fn main() { EXPECT_EQ(expect, str(got)); } +// Tests that the transform works with a function using an external texture, +// even if there's no external texture declared at module scope. +TEST_F(MultiplanarExternalTextureTest, + ExternalTexturePassedAsParamWithoutGlobalDecl) { + auto* src = R"( +fn f(ext_tex : texture_external) -> vec2 { + return textureDimensions(ext_tex); +} +)"; + + auto* expect = R"( +struct ExternalTextureParams { + numPlanes : u32; + vr : f32; + ug : f32; + vg : f32; + ub : f32; +} + +fn f(ext_tex : texture_2d, ext_tex_plane_1 : texture_2d, ext_tex_params : ExternalTextureParams) -> vec2 { + return textureDimensions(ext_tex); +} +)"; + + DataMap data; + data.Add( + MultiplanarExternalTexture::BindingsMap{{{0, 0}, {{0, 1}, {0, 2}}}}); + auto got = Run(src, data); + EXPECT_EQ(expect, str(got)); +} + +// Tests that the the transform handles aliases to external textures +TEST_F(MultiplanarExternalTextureTest, ExternalTextureAlias) { + auto* src = R"( +type ET = texture_external; + +fn f(t : ET, s : sampler) { + textureSampleLevel(t, s, vec2(1.0, 2.0)); +} + +[[group(0), binding(0)]] var ext_tex : ET; +[[group(0), binding(1)]] var smp : sampler; + +[[stage(fragment)]] +fn main() { + f(ext_tex, smp); +} +)"; + + auto* expect = R"( +type ET = texture_external; + +struct ExternalTextureParams { + numPlanes : u32; + vr : f32; + ug : f32; + vg : f32; + ub : f32; +} + +fn textureSampleExternal(plane0 : texture_2d, plane1 : texture_2d, smp : sampler, coord : vec2, params : ExternalTextureParams) -> vec4 { + if ((params.numPlanes == 1u)) { + return textureSampleLevel(plane0, smp, coord, 0.0); + } + let y = (textureSampleLevel(plane0, smp, coord, 0.0).r - 0.0625); + let uv = (textureSampleLevel(plane1, smp, coord, 0.0).rg - 0.5); + let u = uv.x; + let v = uv.y; + let r = ((1.164000034 * y) + (params.vr * v)); + let g = (((1.164000034 * y) - (params.ug * u)) - (params.vg * v)); + let b = ((1.164000034 * y) + (params.ub * u)); + return vec4(r, g, b, 1.0); +} + +fn f(t : texture_2d, ext_tex_plane_1 : texture_2d, ext_tex_params : ExternalTextureParams, s : sampler) { + textureSampleExternal(t, ext_tex_plane_1, s, vec2(1.0, 2.0), ext_tex_params); +} + +@group(0) @binding(2) var ext_tex_plane_1_1 : texture_2d; + +@group(0) @binding(3) var ext_tex_params_1 : ExternalTextureParams; + +@group(0) @binding(0) var ext_tex : texture_2d; + +@group(0) @binding(1) var smp : sampler; + +@stage(fragment) +fn main() { + f(ext_tex, ext_tex_plane_1_1, ext_tex_params_1, smp); +} +)"; + DataMap data; + data.Add( + MultiplanarExternalTexture::BindingsMap{ + {{0, 0}, {{0, 2}, {0, 3}}}, + }); + auto got = Run(src, data); + EXPECT_EQ(expect, str(got)); +} + } // namespace } // namespace transform } // namespace tint diff --git a/src/transform/num_workgroups_from_uniform.cc b/src/transform/num_workgroups_from_uniform.cc index b190935cc3..f2d411de6f 100644 --- a/src/transform/num_workgroups_from_uniform.cc +++ b/src/transform/num_workgroups_from_uniform.cc @@ -52,13 +52,21 @@ struct Accessor { NumWorkgroupsFromUniform::NumWorkgroupsFromUniform() = default; NumWorkgroupsFromUniform::~NumWorkgroupsFromUniform() = default; +bool NumWorkgroupsFromUniform::ShouldRun(const Program* program, + const DataMap&) const { + for (auto* node : program->ASTNodes().Objects()) { + if (auto* deco = node->As()) { + if (deco->builtin == ast::Builtin::kNumWorkgroups) { + return true; + } + } + } + return false; +} + void NumWorkgroupsFromUniform::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const { - if (!Requires(ctx)) { - return; - } - auto* cfg = inputs.Get(); if (cfg == nullptr) { ctx.dst->Diagnostics().add_error( diff --git a/src/transform/num_workgroups_from_uniform.h b/src/transform/num_workgroups_from_uniform.h index 794768c839..1870ea21ff 100644 --- a/src/transform/num_workgroups_from_uniform.h +++ b/src/transform/num_workgroups_from_uniform.h @@ -39,6 +39,9 @@ namespace transform { /// ``` /// The binding group and number used for this uniform buffer is provided via /// the `Config` transform input. +/// +/// @note Depends on the following transforms to have been run first: +/// * CanonicalizeEntryPointIO class NumWorkgroupsFromUniform : public Castable { public: @@ -63,6 +66,12 @@ class NumWorkgroupsFromUniform sem::BindingPoint ubo_binding; }; + /// @param program the program to inspect + /// @param data optional extra transform-specific input data + /// @returns true if this transform should be run for the given program + bool ShouldRun(const Program* program, + const DataMap& data = {}) const override; + protected: /// Runs the transform using the CloneContext built for transforming a /// program. Run() is responsible for calling Clone() on the CloneContext. diff --git a/src/transform/num_workgroups_from_uniform_test.cc b/src/transform/num_workgroups_from_uniform_test.cc index bea5222549..dedee54abe 100644 --- a/src/transform/num_workgroups_from_uniform_test.cc +++ b/src/transform/num_workgroups_from_uniform_test.cc @@ -26,8 +26,28 @@ namespace { using NumWorkgroupsFromUniformTest = TransformTest; +TEST_F(NumWorkgroupsFromUniformTest, ShouldRunEmptyModule) { + auto* src = R"()"; + + EXPECT_FALSE(ShouldRun(src)); +} + +TEST_F(NumWorkgroupsFromUniformTest, ShouldRunHasNumWorkgroups) { + auto* src = R"( +[[stage(compute), workgroup_size(1)]] +fn main([[builtin(num_workgroups)]] num_wgs : vec3) { +} +)"; + + EXPECT_TRUE(ShouldRun(src)); +} + TEST_F(NumWorkgroupsFromUniformTest, Error_MissingTransformData) { - auto* src = ""; + auto* src = R"( +[[stage(compute), workgroup_size(1)]] +fn main([[builtin(num_workgroups)]] num_wgs : vec3) { +} +)"; auto* expect = "error: missing transform data for " @@ -42,19 +62,6 @@ TEST_F(NumWorkgroupsFromUniformTest, Error_MissingTransformData) { EXPECT_EQ(expect, str(got)); } -TEST_F(NumWorkgroupsFromUniformTest, Error_MissingCanonicalizeEntryPointIO) { - auto* src = ""; - - auto* expect = - "error: tint::transform::NumWorkgroupsFromUniform depends on " - "tint::transform::CanonicalizeEntryPointIO but the dependency was not " - "run"; - - auto got = Run(src); - - EXPECT_EQ(expect, str(got)); -} - TEST_F(NumWorkgroupsFromUniformTest, Basic) { auto* src = R"( @stage(compute) @workgroup_size(1) diff --git a/src/transform/pad_array_elements.cc b/src/transform/pad_array_elements.cc index f469a2067a..f04976d712 100644 --- a/src/transform/pad_array_elements.cc +++ b/src/transform/pad_array_elements.cc @@ -97,6 +97,19 @@ PadArrayElements::PadArrayElements() = default; PadArrayElements::~PadArrayElements() = default; +bool PadArrayElements::ShouldRun(const Program* program, const DataMap&) const { + for (auto* node : program->ASTNodes().Objects()) { + if (auto* var = node->As()) { + if (auto* arr = program->Sem().Get(var)) { + if (!arr->IsStrideImplicit()) { + return true; + } + } + } + } + return false; +} + void PadArrayElements::Run(CloneContext& ctx, const DataMap&, DataMap&) const { auto& sem = ctx.src->Sem(); diff --git a/src/transform/pad_array_elements.h b/src/transform/pad_array_elements.h index 78a8c01165..fe7797856b 100644 --- a/src/transform/pad_array_elements.h +++ b/src/transform/pad_array_elements.h @@ -38,6 +38,12 @@ class PadArrayElements : public Castable { /// Destructor ~PadArrayElements() override; + /// @param program the program to inspect + /// @param data optional extra transform-specific input data + /// @returns true if this transform should be run for the given program + bool ShouldRun(const Program* program, + const DataMap& data = {}) const override; + protected: /// Runs the transform using the CloneContext built for transforming a /// program. Run() is responsible for calling Clone() on the CloneContext. diff --git a/src/transform/pad_array_elements_test.cc b/src/transform/pad_array_elements_test.cc index 441eadcda9..5139edc3d5 100644 --- a/src/transform/pad_array_elements_test.cc +++ b/src/transform/pad_array_elements_test.cc @@ -24,6 +24,28 @@ namespace { using PadArrayElementsTest = TransformTest; +TEST_F(PadArrayElementsTest, ShouldRunEmptyModule) { + auto* src = R"()"; + + EXPECT_FALSE(ShouldRun(src)); +} + +TEST_F(PadArrayElementsTest, ShouldRunHasImplicitArrayStride) { + auto* src = R"( +var arr : array; +)"; + + EXPECT_FALSE(ShouldRun(src)); +} + +TEST_F(PadArrayElementsTest, ShouldRunHasExplicitArrayStride) { + auto* src = R"( +var arr : [[stride(8)]] array; +)"; + + EXPECT_TRUE(ShouldRun(src)); +} + TEST_F(PadArrayElementsTest, EmptyModule) { auto* src = ""; auto* expect = ""; diff --git a/src/transform/remove_phonies.cc b/src/transform/remove_phonies.cc index 8dfb9b4e20..06966b2278 100644 --- a/src/transform/remove_phonies.cc +++ b/src/transform/remove_phonies.cc @@ -68,6 +68,15 @@ RemovePhonies::RemovePhonies() = default; RemovePhonies::~RemovePhonies() = default; +bool RemovePhonies::ShouldRun(const Program* program, const DataMap&) const { + for (auto* node : program->ASTNodes().Objects()) { + if (node->Is()) { + return true; + } + } + return false; +} + void RemovePhonies::Run(CloneContext& ctx, const DataMap&, DataMap&) const { auto& sem = ctx.src->Sem(); diff --git a/src/transform/remove_phonies.h b/src/transform/remove_phonies.h index 45b117f5ec..415ce890c7 100644 --- a/src/transform/remove_phonies.h +++ b/src/transform/remove_phonies.h @@ -34,6 +34,12 @@ class RemovePhonies : public Castable { /// Destructor ~RemovePhonies() override; + /// @param program the program to inspect + /// @param data optional extra transform-specific input data + /// @returns true if this transform should be run for the given program + bool ShouldRun(const Program* program, + const DataMap& data = {}) const override; + protected: /// Runs the transform using the CloneContext built for transforming a /// program. Run() is responsible for calling Clone() on the CloneContext. diff --git a/src/transform/remove_phonies_test.cc b/src/transform/remove_phonies_test.cc index 6e8a1f1fca..d073d358af 100644 --- a/src/transform/remove_phonies_test.cc +++ b/src/transform/remove_phonies_test.cc @@ -26,6 +26,22 @@ namespace { using RemovePhoniesTest = TransformTest; +TEST_F(RemovePhoniesTest, ShouldRunEmptyModule) { + auto* src = R"()"; + + EXPECT_FALSE(ShouldRun(src)); +} + +TEST_F(RemovePhoniesTest, ShouldRunHasPhony) { + auto* src = R"( +fn f() { + _ = 1; +} +)"; + + EXPECT_TRUE(ShouldRun(src)); +} + TEST_F(RemovePhoniesTest, EmptyModule) { auto* src = ""; auto* expect = ""; diff --git a/src/transform/remove_unreachable_statements.cc b/src/transform/remove_unreachable_statements.cc index adb1f4123b..88d4158f8d 100644 --- a/src/transform/remove_unreachable_statements.cc +++ b/src/transform/remove_unreachable_statements.cc @@ -37,6 +37,18 @@ RemoveUnreachableStatements::RemoveUnreachableStatements() = default; RemoveUnreachableStatements::~RemoveUnreachableStatements() = default; +bool RemoveUnreachableStatements::ShouldRun(const Program* program, + const DataMap&) const { + for (auto* node : program->ASTNodes().Objects()) { + if (auto* stmt = program->Sem().Get(node)) { + if (!stmt->IsReachable()) { + return true; + } + } + } + return false; +} + void RemoveUnreachableStatements::Run(CloneContext& ctx, const DataMap&, DataMap&) const { diff --git a/src/transform/remove_unreachable_statements.h b/src/transform/remove_unreachable_statements.h index 7f0fc0d835..060a2a6ba4 100644 --- a/src/transform/remove_unreachable_statements.h +++ b/src/transform/remove_unreachable_statements.h @@ -34,6 +34,12 @@ class RemoveUnreachableStatements /// Destructor ~RemoveUnreachableStatements() override; + /// @param program the program to inspect + /// @param data optional extra transform-specific input data + /// @returns true if this transform should be run for the given program + bool ShouldRun(const Program* program, + const DataMap& data = {}) const override; + protected: /// Runs the transform using the CloneContext built for transforming a /// program. Run() is responsible for calling Clone() on the CloneContext. diff --git a/src/transform/remove_unreachable_statements_test.cc b/src/transform/remove_unreachable_statements_test.cc index 874bbf239f..33fe5a6ca4 100644 --- a/src/transform/remove_unreachable_statements_test.cc +++ b/src/transform/remove_unreachable_statements_test.cc @@ -22,6 +22,37 @@ namespace { using RemoveUnreachableStatementsTest = TransformTest; +TEST_F(RemoveUnreachableStatementsTest, ShouldRunEmptyModule) { + auto* src = R"()"; + + EXPECT_FALSE(ShouldRun(src)); +} + +TEST_F(RemoveUnreachableStatementsTest, ShouldRunHasNoUnreachable) { + auto* src = R"( +fn f() { + if (true) { + var x = 1; + } +} +)"; + + EXPECT_FALSE(ShouldRun(src)); +} + +TEST_F(RemoveUnreachableStatementsTest, ShouldRunHasUnreachable) { + auto* src = R"( +fn f() { + return; + if (true) { + var x = 1; + } +} +)"; + + EXPECT_TRUE(ShouldRun(src)); +} + TEST_F(RemoveUnreachableStatementsTest, EmptyModule) { auto* src = ""; auto* expect = ""; diff --git a/src/transform/simplify_pointers.cc b/src/transform/simplify_pointers.cc index c050a2eb5c..cfc06b7d9b 100644 --- a/src/transform/simplify_pointers.cc +++ b/src/transform/simplify_pointers.cc @@ -232,9 +232,6 @@ SimplifyPointers::SimplifyPointers() = default; SimplifyPointers::~SimplifyPointers() = default; void SimplifyPointers::Run(CloneContext& ctx, const DataMap&, DataMap&) const { - if (!Requires(ctx)) { - return; - } State(ctx).Run(); } diff --git a/src/transform/simplify_pointers.h b/src/transform/simplify_pointers.h index 33af4addd7..2541a36e2f 100644 --- a/src/transform/simplify_pointers.h +++ b/src/transform/simplify_pointers.h @@ -29,6 +29,9 @@ namespace transform { /// Note: SimplifyPointers does not operate on module-scope `let`s, as these /// cannot be pointers: https://gpuweb.github.io/gpuweb/wgsl/#module-constants /// `A module-scope let-declared constant must be of constructible type.` +/// +/// @note Depends on the following transforms to have been run first: +/// * Unshadow class SimplifyPointers : public Castable { public: /// Constructor diff --git a/src/transform/simplify_pointers_test.cc b/src/transform/simplify_pointers_test.cc index a9cce891ad..584a606307 100644 --- a/src/transform/simplify_pointers_test.cc +++ b/src/transform/simplify_pointers_test.cc @@ -23,18 +23,6 @@ namespace { using SimplifyPointersTest = TransformTest; -TEST_F(SimplifyPointersTest, Error_MissingSimplifyPointers) { - auto* src = ""; - - auto* expect = - "error: tint::transform::SimplifyPointers depends on " - "tint::transform::Unshadow but the dependency was not run"; - - auto got = Run(src); - - EXPECT_EQ(expect, str(got)); -} - TEST_F(SimplifyPointersTest, EmptyModule) { auto* src = ""; auto* expect = ""; diff --git a/src/transform/single_entry_point.cc b/src/transform/single_entry_point.cc index a1caa7d356..fd7bac9ded 100644 --- a/src/transform/single_entry_point.cc +++ b/src/transform/single_entry_point.cc @@ -107,11 +107,6 @@ void SingleEntryPoint::Run(CloneContext& ctx, // Clone the entry point. ctx.dst->AST().AddFunction(ctx.Clone(entry_point)); - - // Retain the list of applied transforms. - // We need to do this manually since we are not going to use the top-level - // ctx.Clone() function. - ctx.dst->SetTransformApplied(ctx.src->TransformsApplied()); } SingleEntryPoint::Config::Config(std::string entry_point) diff --git a/src/transform/test_helper.h b/src/transform/test_helper.h index 0c4618a6bc..a09ccba66c 100644 --- a/src/transform/test_helper.h +++ b/src/transform/test_helper.h @@ -81,6 +81,17 @@ class TransformTestBase : public BASE { return manager.Run(&program, data); } + /// @param in the input WGSL source + /// @param data the optional DataMap to pass to Transform::Run() + /// @return true if the transform should be run for the given input. + template + bool ShouldRun(std::string in, const DataMap& data = {}) { + auto file = std::make_unique("test", in); + auto program = reader::wgsl::Parse(file.get()); + EXPECT_TRUE(program.IsValid()) << program.Diagnostics().str(); + return TRANSFORM().ShouldRun(&program, data); + } + /// @param output the output of the transform /// @returns the output program as a WGSL string, or an error string if the /// program is not valid. diff --git a/src/transform/transform.cc b/src/transform/transform.cc index 0834d97f73..7ef31003ac 100644 --- a/src/transform/transform.cc +++ b/src/transform/transform.cc @@ -52,7 +52,6 @@ Output Transform::Run(const Program* program, CloneContext ctx(&builder, program); Output output; Run(ctx, data, output.data); - builder.SetTransformApplied(this); output.program = Program(std::move(builder)); return output; } @@ -62,22 +61,7 @@ void Transform::Run(CloneContext& ctx, const DataMap&, DataMap&) const { << "Transform::Run() unimplemented for " << TypeInfo().name; } -bool Transform::ShouldRun(const Program*) const { - return true; -} - -bool Transform::Requires( - CloneContext& ctx, - std::initializer_list deps) const { - for (auto* dep : deps) { - if (!ctx.src->HasTransformApplied(dep)) { - ctx.dst->Diagnostics().add_error( - diag::System::Transform, std::string(TypeInfo().name) + - " depends on " + std::string(dep->name) + - " but the dependency was not run"); - return false; - } - } +bool Transform::ShouldRun(const Program*, const DataMap&) const { return true; } diff --git a/src/transform/transform.h b/src/transform/transform.h index 5f8a107e37..8bc109e0e2 100644 --- a/src/transform/transform.h +++ b/src/transform/transform.h @@ -160,8 +160,10 @@ class Transform : public Castable { virtual Output Run(const Program* program, const DataMap& data = {}) const; /// @param program the program to inspect + /// @param data optional extra transform-specific input data /// @returns true if this transform should be run for the given program - virtual bool ShouldRun(const Program* program) const; + virtual bool ShouldRun(const Program* program, + const DataMap& data = {}) const; protected: /// Runs the transform using the CloneContext built for transforming a @@ -174,23 +176,6 @@ class Transform : public Castable { const DataMap& inputs, DataMap& outputs) const; - /// Requires appends an error diagnostic to `ctx.dst` if the template type - /// transforms were not already run on `ctx.src`. - /// @param ctx the CloneContext - /// @returns true if all dependency transforms have been run - template - bool Requires(CloneContext& ctx) const { - return Requires(ctx, {&::tint::TypeInfo::Of()...}); - } - - /// Requires appends an error diagnostic to `ctx.dst` if the list of - /// Transforms were not already run on `ctx.src`. - /// @param ctx the CloneContext - /// @param deps the list of Transform TypeInfos - /// @returns true if all dependency transforms have been run - bool Requires(CloneContext& ctx, - std::initializer_list deps) const; - /// Removes the statement `stmt` from the transformed program. /// RemoveStatement handles edge cases, like statements in the initializer and /// continuing of for-loops. diff --git a/src/transform/vectorize_scalar_matrix_constructors.cc b/src/transform/vectorize_scalar_matrix_constructors.cc index d72efa2013..75fcfb10a4 100644 --- a/src/transform/vectorize_scalar_matrix_constructors.cc +++ b/src/transform/vectorize_scalar_matrix_constructors.cc @@ -32,6 +32,22 @@ VectorizeScalarMatrixConstructors::VectorizeScalarMatrixConstructors() = VectorizeScalarMatrixConstructors::~VectorizeScalarMatrixConstructors() = default; +bool VectorizeScalarMatrixConstructors::ShouldRun(const Program* program, + const DataMap&) const { + for (auto* node : program->ASTNodes().Objects()) { + if (auto* call = program->Sem().Get(node)) { + if (call->Target()->Is() && + call->Type()->Is()) { + auto& args = call->Arguments(); + if (args.size() > 0 && args[0]->Type()->is_scalar()) { + return true; + } + } + } + } + return false; +} + void VectorizeScalarMatrixConstructors::Run(CloneContext& ctx, const DataMap&, DataMap&) const { diff --git a/src/transform/vectorize_scalar_matrix_constructors.h b/src/transform/vectorize_scalar_matrix_constructors.h index d6e1df3110..220d366907 100644 --- a/src/transform/vectorize_scalar_matrix_constructors.h +++ b/src/transform/vectorize_scalar_matrix_constructors.h @@ -30,6 +30,12 @@ class VectorizeScalarMatrixConstructors /// Destructor ~VectorizeScalarMatrixConstructors() override; + /// @param program the program to inspect + /// @param data optional extra transform-specific input data + /// @returns true if this transform should be run for the given program + bool ShouldRun(const Program* program, + const DataMap& data = {}) const override; + protected: /// Runs the transform using the CloneContext built for transforming a /// program. Run() is responsible for calling Clone() on the CloneContext. diff --git a/src/transform/vectorize_scalar_matrix_constructors_test.cc b/src/transform/vectorize_scalar_matrix_constructors_test.cc index 7e72e4cfb6..51b252f96f 100644 --- a/src/transform/vectorize_scalar_matrix_constructors_test.cc +++ b/src/transform/vectorize_scalar_matrix_constructors_test.cc @@ -27,6 +27,12 @@ namespace { using VectorizeScalarMatrixConstructorsTest = TransformTestWithParam>; +TEST_F(VectorizeScalarMatrixConstructorsTest, ShouldRunEmptyModule) { + auto* src = R"()"; + + EXPECT_FALSE(ShouldRun(src)); +} + TEST_P(VectorizeScalarMatrixConstructorsTest, Basic) { uint32_t cols = GetParam().first; uint32_t rows = GetParam().second; @@ -63,6 +69,8 @@ fn main() { auto src = utils::ReplaceAll(tmpl, "${values}", scalar_values); auto expect = utils::ReplaceAll(tmpl, "${values}", vector_values); + EXPECT_TRUE(ShouldRun(src)); + auto got = Run(src); EXPECT_EQ(expect, str(got)); @@ -92,6 +100,8 @@ fn main() { auto src = utils::ReplaceAll(tmpl, "${columns}", columns); auto expect = src; + EXPECT_FALSE(ShouldRun(src)); + auto got = Run(src); EXPECT_EQ(expect, str(got)); diff --git a/src/transform/wrap_arrays_in_structs.cc b/src/transform/wrap_arrays_in_structs.cc index 64f851e331..d62c70294d 100644 --- a/src/transform/wrap_arrays_in_structs.cc +++ b/src/transform/wrap_arrays_in_structs.cc @@ -38,6 +38,16 @@ WrapArraysInStructs::WrapArraysInStructs() = default; WrapArraysInStructs::~WrapArraysInStructs() = default; +bool WrapArraysInStructs::ShouldRun(const Program* program, + const DataMap&) const { + for (auto* node : program->ASTNodes().Objects()) { + if (program->Sem().Get(node->As())) { + return true; + } + } + return false; +} + void WrapArraysInStructs::Run(CloneContext& ctx, const DataMap&, DataMap&) const { diff --git a/src/transform/wrap_arrays_in_structs.h b/src/transform/wrap_arrays_in_structs.h index 4af74dc6e8..dfdc304a42 100644 --- a/src/transform/wrap_arrays_in_structs.h +++ b/src/transform/wrap_arrays_in_structs.h @@ -44,6 +44,12 @@ class WrapArraysInStructs : public Castable { /// Destructor ~WrapArraysInStructs() override; + /// @param program the program to inspect + /// @param data optional extra transform-specific input data + /// @returns true if this transform should be run for the given program + bool ShouldRun(const Program* program, + const DataMap& data = {}) const override; + protected: /// Runs the transform using the CloneContext built for transforming a /// program. Run() is responsible for calling Clone() on the CloneContext. diff --git a/src/transform/wrap_arrays_in_structs_test.cc b/src/transform/wrap_arrays_in_structs_test.cc index 21aa7c84c7..4f3b5d4e66 100644 --- a/src/transform/wrap_arrays_in_structs_test.cc +++ b/src/transform/wrap_arrays_in_structs_test.cc @@ -25,9 +25,23 @@ namespace { using WrapArraysInStructsTest = TransformTest; +TEST_F(WrapArraysInStructsTest, ShouldRunEmptyModule) { + auto* src = R"()"; + + EXPECT_FALSE(ShouldRun(src)); +} + +TEST_F(WrapArraysInStructsTest, ShouldRunHasArray) { + auto* src = R"( +var arr : array; +)"; + + EXPECT_TRUE(ShouldRun(src)); +} + TEST_F(WrapArraysInStructsTest, EmptyModule) { - auto* src = ""; - auto* expect = ""; + auto* src = R"()"; + auto* expect = src; auto got = Run(src); diff --git a/src/transform/zero_init_workgroup_memory.cc b/src/transform/zero_init_workgroup_memory.cc index 398b893b44..f4109f0638 100644 --- a/src/transform/zero_init_workgroup_memory.cc +++ b/src/transform/zero_init_workgroup_memory.cc @@ -433,6 +433,18 @@ ZeroInitWorkgroupMemory::ZeroInitWorkgroupMemory() = default; ZeroInitWorkgroupMemory::~ZeroInitWorkgroupMemory() = default; +bool ZeroInitWorkgroupMemory::ShouldRun(const Program* program, + const DataMap&) const { + for (auto* decl : program->AST().GlobalDeclarations()) { + if (auto* var = decl->As()) { + if (var->declared_storage_class == ast::StorageClass::kWorkgroup) { + return true; + } + } + } + return false; +} + void ZeroInitWorkgroupMemory::Run(CloneContext& ctx, const DataMap&, DataMap&) const { diff --git a/src/transform/zero_init_workgroup_memory.h b/src/transform/zero_init_workgroup_memory.h index 9eb3a3ec2e..df6c9aeaf5 100644 --- a/src/transform/zero_init_workgroup_memory.h +++ b/src/transform/zero_init_workgroup_memory.h @@ -32,6 +32,12 @@ class ZeroInitWorkgroupMemory /// Destructor ~ZeroInitWorkgroupMemory() override; + /// @param program the program to inspect + /// @param data optional extra transform-specific input data + /// @returns true if this transform should be run for the given program + bool ShouldRun(const Program* program, + const DataMap& data = {}) const override; + protected: /// Runs the transform using the CloneContext built for transforming a /// program. Run() is responsible for calling Clone() on the CloneContext. diff --git a/src/transform/zero_init_workgroup_memory_test.cc b/src/transform/zero_init_workgroup_memory_test.cc index d23e0af56f..e422b384b7 100644 --- a/src/transform/zero_init_workgroup_memory_test.cc +++ b/src/transform/zero_init_workgroup_memory_test.cc @@ -24,6 +24,28 @@ namespace { using ZeroInitWorkgroupMemoryTest = TransformTest; +TEST_F(ZeroInitWorkgroupMemoryTest, ShouldRunEmptyModule) { + auto* src = R"()"; + + EXPECT_FALSE(ShouldRun(src)); +} + +TEST_F(ZeroInitWorkgroupMemoryTest, ShouldRunHasNoWorkgroupVars) { + auto* src = R"( +var v : i32; +)"; + + EXPECT_FALSE(ShouldRun(src)); +} + +TEST_F(ZeroInitWorkgroupMemoryTest, ShouldRunHasWorkgroupVars) { + auto* src = R"( +var a : i32; +)"; + + EXPECT_TRUE(ShouldRun(src)); +} + TEST_F(ZeroInitWorkgroupMemoryTest, EmptyModule) { auto* src = ""; auto* expect = src; diff --git a/src/writer/glsl/generator_impl.cc b/src/writer/glsl/generator_impl.cc index e7b697e2d0..1444a69eb6 100644 --- a/src/writer/glsl/generator_impl.cc +++ b/src/writer/glsl/generator_impl.cc @@ -124,14 +124,6 @@ GeneratorImpl::GeneratorImpl(const Program* program) : TextGenerator(program) {} GeneratorImpl::~GeneratorImpl() = default; bool GeneratorImpl::Generate() { - if (!builder_.HasTransformApplied()) { - diagnostics_.add_error( - diag::System::Writer, - "GLSL writer requires the transform::Glsl sanitizer to have been " - "applied to the input program"); - return false; - } - const TypeInfo* last_kind = nullptr; size_t last_padding_line = 0; diff --git a/src/writer/glsl/generator_impl_test.cc b/src/writer/glsl/generator_impl_test.cc index c89041f019..1f9db55849 100644 --- a/src/writer/glsl/generator_impl_test.cc +++ b/src/writer/glsl/generator_impl_test.cc @@ -21,16 +21,6 @@ namespace { using GlslGeneratorImplTest = TestHelper; -TEST_F(GlslGeneratorImplTest, ErrorIfSanitizerNotRun) { - auto program = std::make_unique(std::move(*this)); - GeneratorImpl gen(program.get()); - EXPECT_FALSE(gen.Generate()); - EXPECT_EQ( - gen.error(), - "error: GLSL writer requires the transform::Glsl sanitizer to have been " - "applied to the input program"); -} - TEST_F(GlslGeneratorImplTest, Generate) { Func("my_func", ast::VariableList{}, ty.void_(), ast::StatementList{}, ast::DecorationList{}); diff --git a/src/writer/glsl/test_helper.h b/src/writer/glsl/test_helper.h index 298cbd8d92..0364f6ca81 100644 --- a/src/writer/glsl/test_helper.h +++ b/src/writer/glsl/test_helper.h @@ -44,9 +44,6 @@ class TestHelperBase : public BODY, public ProgramBuilder { if (gen_) { return *gen_; } - // Fake that the GLSL sanitizer has been applied, so that we can unit test - // the writer without it erroring. - SetTransformApplied(); [&]() { ASSERT_TRUE(IsValid()) << "Builder program is not valid\n" << diag::Formatter().format(Diagnostics()); diff --git a/src/writer/msl/generator_impl.cc b/src/writer/msl/generator_impl.cc index e25aa85866..3b09b62def 100644 --- a/src/writer/msl/generator_impl.cc +++ b/src/writer/msl/generator_impl.cc @@ -187,9 +187,10 @@ SanitizedResult Sanitize( if (!result.program.IsValid()) { return result; } - result.used_array_length_from_uniform_indices = - std::move(out.data.Get() - ->used_size_indices); + if (auto* res = out.data.Get()) { + result.used_array_length_from_uniform_indices = + std::move(res->used_size_indices); + } result.needs_storage_buffer_sizes = !result.used_array_length_from_uniform_indices.empty(); return result;